This directory contains various training scripts.
Torch blog posts
- The torch.ch blog contains detailed posts about the rnn package.
- recurrent-visual-attention.lua: training script used in the Recurrent Model for Visual Attention blog post. Implements the REINFORCE learning rule to learn an attention mechanism for classifying MNIST digits (optionally translated).
- noise-contrastive-estimate.lua: one of two training scripts used in the Language modeling a billion words blog post. Single-GPU script for training recurrent language models on the Google billion words dataset.
- multigpu-nce-rnnlm.lua: 4-GPU version of `noise-contrastive-estimate.lua` for training larger models across multiple GPUs. The second of the two training scripts used in the Language modeling a billion words blog post.
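The recurrent-visual-attention.lua script trains its attention mechanism with the REINFORCE learning rule. The following stdlib-only Python sketch shows the core of that rule under stated assumptions. The task is a hypothetical 1-D toy: the policy is a Gaussian over a "location", the reward is the negative squared distance to a target, and a moving-average baseline reduces variance. The function name, reward, and hyperparameters are illustrative only. This is not the script's model and not the package's modules.

```python
import random


def reinforce_gaussian(target, steps=3000, lr=0.01, sigma=0.5, seed=0):
    """Learn the mean `mu` of a Gaussian policy with REINFORCE.

    Hypothetical toy task (not the script's): sample an action
    l ~ N(mu, sigma^2), think of it as a 1-D glimpse location, receive
    reward R = -(l - target)^2, and update mu along
    (R - baseline) * d/dmu log N(l; mu, sigma^2).
    The reward is a black box: no gradient flows through it.
    """
    rng = random.Random(seed)
    mu, baseline = 0.0, 0.0
    for _ in range(steps):
        l = rng.gauss(mu, sigma)                    # stochastic action
        reward = -(l - target) ** 2                 # scalar reward from the "environment"
        grad_logp = (l - mu) / sigma ** 2           # score function of the Gaussian w.r.t. mu
        mu += lr * (reward - baseline) * grad_logp  # gradient ascent on expected reward
        baseline = 0.9 * baseline + 0.1 * reward    # moving-average baseline reduces variance
    return mu
```

Calling `reinforce_gaussian(target=2.0)` should return a value near 2.0. The exact number depends on the seed because the estimator is stochastic. The baseline does not bias the gradient, since its expected product with the score function is zero. It only shrinks the variance of each update.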
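The billion-words scripts train language models with noise-contrastive estimation (NCE). NCE replaces the full softmax with a binary task: tell the observed word apart from k words sampled from a noise distribution q. The following stdlib-only Python sketch computes this textbook per-example NCE objective. It assumes that `exp(s(w))` serves directly as the model's unnormalized probability (normalizer fixed to 1) and that q is known, for example a unigram distribution. The names `nce_loss`, `scores`, and `noise_prob` are hypothetical. This is not the rnn package's NCE implementation.

```python
import math


def nce_loss(scores, target, noise_samples, noise_prob):
    """Noise-contrastive estimation loss for one target word.

    scores: dict word -> unnormalized log-score s(w); exp(s(w)) is treated
            as the model's unnormalized probability (normalizer assumed 1).
    target: the observed next word.
    noise_samples: k words drawn from the noise distribution q.
    noise_prob: dict word -> q(w).
    """
    k = len(noise_samples)

    def p_data(w):
        # Posterior that w came from the data rather than from the noise:
        # P(D=1 | w) = exp(s(w)) / (exp(s(w)) + k * q(w))
        u = math.exp(scores[w])
        return u / (u + k * noise_prob[w])

    loss = -math.log(p_data(target))       # classify the target as "data"
    for w in noise_samples:
        loss -= math.log(1.0 - p_data(w))  # classify each noise sample as "noise"
    return loss
```

Only the target and the k sampled words are scored, so the cost per example does not grow with the vocabulary size. This property is what makes NCE attractive for very large vocabularies.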
Simple training scripts
- These scripts showcase the fundamental principles of the package and are listed in chronological order of introduction.
- simple-recurrent-network.lua: uses the `nn.Recurrent` module to instantiate a Simple RNN, illustrating the first `AbstractRecurrent` instance in action. `nn.Recurrent` has since been superseded by the more flexible `nn.Recursor` and `nn.Recurrence`. The `nn.Recursor` class decorates any module so that it conforms to the `nn.AbstractRecurrent` interface, while `nn.Recurrence` implements the recursion `h[t] <- forward(h[t-1], x[t])`. Together, `nn.Recursor` and `nn.Recurrence` can be used to implement a wide range of experimental recurrent architectures.
- simple-sequencer-network.lua: uses the `nn.Sequencer` module to accept a batch of sequences as `input` of size `seqlen x batchsize x ...`. Both tables and tensors are accepted as input, and the output has the same type as the input (table->table, tensor->tensor). The `Sequencer` class abstracts away the implementation of back-propagation through time. It also provides a `remember(['neither','both'])` method that controls what the `Sequencer` remembers between iterations (forward, backward, update).
- simple-recurrence-network.lua: uses the `nn.Recurrence` module to define the Simple RNN `h[t] <- sigmoid(h[t-1], x[t])`, then decorates it with `nn.Sequencer` so that an entire batch of sequences (`input`) can be forward- and backward-propagated per update.
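The Simple RNN `h[t] <- sigmoid(h[t-1], x[t])` applied across a `seqlen x batchsize x ...` batch can be illustrated in plain Python. The following minimal sketch is conceptual only. It spells out the shorthand as `h[t] = sigmoid(W h[t-1] + U x[t])`, where the weight matrices `W` and `U` are assumptions. It then loops the step over time for every sequence in the batch, which is the role `nn.Sequencer` plays for `nn.Recurrence`. All function names are hypothetical and unrelated to the Torch API.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(m, v):
    # Dense matrix-vector product on nested lists.
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def recurrence_step(h_prev, x, W, U):
    """One step of h[t] = sigmoid(W h[t-1] + U x[t]) (assumed Simple RNN form)."""
    return [sigmoid(a + b) for a, b in zip(matvec(W, h_prev), matvec(U, x))]


def sequence_forward(inputs, h0, W, U):
    """Apply the same step at every time step, Sequencer-style.

    inputs: nested lists of shape seqlen x batchsize x input_size.
    Returns outputs of shape seqlen x batchsize x hidden_size.
    """
    h = [list(h0) for _ in inputs[0]]  # one hidden state per sequence in the batch
    outputs = []
    for x_t in inputs:                 # iterate over time
        h = [recurrence_step(h_b, x_b, W, U) for h_b, x_b in zip(h, x_t)]
        outputs.append(h)
    return outputs
```

The same weights are reused at every time step. Two sequences that receive identical inputs at the last step can therefore still produce different outputs, because their hidden states carry different histories.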
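`nn.Sequencer` abstracts away back-propagation through time (BPTT). The following stdlib-only Python sketch makes explicit what that hides, under stated assumptions. It uses a hypothetical scalar RNN `h[t] = sigmoid(w*h[t-1] + u*x[t])` with a squared-error loss on the final hidden state only. The function names are illustrative, and this is not how `nn.Sequencer` is implemented. The gradients are accumulated over time steps because `w` and `u` are shared, and a finite-difference check can confirm them.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def forward(xs, w, u, h0=0.0):
    """Return [h0, h1, ..., hT] for h[t] = sigmoid(w*h[t-1] + u*x[t])."""
    hs = [h0]
    for x in xs:
        hs.append(sigmoid(w * hs[-1] + u * x))
    return hs


def loss(xs, y, w, u):
    # Assumed loss: squared error on the last hidden state only.
    return 0.5 * (forward(xs, w, u)[-1] - y) ** 2


def bptt(xs, y, w, u, h0=0.0):
    """Gradients (dL/dw, dL/du) via back-propagation through time."""
    hs = forward(xs, w, u, h0)
    dw = du = 0.0
    dh = hs[-1] - y                        # dL/dh[T]
    for t in range(len(xs), 0, -1):        # walk backward through time
        dz = dh * hs[t] * (1.0 - hs[t])    # through the sigmoid: sigma'(z) = h(1 - h)
        dw += dz * hs[t - 1]               # shared weight: accumulate over steps
        du += dz * xs[t - 1]
        dh = dz * w                        # propagate to h[t-1]
    return dw, du
```

A central-difference estimate such as `(loss(xs, y, w + eps, u) - loss(xs, y, w - eps, u)) / (2 * eps)` should agree with `bptt`'s `dw` to within numerical precision. This is a cheap way to validate a hand-written backward pass.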