Microsoft Search, Assistant and Intelligence

PyMarlin: A lightweight library that improves deep learning training agility


By Amin Saied, Ananth Rao, Ashwin Srinivasan, Damien Jose, Eduardo Gonzalez, Han Yu, Jon Sleep, Krishan Subudhi, Shruti Gullapuram

PyMarlin is a lightweight PyTorch extension library for agile experimentation. It was designed to simplify the end-to-end deep learning experimentation lifecycle, regardless of the compute environment. In July 2021, the PyMarlin team open-sourced its internal model training library to all PyTorch users. PyMarlin abstracts away the boilerplate code for scaling, logging, and argument parsing that is crucial for training deep learning models, so it can be thought of as a high-level abstraction over PyTorch. We have created a five-minute “Getting Started” module for anyone interested in trying out PyMarlin. Today we’ll look at how PyMarlin works, how it supports extensibility, and the next steps needed to advance its functionality further.

How the typical deep learning training lifecycle works


Figure 1: Typical deep learning training steps

These three steps (and their sub-steps) are the backbone of any typical deep learning model training lifecycle. But this process also involves writing a lot of code and testing. Since scientists and researchers focus mostly on the model training part, they generally write other components without following any design pattern. This makes the training code difficult to extend.

For example, let’s say a researcher has written code for text summarization, including all the code necessary for scaling and logging. A fellow researcher wants to try out a new optimizer. Another colleague wants to experiment with new evaluation metrics and loss functions. And yet another scientist wants to use the same recipe on different data. Typically, each of them makes a separate copy of the code and modifies it. Now suppose the original researcher changes the encoder and decoder architecture and comes up with a better model. Everyone else then has to port that change into their own copy by hand. What a waste of everyone’s time!

Speeding up training with Distributed Data Parallel (DDP) and mixed precision can introduce bugs, too. For example, when training on multiple GPUs and multiple nodes, the batch size per GPU must be reduced to keep the same global batch size. This often involves a manual, error-prone calculation of the minibatch size or the number of gradient accumulation steps. During the validation step, the outputs from all GPUs need to be gathered to compute evaluation metrics accurately. Adding optimizations, such as disabling the gradient all-reduce during gradient accumulation steps, can speed up training but adds yet more code to get right. Finally, in mixed precision training with PyTorch’s native amp module, gradients must be unscaled before they can be clipped.
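The batch-size bookkeeping described above can be made explicit. Below is a minimal sketch, not PyMarlin’s implementation, and the function name `split_global_batch` is hypothetical. It takes a target global batch size, the number of processes, and the largest minibatch that fits on one GPU. From these it derives a per-GPU minibatch and a number of gradient accumulation steps that reproduce the global batch exactly.

```python
import math


def split_global_batch(global_batch_size, world_size, max_gpu_batch_size):
    """Return (per_gpu_batch_size, grad_accumulation_steps). Hypothetical helper.

    The result satisfies
        per_gpu_batch_size * grad_accumulation_steps * world_size == global_batch_size
    with per_gpu_batch_size <= max_gpu_batch_size.
    """
    if min(global_batch_size, world_size, max_gpu_batch_size) < 1:
        raise ValueError("all sizes must be positive")
    if global_batch_size % world_size:
        raise ValueError("global batch size must be divisible by the number of processes")
    per_process = global_batch_size // world_size
    # Start from the fewest accumulation steps that could fit on one GPU and
    # increase until the per-process batch splits into equal minibatches.
    steps = math.ceil(per_process / max_gpu_batch_size)
    while per_process % steps:
        steps += 1
    return per_process // steps, steps
```

Consider a global batch of 256 on two nodes with four GPUs each (a world size of 8), with at most 16 samples per GPU. For that case, `split_global_batch(256, 8, 16)` returns `(16, 2)`: each GPU processes two minibatches of 16 before every optimizer step.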

There are many open source libraries that provide functionality similar to PyMarlin, and some of them have extra features that can come in quite handy. For example, the Hugging Face Trainer supports other logging frameworks such as wandb, but it is not model agnostic. PyMarlin, however, offers unique benefits. We focused on keeping the code simple and easily readable. PyMarlin is not designed to be a black box: power users will be able to understand PyMarlin’s code and extend it as necessary.

PyMarlin at a glance

A brief look at the architecture[1]

PyMarlin has four core components: DataProcessor and DataInterface, ModuleInterface, TrainerBackend, and Trainer. First, let’s look at DataProcessor and DataInterface. Their role is to decouple data processing and dataset building from model training. The DataProcessor processes and optionally analyzes data. Users can have multiple data processors, which can be chained together. DataInterface has abstract methods that the ModuleInterface calls to obtain the train and validation datasets during training.
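The decoupling described above can be sketched in plain Python. The classes below are hypothetical stand-ins, not PyMarlin’s DataProcessor and DataInterface:

- Each processor performs one step.
- `run_chain` feeds each step’s output into the next.
- The data interface only hands finished datasets to the training code.

The method names `get_train_dataset` and `get_val_dataset` are illustrative.

```python
from abc import ABC, abstractmethod


class Processor(ABC):
    """Hypothetical stand-in for a DataProcessor: one preprocessing step."""

    @abstractmethod
    def process(self, data):
        ...


class ReadLines(Processor):
    def process(self, data):
        # Split raw text into stripped, non-empty lines.
        return [line.strip() for line in data.splitlines() if line.strip()]


class Lowercase(Processor):
    def process(self, data):
        return [line.lower() for line in data]


def run_chain(processors, data):
    """Feed the output of each processor into the next one."""
    for processor in processors:
        data = processor.process(data)
    return data


class SimpleDataInterface:
    """Hypothetical stand-in for a DataInterface: training code only asks
    for datasets and never sees how they were built."""

    def __init__(self, train, val):
        self._train, self._val = train, val

    def get_train_dataset(self):
        return self._train

    def get_val_dataset(self):
        return self._val
```

Swapping the data source then means changing the processors, while the training code that calls `get_train_dataset()` stays untouched.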

The ModuleInterface is where scientists and researchers write their training code. It can be thought of as the implementation of the training recipe. ModuleInterface inherits from nn.Module and hence can be treated like any PyTorch module.

The TrainerBackend is responsible for training and validating the ModuleInterface for one entire epoch. PyMarlin offers several useful backend implementations, such as SingleProcess, SingleProcessAmp, and DDPTrainerBackend.

Finally, the Trainer serves as the bridge between the TrainerBackend and the ModuleInterface. It takes care of device management, rank fetching, checkpointing, and reloading. It also handles all the calculations for minibatch size, gradient accumulation, and the number of remaining epochs. In addition, it initializes stats writers such as TensorBoard and restarts training from a previous state.


Figure 2: Steps to follow while writing code using PyMarlin

A PyMarlin Deep Dive

Although PyMarlin has four core components, it has additional features to assist coders. We’ll first explore the core components in greater depth, then look at the supporting features.

  1. DataProcessor and DataInterface
    The DataProcessor modules aim to support most large-scale preprocessing requirements. A DataProcessor can be seen as a single processing step, such as reading files, and multiple DataProcessors can be used sequentially, each covering one preprocessing step. Once the business logic is added in the process() function, built-in multiprocessing support can easily be leveraged. The process() logic can run locally on a single machine or as a distributed job across nodes, with built-in support for Azure Machine Learning (AML). DataProcessor also allows selective preprocessing. With a large dataset, you can decide how many parts to split it into and choose which part to process at a time on a single node.

    This example covers preprocessing raw Wikipedia data. It splits 27 raw Wikipedia text files into sentences, with the following run times:

    • Single node, without multiprocessing: 2.5 hours
    • Single node, with multiprocessing: 20 minutes
    • Four nodes, with multiprocessing: 13 minutes
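Selective preprocessing comes down to splitting the inputs into parts and letting each node pick one. The helper below is a hypothetical sketch, not PyMarlin’s API. It divides a list of files into `num_parts` contiguous chunks whose sizes differ by at most one, and returns chunk `part_index`. On a multi-node job, `part_index` would typically be the node’s rank; that is an assumption about how the job is launched.

```python
def select_part(items, num_parts, part_index):
    """Return the part_index-th of num_parts contiguous, near-equal chunks of items."""
    if num_parts < 1 or not 0 <= part_index < num_parts:
        raise ValueError("need num_parts >= 1 and 0 <= part_index < num_parts")
    base, extra = divmod(len(items), num_parts)
    # The first `extra` parts receive one additional item each.
    start = part_index * base + min(part_index, extra)
    end = start + base + (1 if part_index < extra else 0)
    return items[start:end]
```

Take the 27 Wikipedia files above spread across four nodes. The parts hold 7, 7, 7, and 6 files, and together they cover every file exactly once. Within a node, the files of its part could then be handed to a `multiprocessing.Pool` for the per-file work.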
  2. ModuleInterface
    This interface contains the model architecture in the form of a PyTorch nn.Module, together with optimizers and schedulers, train and validation step recipes, and any callbacks. Scientists implement its abstract functions to create a training recipe, which can be extended further. In general, ModuleInterface takes a DataInterface instance as input and calls it to obtain the datasets it uses to create DataLoaders. In place of the usual forward function, the code is split into two functions, train_step and val_step, to separate the training loop from the validation loop. ModuleInterface also inherits from CallBackInterface, so users can optionally override callbacks such as on_end_val_epoch() to calculate metrics. We created an example ModuleInterface for further reference.
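To make the train_step/val_step split concrete, here is a plain-PyTorch sketch that runs without PyMarlin installed. It does not subclass PyMarlin’s ModuleInterface. Its method names and signatures are simplified assumptions rather than the library’s exact abstract API:

- `train_step(batch, device)`
- `val_step(batch, device)`
- `on_end_val_epoch()`
- `get_optimizer()`

Following the Trainer behavior described in this post, the module moves its own inputs to the device.

```python
import torch
from torch import nn


class ToyClassifierModule(nn.Module):
    """Hypothetical module mirroring the ModuleInterface split (not PyMarlin's API)."""

    def __init__(self, num_features=4, num_classes=3):
        super().__init__()
        self.net = nn.Linear(num_features, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()
        self._val_correct = 0
        self._val_total = 0

    def get_optimizer(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_step(self, batch, device):
        # Inputs are not moved automatically, so move them here.
        inputs, labels = (t.to(device) for t in batch)
        # Return the loss; the caller (e.g. a trainer backend) is assumed to call backward().
        return self.loss_fn(self.net(inputs), labels)

    @torch.no_grad()
    def val_step(self, batch, device):
        inputs, labels = (t.to(device) for t in batch)
        preds = self.net(inputs).argmax(dim=-1)
        self._val_correct += (preds == labels).sum().item()
        self._val_total += labels.numel()
        return preds

    def on_end_val_epoch(self):
        # Callback-style hook: turn the accumulated counts into a metric and reset them.
        accuracy = self._val_correct / max(self._val_total, 1)
        self._val_correct = self._val_total = 0
        return accuracy
```

A caller standing in for the trainer backend would follow this cycle:

1. Run `loss = module.train_step(batch, device)`.
2. Call `loss.backward()` and the optimizer step.
3. After validation, read the metric from `on_end_val_epoch()`.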
  3. TrainerBackend
    In PyMarlin, we’ve made distributed (specifically DDP) training, as well as FP16 training, easy by implementing them as trainer backends. You can select one by setting the trainer backend string as follows (or, equivalently, by setting trainer.backend in the YAML config):

    trainer = Trainer(module_interface=MyModuleInterface(…), backend="ddp-amp")
    trainer.train()
    

    Behind the scenes, many things are happening, such as loss scaling for FP16 and distributed output collection. Without the backends, this work would require dozens of lines of boilerplate code copy-pasted from scenario to scenario. Each trainer backend lives in its own file, which keeps the code modular, extendable, and clutter-free. Currently we support the following backends:

    1. SingleProcess (sp)
      Trains on a single CPU or GPU.
    2. SingleProcessAmp (sp-amp)
      Trains on a single GPU with mixed precision. Recommended for V100 or A100 GPUs.
    3. SingleProcessApexAmp (sp-amp-apex)
      Same as SingleProcessAmp, but uses the NVIDIA Apex library instead of native amp.
    4. DDPTrainerBackend
      A decorator that can convert any of the other backends to work in a distributed data parallel setting, for example:
      backend = DDPTrainerBackend(SingleProcessAmp())
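The decorator design means the distributed behavior is written once and layered over any single-process backend, rather than duplicated in each one. The toy sketch below illustrates that composition together with the backend strings used above ('sp', 'sp-amp', 'ddp', 'ddp-amp'). Every class and function in it is a hypothetical stand-in: `ToySingleProcess`, `ToySingleProcessAmp`, `ToyDDP`, and `build_backend`. The "training" is string bookkeeping, so the structure is easy to see.

```python
class ToySingleProcess:
    """Hypothetical single-process backend: performs each step itself."""
    name = "sp"

    def train_epoch(self, num_steps):
        return [f"{self.name}:step{i}" for i in range(num_steps)]


class ToySingleProcessAmp(ToySingleProcess):
    """Same loop; a real version would add autocast and loss scaling."""
    name = "sp-amp"


class ToyDDP:
    """Decorator: wraps any single-process backend instead of subclassing it."""

    def __init__(self, inner):
        self.inner = inner

    def train_epoch(self, num_steps):
        # A real DDP backend would wrap the model in DistributedDataParallel,
        # shard data with a DistributedSampler, and gather validation outputs
        # across ranks; the inner backend still performs each step.
        return [f"ddp({out})" for out in self.inner.train_epoch(num_steps)]


_SINGLE_PROCESS = {"sp": ToySingleProcess, "sp-amp": ToySingleProcessAmp}


def build_backend(spec):
    """Map 'sp', 'sp-amp', 'ddp', or 'ddp-amp' to a backend object."""
    if spec.startswith("ddp"):
        # 'ddp' wraps 'sp'; 'ddp-amp' wraps 'sp-amp'.
        return ToyDDP(build_backend("sp" + spec[len("ddp"):]))
    if spec not in _SINGLE_PROCESS:
        raise ValueError(f"unknown backend {spec!r}")
    return _SINGLE_PROCESS[spec]()
```

For example, `build_backend("ddp-amp").train_epoch(2)` yields `["ddp(sp-amp:step0)", "ddp(sp-amp:step1)"]`. Adding a new single-process backend automatically gives it a distributed variant.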

    More information can be found in the documentation.
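One detail a mixed precision backend hides is the ordering constraint noted earlier: with PyTorch’s native amp module, gradients must be unscaled before they are clipped. The sketch below shows that order in plain PyTorch. It is not PyMarlin’s backend code, `amp_train_step` is a hypothetical name, and the mean squared error loss is chosen only for illustration. It uses the `GradScaler` methods `scale`, `unscale_`, `step`, and `update`. When the scaler is created with `enabled=False`, these calls pass through unchanged, so the sketch also runs on CPU.

```python
import torch
from torch import nn


def amp_train_step(model, batch, optimizer, scaler, max_grad_norm=1.0):
    """One optimizer step with loss scaling and gradient clipping (hypothetical helper).

    unscale_ must run before clip_grad_norm_; otherwise the clipping
    threshold would be compared against scaled gradients.
    """
    inputs, targets = batch
    device_type = "cuda" if inputs.is_cuda else "cpu"
    with torch.autocast(device_type=device_type, enabled=scaler.is_enabled()):
        loss = nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()  # backward pass on the scaled loss
    scaler.unscale_(optimizer)     # divide the gradients by the scale factor in place
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)         # skips the update if gradients contain inf/NaN
    scaler.update()                # adjust the scale factor for the next step
    optimizer.zero_grad(set_to_none=True)
    return loss.item()
```

To enable scaling on a GPU, move the model and data to CUDA and create `torch.cuda.amp.GradScaler(enabled=True)`. Recent PyTorch releases prefer the equivalent `torch.amp.GradScaler("cuda")`.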

  4. Trainer
    The Trainer is responsible for coordinating the model definition (ModuleInterface) and the TrainerBackend, connecting the high-level model recipe with the backend on which it will be trained. The Trainer can scale to multiple processes. It automatically fetches the ranks from the environment variables set by torch.distributed.launch or by Azure ML’s MPI launcher, so PyMarlin training scales from one GPU to many GPUs. It then passes the ranks to the backend, as shown in Figure 3. Other frameworks can also be integrated easily. For example, users might spawn multiple processes using a custom script or a framework other than Azure ML. In that case, they can write a function to fetch the ranks, create an instance of DistributedTrainingArguments, and pass it to the Trainer as part of its TrainerArguments.

    The Trainer also moves the ModuleInterface to the right device: after fetching the ranks, PyMarlin moves the ModuleInterface to the GPU given by local_rank. Inputs to the ModuleInterface’s train_step and val_step are not moved to the device; the user is responsible for moving them. Users can extend and modify the `to` function to change the device movement behavior, but because `to` is a torch.nn.Module function, proper care must be taken before overriding it. Model parallelism is not supported out of the box, but it can be achieved by writing custom code for rank fetching and model movement.


    Figure 3: Trainer lifecycle
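Rank fetching can be sketched as a small function that inspects environment variables. This is a hypothetical example of the kind of custom function described above, not PyMarlin’s implementation. It assumes OpenMPI’s `OMPI_COMM_WORLD_*` variables for MPI launches, as used with Azure ML. For other launches it assumes the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` variables that `torchrun` and `torch.distributed.launch --use_env` set. Other launchers may use different names.

```python
import os
from dataclasses import dataclass


@dataclass
class Ranks:
    """Hypothetical container for the values a Trainer needs."""
    rank: int = 0
    local_rank: int = 0
    world_size: int = 1


def fetch_ranks(env=None):
    """Read process ranks from the environment; default to a single process."""
    env = os.environ if env is None else env
    # OpenMPI variables (assumed for MPI launches, e.g. via Azure ML).
    if "OMPI_COMM_WORLD_RANK" in env:
        return Ranks(
            rank=int(env["OMPI_COMM_WORLD_RANK"]),
            local_rank=int(env["OMPI_COMM_WORLD_LOCAL_RANK"]),
            world_size=int(env["OMPI_COMM_WORLD_SIZE"]),
        )
    # Variables set by torchrun / torch.distributed.launch --use_env.
    if "RANK" in env and "WORLD_SIZE" in env:
        return Ranks(
            rank=int(env["RANK"]),
            local_rank=int(env.get("LOCAL_RANK", 0)),
            world_size=int(env["WORLD_SIZE"]),
        )
    return Ranks()
```

In PyMarlin, values like these would go into a DistributedTrainingArguments instance, as described above. The local rank is also what selects the GPU the ModuleInterface is moved to.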

     

  5. Stats and Loggers
    We have implemented a wrapper around TensorBoard’s SummaryWriter for logging stats to TensorBoard (TB). It makes it easy to save TB events and visualize them later to track the progress of your training experiment. We also provide Azure ML and stdout writers that write your stats to the logs. Currently PyMarlin supports these three writers out of the box: StdOut, TensorBoard, and AML. Users can also create their own writers and pass them to the trainer.
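Because every writer exposes the same kind of call, logging to several destinations at once is a simple fan-out. The sketch below is hypothetical. The `log_scalar(tag, value, step)` method name and the `Stats`, `ListWriter`, and `StdoutWriter` classes are illustrative, not PyMarlin’s writer interface. A TensorBoard writer in this scheme could forward to `torch.utils.tensorboard.SummaryWriter.add_scalar(tag, value, step)`.

```python
class ListWriter:
    """Hypothetical custom writer that keeps scalars in memory."""

    def __init__(self):
        self.records = []

    def log_scalar(self, tag, value, step):
        self.records.append((tag, value, step))


class StdoutWriter:
    """Hypothetical writer that prints scalars to standard output."""

    def log_scalar(self, tag, value, step):
        print(f"step {step}: {tag}={value:.4f}")


class Stats:
    """Fan one logging call out to every registered writer."""

    def __init__(self, writers):
        self.writers = list(writers)

    def log_scalar(self, tag, value, step):
        for writer in self.writers:
            writer.log_scalar(tag, value, step)
```

With this layout, adding a new destination means writing one small class and adding it to the list, without touching training code.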
  6. Checkpointer
    Checkpointing is made simple with a built-in checkpoint utility module. Its default implementation saves the state of the ModuleInterface, TrainerBackend, and Trainer at every epoch, and loads any available model checkpoint at the start of training. Users can control the save and load directories and the checkpointing frequency via arguments. The default checkpoint also stores optimizers, schedulers, and additional information, so tasks such as resuming training after any number of steps are easy. In line with the goal of flexibility and extensibility, users can also implement their own checkpointers by extending the AbstractCheckpointer class with custom checkpointing logic. As shown in Figure 4, users can save any states or variables from the three core classes by overriding each class’s get_state() and update_state() methods. The checkpointer calls these at each save() and load(), respectively. For an example of how to implement a custom checkpointer, please visit our documentation site.


    Figure 4: Checkpointing design
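The get_state()/update_state() contract can be illustrated without PyMarlin. In this hypothetical sketch, `SimpleCheckpointer` collects state dictionaries from named components and pickles them every `period` epochs. On load, it restores the latest checkpoint into fresh objects. The class names and file-naming scheme are assumptions. A real checkpointer for PyTorch models would typically store `state_dict()` outputs with `torch.save` rather than `pickle`.

```python
import os
import pickle
import tempfile


class ToyTrainer:
    """Hypothetical component exposing the get_state/update_state contract."""

    def __init__(self):
        self.epoch = 0

    def get_state(self):
        return {"epoch": self.epoch}

    def update_state(self, state):
        self.epoch = state["epoch"]


class SimpleCheckpointer:
    """Hypothetical checkpointer: saves every `period` epochs, loads the latest."""

    def __init__(self, save_dir, period=1):
        self.save_dir = save_dir
        self.period = period
        os.makedirs(save_dir, exist_ok=True)

    def save(self, components, epoch):
        if (epoch + 1) % self.period != 0:
            return None
        state = {name: c.get_state() for name, c in components.items()}
        path = os.path.join(self.save_dir, f"checkpoint_{epoch}.pkl")
        with open(path, "wb") as f:
            pickle.dump(state, f)
        return path

    def load(self, components):
        files = [f for f in os.listdir(self.save_dir) if f.startswith("checkpoint_")]
        if not files:
            return False
        # Pick the checkpoint with the highest epoch number.
        latest = max(files, key=lambda f: int(f[len("checkpoint_"):-len(".pkl")]))
        with open(os.path.join(self.save_dir, latest), "rb") as f:
            state = pickle.load(f)
        for name, component in components.items():
            if name in state:
                component.update_state(state[name])
        return True
```

Resuming then amounts to calling `load()` on freshly constructed components before the training loop starts. In this sketch, `load()` returns `False` when no checkpoint exists yet.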

     

  7. AML integration
    AML integration works out of the box with the Azure ML (MPI) launcher. The Trainer fetches ranks from the MPI environment variables, which are set when Azure ML spawns the nodes and processes. Choose a distributed backend (ddp or ddp-amp) to run distributed training on AML compute.
  8. YAML parser
    It’s hard to keep track of the many hyperparameters in deep learning experiments. To ease this, we offer a custom argument parser that lets you maintain a YAML file containing all parameters. Values that need to be overridden during experimentation can also be passed via the command line. For example, given a YAML file with the following configuration:

    trainer:
        backend: "sp"
        train_batch_size: 32
        val_batch_size: 16
        epochs: 1 # Total epochs to run.
    

    You can override the backend from the command line when running the script:

    python Myscript.py --trainer.backend "sp-amp"
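The override mechanism can be sketched as a dotted-path update on the nested dictionary a YAML loader produces (for example, PyYAML’s `yaml.safe_load`). `apply_overrides` is a hypothetical helper, not PyMarlin’s parser. It converts each command-line string to the type of the value it replaces, and it leaves the original configuration untouched.

```python
import copy


def apply_overrides(config, argv):
    """Return a copy of config with `--a.b.c value` pairs from argv applied."""
    config = copy.deepcopy(config)
    if len(argv) % 2:
        raise ValueError("expected --key value pairs")
    for flag, raw in zip(argv[::2], argv[1::2]):
        if not flag.startswith("--"):
            raise ValueError(f"unexpected argument {flag!r}")
        *parents, leaf = flag[2:].split(".")
        node = config
        for key in parents:
            node = node[key]  # KeyError signals an unknown section
        old = node[leaf]      # KeyError signals an unknown option
        # Convert the string to the type of the value being replaced.
        if isinstance(old, bool):
            node[leaf] = raw.lower() in ("1", "true", "yes")
        elif old is None:
            node[leaf] = raw
        else:
            node[leaf] = type(old)(raw)
    return config
```

Applied to the configuration above with `["--trainer.backend", "sp-amp", "--trainer.epochs", "3"]`, this sets the backend to "sp-amp" and the epoch count to the integer 3. The other keys keep their YAML values.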

    The simplest example of PyMarlin in action is CIFAR image classification, for which we have a Colab notebook. We recommend following along in the CIFAR.ipynb notebook on Colaboratory. The notebook goes through the major steps of the workflow for creating a PyMarlin scenario:

    • Data preprocessing and analysis (through implementing DataProcessor and DataInterface)
    • Defining the model, the training setup (optimizers, LR scheduler), and validation metrics through ModuleInterface
    • Start training by initializing instances of DataInterface, ModuleInterface and Trainer and executing Trainer.train()

Extensibility

The biggest benefit of PyMarlin is its extensibility.

You can change the dataset but keep the training recipe the same (for example, CIFAR to MNIST). To reuse a strong image classification model for another task, you can keep your ModuleInterface almost the same and just implement a new data interface. Switching from our CIFAR example to MNIST is simple with the torchvision package. The MNIST-specific parts are the torchvision MNIST dataset classes and the digit class labels:

import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
from pymarlin.core import data_interface

class MNISTDataProcessor(data_interface.DataProcessor):
    def process(self):
        '''
        Downloads and caches the MNIST data.
        Normalizes the data and creates torch datasets.
        '''
        # MNIST images have a single channel, so normalize with one mean/std value.
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5,), (0.5,))])
        trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                              download=True, transform=transform)
        testset = torchvision.datasets.MNIST(root='./data', train=False,
                                             download=True, transform=transform)
        self.datasets = {'Train': trainset, 'Test': testset}
        return self.datasets

    def analyze(self):
        '''
        Displays the sizes of the train and test datasets and
        prints a few images and their labels from the train dataset.
        '''
        datasets = self.datasets
        print(f'train data size = {len(datasets["Train"])}')
        print(f'val data size = {len(datasets["Test"])}')
        print('Examples')
        random_indices = np.random.choice(len(datasets['Train']), 5, replace=False)
        sample_images = [datasets['Train'][i][0] for i in random_indices]
        sample_labels = [datasets['Train'][i][1] for i in random_indices]
        # make_grid repeats single-channel images into 3 channels for display.
        self._imshow(torchvision.utils.make_grid(sample_images))
        classes = [str(digit) for digit in range(10)]
        print('| '.join('%5s' % classes[label] for label in sample_labels))

    def _imshow(self, img):
        img = img / 2 + 0.5     # unnormalize
        npimg = img.numpy()
        plt.figure(figsize=(10, 5))
        plt.imshow(np.transpose(npimg, (1, 2, 0)))  # height x width x channels
        plt.show()


You can also change the model architecture but keep the data the same. Starting from the CIFAR example, this is again simple with torchvision, assuming the optimizer and everything else stay the same:

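from torch import nn
from torchvision.models import resnet18
from pymarlin.core import module_interface

class CIFAR(module_interface.ModuleInterface):  # class name assumed; other methods unchanged
    def __init__(self, data_interface):
        super().__init__()  # always initialize superclass first
        self.data_interface = data_interface
        self.net = resnet18(pretrained=True)
        # Replace the 1000-class ImageNet head with a 10-class head for CIFAR-10.
        self.net.fc = nn.Linear(self.net.fc.in_features, 10)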

All PyMarlin code is modular and extendable, and we encourage the open source community to contribute new additions to the library. The BART CNN/DailyMail summarization example shows how to create a new trainer backend for ONNX Runtime training. See ORT_README.md in the microsoft/PyMarlin GitHub repository.

Future Roadmap

The PyMarlin team plans to support more trainer backends that will enable users who want to train large models through DeepSpeed model parallelism. Differential privacy training is another feature we plan to support. We are always looking for ways to further expand and improve PyMarlin, and we welcome contributions from the community! You can find our external contribution guidelines here.

PyMarlin team – we’re hiring

We are a group of applied scientists and engineers (Amin Saied, Ananth Rao, Ashwin Srinivasan, Damien Jose, Eduardo Gonzalez, Han Yu, Jon Sleep, Krishan Subudhi, Shruti Gullapuram, Alejandro Stevenson-Duran, Manash Goswami) from Microsoft Office and Azure. We are enthusiastic about running extensible and scalable code and working on large-scale language model pretraining and finetuning for enterprise scenarios. If this type of work interests you, the PyMarlin team in MSAI is hiring both scientists and engineers! Please visit our careers page.


[1] More information can be found in our documentation.