This project focuses on training a robust toxicity classification model that can withstand adversarial attacks. The training pipeline includes:
- Baseline Classifier: Fine-tuned transformer-based model.
- Adversarial Perturbations: Applying adversarial attacks on text inputs.
- Reinforcement Learning (RL) Fine-Tuning: Using RL to improve robustness.
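To make the second bullet concrete: an adversarial perturbation on text can be as simple as a character-level edit that leaves the comment readable to a human but changes the model's tokenization. The sketch below is purely illustrative (the project's actual attack lives in the `src/` scripts, and this helper is not part of the repo):

```python
import random

def perturb(text: str, rate: float = 0.1, seed: int = 0) -> str:
    """Swap two adjacent characters inside randomly chosen words to
    simulate a simple character-level adversarial perturbation
    (illustrative only -- not the project's attack)."""
    rng = random.Random(seed)
    out = []
    for w in text.split():
        if len(w) > 3 and rng.random() < rate:
            i = rng.randrange(1, len(w) - 2)
            w = w[:i] + w[i + 1] + w[i] + w[i + 2:]  # swap chars i and i+1
        out.append(w)
    return " ".join(out)

print(perturb("this comment is completely harmless", rate=1.0))
```

A robust classifier should assign roughly the same toxicity score to the original and the perturbed text.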
First, ensure you have Python 3.8+ installed. Then manually install the required libraries:

```bash
pip install transformers datasets torch torchmetrics accelerate tqdm kaggle spacy
```

The project relies on SpaCy's word embeddings. Run the following commands to download the required models:
```bash
python -m spacy download en_core_web_lg
python -m spacy download en_core_web_md
```

To download the dataset from Kaggle, you need to configure your Kaggle API key:
- Go to Kaggle API Tokens and download `kaggle.json`.
- Move it to the correct location:

```bash
mkdir -p ~/.kaggle
mv path/to/kaggle.json ~/.kaggle/
chmod 600 ~/.kaggle/kaggle.json
```
Download the dataset using:
```bash
kaggle competitions download -c jigsaw-toxic-comment-classification-challenge --force
unzip -o jigsaw-toxic-comment-classification-challenge.zip -d ./jigsaw_toxicity_data
unzip -o ./jigsaw_toxicity_data/train.csv.zip -d ./jigsaw_toxicity_data
unzip -o ./jigsaw_toxicity_data/test.csv.zip -d ./jigsaw_toxicity_data
unzip -o ./jigsaw_toxicity_data/test_labels.csv.zip -d ./jigsaw_toxicity_data
```

Train the baseline classifier:

```bash
python ./src/train.py --data_directory ./jigsaw_toxicity_data --model_path ./models/classifier.pt
```

Evaluate the model, including under adversarial perturbations:

```bash
python ./src/evaluate.py --data_directory ./jigsaw_toxicity_data --model_path ./models/classifier.pt --adversarial
```

Run RL fine-tuning for robustness:

```bash
python ./src/rl_policy.py --data_directory ./jigsaw_toxicity_data --classifier_model_path ./models/classifier.pt --policy_model_path ./models/policy.pt
```

📂 Safety-Alignment-Classifier
```
│── 📂 jigsaw_toxicity_data   # Contains dataset files
│── 📂 models                 # Saved models
│── 📂 src                    # Training & inference scripts
│   │── dataset.py            # Dataset creation
│   │── model.py              # Model creation
│   │── train.py              # Baseline classifier training
│   │── evaluate.py           # Model evaluation
│   │── rl_policy.py          # RL fine-tuning and evaluation for robustness
```
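Once unzipped, `train.csv` holds an `id`, a `comment_text`, and six binary label columns (`toxic`, `severe_toxic`, `obscene`, `threat`, `insult`, `identity_hate`). A minimal sketch of reading those multi-label targets with the standard library (illustrative only; `src/dataset.py` is the project's actual loader):

```python
import csv
import io

LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

def load_rows(fp):
    """Read (comment_text, multi-hot label vector) pairs from a
    Jigsaw-style CSV file object."""
    reader = csv.DictReader(fp)
    return [(row["comment_text"], [int(row[l]) for l in LABELS]) for row in reader]

# Tiny in-memory sample with the same header as the real train.csv.
sample = io.StringIO(
    "id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate\n"
    '"0001","a perfectly fine comment",0,0,0,0,0,0\n'
    '"0002","a rude comment",1,0,1,0,1,0\n'
)
rows = load_rows(sample)
print(rows[1])  # → ('a rude comment', [1, 0, 1, 0, 1, 0])
```

Each label vector is a multi-hot target, so the baseline classifier is trained as a multi-label (not multi-class) problem.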