My PyTorch implementation of the original Transformer model from the paper Attention Is All You Need, inspired by the code and blog posts I've read on this topic. There's nothing really special going on here, except that I tried to keep it as barebones as possible. There is also training code for a simple German -> English translator, written in pure PyTorch using the Torchtext library.
- The Illustrated Transformer by Jay Alammar
- The Original Transformer (PyTorch) by Aleksa Gordic
- Attention is all you need from scratch by Aladdin Persson
- PyTorch Seq2Seq by Ben Trevett
- Transformers: Attention in Disguise by Mihail Eric
- The Annotated Transformer by Harvard NLP
And probably a couple more which I don't remember ...
- Install the required pip packages:

    pip install -r requirements.txt

- Install the spaCy models:

    python -m spacy download de_core_news_sm
    python -m spacy download en_core_web_sm

Note: This code uses Torchtext's new API (v0.10.0+). dataset.py contains a custom text dataset class that inherits from torch.utils.data.Dataset, which differs from the classic approach using Field and BucketIterator (both now moved to torchtext.legacy). The torchtext library is still under heavy development, so this code will probably break with upcoming versions.
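As a rough illustration of the kind of torch.utils.data.Dataset subclass the note describes, here is a minimal sketch. It is not the class in dataset.py: the names `TranslationDataset`, `build_vocab`, and `collate`, the special-token indices, the toy sentence pairs, and the whitespace tokenizer (standing in for spaCy) are all assumptions made for the example.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

PAD, SOS, EOS, UNK = 0, 1, 2, 3  # hypothetical special-token indices


def build_vocab(sentences):
    """Map each whitespace token to an index, after the special tokens."""
    vocab = {"<pad>": PAD, "<sos>": SOS, "<eos>": EOS, "<unk>": UNK}
    for sentence in sentences:
        for token in sentence.lower().split():
            vocab.setdefault(token, len(vocab))
    return vocab


class TranslationDataset(Dataset):
    """Hypothetical parallel-text dataset; not the repo's actual class."""

    def __init__(self, pairs):
        self.pairs = pairs
        self.src_vocab = build_vocab(src for src, _ in pairs)
        self.tgt_vocab = build_vocab(tgt for _, tgt in pairs)

    def __len__(self):
        return len(self.pairs)

    def encode(self, sentence, vocab):
        # Wrap the token ids in start/end markers.
        ids = [vocab.get(tok, UNK) for tok in sentence.lower().split()]
        return torch.tensor([SOS] + ids + [EOS], dtype=torch.long)

    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        return self.encode(src, self.src_vocab), self.encode(tgt, self.tgt_vocab)


def collate(batch):
    """Pad variable-length sentences so they stack into (batch, seq_len) tensors."""
    src, tgt = zip(*batch)
    return (pad_sequence(src, batch_first=True, padding_value=PAD),
            pad_sequence(tgt, batch_first=True, padding_value=PAD))


# Toy data, made up for this example only.
pairs = [
    ("Ein Hund rennt", "A dog runs"),
    ("Eine Frau liest ein Buch", "A woman reads a book"),
]
dataset = TranslationDataset(pairs)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate)
src_batch, tgt_batch = next(iter(loader))
```

Padding inside `collate_fn` takes over the batching job that BucketIterator handled in the legacy API, although this sketch does not group sentences of similar length.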
In train.py we train a simple German -> English translation model on the Multi30k dataset using the Transformer model. Make sure you configure the necessary paths for weights, logs, etc. in config.py. Then you can simply run the file as below:
    python train.py

Epoch: 1/10 100%|######################################################################| 227/227 [00:10<00:00, 21.61batch/s, loss=4.33]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 45.25batch/s, loss=3.13]
Saved Model at weights/1.pt
Epoch: 2/10 100%|######################################################################| 227/227 [00:10<00:00, 22.64batch/s, loss=2.82]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 51.68batch/s, loss=2.55]
Saved Model at weights/2.pt
Epoch: 3/10 100%|######################################################################| 227/227 [00:10<00:00, 22.56batch/s, loss=2.22]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 51.98batch/s, loss=2.22]
Saved Model at weights/3.pt
Epoch: 4/10 100%|######################################################################| 227/227 [00:10<00:00, 22.64batch/s, loss=1.83]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 52.20batch/s, loss=2.07]
Saved Model at weights/4.pt
Epoch: 5/10 100%|######################################################################| 227/227 [00:10<00:00, 22.64batch/s, loss=1.55]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 52.12batch/s, loss=2]
Saved Model at weights/5.pt
Epoch: 6/10 100%|######################################################################| 227/227 [00:10<00:00, 22.25batch/s, loss=1.34]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 51.45batch/s, loss=1.95]
Saved Model at weights/6.pt
Epoch: 7/10 100%|######################################################################| 227/227 [00:10<00:00, 22.55batch/s, loss=1.17]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 51.34batch/s, loss=1.95]
Saved Model at weights/7.pt
Epoch: 8/10 100%|######################################################################| 227/227 [00:10<00:00, 22.46batch/s, loss=1.03]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 51.43batch/s, loss=1.96]
Saved Model at weights/8.pt
Epoch: 9/10 100%|######################################################################| 227/227 [00:10<00:00, 22.45batch/s, loss=0.91]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 52.84batch/s, loss=1.99]
Saved Model at weights/9.pt
Epoch: 10/10 100%|######################################################################| 227/227 [00:10<00:00, 22.50batch/s, loss=0.808]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 51.74batch/s, loss=2.01]
Saved Model at weights/10.pt
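In the log above, the training loss keeps falling while the evaluation loss bottoms out at 1.95 in epochs 6 and 7 and then creeps back up, so the last checkpoint is not necessarily the best one to use for inference. A small sketch of picking a checkpoint by evaluation loss follows. The loss values and the `weights/<epoch>.pt` naming are copied from the log; the `best_checkpoint` helper is hypothetical and not part of the repo.

```python
# Evaluation loss per epoch, copied from the training log above.
eval_loss = {1: 3.13, 2: 2.55, 3: 2.22, 4: 2.07, 5: 2.00,
             6: 1.95, 7: 1.95, 8: 1.96, 9: 1.99, 10: 2.01}


def best_checkpoint(losses, weights_dir="weights"):
    """Return (epoch, path) for the lowest evaluation loss; ties go to the earliest epoch."""
    epoch = min(sorted(losses), key=losses.get)
    return epoch, f"{weights_dir}/{epoch}.pt"


epoch, path = best_checkpoint(eval_loss)
print(epoch, path)  # 6 weights/6.pt
```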
Given the sentence "Eine Gruppe von Menschen steht vor einem Iglu" as input in predict.py, we get the following output, which is pretty decent (apart from rendering "Iglu", i.e. igloo, as "warehouse") considering how small and simple the dataset is.
    python predict.py

"Translation: A group of people standing in front of a warehouse ."

- predict.py for inference
- Add pretrained weights
- Visualize attentions
- An in-depth notebook
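On the inference step: the contents of predict.py are not shown here, so the following is only a sketch of greedy decoding, one common way to generate a translation from an encoder-decoder Transformer. It is not necessarily what predict.py does. The `greedy_decode` function, the assumed `model(src, tgt)` interface returning logits of shape (batch, tgt_len, vocab_size), and the toy `CountingModel` are all hypothetical.

```python
import torch


@torch.no_grad()
def greedy_decode(model, src, sos_idx, eos_idx, max_len=50):
    """Generate target token ids one at a time, always taking the most likely next token.

    Assumes model(src, tgt) returns logits of shape (batch, tgt_len, vocab_size).
    """
    # Every sequence in the batch starts with the start-of-sentence token.
    tgt = torch.full((src.size(0), 1), sos_idx, dtype=torch.long, device=src.device)
    for _ in range(max_len - 1):
        logits = model(src, tgt)
        # Only the prediction at the last position is new.
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tgt = torch.cat([tgt, next_token], dim=1)
        # Stop once every sequence in the batch has just emitted EOS.
        if (next_token == eos_idx).all():
            break
    return tgt


class CountingModel(torch.nn.Module):
    """Toy stand-in: predicts token t+1 after token t, just to exercise the loop."""

    def __init__(self, vocab_size=6):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, src, tgt):
        next_ids = (tgt + 1).clamp(max=self.vocab_size - 1)
        return torch.nn.functional.one_hot(next_ids, self.vocab_size).float()


src = torch.zeros(1, 4, dtype=torch.long)  # dummy source sentence
out = greedy_decode(CountingModel(), src, sos_idx=1, eos_idx=5)
print(out.tolist())  # [[1, 2, 3, 4, 5]]
```

Greedy decoding is the simplest choice. Beam search usually gives better translations at the cost of more computation. With batch sizes above one, this sketch keeps generating for all sentences until every one has emitted EOS in the same step or max_len is reached, so it is best suited to translating a single sentence at a time.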