This repository contains a text classification system designed to detect signs of depression in written language. The project leverages a pre-trained Bidirectional Encoder Representations from Transformers (BERT) model, fine-tuned on a specialized dataset to classify text into two categories: Depression and No Depression.
It uses Hugging Face's `transformers` and `datasets` libraries for both training and inference.
- State-of-the-Art Modeling: Uses the English `bert-base-uncased` model for deep contextual understanding of text.
- Binary Classification: Fine-tuned to classify each input text as Depression or No Depression.
- Automated Training Pipeline: `train_bert.py` handles dataset downloading, preprocessing, hyperparameter optimization, and model serialization.
- Interactive Inference: `predict_bert.py` runs real-time inference on custom user input using the fine-tuned model.
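A two-class head produces two logits per input. Turning them into a label takes a softmax followed by an argmax. The sketch below uses only the standard library, so it runs without a GPU or a model download. The index order in `ID2LABEL` (0 = No Depression, 1 = Depression) is a hypothetical choice for illustration. Check how `train_bert.py` encodes the dataset labels before relying on it.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical label mapping: the real index order depends on how
# train_bert.py encodes the dataset labels.
ID2LABEL: Dict[int, str] = {0: "No Depression", 1: "Depression"}
LABEL2ID: Dict[str, int] = {label: idx for idx, label in ID2LABEL.items()}


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode(logits: List[float]) -> Tuple[str, float]:
    """Return the predicted label and its probability for one input's logits."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return ID2LABEL[best], probs[best]


if __name__ == "__main__":
    # Made-up logits standing in for the output of a classification head.
    label, confidence = decode([-1.2, 2.3])
    print(f"{label} ({confidence:.2%})")
```

In `transformers`, you can pass the same mappings as `id2label` and `label2id` (together with `num_labels=2`) to `AutoModelForSequenceClassification.from_pretrained`. Saved checkpoints then carry human-readable label names.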
Python 3.7+ is recommended. To set up the environment and install dependencies, run:

```bash
pip install transformers torch pandas scikit-learn datasets numpy
```

Then clone the repository:

```bash
git clone https://github.com/Adarsh-Aravind/Depression-Detection-Bert-.git
cd Depression-Detection-Bert-
```

The `train_bert.py` script automatically downloads the `ShreyaR/DepressionDetection` dataset from Hugging Face, preprocesses it, and starts fine-tuning.
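The scripts depend on several third-party packages. Before starting a long training run, a quick standard-library check can confirm that they are importable. The helper below is illustrative and not part of the repository. Its package list mirrors the `pip install` command. Note that `scikit-learn` is imported as `sklearn`.

```python
import importlib.util
import sys
from typing import List, Sequence

# Import names for the packages installed by the pip command.
# scikit-learn is installed as "scikit-learn" but imported as "sklearn".
REQUIRED = ("transformers", "torch", "pandas", "sklearn", "datasets", "numpy")


def missing_packages(names: Sequence[str] = REQUIRED) -> List[str]:
    """Return the import names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def check_environment(min_version=(3, 7)) -> bool:
    """Print a short report and return True if the environment looks ready."""
    ok = sys.version_info >= min_version
    if not ok:
        found = sys.version.split()[0]
        print(f"Python {min_version[0]}.{min_version[1]}+ recommended, found {found}")
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    return ok and not missing


if __name__ == "__main__":
    print("Environment OK" if check_environment() else "Environment incomplete")
```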
To start training, run:

```bash
python train_bert.py
```

Configuration details: the model trains for 10 epochs with a learning rate of 1.5e-5, weight decay of 0.01, and a batch size of 20, with fp16 mixed precision enabled. Results and checkpoints are saved to the `./bert_depression_model` directory.
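The training script itself is not reproduced here. If `train_bert.py` uses the Hugging Face `Trainer`, which is an assumption, the stated configuration maps onto `TrainingArguments` keyword arguments roughly as sketched below. The sketch also assumes that "batch size of 20" means `per_device_train_batch_size`. It computes how many optimizer steps a run takes, assuming no gradient accumulation and no dropped last batch. The 5,000-example size in the demo is hypothetical and is not the real dataset's size. Also note that fp16 mixed precision generally requires a CUDA GPU.

```python
import math
from typing import Any, Dict

# Values from the configuration described in this README.
EPOCHS = 10
LEARNING_RATE = 1.5e-5
WEIGHT_DECAY = 0.01
BATCH_SIZE = 20
OUTPUT_DIR = "./bert_depression_model"


def training_kwargs() -> Dict[str, Any]:
    """Keyword arguments for transformers.TrainingArguments (assumed usage)."""
    return {
        "output_dir": OUTPUT_DIR,
        "num_train_epochs": EPOCHS,
        "learning_rate": LEARNING_RATE,
        "weight_decay": WEIGHT_DECAY,
        "per_device_train_batch_size": BATCH_SIZE,
        "fp16": True,  # mixed precision; typically needs a CUDA GPU
    }


def total_optimizer_steps(num_examples: int, batch_size: int = BATCH_SIZE,
                          epochs: int = EPOCHS, num_devices: int = 1) -> int:
    """Optimizer steps for a run without gradient accumulation.

    The last partial batch still counts as a step, hence the ceiling.
    """
    steps_per_epoch = math.ceil(num_examples / (batch_size * num_devices))
    return steps_per_epoch * epochs


if __name__ == "__main__":
    # 5,000 is a hypothetical training-set size, not the real dataset's.
    print(total_optimizer_steps(5000))  # 250 steps/epoch * 10 epochs
```

With `transformers` installed, the arguments would be built as `TrainingArguments(**training_kwargs())`. The step count is useful when sizing learning-rate warmup or estimating run time.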
Once the model is trained and saved, use the `predict_bert.py` script for interactive testing:

```bash
python predict_bert.py
```

The script prompts you to enter sentences and prints the model's classification. Type `quit` to exit the loop.
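The prompt-and-classify loop can be sketched in plain Python. The classifier is passed in as a function, so the loop can be exercised without the trained model. `keyword_stub` and `prompts` are hypothetical helpers. `keyword_stub` is a placeholder, not a depression detector, and the real `predict_bert.py` would call the fine-tuned BERT model instead. This sketch also treats `quit` case-insensitively and skips blank input, which may differ from the actual script.

```python
from typing import Callable, Iterable, Iterator, List


def interactive_loop(classify: Callable[[str], str], lines: Iterable[str]) -> List[str]:
    """Classify each entered sentence until the user types 'quit'.

    `lines` stands in for repeated input() calls so the loop can be tested.
    """
    results: List[str] = []
    for raw in lines:
        text = raw.strip()
        if text.lower() == "quit":
            break
        if not text:
            continue  # skip empty input rather than classifying it
        label = classify(text)
        print(f"Prediction: {label}")
        results.append(label)
    return results


def keyword_stub(text: str) -> str:
    """Hypothetical placeholder for the fine-tuned BERT model."""
    return "Depression" if "hopeless" in text.lower() else "No Depression"


def prompts() -> Iterator[str]:
    """Yield user input from the terminal until EOF."""
    while True:
        try:
            yield input("Enter a sentence (or 'quit'): ")
        except EOFError:
            return
```

To run the loop from a terminal, call `interactive_loop(keyword_stub, prompts())`. Injecting the classifier keeps the loop logic independent of how the model is loaded.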
Key files:

- `train_bert.py`: The complete training script, including data tokenization, training configuration, and execution.
- `predict_bert.py`: An interactive CLI script that analyzes input text and predicts its class.
- Dataset sourced from Hugging Face: `ShreyaR/DepressionDetection`.
- Powered by Hugging Face and PyTorch.