MTMA22 Day 1: Project Definition
This year I’m fortunate to be co-leading a Machine Translation Marathon in the Americas project with Marcin Junczys-Dowmunt. The project has modest goals: to add Python bindings to Marian NMT.
Marian, being a C++ toolkit, enjoys the speed afforded by close-to-the-metal programming with fine-grained memory control. However, being a C++ library also introduces a lot of complexity for those unfamiliar with C++, including dependency management, debugging, installation, and embedding into other applications. A common request to the Marian developers is to add Python bindings to help mitigate some of these issues. We set out on this MTMA project with the goal of adding shallow bindings for some of the most common classes we expect Python developers to use, starting with inference.
There are many ways to add Python bindings to existing C++ code, ranging from Cython to SWIG to CFFI. We elected to use pybind11, following the philosophy that, since Marian is a C++ library, writing the bindings in C++ makes sense to reduce the maintenance burden.
We started by defining a high-level Python API and the example usage we expect users to want to follow, and landed on the following mocked-up API:
```python
import pymarian
import sentencepiece as spm

src_spm = spm.SentencePieceProcessor(source_spm_model)
tgt_spm = spm.SentencePieceProcessor(target_spm_model)

# may change
translator = pymarian.Translator(yaml_file)

for sent in sents_to_translate:
    src_pieces_str = " ".join(src_spm.Encode(sent, out_type=str))
    tgt_pieces = translator.translate(src_pieces_str).split()
    translation = tgt_spm.Decode(tgt_pieces)
    print(f"Translation of '{sent}' is '{translation}'")
```
We first needed to identify the right bit of Marian to wrap to achieve this API. We landed on `marian::TranslateService<Search>` as the most appropriate class to wrap, and identified a couple of snags:

- `marian::TranslateService<Search>::run` only handles string-in, string-out translation, which batches internally. This is fine as is, but we expect users may want to pass lists of strings without requiring unnecessary joins and splits. This is a TODO.
- `marian::TranslateService<Search>` is constructed using a pointer to a `marian::Options` object, which is similar to a lookup table and can be parsed from a number of configuration formats. To iterate quickly, we locally added a C++ constructor which loads a YAML file directly. We may want to clean this up to be a factory instead, but that's a TODO.
Now we have C++ and Python API parity for inference. We began digging into the pybind11 bit, which was extremely straightforward in the initial pass:
```cpp
#include "pybind11/pybind11.h"

#include "marian.h"
#include "translator/translator.h"
#include "translator/beam_search.h"

namespace py = pybind11;

PYBIND11_MODULE(pymarian, m) {
  // Classes
  py::class_<marian::TranslateService<marian::BeamSearch>>(m, "Translator")
      .def(py::init<std::string>())
      .def("translate", &marian::TranslateService<marian::BeamSearch>::run);
}
```
This just defines a Python module `pymarian` containing a single class `Translator` with a `str` constructor (corresponding to the YAML file) and a `str`-in, `str`-out `translate` method which just calls the `run` method described above.
In total this took about 30 minutes. We spent much of the rest of the day fiddling with `setuptools` and `cmake` to make (heh) sure they can speak to one another. The details are a bit gritty so I won't describe them here, but after a single day we nearly have a working Python implementation of Marian inference. I expect that on day 2 we will harden this API, add more functionality, and begin benchmarking the code to see the delta between native and wrapped code.