Welcome to Meta Reward Modeling (MRM)! MRM is a research-oriented framework for personalized reward modeling and alignment of Large Language Models (LLMs). It is designed to study how reward models can efficiently adapt to diverse user preferences under sparse feedback and generalize to unseen users. The framework follows a clean and modular design, making it easy to prototype, extend, and evaluate personalized alignment methods.
This project provides a full implementation of MRM, where each user's preference learning is treated as a separate task. It includes MAML-style training pipelines, few-shot user adaptation, robust optimization objectives for hard-to-learn users, and reproducible evaluation scripts for user-level performance analysis.
- **Meta-Learned Personalized Reward Initialization**: Treats each user as a separate task and learns a shared reward initialization that can quickly adapt to new users with only a few preference examples.
- **Robust Training for Diverse User Preferences**: Uses a robust personalization objective that focuses more on hard-to-model users, improving performance consistency across diverse and long-tail preferences.
- **Lightweight and Modular Reward Design**: Represents user rewards with low-dimensional adaptive weights over shared components, enabling efficient personalization, clean ablations, and easy extension.
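The lightweight reward design described above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the repository's actual implementation: the class name `TinyMRMHead`, the linear component layer, and the toy dimensions are all hypothetical; only the `shared_weight` attribute name mirrors the quickstart example in this README.

```python
import torch
import torch.nn as nn


class TinyMRMHead(nn.Module):
    """Hypothetical sketch: k shared reward components over frozen
    embeddings, combined by a low-dimensional per-user weight vector."""

    def __init__(self, in_dim: int = 4096, num_components: int = 2):
        super().__init__()
        # Shared across all users; meta-learned in the outer loop.
        self.components = nn.Linear(in_dim, num_components, bias=False)
        # Low-dimensional per-user weights: the only parameters
        # adapted during few-shot personalization.
        self.shared_weight = nn.Parameter(torch.zeros(num_components))

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        # emb: (batch, in_dim) -> scalar reward per example: (batch,)
        return self.components(emb) @ self.shared_weight


head = TinyMRMHead(in_dim=8, num_components=2)
scores = head(torch.randn(4, 8))
print(scores.shape)  # torch.Size([4])
```

Because only `shared_weight` (here 2 numbers) is adapted per user, personalization is cheap and ablations over the shared components stay clean.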
- [2026.01] Initial release of the project.
The code is tested with Python 3.10.0, PyTorch 2.4.0, and CUDA 12.5.

You can create a conda environment with the required dependencies using the provided `requirements.txt` file:

```bash
conda create -n mrm python=3.10 -y
conda activate mrm
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
```

- The datasets used in the paper are the PRISM dataset and the Reddit TLDR dataset.
- Run the following commands to download and preprocess the data and generate the embeddings:

For PRISM:

```bash
python scripts/preprocess_prism.py \
    --model_path Skywork/Skywork-Reward-V2-Llama-3.1-8B \
    --save_prefix data/emb/prism/V2
```

For Reddit TLDR:

```bash
python scripts/preprocess_reddit.py \
    --model_path Skywork/Skywork-Reward-V2-Llama-3.1-8B \
    --save_prefix data/emb/reddit/V2
```

**Note**

Change the model path to `Skywork/Skywork-Reward-Llama-3.1-8B-v0.2` if you want to use the V1 version of the reward model for preprocessing.
After data preparation, you can start training the meta reward model using the following commands.

For PRISM:

```bash
bash scripts/train_on_prism.sh
```

For Reddit TLDR:

```bash
bash scripts/train_on_reddit.sh
```

Both scripts train the model with our default hyperparameters and evaluate it on the test set at predefined intervals. Logs and model checkpoints are saved under the `output/` directory.
**Note**

- Our training pipelines include automatic evaluation and checkpointing, so you typically do not need to run the evaluation script.
- You can modify the hyperparameters in the training scripts as needed. For example, set `seen_train_limit` to 100 to replicate the Reddit TLDR results with 100 training samples per user.
- If you want to skip the training phase and directly evaluate with our pretrained checkpoints, download them from here.
After training, you can evaluate the saved model checkpoints using the following commands.

For PRISM:

```bash
bash scripts/test_on_prism.sh
```

For Reddit TLDR:

```bash
bash scripts/test_on_reddit.sh
```

**Important**

- Because every training run randomly splits the users and samples, make sure to use the same settings (inner epochs, inner learning rate, and random seed) as in the training phase when evaluating, to obtain consistent results.
- For our released checkpoints, please refer to the provided inference scripts for the exact hyperparameters used.
This example shows a typical workflow for a single user:

- Encode text pairs with Skywork-Reward-V2-Llama-3.1-8B into embeddings,
- Adapt the MRM on the user's few-shot examples (update `shared_weight` only),
- Run inference on new pairs for that same user.
```python
import torch
from copy import deepcopy
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from utils import bt_loss
from train import MRM
from inference import load_ckpt_into_model


@torch.no_grad()
def encode_pairs(model, tokenizer, pairs, device="cuda"):
    """Embed each (chosen, rejected) completion with the base reward LLM."""
    model.eval()
    ch, rj = [], []
    for ex in pairs:
        conv = ex["prompt"]
        for key, buf in [("chosen", ch), ("rejected", rj)]:
            ids = tokenizer.apply_chat_template(
                conv + [{"role": "assistant", "content": ex[key]}],
                tokenize=True, return_tensors="pt"
            ).to(device)
            out = model(ids, output_hidden_states=True)
            # Use the last hidden state of the final token as the embedding.
            buf.append(out.hidden_states[-1][0, -1].float().cpu())
    return torch.stack(ch), torch.stack(rj)


def adapt_single_user(base_model, support_ch, support_rj, inner_lr=1e-3, inner_epochs=5, device="cuda"):
    """Few-shot adaptation: update only `shared_weight` on the support pairs."""
    model = deepcopy(base_model).to(device).train()
    opt = torch.optim.Adam([model.shared_weight], lr=inner_lr)
    support_ch, support_rj = support_ch.to(device), support_rj.to(device)
    for _ in range(inner_epochs):
        opt.zero_grad()
        loss = bt_loss(model(support_ch), model(support_rj))
        loss.backward()
        opt.step()
    return model.eval()


@torch.no_grad()
def infer_on_pairs(model, ch, rj, device="cuda"):
    return model(ch.to(device)), model(rj.to(device))


device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH = "Skywork/Skywork-Reward-V2-Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
llm = AutoModelForSequenceClassification.from_pretrained(
    MODEL_PATH, num_labels=1, torch_dtype=torch.bfloat16, device_map=device
)

# Load the meta-trained MRM checkpoint.
CKPT_PATH = "ckpt/model.pt"
mrm = MRM(in_dim=4096, hidden_sizes=[2], use_bias=False)
load_ckpt_into_model(mrm, CKPT_PATH, device)

# Few-shot support set expressing this user's preferences.
support_pairs = [
    {
        "prompt": [{"role": "user", "content": "TL;DR this post: I tried waking up at 5am for a month and tracked my productivity."}],
        "chosen": "Waking up early helped at first, but long-term productivity depended more on sleep quality than wake-up time.",
        "rejected": "The post is about waking up early and productivity.",
    },
    {
        "prompt": [{"role": "user", "content": "Summarize the main point: I switched from iPhone to Android after 10 years."}],
        "chosen": "The author values customization and battery life more than ecosystem lock-in, which motivated the switch.",
        "rejected": "The author bought a new phone.",
    },
]
sup_ch, sup_rj = encode_pairs(llm, tokenizer, support_pairs, device)
user_mrm = adapt_single_user(mrm, sup_ch, sup_rj, device=device)

# Inference on new pairs for the same user.
test_pairs = [
    {
        "prompt": [{"role": "user", "content": "TL;DR: I quit my job to freelance and here is what I learned in 6 months."}],
        "chosen": "Freelancing offers flexibility but requires strong self-discipline and financial planning to be sustainable.",
        "rejected": "The author talks about quitting a job and freelancing.",
    }
]
test_ch, test_rj = encode_pairs(llm, tokenizer, test_pairs, device)
s_ch, s_rj = infer_on_pairs(user_mrm, test_ch, test_rj, device)
print("reward(chosen) =", s_ch.tolist())
print("reward(rejected)=", s_rj.tolist())
```

Meta Reward Modeling is a modular framework for personalized reward modeling, designed to learn rewards that can quickly adapt to individual users from limited preference data. The method separates shared reward structure from user-specific adaptation, enabling few-shot personalization and robust generalization to unseen users.
At a high level, the workflow proceeds as follows:

1. **Preference Representation**: Each response is scored using shared base reward functions, producing feature-level reward signals that are common across users.
2. **Meta-Learned Personalization**: A shared initialization of user-specific weights is learned across users. For each user, these weights are adapted with a few gradient steps using their own preference data.
3. **Robust Meta Optimization**: During training, user-level losses are reweighted to focus more on hard-to-model users, ensuring stable performance across diverse and long-tail preferences.
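Steps 2 and 3 above can be sketched as a MAML-style loop on toy data. This is an illustrative sketch only, not the repository's training code: the toy dimensions, the single inner gradient step, the inner learning rate, and the softmax-based reweighting (temperature 0.5) are all assumptions; the actual objective and schedules live in the training scripts.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def bt_loss(s_chosen, s_rejected):
    # Bradley-Terry preference loss: -log sigmoid(r_chosen - r_rejected)
    return -F.logsigmoid(s_chosen - s_rejected).mean()


# Toy setup: 3 users, 4-dim embeddings, 2 shared reward components
# (frozen here); we meta-learn only the shared initialization of the
# low-dimensional per-user weights.
components = torch.randn(4, 2)                # shared reward components
init_w = torch.zeros(2, requires_grad=True)   # meta-learned initialization
meta_opt = torch.optim.SGD([init_w], lr=0.1)

# Each "user" is a (chosen_embeddings, rejected_embeddings) pair.
users = [(torch.randn(5, 4), torch.randn(5, 4)) for _ in range(3)]

for step in range(20):
    user_losses = []
    for ch, rj in users:
        # Inner loop: one gradient step from the shared initialization,
        # keeping the graph so the meta-gradient can flow back to init_w.
        w = init_w.clone()
        inner = bt_loss(ch @ components @ w, rj @ components @ w)
        (g,) = torch.autograd.grad(inner, w, create_graph=True)
        w_adapted = w - 0.5 * g
        # Post-adaptation loss for this user.
        user_losses.append(bt_loss(ch @ components @ w_adapted,
                                   rj @ components @ w_adapted))
    losses = torch.stack(user_losses)
    # Robust reweighting: softmax over user losses up-weights hard users.
    weights = torch.softmax(losses.detach() / 0.5, dim=0)
    meta_loss = (weights * losses).sum()
    meta_opt.zero_grad()
    meta_loss.backward()
    meta_opt.step()
```

The `detach()` on the reweighting term is a deliberate choice in this sketch: the weights act as a fixed emphasis on hard users at each step rather than a quantity to optimize through.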
We would like to thank the contributors, open-source projects, and research communities whose work made Meta Reward Modeling possible.
This project is licensed under the MIT License. Please refer to the LICENSE file for more details.
| PersonalWAB Project Page | LoRe GitHub Repo | SynthesizeMe GitHub Repo |
If you use Meta Reward Modeling in your research or applications, please consider citing:
```bibtex
@misc{cai2026adaptsanymetareward,
      title={One Adapts to Any: Meta Reward Modeling for Personalized LLM Alignment},
      author={Hongru Cai and Yongqi Li and Tiezheng Yu and Fengbin Zhu and Wenjie Wang and Fuli Feng and Wenjie Li},
      year={2026},
      eprint={2601.18731},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2601.18731},
}
```
