{"id":803617,"date":"2021-12-16T14:47:39","date_gmt":"2021-12-16T22:47:39","guid":{"rendered":"https:\/\/www.microsoft.com\/en-us\/research\/?post_type=msr-blog-post&p=803617"},"modified":"2021-12-16T14:47:39","modified_gmt":"2021-12-16T22:47:39","slug":"pymarlin-a-lightweight-library-that-improves-deep-learning-training-agility","status":"publish","type":"msr-blog-post","link":"https:\/\/www.microsoft.com\/en-us\/research\/articles\/pymarlin-a-lightweight-library-that-improves-deep-learning-training-agility\/","title":{"rendered":"PyMarlin: A lightweight library that improves deep learning training agility"},"content":{"rendered":"

By Amin Saied, Ananth Rao, Ashwin Srinivasan, Damien Jose, Eduardo Gonzalez, Han Yu, Jon Sleep, Krishan Subudhi, Shruti Gullapuram

PyMarlin is a lightweight PyTorch extension library for agile experimentation. It was designed to simplify the end-to-end deep learning experimentation lifecycle, agnostic of the compute environment. In July 2021, the PyMarlin team open-sourced their internal model training library to all PyTorch users. PyMarlin abstracts away the boilerplate code for scaling, logging, and argument parsing that is crucial for training deep learning-based models, and it can be thought of as a high-level abstraction over PyTorch. We have created a five-minute “Getting Started” module for anyone interested in trying out PyMarlin. Today we’ll look at how PyMarlin works, how it supports extensibility, and the next steps needed to advance its functionality further.

How the typical deep learning training lifecycle works
\"Typical

Figure 1: Typical deep learning training steps

These three steps (and their sub-steps) are the backbone of any typical deep learning model training lifecycle. But this process also involves writing and testing a lot of code. Since scientists and researchers focus mostly on the model training part, they generally write the other components without following any design pattern, which makes the training code difficult to extend.

For example, let’s say a researcher has written code for text summarization, including all the code necessary for scaling and logging. A fellow researcher wants to try out a new optimizer. Another colleague wants to experiment with new evaluation metrics and loss functions. And yet another scientist wants to use the same recipe but on different data. In this case, all the stakeholders make separate copies of the code and make their own modifications. But then, suppose the original researcher changes the encoder and decoder architecture and comes up with a better model. The other stakeholders may have to change their ML code. What a waste of everyone’s time!

Speeding up training with Distributed Data Parallel (DDP) and mixed precision can introduce bugs, too. For example, when using multiple GPUs across multiple nodes, the batch size per GPU must be reduced to maintain the same global batch size. This can involve manual, error-prone calculation of the minibatch size or the number of gradient accumulation steps. During the validation step, the outputs from multiple GPUs need to be gathered to calculate evaluation metrics accurately. Adding an optimization such as disabling all-reduce during gradient accumulation can speed up the model further. And in mixed precision training using PyTorch’s native amp module, gradients must be unscaled before they can be clipped.
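That last requirement trips people up often, so here is a minimal sketch of the unscale-before-clip pattern with PyTorch’s native amp module. This is standard PyTorch rather than PyMarlin-specific code, and model, optimizer, loss_fn, and dataloader are assumed to be defined elsewhere:

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for batch, labels in dataloader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():      # forward pass in mixed precision
        loss = loss_fn(model(batch), labels)
    scaler.scale(loss).backward()        # backward on the scaled loss
    scaler.unscale_(optimizer)           # unscale gradients BEFORE clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)               # skips the step if gradients overflowed
    scaler.update()
```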

There are many open-source libraries that provide functionality similar to PyMarlin, and some of them have extra features that can come in quite handy. The Hugging Face Trainer, for instance, supports other logging frameworks like wandb, but it is not model agnostic. PyMarlin, however, offers unique benefits: we focused on keeping the code simple and easily readable. PyMarlin is not designed to be a black box, so power users can understand PyMarlin’s code and extend it as necessary.

PyMarlin at a glance

A brief look at the architecture[1]

PyMarlin has four core components: the DataProcessor and DataInterface, the Module Interface, the Trainer Backend, and the Trainer. First, we’ll look at the DataProcessor and DataInterface, whose role is to decouple data processing and dataset building from model training. The DataProcessor processes and optionally analyzes data; users can have multiple DataProcessors, which can be chained together. The DataInterface has abstract methods that the ModuleInterface calls to obtain train and validation datasets during training.
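As a sketch of this separation, the snippet below defines one processing step and a DataInterface that serves the finished datasets. The subclass names and constructor details are hypothetical, and depending on the PyMarlin version additional abstract methods (such as analyze()) may also need to be implemented; see the PyMarlin documentation for the exact signatures:

```python
from pymarlin.core import data_interface

class TokenizeProcessor(data_interface.DataProcessor):
    """A single preprocessing step: tokenize raw text."""
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def process(self, raw_texts):
        # Business logic for this one step lives in process().
        return [self.tokenizer(text) for text in raw_texts]

class SummarizationData(data_interface.DataInterface):
    """Serves finished datasets to the ModuleInterface during training."""
    def __init__(self, train_dataset, val_dataset):
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

    def get_train_dataset(self):
        return self.train_dataset

    def get_val_dataset(self):
        return self.val_dataset
```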

The Module Interface is where scientists and researchers write their training code. This module can be thought of as the implementation of the training recipe. The Module Interface inherits from nn.Module and hence can be treated like any PyTorch module.
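A training recipe then takes roughly the following shape. The method names below (get_optimizers_schedulers, get_train_dataloader, train_step) illustrate the pattern rather than reproduce the exact interface, so consult the PyMarlin docs for the precise methods to override:

```python
import torch
from pymarlin.core import module_interface

class SummarizationRecipe(module_interface.ModuleInterface):
    """The training recipe: model, data, loss, and optimization in one place."""
    def __init__(self, model, data):   # data: a DataInterface instance
        super().__init__()
        self.model = model
        self.data = data

    def get_optimizers_schedulers(self, estimated_global_steps_per_epoch, epochs):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=3e-5)
        return [optimizer], []          # no learning-rate schedulers in this sketch

    def get_train_dataloader(self, sampler, batch_size):
        return torch.utils.data.DataLoader(
            self.data.get_train_dataset(), batch_size=batch_size)

    def train_step(self, global_step, batch, device):
        # Return the loss; the trainer backend handles backward passes,
        # gradient scaling, clipping, and accumulation.
        return self.model(batch.to(device))
```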

The Trainer Backend is responsible for training/validating the Module Interface for one entire epoch. PyMarlin offers various useful backend implementations, such as SingleProcess, SingleProcessAmp, and DDPTrainerBackend.

Finally, the Trainer serves as the bridge between the Trainer Backend and the Module Interface: it takes care of device management, rank fetching, checkpointing, and reloading. It also handles all the calculations for minibatch size, gradient accumulation, and the number of remaining epochs, initializes stats writers like TensorBoard, and restarts training from a previous state.
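Putting the pieces together, a training script might be wired up along these lines. This is a hypothetical sketch: the constructor arguments and the TrainerArguments name are assumptions based on the components described above, and the “Getting Started” module shows the exact invocation:

```python
from pymarlin.core import trainer, trainer_backend

# recipe is the Module Interface implementation from the previous sketch
backend = trainer_backend.SingleProcessAmp()   # or a DDP backend for multi-GPU
args = trainer.TrainerArguments(epochs=3, gpu_batch_size_limit=16)

marlin_trainer = trainer.Trainer(
    module=recipe,
    trainer_backend=backend,
    args=args,
)
marlin_trainer.train()   # the Trainer drives the backend one epoch at a time
```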

\"PyMarlin

Figure 2: Steps to follow while writing code using PyMarlin

A PyMarlin Deep Dive

Beyond its four core components, PyMarlin has additional features to assist coders. We’ll first explore the core components in greater depth, then look at the supporting features.

1. DataProcessor and DataInterface

The DataProcessor modules aim to support most large-scale preprocessing requirements. A DataProcessor can be seen as a single step of processing, such as reading files. Multiple DataProcessors can be used sequentially, each covering one preprocessing step. Once the business logic is added in the process() function, the built-in multiprocessing support can be easily leveraged. The business logic in the DataProcessor’s process() function can be invoked either on a single compute target, locally, or even as a distributed job across nodes, and it comes with built-in support for Azure Machine Learning (AML). It also allows for selective preprocessing: for example, with a large dataset you could decide how many parts it should be split into and choose which part to process at a time on a single node.
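For instance, two hypothetical processing steps might be chained as follows; everything here except the process() entry point is illustrative, and in practice PyMarlin’s built-in multiprocessing and AML support wrap that same entry point:

```python
from pymarlin.core import data_interface

class FileReader(data_interface.DataProcessor):
    """Step 1: read raw text files into a list of documents."""
    def process(self, file_paths):
        docs = []
        for path in file_paths:
            with open(path, encoding="utf-8") as f:
                docs.append(f.read())
        return docs

class SentenceSplitter(data_interface.DataProcessor):
    """Step 2: naively split each document into sentences."""
    def process(self, docs):
        return [sentence for doc in docs for sentence in doc.split(". ")]

# Chain the steps: the output of one processor feeds the next.
docs = FileReader().process(wiki_file_paths)    # wiki_file_paths: assumed input list
sentences = SentenceSplitter().process(docs)
```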

This example covers pre-processing raw Wikipedia data. In that example, which splits sentences for 27 raw Wikipedia text files, we see the following time savings: