NMT with xFormers: Part 2
In the last post I explored subword tokenization, data prep, and the training loop for a transformer-based NMT model. While I had promised the next post would begin implementing decoding strategies, I wanted to explore efficiency gains to be had by digging into batch samplers. I’ll start by describing the problem we’re trying to solve.
By default, and speaking coarsely, when a PyTorch DataLoader is provided with batch_size=B, B sequential examples are drawn from the underlying dataset. In the case of seq2seq, these examples will likely not be of the same length and thus require padding to create a matrix of shape B x N, where N is the length of the longest example in the batch. When computing losses, these padded tokens are masked so as not to contribute to the loss calculation. While this glorified no-op is unavoidable, it is expensive, and the problem is exacerbated when minibatches are formed with examples of dramatically varying lengths. For example, take this silly case with a batch of two sentences:
sent1_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, ..., 99, 100, -1]
sent2_ids = [0, 101, -1]
# To be padded in the `collate_fn`
batch = [sent1_ids, sent2_ids]
In this contrived example, we see that sent2_ids needs to be padded by a large amount. Because excessive padding leads to wasted computation, the name of the game is to minimize padding within a batch.
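To put a rough number on that waste, here is a minimal sketch that computes the fraction of a padded batch occupied by padding. The helper name padding_fraction and the lengths used below are my own illustrative assumptions, not part of the original code.

```python
from typing import List


def padding_fraction(lengths: List[int]) -> float:
    """Fraction of a padded B x N batch that is padding (hypothetical helper)."""
    if not lengths:
        return 0.0
    # Every example is padded up to the longest one in the batch
    total_slots = max(lengths) * len(lengths)
    return 1.0 - sum(lengths) / total_slots


# Illustrative lengths: one long and one short sentence
print(padding_fraction([100, 3]))
# Two sentences of similar length waste far less
print(padding_fraction([100, 97]))
```

With one long and one short sentence, nearly half of the matrix is padding, while two similarly sized sentences waste only a small fraction.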
When the batch_size kwarg is omitted from the DataLoader constructor, a user may provide a batch_sampler which instructs the DataLoader how to form a batch. One common way to reduce padding is to cap the number of tokens per batch: similarly sized examples are grouped so that, once all of them are padded, the batch does not exceed a given token count. This gives more predictable performance characteristics, since the number of tokens doesn’t vary dramatically from batch to batch, unlike batching by number of sentences (consider memory usage when training on a batch of 64 long sentences vs. a batch of 64 short sentences). In this post I’ll describe how to write a batch sampler which bounds the number of post-padding tokens in a batch.
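The difference between the two batching schemes can be sketched with a little arithmetic. The sentence lengths of 100 and 10 tokens are assumptions chosen for illustration; 64 comes from the example above and 1024 matches the budget used later in the post.

```python
def padded_tokens(num_sentences: int, longest: int) -> int:
    """Tokens in a padded batch when batching by number of sentences."""
    return num_sentences * longest


def sentences_per_batch(max_tokens: int, longest: int) -> int:
    """How many sentences of a given padded length fit in a token budget."""
    return max_tokens // longest


# Batching by sentence count: the token count swings with sentence length
print(padded_tokens(64, 100), padded_tokens(64, 10))
# Batching by token budget: the sentence count swings instead
print(sentences_per_batch(1024, 100), sentences_per_batch(1024, 10))
```

With a fixed sentence count the memory footprint follows the sentence length, whereas a token budget keeps the footprint roughly constant and lets the number of sentences per batch vary.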
Because we will need the encoded source and target sentences along with their lengths, I define a function which encodes and bookends the examples with BOS and EOS and returns the lengths on a per-example basis:
def encode_dataset(ds, text_transform, bos, eos):
    encoded_src, encoded_tgt = [], []
    for src, tgt in ds:
        # Bookend every encoded sentence with BOS and EOS
        encoded_src.append(bos + text_transform(src) + eos)
        encoded_tgt.append(bos + text_transform(tgt) + eos)
    return encoded_src, encoded_tgt, list(map(len, encoded_src)), list(map(len, encoded_tgt))


src, tgt, src_lens, tgt_lens = encode_dataset(
    dataset,
    lambda text: text_transform(text, spm, vocab),
    vocab(["<s>"]),
    vocab(["</s>"]),
)
# We define the max length of an example to be the length of the longer sent as source or target
max_lens = [max(src_len, tgt_len) for src_len, tgt_len in zip(src_lens, tgt_lens)]
Next we need to define the batch sampler. Fundamentally, we’d like it to sort the examples by length and keep taking examples so long as the batch doesn’t overflow some max_tokens parameter. I take inspiration from the AllenNLP max-tokens batch sampler, which employs a neat trick: it injects a small amount of noise into the example lengths to promote diversity in batching and to make it non-deterministic, so we don’t overfit to the same examples appearing together in a batch each iteration.
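The effect of the noise on the sort order can be sketched in plain Python. The function name noisy_order and the explicit rng argument are my own assumptions for illustration; this is not AllenNLP's code.

```python
import random
from typing import List


def noisy_order(lengths: List[int], ratio: float, rng: random.Random) -> List[int]:
    """Indices sorted by noisy length, longest first (illustrative helper)."""
    # Perturb each length by up to +/- ratio * length
    noisy = [n + rng.uniform(-n * ratio, n * ratio) for n in lengths]
    return sorted(range(len(lengths)), key=lambda i: noisy[i], reverse=True)


lengths = [30, 31, 32, 33, 100]
# With no noise the order is fully determined by the lengths
print(noisy_order(lengths, 0.0, random.Random(0)))
# With 10% noise, neighbours of similar length may trade places,
# but the length-100 example always stays first
print(noisy_order(lengths, 0.1, random.Random(0)))
```

Examples only swap places when their lengths are within the noise band of each other, so batches stay roughly length-homogeneous while their exact composition changes from run to run.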
import random
from typing import List, Optional

import torch
from torch.utils.data import Sampler


class MaxTokenPerBatchSampler(Sampler):
    def __init__(self, lengths: List[int], max_tokens: int, max_length: Optional[int] = None,
                 min_length: int = 1, padding_ratio: float = 0.1, shuffle: bool = True):
        if max_length is None:
            max_length = float('inf')
        assert 0 < min_length < max_length, "min_length must be positive and less than max_length"
        if not shuffle:
            padding_ratio = 0.0
        # Keep each original dataset index next to its length so that filtering
        # does not shift the indices handed back to the DataLoader
        kept = [(i, length) for i, length in enumerate(lengths) if min_length <= length <= max_length]
        noisy_lengths = [self._generate_noise(length, padding_ratio) for _, length in kept]
        order = torch.argsort(torch.tensor(noisy_lengths), descending=True).tolist()
        self.sorted_length_idx = [kept[i][0] for i in order]
        self.lengths = [kept[i][1] for i in order]
        self.max_length = max_length
        self.min_length = min_length
        self.max_tokens = max_tokens
        self.shuffle = shuffle

    @staticmethod
    def _generate_noise(val: int, ratio: float):
        # Inspired by allennlp
        noise_value = val * ratio
        noise = random.uniform(-noise_value, noise_value)
        return val + noise

    def __iter__(self):
        curr_max_length = 0
        batch = []
        for idx, length in zip(self.sorted_length_idx, self.lengths):
            # If this length is bigger than curr max length, we need to allocate
            # more padding elements
            max_batch_size = max(length, curr_max_length) * (len(batch) + 1)
            # If adding this element would overflow us, start new batch
            # (the `batch` check avoids yielding an empty batch)
            if batch and max_batch_size > self.max_tokens:
                if self.shuffle:
                    random.shuffle(batch)
                yield batch
                curr_max_length = 0
                batch = []
            curr_max_length = max(curr_max_length, length)
            batch.append(idx)
        if len(batch) != 0:
            if self.shuffle:
                random.shuffle(batch)
            yield batch


# Form batches of no more than 1024 post-pad tokens
batch_sampler = MaxTokenPerBatchSampler(max_lens, max_tokens=1024, shuffle=False)
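A quick way to sanity-check a sampler like this is to recompute the post-padding size of each batch it yields. The helpers below are a hypothetical sketch (post_pad_tokens and within_budget are names I made up), written in plain Python so they work on any iterable of index lists.

```python
from typing import Iterable, List


def post_pad_tokens(batch: List[int], lengths: List[int]) -> int:
    """Tokens in the batch once every example is padded to the longest one."""
    if not batch:
        return 0
    return max(lengths[i] for i in batch) * len(batch)


def within_budget(batches: Iterable[List[int]], lengths: List[int], max_tokens: int) -> bool:
    """True if no batch exceeds the token budget after padding."""
    return all(post_pad_tokens(b, lengths) <= max_tokens for b in batches)


# Toy example: lengths indexed by dataset position, batches of indices
toy_lengths = [10, 9, 3, 3, 2]
toy_batches = [[0, 1], [2, 3, 4]]
print([post_pad_tokens(b, toy_lengths) for b in toy_batches])
print(within_budget(toy_batches, toy_lengths, max_tokens=20))
```

Under the assumption that no single example is longer than the budget, a call such as within_budget(batch_sampler, max_lens, 1024) should hold for the sampler defined above.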
One thing to observe about this sampler is that the batches it returns contain the indices of the examples that form each batch. These indices are handed to the DataLoader, which uses them to fetch examples from the underlying Dataset; the fetched examples are then passed to the collate_fn. This means that the underlying dataset needs to be a map-style Dataset rather than an IterableDataset. I define a dummy dataset below which implements this thin interface:
import torch
from torch.utils.data import Dataset


class DumbDataset(Dataset):
    def __init__(self, src, tgt):
        self.src, self.tgt = src, tgt

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        return torch.tensor(self.src[idx]), torch.tensor(self.tgt[idx])


# Take the encoded segments from before
ds = DumbDataset(src, tgt)
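As a mental model, the way indices flow from the batch sampler through the dataset and into the collate function can be sketched in a few lines of plain Python. This is a deliberate simplification under my own naming (iterate_batches is hypothetical); the real DataLoader adds worker processes, memory pinning, and more.

```python
from typing import Any, Callable, Iterable, Iterator, List, Sequence


def iterate_batches(
    dataset: Sequence[Any],
    batch_sampler: Iterable[List[int]],
    collate_fn: Callable[[List[Any]], Any],
) -> Iterator[Any]:
    """Simplified model of a DataLoader driven by a batch sampler."""
    for indices in batch_sampler:
        # Index lookup is why the dataset must support __getitem__
        examples = [dataset[i] for i in indices]
        yield collate_fn(examples)


toy_dataset = ["a", "b", "c", "d"]
toy_sampler = [[2, 0], [1, 3]]
print(list(iterate_batches(toy_dataset, toy_sampler, collate_fn=list)))
```

Because each example is fetched by index, a dataset that can only be iterated over has no way to serve the batches the sampler asks for.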
Now the only remaining business is to revisit the collate_fn, which previously had the responsibility of encoding, prepending/appending BOS/EOS, padding, and collating. Since the sentences are already encoded and have the BOS/EOS bookends, we just need to pad and collate:
import torch


def collate_fn(batch, padding_value):
    # We need to pad to the longest source or target sentence
    max_len = max(max(len(src), len(tgt)) for src, tgt in batch)
    padded_src_sents = []
    padded_tgt_sents = []
    for src_encoded, tgt_encoded in batch:
        # Append enough padding tokens to bring each sentence up to max_len
        src_padded = torch.cat([src_encoded, torch.full((max_len - src_encoded.shape[0],), padding_value)], dim=0)
        tgt_padded = torch.cat([tgt_encoded, torch.full((max_len - tgt_encoded.shape[0],), padding_value)], dim=0)
        padded_src_sents.append(src_padded)
        padded_tgt_sents.append(tgt_padded)
    # Because padded_*_sents are tensors now we need to stack on dim=0 for batch-first.
    padded_src = torch.stack(padded_src_sents, dim=0)
    padded_tgt = torch.stack(padded_tgt_sents, dim=0)
    # Masks are True for real tokens and False for padding
    src_mask = (padded_src != padding_value)
    tgt_mask = (padded_tgt != padding_value)
    return padded_src, padded_tgt, src_mask, tgt_mask
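To illustrate what the returned masks are for, here is a plain-Python sketch of averaging per-token losses while skipping padded positions. It is not the loss code from the training loop (which is not shown in this post); masked_mean and the toy numbers are assumptions for illustration only.

```python
from typing import List


def masked_mean(token_losses: List[List[float]], mask: List[List[bool]]) -> float:
    """Average per-token losses over positions where the mask is True."""
    total, count = 0.0, 0
    for loss_row, mask_row in zip(token_losses, mask):
        for loss, keep in zip(loss_row, mask_row):
            if keep:
                total += loss
                count += 1
    return total / count if count else 0.0


# The second sentence has one padded position whose loss must be ignored
losses = [[1.0, 3.0], [2.0, 100.0]]
mask = [[True, True], [True, False]]
print(masked_mean(losses, mask))
```

The padded position still had a loss computed for it, which is the wasted work that motivates keeping padding to a minimum.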
Now we create the new DataLoader:
from torch.utils.data import DataLoader

loader = DataLoader(
    ds,
    batch_sampler=batch_sampler,
    collate_fn=lambda batch: collate_fn(batch, padding_value),
)
We check that training still works by inspecting the losses from the first few steps, logged with some additional information:
Epoch 1 step: 1 Loss: 9.901693 Took 1.652534 seconds. bsz (toks): 778
Epoch 1 step: 2 Loss: 9.302147 Took 4.578519 seconds. bsz (toks): 920
Epoch 1 step: 3 Loss: 8.932611 Took 4.485247 seconds. bsz (toks): 975
Epoch 1 step: 4 Loss: 8.777491 Took 4.411197 seconds. bsz (toks): 961
Epoch 1 step: 5 Loss: 8.455295 Took 4.383704 seconds. bsz (toks): 991
Epoch 1 step: 6 Loss: 8.414553 Took 4.352589 seconds. bsz (toks): 982
Epoch 1 step: 7 Loss: 8.198702 Took 4.425001 seconds. bsz (toks): 986
Epoch 1 step: 8 Loss: 8.002873 Took 4.287635 seconds. bsz (toks): 1022
Epoch 1 step: 9 Loss: 7.746803 Took 4.183296 seconds. bsz (toks): 1016
Epoch 1 step: 10 Loss: 7.552640 Took 4.369744 seconds. bsz (toks): 998
Epoch 1 step: 11 Loss: 7.516591 Took 4.132573 seconds. bsz (toks): 1020
Epoch 1 step: 12 Loss: 7.302334 Took 4.133327 seconds. bsz (toks): 1013
While we don’t have an apples-to-apples comparison with the previous approach, which batched by number of sentences, we see a steady decrease in loss. Larger batch sizes typically call for larger learning rates, and I have not adjusted the learning rate here, which likely explains the slower initial convergence. In the next part I actually will explore decoding strategies… :-)