This directory contains the serving code for the trained pairwise prompt ranker.
The service loads a trained checkpoint from predictor_train/ and exposes HTTP
endpoints for prompt scoring and prompt comparison.
The released version supports batch inference. Single-request endpoints reuse the same batch scoring path internally, and the batch endpoints run batched tokenization plus batched forward inference for better throughput.
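The "single requests reuse the batch path" pattern can be sketched as a thin wrapper; the names below are illustrative, not the service's actual code:

```python
from typing import Callable, List


def make_single_scorer(
    score_batch: Callable[[List[str]], List[float]],
) -> Callable[[str], float]:
    """Wrap a batch scorer so single requests reuse the batched path."""

    def score(prompt: str) -> float:
        # A single prompt is just a batch of size one.
        return score_batch([prompt])[0]

    return score


# Dummy batch scorer standing in for batched tokenization + forward inference.
def dummy_score_batch(prompts: List[str]) -> List[float]:
    return [len(p) / 100.0 for p in prompts]


score = make_single_scorer(dummy_score_batch)
print(score("Explain machine learning simply."))  # same value as a batch of one
```

Keeping one scoring path means the single-request endpoints stay consistent with the batch endpoints by construction.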
```bash
pip install -r requirements.txt
```

Typical checkpoint paths from ../predictor_train/:

- ../predictor_train/outputs/alpaca_gpt4_bert/best_model.pt
- ../predictor_train/outputs/alpaca_gpt4_bert/last_model.pt
- ../predictor_train/outputs/lmsys_gpt4_bert/best_model.pt
Use best_model.pt if you want the checkpoint with the best validation
accuracy. Use last_model.pt if you want the final epoch checkpoint.
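If a script needs to choose between the two automatically, a small stdlib helper can prefer `best_model.pt` and fall back to `last_model.pt`; this helper is illustrative and not part of the service:

```python
from pathlib import Path


def pick_checkpoint(run_dir: str) -> Path:
    """Prefer the best-validation checkpoint, fall back to the last epoch."""
    run = Path(run_dir)
    for name in ("best_model.pt", "last_model.pt"):
        candidate = run / name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(f"no checkpoint found in {run}")
```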
```bash
uvicorn scripts.serve_predictor_score:app --host 0.0.0.0 --port 8000
```

You can override model settings with environment variables:

- `PREDICTOR_MODEL_PATH`: checkpoint path
- `PREDICTOR_MODEL_NAME`: encoder name, default `bert-base-uncased`
- `PREDICTOR_MAX_LENGTH`: tokenizer max length, default `128`
- `PREDICTOR_DEVICE`: force a device such as `cpu` or `cuda`
Example:

```bash
PREDICTOR_MODEL_PATH=../predictor_train/outputs/alpaca_gpt4_bert/best_model.pt \
PREDICTOR_MODEL_NAME=bert-base-uncased \
uvicorn scripts.serve_predictor_score:app --host 0.0.0.0 --port 8000
```

## GET /

- health check:

```bash
curl http://127.0.0.1:8000/
```

## POST /score
- one-command example:

```bash
curl -X POST http://127.0.0.1:8000/score \
  -H "Content-Type: application/json" \
  -d '{"prompt":"Explain machine learning simply."}'
```

- response body:

```json
{
  "prompt": "Explain machine learning simply.",
  "score": 0.42
}
```

A larger score means the predictor estimates that the prompt is more likely to produce a longer response.
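The same request can be made from Python with only the standard library; this client sketch assumes the server above is running on port 8000:

```python
import json
from urllib import request


def build_score_payload(prompt: str) -> bytes:
    """JSON body for POST /score."""
    return json.dumps({"prompt": prompt}).encode("utf-8")


def score_prompt(prompt: str, base_url: str = "http://127.0.0.1:8000") -> float:
    """POST one prompt to /score and return the predicted score."""
    req = request.Request(
        f"{base_url}/score",
        data=build_score_payload(prompt),
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as resp:
        return json.load(resp)["score"]


if __name__ == "__main__":
    print(score_prompt("Explain machine learning simply."))
```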
## POST /score_batch
- one-command example:

```bash
curl -X POST http://127.0.0.1:8000/score_batch \
  -H "Content-Type: application/json" \
  -d '{"prompts":["Explain machine learning simply.","Give a detailed explanation of machine learning."]}'
```

- response body:

```json
{
  "prompts": [
    "Explain machine learning simply.",
    "Give a detailed explanation of machine learning."
  ],
  "scores": [0.18, 0.73]
}
```

## POST /compare
- one-command example:

```bash
curl -X POST http://127.0.0.1:8000/compare \
  -H "Content-Type: application/json" \
  -d '{"prompt_a":"Explain machine learning simply.","prompt_b":"Give a detailed explanation of machine learning."}'
```

- response body:

```json
{
  "prompt_a": "Explain machine learning simply.",
  "prompt_b": "Give a detailed explanation of machine learning.",
  "score_a": 0.18,
  "score_b": 0.73,
  "winner": "prompt_b"
}
```

## POST /compare_batch
- one-command example:

```bash
curl -X POST http://127.0.0.1:8000/compare_batch \
  -H "Content-Type: application/json" \
  -d '{"pairs":[{"prompt_a":"Explain machine learning simply.","prompt_b":"Give a detailed explanation of machine learning."},{"prompt_a":"Summarize AI.","prompt_b":"Write a long tutorial on artificial intelligence."}]}'
```

- response body:

```json
{
  "results": [
    {
      "prompt_a": "Explain machine learning simply.",
      "prompt_b": "Give a detailed explanation of machine learning.",
      "score_a": 0.18,
      "score_b": 0.73,
      "winner": "prompt_b"
    },
    {
      "prompt_a": "Summarize AI.",
      "prompt_b": "Write a long tutorial on artificial intelligence.",
      "score_a": 0.11,
      "score_b": 0.84,
      "winner": "prompt_b"
    }
  ]
}
```
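In both compare endpoints, `winner` follows directly from the two scores. A minimal sketch of that decision rule, mirroring the response shape (the tie-breaking toward `prompt_a` is an assumption here, not documented service behavior):

```python
def compare_result(prompt_a: str, prompt_b: str,
                   score_a: float, score_b: float) -> dict:
    """Assemble a compare-style result; the higher-scoring prompt wins."""
    # Tie-breaking toward prompt_a is an assumption, not documented behavior.
    winner = "prompt_b" if score_b > score_a else "prompt_a"
    return {
        "prompt_a": prompt_a,
        "prompt_b": prompt_b,
        "score_a": score_a,
        "score_b": score_b,
        "winner": winner,
    }


result = compare_result(
    "Explain machine learning simply.",
    "Give a detailed explanation of machine learning.",
    0.18,
    0.73,
)
print(result["winner"])  # prompt_b
```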