This repository contains code for training and fine-tuning Enformer-based models for genomic track prediction.
Commands to set up the environment:
module load python3
module load pytorch
module load tensorflow
[ ! -d "env" ] && virtualenv --system-site-packages env
source env/bin/activate
pip install -r requirements.txt
To use SCC interactive sessions:
qrsh -P aclab -l gpus=1 -l gpu_c=3.5 -pe omp 1
You can submit bash files on the SCC using generic_submit.sh with the following syntax:
bash generic_submit.sh job_script.sh arg1 arg2 arg3
This is roughly equivalent to qsub job_script.sh arg1 arg2 arg3, but it first copies job_script.sh into the submitted_jobs directory under a unique name, job_script.sh.TIMESTAMP, and then submits that copy instead. This means you can modify and resubmit job_script.sh repeatedly, even while jobs are queued, without having to worry about which job ID corresponds to which submission.
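The copy-then-submit behavior can be sketched as follows (a hypothetical Python equivalent; the actual generic_submit.sh is a bash script, and the exact timestamp format is an assumption):

```python
import datetime
import shutil
import subprocess
from pathlib import Path

def submit_with_timestamp(job_script, args, submit_cmd="qsub"):
    """Copy job_script into submitted_jobs/ under a unique timestamped
    name, then submit the copy instead of the original script."""
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    dest_dir = Path("submitted_jobs")
    dest_dir.mkdir(exist_ok=True)
    stamped = dest_dir / f"{Path(job_script).name}.{stamp}"
    shutil.copy(job_script, stamped)
    # Submit the frozen copy, forwarding any extra job arguments.
    subprocess.run([submit_cmd, str(stamped), *args], check=True)
    return stamped
```

Because the scheduler only ever sees the frozen copy, editing the original job_script.sh afterwards cannot affect jobs already in the queue.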
The main entry point for training is submit_basenji_combined.py. You can either run it directly with Python or use basenji_combined.sh which sets up the environment and passes parameters to the Python script.
There are three main experiment types:
Train an Enformer model from scratch on human genomic data.
Key flags:
- model_architecture=enformer_pytorch - use the Enformer architecture
- task_type=basenji - genomic track prediction task
- ro_only_human=true - train on human data
- ro_pretrained_pretraining=false - train from scratch (not fine-tuning)
- efa_enformer_n_layer=11 - number of transformer layers
- learning_rate=5e-5 - learning rate for training from scratch
- epochs=40 - number of training epochs
Example command:
python submit_basenji_combined.py \
--model_architecture enformer_pytorch \
--task_type basenji \
--ro_only_human true \
--ro_pretrained_pretraining false \
--efa_enformer_n_layer 11 \
--learning_rate 5e-5 \
--epochs 40 \
--batch_size 1 \
--loss_type manual_poisson \
--lr_type cosine \
--weight_decay 1e-4 \
--grad_norm_clip 0.2 \
--wandb_project my_pretraining_project
Fine-tune a pretrained model on data from a different species or on specific histone modification tracks.
Key flags:
- ro_pretrained_pretraining=true - enable fine-tuning from a checkpoint
- ro_pretrained_path=<path> - path to the pretrained model checkpoint
- Species/track flag set to true (see lists below)
- learning_rate=3e-5 - learning rate for fine-tuning
- epochs=10 - fewer epochs needed for fine-tuning
Available species flags:
ro_only_human, ro_only_mouse, ro_only_cattle, ro_only_pig, ro_only_chicken, ro_only_dog, ro_only_mole_rat, ro_only_rhesus, ro_only_mouse_12
Available histone track flags:
ro_only_mouse_h3k27ac, ro_only_mouse_h3k27me3, ro_only_rhesus_h3k27ac, ro_only_rhesus_h3k27me3, ro_only_chicken_h3k27ac, ro_only_chicken_h3k27me3
Example: Fine-tuning on cattle data:
python submit_basenji_combined.py \
--model_architecture enformer_pytorch \
--task_type basenji \
--ro_only_cattle true \
--ro_pretrained_pretraining true \
--ro_pretrained_path /path/to/pretrained/model \
--efa_enformer_n_layer 11 \
--learning_rate 3e-5 \
--epochs 10 \
--batch_size 1 \
--loss_type manual_poisson \
--wandb_project cattle_finetuning
Example: Fine-tuning on the chicken H3K27me3 track:
python submit_basenji_combined.py \
--model_architecture enformer_pytorch \
--task_type basenji \
--ro_only_chicken_h3k27me3 true \
--ro_pretrained_pretraining true \
--ro_pretrained_path /path/to/pretrained/model \
--efa_enformer_n_layer 11 \
--learning_rate 3e-5 \
--epochs 10 \
--batch_size 1 \
--wandb_project chicken_h3k27me3_finetuning
Study the effect of varying the number of training tracks on individual track performance.
Key flags:
- num_tracks=<N> - number of tracks to use (e.g., 50, 100, 200, 500, 800, 1000, 1500, 1642)
- track_of_interest=<track_id> - specific track to evaluate
- seed=<seed> - random seed for reproducibility
- ro_only_mouse=true - typically run on mouse data
- learning_rate=3e-7 - lower learning rate when fine-tuning with specific tracks
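The way num_tracks, track_of_interest, and seed might interact can be sketched like this (a hypothetical illustration, not the repository's actual loader code): sample N distinct track indices with the given seed, making sure the track of interest is always among them.

```python
import random

def choose_tracks(total_tracks, num_tracks, track_of_interest, seed):
    """Pick `num_tracks` distinct track indices out of `total_tracks`,
    seeded for reproducibility, always including `track_of_interest`."""
    rng = random.Random(seed)
    # Sample the remaining slots from all other tracks, then re-add
    # the track of interest so it is guaranteed to be evaluated.
    others = [t for t in range(total_tracks) if t != track_of_interest]
    chosen = rng.sample(others, num_tracks - 1)
    return sorted(chosen + [track_of_interest])

# e.g., 800 tracks out of 1642, always keeping track 280
tracks = choose_tracks(1642, 800, 280, seed=800280)
```

Seeding with a fixed value makes the subset reproducible across runs, so results for a given (num_tracks, track_of_interest) pair are comparable.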
Example:
python submit_basenji_combined.py \
--model_architecture enformer_pytorch \
--task_type basenji \
--ro_only_mouse true \
--num_tracks 800 \
--track_of_interest 280 \
--seed 800280 \
--ro_pretrained_pretraining true \
--ro_pretrained_path /path/to/checkpoint \
--efa_enformer_n_layer 11 \
--learning_rate 3e-7 \
--epochs 10 \
--batch_size 1 \
--wandb_project num_tracks_experiment
There are 5 groups of files used in each training job:
- Model files - e.g., model/model_basenji_rewrite.py
- Training loop file - e.g., model/training_loop_reg.py
- Dataset loader file - loader/expanded_basenji.py
- Main script - submit_basenji_combined.py (combines all components)
- Bash wrapper - basenji_combined.sh (sets up environment and parameters)
The model_architecture parameter controls which model is used:
- enformer_pytorch - Pure Enformer implementation. Parameters use the efa_ prefix (e.g., efa_enformer_n_layer).
- enformer_rewrite - Custom model with more options. Parameters use the ro_ prefix. Sub-prefixes include:
  - ro_enformer_* - Enformer block parameters
  - ro_performer_* - Performer block parameters
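The prefix convention can be mimicked as follows (a hypothetical sketch of how prefixed flags could be routed to the right model component; the real dispatch lives in submit_basenji_combined.py, and the example argument names besides efa_enformer_n_layer are made up):

```python
def split_by_prefix(args, prefix):
    """Collect arguments whose names start with `prefix`, stripping it,
    so e.g. efa_enformer_n_layer becomes enformer_n_layer."""
    return {k[len(prefix):]: v for k, v in args.items() if k.startswith(prefix)}

args = {"efa_enformer_n_layer": 11, "ro_performer_heads": 8, "learning_rate": 5e-5}
efa_kwargs = split_by_prefix(args, "efa_")  # {'enformer_n_layer': 11}
ro_kwargs = split_by_prefix(args, "ro_")    # {'performer_heads': 8}
```

Prefixing keeps the two architectures' hyperparameters from colliding in a single flat argument namespace.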
WandB is used for experiment tracking. It stores data online for easy access and provides ways to organize runs by parameters.
- Create an account on the WandB website and ask to be added to the optimizedlearning team.
- Install: pip install wandb (or pip install -r requirements.txt)
- Initialize: run wandb init from the project directory and choose the optimizedlearning team
- Login: run wandb login and follow the prompts for your API key
In the code, experiments are logged with wandb.log calls, and hyperparameters are tracked via wandb.config.
To generate a new dataset, use the scripts in genomics_debug/script/. Example:
/genomicsML/genomics_debug/script/debug_again_cattle.sh
Variables to configure for new datasets:
- OUTPUT_BASE
- ORGANISM_ORIGINAL (e.g., cattle_nooverlap)
- ORGANISM (e.g., ${ORGANISM_ORIGINAL}_reassigned)
- GAP_FILE - path to gap BED file
- FASTA_FILE - path to genome FASTA
- TXT_FILE - path to targets file
- CHROM_SIZES - path to chromosome sizes file
- ALIGNMENT - path to chain file for liftover
To add a new species:

1. Update basenji_combined.sh: add a new variable in the 'SPECIE TYPE' section (e.g., ro_only_cat=true) and add the corresponding argument (e.g., --ro_only_cat $ro_only_cat).

2. Update submit_basenji_combined.py: add the dataset path for the new species. Search for an existing species like mouse_12 and replicate the pattern:

   if args.ro_only_cat:
       train_dataset = BasenjiDataset('cat_nooverlap_reassigned', 'train', data_dir="/path/to/data/")
       validation_dataset = BasenjiDataset('cat_nooverlap_reassigned', 'valid', data_dir="/path/to/data/")

3. Update output features: adjust the number of output features for the new species in the model configuration.
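Step 3 might look something like the following (a hypothetical sketch; the flag names follow the pattern above, but all track counts are placeholder values to be read off each dataset's targets file):

```python
# Hypothetical mapping from species flag to number of output tracks;
# the real counts come from each species' targets file.
OUTPUT_FEATURES = {
    "ro_only_human": 5313,   # placeholder value
    "ro_only_mouse": 1643,   # placeholder value
    "ro_only_cat": 42,       # fill in from the new targets file
}

def num_output_features(species_flag):
    """Look up the output head size for a species, failing loudly if
    the new species was never registered."""
    try:
        return OUTPUT_FEATURES[species_flag]
    except KeyError:
        raise ValueError(f"No output-feature count registered for {species_flag}")
```

Failing loudly on an unregistered species catches the common mistake of adding the dataset path (step 2) but forgetting the model head size (step 3).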
- Todo document: Google Doc
- Google Drive: Shared Folder