NMT with xFormers: Part 2


In the last post I explored subword tokenization, data prep, and the training loop for a transformer-based NMT model. While I had promised the next post would begin implementing decoding strategies, I wanted to explore efficiency gains to be had by digging into batch samplers. I’ll start by describing the problem we’re trying to solve.

By default, and speaking coarsely, when a PyTorch DataLoader is given batch_size=B, it draws B sequential examples from the underlying dataset. In seq2seq, these examples will likely differ in length and thus require padding to create a matrix of shape B x N, where N is the length of the longest example in the batch. When computing the loss, the padded tokens are masked so that they do not contribute. While this glorified no-op is unavoidable, it is expensive, and the problem is exacerbated when minibatches are formed from examples of dramatically varying lengths. For example, take this silly case where B=2:

sent1_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, ...,  99, 100, -1]
sent2_ids = [0, 101, -1]
# To be padded in the `collate_fn`
batch = [sent1_ids, sent2_ids]

In this contrived example, we see that sent2_ids needs to be padded by a large amount. Because excessive padding leads to wasted computation, the name of the game is to minimize padding within a batch.
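
To put a number on the waste, here is a minimal sketch that computes the fraction of a padded batch that is padding. The helper name `padding_fraction` and the lengths are hypothetical, chosen only to mirror the shape of the example above.

```python
def padding_fraction(lengths):
    """Fraction of a padded B x N batch that is padding rather than real tokens."""
    longest = max(lengths)
    total_slots = longest * len(lengths)  # B x N after padding
    real_tokens = sum(lengths)
    return (total_slots - real_tokens) / total_slots

# Hypothetical lengths: one long sentence batched with a very short one.
mismatched = padding_fraction([102, 3])
# Two sentences of similar length waste far less.
similar = padding_fraction([102, 100])
print(f"mismatched: {mismatched:.1%}, similar: {similar:.1%}")
```

Nearly half of the mismatched batch is padding, while the batch of similar lengths wastes about one percent.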

When the batch_size kwarg is omitted from the DataLoader constructor, a user may provide a batch_sampler which tells the DataLoader how to form each batch. A common way to reduce padding is to set a maximum number of tokens per batch: examples of similar pre-padding length are grouped together such that, once padded, the batch does not exceed the token budget. This gives us predictable performance characteristics, since the number of tokens doesn’t vary dramatically from batch to batch, unlike when we batch by number of sentences (consider memory usage when training on a batch of 64 long sentences vs. a batch of 64 short sentences). In this post I’ll describe how to write a batch sampler which caps the number of tokens in a batch.
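
The 64-sentence comparison can be made concrete with a little arithmetic. This is only a sketch: the sentence lengths (100 and 5 tokens) and the 1024-token budget are illustrative assumptions.

```python
def padded_tokens(lengths):
    """Tokens in the padded batch: batch size times the longest example."""
    return max(lengths) * len(lengths)

long_batch = [100] * 64   # 64 long sentences
short_batch = [5] * 64    # 64 short sentences
long_tokens = padded_tokens(long_batch)
short_tokens = padded_tokens(short_batch)
print(long_tokens, short_tokens)  # 6400 vs 320 tokens per batch

# With a token budget, the number of sentences adapts to their length instead.
max_tokens = 1024
long_per_batch = max_tokens // 100
short_per_batch = max_tokens // 5
print(long_per_batch, short_per_batch)  # 10 long sentences or 204 short ones
```

Batching by sentence count gives a 20x swing in tokens per batch here, whereas a token budget keeps every batch at roughly the same size.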

Because we will need the encoded source and target sentences along with their lengths, I define a function which encodes and bookends the examples with BOS and EOS and returns the lengths on a per-example basis:

def encode_dataset(ds, text_transform, bos, eos):
    encoded_src, encoded_tgt = [], []
    for src, tgt in ds:
        encoded_src.append(bos + text_transform(src) + eos)
        encoded_tgt.append(bos + text_transform(tgt) + eos)
    return encoded_src, encoded_tgt, list(map(len, encoded_src)), list(map(len, encoded_tgt))

src, tgt, src_lens, tgt_lens = encode_dataset(dataset, lambda text: text_transform(text, spm, vocab), vocab(["<s>"]), vocab(["</s>"]))
# The length of an example is the longer of its source and target sentences
max_lens = [max(src_len, tgt_len) for src_len, tgt_len in zip(src_lens, tgt_lens)]

Next we need to define the batch sampler. Fundamentally, it should sort the examples by length and keep adding examples to a batch so long as the padded batch doesn’t overflow some max_tokens parameter. I take inspiration from the allennlp max tokens batch sampler, which employs a neat trick: it injects a small amount of noise into the example lengths before sorting. This promotes diversity in batching and makes it non-deterministic, so we don’t see the same examples grouped in the same batch every iteration.

import random

import torch
from torch.utils.data import Sampler
from typing import List, Optional

class MaxTokenPerBatchSampler(Sampler):
    def __init__(self, lengths: List[int], max_tokens: int, max_length: Optional[int]=None, min_length: int = 1, padding_ratio: float = 0.1, shuffle: bool=True):
        if max_length is None:
            max_length = float('inf')
        assert 0 < min_length < max_length, "min_length must be positive and less than max_length"
        if not shuffle:
            padding_ratio = 0.0
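        # Keep each example's original dataset index alongside its length so that
        # filtering by length does not shift the indices handed to the DataLoader.
        kept = [(i, length) for i, length in enumerate(lengths) if min_length <= length <= max_length]
        noisy_lengths = [self._generate_noise(length, padding_ratio) for _, length in kept]
        order = torch.argsort(torch.tensor(noisy_lengths), descending=True).tolist()
        self.sorted_length_idx = [kept[j][0] for j in order]
        self.lengths = [kept[j][1] for j in order]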
        self.max_length = max_length
        self.min_length = min_length
        self.max_tokens = max_tokens
        self.shuffle = shuffle

    @staticmethod
    def _generate_noise(val: int, ratio: float):
        # Inspired by allennlp
        noise_value = val * ratio
        noise = random.uniform(-noise_value, noise_value)
        return val + noise

    def __iter__(self):
        curr_max_length = 0
        batch = []
        for idx, length in zip(self.sorted_length_idx, self.lengths):
            # If this length is bigger than curr max length, we need to allocate
            # more padding elements
            max_batch_size = max(length, curr_max_length) * (len(batch) + 1)

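            # If adding this element would overflow us, start new batch
            # (the `batch` check avoids yielding an empty batch when a single
            # example is longer than max_tokens)
            if batch and max_batch_size > self.max_tokens:
                if self.shuffle:
                    # Assumption: the body of this branch was missing from the
                    # listing; shuffling within the batch fits the flag's intent
                    random.shuffle(batch)
                yield batch
                curr_max_length = 0
                batch = []
            batch.append(idx)
            curr_max_length = max(curr_max_length, length)

        if len(batch) != 0:
            if self.shuffle:
                random.shuffle(batch)
            yield batch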

# Form batches of no more than 1024 post-pad tokens
batch_sampler = MaxTokenPerBatchSampler(max_lens, max_tokens=1024, shuffle=False)
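
To see the grouping logic in isolation, here is a stripped-down, pure-Python trace of the same greedy idea, without the noise or shuffling. The function name `greedy_token_batches`, the lengths, and the 110-token budget are all hypothetical; this is a sketch of the technique, not a replacement for the sampler above.

```python
def greedy_token_batches(lengths, max_tokens):
    """Group example indices, longest first, so each padded batch fits max_tokens."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    batches, batch, longest = [], [], 0
    for i in order:
        # Padded size if example i joined the current batch.
        padded = max(longest, lengths[i]) * (len(batch) + 1)
        if batch and padded > max_tokens:
            batches.append(batch)
            batch, longest = [], 0
        batch.append(i)
        longest = max(longest, lengths[i])
    if batch:
        batches.append(batch)
    return batches

lengths = [3, 50, 48, 7, 5, 52, 4]  # hypothetical example lengths
batches = greedy_token_batches(lengths, max_tokens=110)
for b in batches:
    padded = max(lengths[i] for i in b) * len(b)
    print(b, "padded tokens:", padded)
```

Note that the second batch pairs a 48-token example with a 7-token one: the greedy pass only guarantees the budget, not minimal padding.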

One thing to observe from this class is that the batches it yields contain the indices of the examples, not the examples themselves. These indices are returned to the DataLoader, which uses them to fetch examples from the underlying Dataset; the fetched examples are then passed to the collate_fn. This means that the underlying dataset needs to be a map-style dataset (one that supports indexing via __getitem__ and __len__) rather than an IterableDataset. I define a dummy dataset below which implements this thin interface:

from torch.utils.data import Dataset

class DumbDataset(Dataset):
    def __init__(self, src, tgt):
        self.src, self.tgt = src, tgt

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        return torch.tensor(self.src[idx]), torch.tensor(self.tgt[idx])

# Take the encoded segments from before
ds = DumbDataset(src, tgt)
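
The index flow described above can be sketched in a few lines of plain Python. This is a simplified mental model of what the DataLoader does with a batch sampler, not its actual implementation; `iterate_batches` and the toy data are hypothetical stand-ins.

```python
def iterate_batches(dataset, batch_sampler, collate_fn):
    """Simplified model of a DataLoader driven by a batch sampler."""
    for indices in batch_sampler:
        # Map-style access: each index is looked up with dataset[i].
        examples = [dataset[i] for i in indices]
        yield collate_fn(examples)

# A plain list behaves like a map-style dataset, and the "sampler" here is
# just a list of index lists.
toy_dataset = ["a", "bb", "ccc", "dddd"]
toy_sampler = [[3, 2], [1, 0]]
out = list(iterate_batches(toy_dataset, toy_sampler, collate_fn=tuple))
print(out)  # [('dddd', 'ccc'), ('bb', 'a')]
```

This is why an IterableDataset will not work here: it has no way to answer `dataset[i]` for an arbitrary index.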

Now the only remaining business is to revisit the collate_fn, which previously had the responsibility of encoding, prepending/appending BOS/EOS, padding, and collating. Since the sentences are already encoded and have the BOS/EOS bookends, we just need to pad and collate:

def collate_fn(batch, padding_value):
    # We need to pad to the longest source or target sentence
    max_len = max(max(len(src), len(tgt)) for src, tgt in batch)
    padded_src_sents = []
    padded_tgt_sents = []
    for src_encoded, tgt_encoded in batch:
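        src_padded = torch.cat([src_encoded, torch.full((max_len - src_encoded.shape[0],), padding_value)], dim=0)
        tgt_padded = torch.cat([tgt_encoded, torch.full((max_len - tgt_encoded.shape[0],), padding_value)], dim=0)
        # Collect the padded tensors so they can be stacked below
        padded_src_sents.append(src_padded)
        padded_tgt_sents.append(tgt_padded)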

    # Because padded_*_sents are tensors now we need to stack on dim=0 for batch-first.
    padded_src = torch.stack(padded_src_sents, dim=0)
    padded_tgt = torch.stack(padded_tgt_sents, dim=0)
    src_mask = (padded_src != padding_value)
    tgt_mask = (padded_tgt != padding_value)
    return padded_src, padded_tgt, src_mask, tgt_mask
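
For intuition, the same pad-and-mask step can be written on plain Python lists. This is a minimal sketch: `pad_and_mask` is a hypothetical helper and the padding id of 0 is an assumption made for the example.

```python
def pad_and_mask(seqs, padding_value):
    """Right-pad each sequence to the longest one and mark the real tokens."""
    max_len = max(len(s) for s in seqs)
    padded = [s + [padding_value] * (max_len - len(s)) for s in seqs]
    mask = [[tok != padding_value for tok in row] for row in padded]
    return padded, mask

PAD = 0  # assumed padding id
padded, mask = pad_and_mask([[5, 6, 7, 8], [5, 9]], PAD)
print(padded)  # [[5, 6, 7, 8], [5, 9, 0, 0]]
print(mask)    # [[True, True, True, True], [True, True, False, False]]
```

Deriving the mask by comparing against the padding value only works if that id never appears as a real token, which holds when the vocabulary reserves a dedicated padding entry.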

Now we create the new DataLoader:

from torch.utils.data import DataLoader

loader = DataLoader(
    ds, batch_sampler=batch_sampler,
    collate_fn=lambda batch: collate_fn(batch, padding_value)
)

We check that training still works by inspecting the losses from the first few steps, with some additional information included:

Epoch 1 step: 1 Loss: 9.901693 Took 1.652534 seconds. bsz (toks): 778
Epoch 1 step: 2 Loss: 9.302147 Took 4.578519 seconds. bsz (toks): 920
Epoch 1 step: 3 Loss: 8.932611 Took 4.485247 seconds. bsz (toks): 975
Epoch 1 step: 4 Loss: 8.777491 Took 4.411197 seconds. bsz (toks): 961
Epoch 1 step: 5 Loss: 8.455295 Took 4.383704 seconds. bsz (toks): 991
Epoch 1 step: 6 Loss: 8.414553 Took 4.352589 seconds. bsz (toks): 982
Epoch 1 step: 7 Loss: 8.198702 Took 4.425001 seconds. bsz (toks): 986
Epoch 1 step: 8 Loss: 8.002873 Took 4.287635 seconds. bsz (toks): 1022
Epoch 1 step: 9 Loss: 7.746803 Took 4.183296 seconds. bsz (toks): 1016
Epoch 1 step: 10 Loss: 7.552640 Took 4.369744 seconds. bsz (toks): 998
Epoch 1 step: 11 Loss: 7.516591 Took 4.132573 seconds. bsz (toks): 1020
Epoch 1 step: 12 Loss: 7.302334 Took 4.133327 seconds. bsz (toks): 1013

While we don’t have an apples-to-apples comparison with the previous approach, which batched by number of sentences, we see a steady decrease in loss. Larger batch sizes typically call for larger learning rates, and I have not adjusted the learning rate here, which likely explains the slower initial convergence. In the next part I actually will explore decoding strategies… :-)