
Torchformer

This repo contains PyTorch implementations of two transformer variants: an encoder-decoder model and a decoder-only model.

Neither model implements KV caching.
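To illustrate what that omission costs, here is a hypothetical sketch (not the repo's code) of how the number of key/value projections grows during auto-regressive decoding with and without a cache:

```python
# Hypothetical illustration (not from this repo): without a KV cache,
# decoding step t recomputes keys/values for all t prefix tokens, so the
# total work grows quadratically; with a cache, each token's K/V is
# projected exactly once.
def kv_projections(num_steps, cached):
    total = 0
    for t in range(1, num_steps + 1):
        total += 1 if cached else t  # cached: only the new token; else the whole prefix
    return total

print(kv_projections(8, cached=False))  # 36 projections without caching
print(kv_projections(8, cached=True))   # 8 with caching
```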

Tasks

Both models were trained on the following tasks:

  • string reversal (max input character limit: 8)
  • addition of two operands (max input character limit per operand: 5)

In practice there are four task variants, since each task is split by model type:

  • string_reverse_encoder_decoder
  • string_reverse_decoder_only
  • addition_encoder_decoder
  • addition_decoder_only
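As a concrete illustration of the two tasks, the helpers below show the input-to-target mapping (these helpers are hypothetical; the repo's actual data format is defined in data.py):

```python
# Hypothetical examples of the two task formats (the repo's actual data
# format lives in data.py and may differ).
def reverse_example(s):
    # string reversal: an input of up to 8 characters -> the reversed string
    return s, s[::-1]

def addition_example(a, b):
    # addition: two operands of up to 5 digits each -> their sum
    return f"{a}+{b}", str(a + b)

print(reverse_example("hello"))   # ('hello', 'olleh')
print(addition_example(123, 45))  # ('123+45', '168')
```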

Test

You can test out the models on the tasks above by running python evaluation.py. Edit the main block to try out different tasks.

Files and Directories

  • evaluation.py: test the PyTorch models I pre-trained on the tasks
  • train.py: train your own checkpoints of the PyTorch models. To test a saved checkpoint, copy its filename (without the file extension) into the main block of evaluation.py and run python evaluation.py.
  • data.py: generate new data for the tasks (WARNING: this will overwrite the current data, unless you move the existing data elsewhere)
  • config.py: training configs, one for each of the four tasks. The default config values guarantee convergence on the default generated data: string_reverse_encoder_decoder and string_reverse_decoder_only converge by step/epoch 4000, addition_encoder_decoder by 22000, and addition_decoder_only by 193000. (WARNING: with the default config values, addition_decoder_only plateaus at a validation loss of 0.03605 at step/epoch 172000 and never reaches the 0.03 target, so the script only stops at step/epoch 193000, once patience runs out; adjust the max_patience hyperparameter to control this. All the other tasks reach a validation loss of 0.03 at a much earlier step/epoch with the default config values.)
  • layers.py: contains the nn.Module / layers of the transformer
  • inference.py: contains teacher-forcing and auto-regressive decoding functions that call the layer functions from layers.py
  • tokenizer.py: contains the tokenizer class for each task
  • codebase_string.py: script that copies the entire codebase as a string so I can copy and paste it to Gemini for feedback/debugging
  • data/: directory where the generated data for each task is written to
  • checkpoints/: directory where the model checkpoints for each task are saved
  • basic_torch_examples/: unrelated directory where I was practicing basic PyTorch scripts for linear regression / basic training
  • requirements.txt: install by running pip install -r requirements.txt. (NOTE: there may be some unnecessary dependencies, since I generated this file by running pip freeze on my virtual environment.) I'm using Python version 3.11.9.
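For reference, the auto-regressive decoding in inference.py typically reduces to a greedy loop like the sketch below (all names and the toy model are hypothetical, not the repo's API): repeatedly pick the highest-scoring next token and append it until an end-of-sequence token or a length cap.

```python
# Hedged sketch of greedy auto-regressive decoding; the function names,
# EOS id, and toy model are hypothetical, not taken from this repo.
EOS = 0

def greedy_decode(next_token_logits, prompt, max_len=16):
    tokens = list(prompt)
    while len(tokens) < max_len:
        logits = next_token_logits(tokens)
        nxt = max(range(len(logits)), key=logits.__getitem__)  # argmax
        if nxt == EOS:
            break
        tokens.append(nxt)
    return tokens

# Toy stand-in "model": favors token 3 until the sequence reaches
# length 4, then favors EOS.
def toy_logits(tokens):
    if len(tokens) >= 4:
        return [1.0, 0.0, 0.0, 0.0]  # favor EOS
    return [0.0, 0.0, 0.0, 1.0]      # favor token 3

print(greedy_decode(toy_logits, [5, 7]))  # [5, 7, 3, 3]
```

Teacher forcing, by contrast, feeds the ground-truth prefix at every step during training instead of the model's own predictions.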

About

Transformer implementation in PyTorch
