NMT with xFormers: Part 1
In this post I explore the APIs of the new xFormers library while revisiting torchtext, which has matured a great deal since I last used it. The goal is part exploration, part motivation: I'd like to build tooling around the ideas explored here toward a principled neural machine translation library that leverages optimized components under the hood. Let's jump in.
This post will hopefully be the first in a series where I explore and develop components across the neural machine translation lifecycle. Naturally I'll begin with data: loading and preparation. I wanted to work with a relatively small dataset, so I used the French-English data from Tatoeba as found on this website. While torchtext natively supports fetching established benchmark datasets, I elected to use a dataset it doesn't wrap, for two reasons:
- The benchmark datasets tend to be too big to yield quick turnaround during development.
- I wanted to understand the surface area of incorporating a custom dataset into torchtext.
$ wget http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip && unzip fra-eng.zip
$ head fra-eng/fra.txt
Go. Va !
Hi. Salut !
Run! Cours !
Run! Courez !
Who? Qui ?
Wow! Ça alors !
Fire! Au feu !
Help! À l'aide !
Jump. Saute.
Stop! Ça suffit !
$ wc -l fra-eng/fra.txt
167130 fra-eng/fra.txt
So our data is a TSV file, with the English sentence first and its French translation second. The first thing I'll need is a way to read the TSV:
import csv

def read_tsv(file):
    with open(file, encoding='utf-8') as f:
        reader = csv.reader(f, delimiter='\t')
        yield from reader
I explored the torchtext datasets module and, at the time of writing, the convention was to wrap the torchtext.data.datasets_utils._RawTextIterableDataset class in a small helper function. The arguments to its constructor are a description, the number of lines in the dataset, and an iterable yielding (src, tgt) pairs.
from torchtext.data.datasets_utils import _RawTextIterableDataset

def Tatoeba_EN_FR(file="fra-eng/fra.txt"):
    _iter = read_tsv(file)
    return _RawTextIterableDataset("Tatoeba", 167130, _iter)
Strictly speaking this is a bit bare-bones compared to the datasets defined in the torchtext library (for one thing, the line count is hardcoded), but as an initial pass it's totally fine.
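If the hardcoded line count bothers you, a small variant (a sketch of my own, not how torchtext's bundled datasets handle it; Tatoeba_EN_FR_counted is a hypothetical name) is to count the lines up front at the cost of one extra pass over the file:

def Tatoeba_EN_FR_counted(file="fra-eng/fra.txt"):
    # Count the rows once so we don't hardcode 167130
    with open(file, encoding="utf-8") as f:
        num_lines = sum(1 for _ in f)
    return _RawTextIterableDataset("Tatoeba", num_lines, read_tsv(file))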
Now that we have a torch.utils.data.Dataset, we can start loading data with a PyTorch DataLoader:
>>> import csv
>>> from torchtext.data.datasets_utils import _RawTextIterableDataset
>>> from torch.utils.data import DataLoader
>>> def read_tsv(file):
...     with open(file, encoding='utf-8') as f:
...         reader = csv.reader(f, delimiter='\t')
...         yield from reader
...
>>> def Tatoeba_EN_FR(file="fra-eng/fra.txt"):
...     _iter = read_tsv(file)
...     return _RawTextIterableDataset("Tatoeba", 167130, _iter)
...
>>> ds = Tatoeba_EN_FR()
>>> loader = DataLoader(ds, batch_size=1)
>>> next(iter(loader))
[('Go.',), ('Va !',)]
>>> # Nice!
Now that we can load pairs of sentences from the dataset, we can consider tokenization strategies. There are many, but I'll focus on the unigram language model segmentation scheme as implemented in SentencePiece. Luckily torchtext has wrappers for this, so no extra libraries are needed. For file-based unigram training in torchtext, I use the torchtext.data.functional.generate_sp_model function, leaving defaults for all arguments except model_prefix.
import os
from torchtext.data.functional import generate_sp_model, load_sp_model

prefix = "tatoeba_20k_en_fr"  # By default, torchtext learns 20k subword units
if not os.path.exists(prefix + ".model"):
    # Writes "tatoeba_20k_en_fr.vocab" and "tatoeba_20k_en_fr.model"
    # to the current working directory if it doesn't already exist
    generate_sp_model("fra-eng/fra.txt", model_prefix=prefix)
spm_model = load_sp_model(prefix + ".model")
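generate_sp_model also exposes vocab_size and model_type if you want something other than the 20k unigram default. For example, a smaller vocabulary can speed up quick experiments (illustrative only; the rest of this post sticks with the default, and small_prefix is just a hypothetical name):

# Train a smaller 8k-piece unigram model under a separate prefix
small_prefix = "tatoeba_8k_en_fr"
if not os.path.exists(small_prefix + ".model"):
    generate_sp_model("fra-eng/fra.txt", vocab_size=8000,
                      model_type="unigram", model_prefix=small_prefix)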
Now that we have a unigram model which can perform subword tokenization, we want to build a torchtext vocabulary using these subwords so we can use the facilities it provides. First we can construct a “sentencepiece tokenizer” which provides a fast mechanism for unigram tokenization. Then we will incrementally apply the tokenizer to sentences we encounter in our data and build the vocab using the resulting subwords:
from torchtext.data.functional import sentencepiece_tokenizer
from torchtext.vocab import build_vocab_from_iterator

def yield_tokens(file, spm):
    for (src, tgt) in read_tsv(file):
        # SentencePieceTokenizer expects a list of strings as inputs
        yield from spm([src])
        yield from spm([tgt])

spm = sentencepiece_tokenizer(spm_model)
# Add special tokens since we'll need them later.
vocab = build_vocab_from_iterator(
    yield_tokens("fra-eng/fra.txt", spm),
    specials=["<unk>", "<s>", "</s>", "<pad>"],
)
# Set OOV handler
vocab.set_default_index(vocab["<unk>"])
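A quick sanity check is helpful here (illustrative; the exact pieces and indices depend on the trained model, though with build_vocab_from_iterator's default special_first=True the specials land at the front of the vocab):

sample = "Go."
pieces = list(spm([sample]))[0]        # e.g. ['▁Go', '.'] -- actual pieces may vary
print(pieces)
print(vocab(pieces))                   # indices of those pieces in the vocab
print(vocab["<unk>"], vocab["<pad>"])  # 0 and 3 with the specials ordered as above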
Finally we'll want to be able to collate a batch of sentences. This means tokenizing, numericalizing, and padding each batch to a common length. Earlier we used a batch size of 1 just to see the examples, but typically we'll use a significantly larger batch size. We'll define a collate_fn
import torch

def collate_fn(batch, text_transform, bos, eos, padding_value):
    src_sents, tgt_sents = [], []
    for src, tgt in batch:
        src_sents.append(text_transform(src))
        tgt_sents.append(text_transform(tgt))
    padded_src_sents = []
    padded_tgt_sents = []
    # Add two for BOS and EOS
    max_len = max(max(map(len, src_sents)), max(map(len, tgt_sents))) + 2
    for src_tokens, tgt_tokens in zip(src_sents, tgt_sents):
        src_encoded = torch.tensor(bos + src_tokens + eos, dtype=torch.int64)
        tgt_encoded = torch.tensor(bos + tgt_tokens + eos, dtype=torch.int64)
        src_padded = torch.cat([src_encoded, torch.full((max_len - src_encoded.shape[0],), padding_value)], dim=0)
        tgt_padded = torch.cat([tgt_encoded, torch.full((max_len - tgt_encoded.shape[0],), padding_value)], dim=0)
        padded_src_sents.append(src_padded)
        padded_tgt_sents.append(tgt_padded)
    # Stack 'em up so they're shaped B x max_len
    padded_src = torch.stack(padded_src_sents, dim=0)
    padded_tgt = torch.stack(padded_tgt_sents, dim=0)
    # Mask out padded values
    src_mask = (padded_src != padding_value)
    tgt_mask = (padded_tgt != padding_value)
    return padded_src, padded_tgt, src_mask, tgt_mask
and a text_transform which operates at the per-sentence level:
def text_transform(text, spm, vocab):
    # Create a list of list of subwords for sentence text
    encoded = list(spm([text]))
    # Look up indices for the sentence
    return vocab(encoded[0])
So now our data loader can be finalized:
padding_value = vocab(["<pad>"])[0]
BATCH = 64
loader = DataLoader(
    ds, batch_size=BATCH,
    collate_fn=lambda b: collate_fn(
        b,
        lambda sent: text_transform(sent, spm, vocab),
        vocab(["<s>"]),
        vocab(["</s>"]),
        padding_value,
    ),
)
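As a quick check (illustrative only; _RawTextIterableDataset wraps a one-shot iterator, so if you run this, re-create ds and loader afterwards so training doesn't skip the first batch), one batch comes out with the shapes we expect:

src, tgt, src_mask, tgt_mask = next(iter(loader))
print(src.shape, tgt.shape)            # both (BATCH, max_len_in_batch)
print(src_mask.shape, src_mask.dtype)  # same shape, torch.bool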
Now that we have an initial pass at loading data, we can focus on modeling with xFormers. The initial model comes from a modified version of the configuration defined here. The main modifications are setting the input vocab size to the size of the joint subword vocabulary from the sentencepiece unigram model, and removing seq_len from the attention config so that sequence length can vary per batch.
from xformers.factory.model_factory import xFormer, xFormerConfig

EMB = 384
# Maximum allowable sentence in our dataset (in # subword tokens + 2 for <s>, </s>)
BLOCK_SIZE = 64
VOCAB = len(vocab)

encoder_configs = {
    "reversible": False,
    "block_type": "encoder",
    "dim_model": EMB,
    "position_encoding_config": {
        "name": "vocab",
        "seq_len": BLOCK_SIZE,
        "vocab_size": VOCAB,
        "dim_model": EMB,
    },
    "num_layers": 3,
    "multi_head_config": {
        "num_heads": 4,
        "residual_dropout": 0,
        "attention": {
            "name": "scaled_dot_product",
            "dropout": 0,
            "causal": True,
        },
        "dim_model": EMB,
    },
    "feedforward_config": {
        "name": "MLP",
        "dropout": 0,
        "activation": "relu",
        "hidden_layer_multiplier": 4,
        "dim_model": EMB,
    },
}

decoder_configs = {
    "block_type": "decoder",
    "dim_model": EMB,
    "position_encoding_config": {
        "name": "vocab",
        "seq_len": BLOCK_SIZE,
        "vocab_size": VOCAB,
        "dim_model": EMB,
    },
    "num_layers": 2,
    "multi_head_config_masked": {
        "num_heads": 4,
        "residual_dropout": 0,
        "dim_model": EMB,
        "attention": {
            "name": "scaled_dot_product",
            "dropout": 0,
            "causal": True,
        },
    },
    "multi_head_config_cross": {
        "num_heads": 4,
        "residual_dropout": 0,
        "dim_model": EMB,
        "attention": {
            "name": "scaled_dot_product",
            "dropout": 0,
            "causal": True,
        },
    },
    "feedforward_config": {
        "name": "MLP",
        "dropout": 0,
        "activation": "relu",
        "hidden_layer_multiplier": 4,
        "dim_model": EMB,
    },
}

my_config = [encoder_configs, decoder_configs]

# This part of xFormers is entirely type checked and needs a config object,
# could be changed in the future
config = xFormerConfig(my_config)
transformer = xFormer.from_config(config)
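As a rough sanity check on the configuration (purely illustrative; the exact figure depends on the size of the learned vocabulary), we can count the parameters in the resulting stack:

n_params = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
print(f"{n_params:,} trainable parameters in the encoder/decoder stack")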
Now we need to project the decoder's embedding dimension (384) into the vocabulary space. We define a simple wrapper module which adds a linear output layer on top.
import torch.nn as nn

class MTModel(nn.Module):
    def __init__(self, transformer, out_embed_dim, vocab_size):
        super().__init__()
        self.transformer = transformer
        self.out_proj = nn.Linear(out_embed_dim, vocab_size)

    def forward(self, src, tgt=None, src_mask=None, tgt_mask=None):
        return self.out_proj(self.transformer(src, tgt, src_mask, tgt_mask))

model = MTModel(transformer, EMB, VOCAB)
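Before moving on, a quick smoke test (illustrative only; the same one-shot iterator caveat applies, so re-create the loader before training) confirms the logits come out shaped (batch, target length, vocab size):

src, tgt, src_mask, tgt_mask = next(iter(loader))
with torch.no_grad():
    logits = model(src, tgt, src_mask, tgt_mask)
print(logits.shape)  # expected: (BATCH, max_len_in_batch, VOCAB)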
Now we can begin writing the training code.
As an initial training loop I'll use label-smoothed cross-entropy loss with a weight-decayed Adam optimizer (AdamW) and a fixed learning rate. Training loops can get quite sophisticated, but I'll present a simple one here:
import time

def train(model, loader, max_epochs, padding_value):
    criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=padding_value)
    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
    total_loss = 0
    for epoch in range(1, max_epochs + 1):
        start = time.time()
        for i, (src, tgt, src_mask, tgt_mask) in enumerate(loader):
            opt.zero_grad()
            # Model output is (B, S, V); CrossEntropyLoss wants (B, V, S)
            out = model(src, tgt, src_mask, tgt_mask).transpose(2, 1)
            loss = criterion(out, tgt)
            # Accumulate a Python float so we don't keep the graph alive
            total_loss += loss.item()
            elapsed = time.time() - start
            print("Epoch Step: %d Loss: %f Took %f seconds" % (i, loss, elapsed))
            loss.backward()
            opt.step()
            start = time.time()
    return total_loss
loss = train(model, loader, max_epochs=3, padding_value=padding_value)
Running this for a few steps on CPU shows that it seems to be working:
Epoch Step: 0 Loss: 9.821983 Took 0.996907 seconds
Epoch Step: 1 Loss: 7.661346 Took 0.907667 seconds
Epoch Step: 2 Loss: 7.165735 Took 1.052036 seconds
Epoch Step: 3 Loss: 6.880791 Took 0.941898 seconds
Epoch Step: 4 Loss: 6.698797 Took 1.211879 seconds
Epoch Step: 5 Loss: 6.543982 Took 0.886298 seconds
Epoch Step: 6 Loss: 6.329848 Took 1.045404 seconds
Epoch Step: 7 Loss: 6.076595 Took 1.641880 seconds
Epoch Step: 8 Loss: 5.877031 Took 1.065028 seconds
Epoch Step: 9 Loss: 5.831444 Took 1.135166 seconds
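As a rough sanity check on that first value: with a vocabulary of roughly 20k subwords, near-uniform predictions should give a cross-entropy of about ln(20000) ≈ 9.9 nats, which is right where the step-0 loss starts before it drops:

import math
print(math.log(len(vocab)))  # ≈ 9.9, close to the loss at step 0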
This concludes the first part of the series. In the next post I hope to implement a few decoding strategies to begin evaluating the model.