A simple spam detection project using PyTorch, Scikit-learn, and the Hugging Face sms_spam dataset. The model is trained to classify SMS messages as spam or not spam using a logistic regression-based neural network (CircleModelV0).
Sms-Spam/
├── main.ipynb # Training notebook
├── inference.py # Inference script
├── full_model.pth # Trained PyTorch model (full)
├── vectorizer.pkl # CountVectorizer object (for text preprocessing)
└── README.md
A simple feed-forward neural network:
Input Layer: 8713 features (from CountVectorizer)Hidden Layer: 64 neurons + ReLUOutput Layer: 1 neuron (sigmoid for binary classification)
git clone https://github.com/milliyin/sms-spam-model-train.git
cd Sms-Spampip install torch scikit-learn datasets huggingface_hubpython inference.pyInput: you have won a free prize
Output: SpamThe model uses the public dataset ucirvine/sms_spam from Hugging Face. It contains 5,574 SMS messages labeled as spam or ham.
This project is part of my learning process on how to train machine learning models using my own code. Here’s what I’ve learned and implemented:
- ✅ Installed and configured libraries like
torch,scikit-learn, anddatasets. - ✅ Downloaded and loaded the dataset using the Hugging Face Datasets library.
- ✅ Preprocessed the raw SMS data using
CountVectorizerto convert text into numerical feature vectors. - ✅ Converted those vectors to PyTorch tensors.
- ✅ Split the data into training and testing sets using
train_test_split. - ✅ Assigned input features (X: SMS prompts) and labels (y: 0 for ham, 1 for spam).
- ✅ Created a custom model class
CircleModelV0using PyTorch'snn.Module. - ✅ Wrote the training loop including forward pass, loss computation, backpropagation, and optimization.
- ✅ Evaluated the model using accuracy and tested it on new inputs.
- ✅ Saved and reloaded the model and vectorizer for inference.
If you're using PyTorch 2.6+, the torch.load() call needs to allow loading full models with custom classes:
model = torch.load("full_model.pth", weights_only=False)Make sure the class CircleModelV0 is defined or imported in your script before calling torch.load.
Feel free to open issues or suggest improvements!
Made with ❤️ for NLP enthusiasts.