This repository contains the code for MMOA-RAG, a system for multi-module optimization of the Query Rewriter, Retriever, Selector, and Generator. The code is organized into several components that facilitate the deployment, training, and evaluation of the RAG system.
Paper: Improving Retrieval-Augmented Generation through Multi-Agent Reinforcement Learning
- Computational Resource Requirements
- Deploying the Retrieval Model
- Getting the SFT and MAPPO Training Data
- Warm Start for RAG System
- Multi-Agent Optimization for RAG System
- Evaluation
- Others
## Computational Resource Requirements

We used two servers, each equipped with 8 A800 GPUs (80GB of memory each), for training MMOA-RAG. One server was dedicated to deploying the retrieval model, while the other was used for MARL training.
Why is a separate machine needed to deploy the retrieval model? MARL training updates the Query Rewriter, so Top-k documents must be retrieved in real time during rollout. This places strict latency requirements on the retrieval model, so we deployed it on a separate machine using Faiss and leveraged GPU acceleration to ensure fast retrieval results.
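To illustrate what the retrieval server computes, here is a minimal brute-force top-k inner-product search in plain Python. This is only a sketch of the operation Faiss accelerates on GPU; the actual index, embedding model, and document store are configured in `./flask_server.py`.

```python
# Minimal sketch of top-k dense retrieval: score every document embedding
# against the query embedding by inner product and keep the k best.
# Faiss performs this same search over millions of vectors on GPU;
# this brute-force version is for illustration only.
def top_k_docs(query_vec, doc_vecs, k):
    scored = [
        (sum(q * d for q, d in zip(query_vec, dv)), idx)
        for idx, dv in enumerate(doc_vecs)
    ]
    scored.sort(reverse=True)  # highest inner product first
    return [idx for _, idx in scored[:k]]
```

For example, `top_k_docs([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]], 2)` returns `[0, 2]`.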
## Deploying the Retrieval Model

To deploy the retrieval model, execute the following:

- Ensure the code in `./flask_server.py` is properly configured.
- Start the retrieval model API by running the following on one server:

```bash
bash run_server.sh
```
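Once the server is running, the MARL trainer queries it over HTTP during rollout. The sketch below shows how such a request could be built and its response parsed with the standard library; the route name (`/search`) and the payload fields (`query`, `top_k`, `docs`) are assumptions here, so check `./flask_server.py` for the actual API shape.

```python
import json
import urllib.request

# Hypothetical endpoint; replace with the host and route used by flask_server.py.
SERVER_URL = "http://localhost:5000/search"

def build_request(query, top_k=5):
    """Build a JSON POST request for the retrieval server (payload shape assumed)."""
    payload = json.dumps({"query": query, "top_k": top_k}).encode("utf-8")
    return urllib.request.Request(
        SERVER_URL,
        data=payload,
        headers={"Content-Type": "application/json"},
    )

def parse_response(body):
    """Extract document texts from a JSON response like {"docs": [...]} (assumed)."""
    return json.loads(body)["docs"]
```

The request would then be sent with `urllib.request.urlopen(build_request("some query"))` and the body passed to `parse_response`.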
## Getting the SFT and MAPPO Training Data

To generate the training data for the SFT and MAPPO processes, follow these steps:

- Run the following script to obtain the SFT training data:

```bash
python qr_s_g_sft_data_alpaca.py
```

- Run the following script to obtain the MAPPO training data for each dataset:

```bash
python get_ppo_data_alpaca.py
```

We developed our MAPPO code for jointly optimizing the multiple modules of the RAG system on top of LLaMA-Factory; the core code is at `./LLaMA-Factory/src/llamafactory/train/ppo/trainer_qr_s_g.py`.

## Warm Start for RAG System

To warm start the multiple modules of the RAG system using SFT, execute:
```bash
bash LLaMA-Factory/run_sft.sh
```

## Multi-Agent Optimization for RAG System

To perform joint learning of the multiple modules in the RAG system using MAPPO, run the following command on the other server:

```bash
bash LLaMA-Factory/run_mappo.sh
```

## Evaluation

Evaluate the performance of the RAG system by executing:

```bash
CUDA_VISIBLE_DEVICES=0 python evaluate_qr_s_g.py
```

## Others

Create the necessary directories:
- `./data` for storing datasets. For example, `./data/ambigqa` is used to save the AmbigQA dataset.
- `./models` for saving checkpoints of the retrieval model and LLMs.
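For reference, QA benchmarks such as AmbigQA are commonly scored with token-level F1 between the generated and gold answers. The sketch below shows that standard metric; it is an assumption that `evaluate_qr_s_g.py` uses exactly this formulation, so consult the script for the actual evaluation logic.

```python
from collections import Counter

def token_f1(prediction, ground_truth):
    """Standard token-level F1 between a predicted and a gold answer string."""
    pred = prediction.lower().split()
    gold = ground_truth.lower().split()
    common = Counter(pred) & Counter(gold)   # multiset intersection of tokens
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(gold)
    return 2 * precision * recall / (precision + recall)
```

For example, `token_f1("the cat sat", "cat sat")` gives precision 2/3 and recall 1, for an F1 of 0.8.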