This repository provides the code for our paper MedRAT (ECCV 2024).
conda create --name medrat python=3.7
conda activate medrat
conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=10.2 -c pytorch
pip install pandas scikit-learn pycocoevalcap tqdm
pip install wandb # optional
The <CSV_DIR_PATH> is the path to the .csv file generated by the preprocessing steps. The image directory differs between the training and test sets because the datasets are unpaired, so both <IMAGE_DIR_TRAIN> and <IMAGE_DIR_TEST> must be provided.
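Before launching a run, it can help to verify that the preprocessing outputs are where the training script expects them. The sketch below is a hypothetical pre-flight check (not part of the MedRAT codebase) that mirrors the three path arguments above; the function name and the assumption that the CSV has a header row are ours.

```python
import csv
import os

def check_inputs(csv_path, image_dir_train, image_dir_test):
    # Hypothetical pre-flight check; not part of the MedRAT codebase.
    # Fails fast if the preprocessing CSV or either image directory is missing.
    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"annotation CSV not found: {csv_path}")
    for d in (image_dir_train, image_dir_test):
        if not os.path.isdir(d):
            raise NotADirectoryError(f"image directory not found: {d}")
    with open(csv_path, newline="") as f:
        rows = list(csv.reader(f))
    # Assumes the first row is a header; returns the number of data rows.
    return max(len(rows) - 1, 0)
```

Running it with the same paths you pass to `main.py` surfaces path typos before any model loading happens.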
Train
python main.py --ann_path <CSV_DIR_PATH> --image_dir <IMAGE_DIR_TRAIN> --image_dir_test <IMAGE_DIR_TEST>
Note: set --wandb 1 to enable Weights & Biases monitoring.
Continue from checkpoint
python main.py --ann_path <CSV_DIR_PATH> --image_dir <IMAGE_DIR_TRAIN> --image_dir_test <IMAGE_DIR_TEST> --resume <CHECKPOINT_PATH>
Please consider citing our paper if the project helps your research.
(soon)
We thank the authors of R2GEN-CMN, both for their research and for sharing their code. Our repository is built upon their project. The code for the contrastive loss is built upon SupContrast.