NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch

Most deep learning frameworks, including PyTorch, train using 32-bit floating point (FP32) arithmetic by default. However, using FP32 for all operations is not essential to achieve full accuracy for many state-of-the-art deep neural networks (DNNs). In 2017, NVIDIA researchers developed a methodology for mixed-precision training in which a few operations are executed in FP32 while the majority of the network is executed using 16-bit floating point (FP16) arithmetic. FP16 arithmetic offers the following additional performance benefits on Volta GPUs:

  • FP16 reduces memory bandwidth and storage requirements by 2X. Bandwidth-bound operations can realize up to 2X speedup immediately.
  • FP16 arithmetic enables Volta Tensor Cores, which offer 125 TFlops of computational throughput on general matrix-matrix multiplications (GEMMs) and convolutions, an 8X increase over FP32.

Mixed-precision training enables networks to receive almost all the memory savings and improved throughput of pure FP16 training while matching the accuracy of FP32 training. A number of recently published results demonstrate the accuracy and high performance of the mixed-precision recipe:

  • Facebook AI Research’s FAIRseq translation network achieves a nearly 5X speedup over pure FP32 training on the same number of GPUs, and a state-of-the-art BLEU score on an English-to-German translation task.
  • NVIDIA Research’s sentiment analysis model attains a speedup of up to 4.5X over pure FP32 training on the same number of GPUs, with the techniques being used to train multiple state-of-the-art models.
  • Researchers from NVIDIA and Baidu recently showed that a wide range of bellwether networks, applied to a wide range of tasks, achieve comparable or superior test accuracy when trained with mixed-precision, using the same hyperparameters and training schedules as pure FP32 baselines.

A detailed description of mixed-precision training can be found in this post. In brief, the methodology is:

  1. FP32 master parameters to store and accumulate updates.
  2. Loss scaling to prevent underflowing gradients.
  3. A few operations (e.g. large reductions) converted to FP32.
  4. Everything else (the majority of the network) executed in FP16.
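
To make the four steps concrete, below is a minimal hand-rolled sketch of the recipe in plain PyTorch, without Apex. The shapes, the loss scale value, and the toy loss are illustrative assumptions (and a CUDA GPU is assumed); the Apex tools described below automate exactly these steps for you.

import torch

model = torch.nn.Linear(1024, 1024).cuda().half()   # step 4: run the model itself in FP16
# Step 1: FP32 master copies of the FP16 weights; the optimizer updates these.
master_params = [p.detach().clone().float().requires_grad_(True) for p in model.parameters()]
optimizer = torch.optim.SGD(master_params, lr=1e-3)
loss_scale = 128.0                                   # step 2: a (static) loss scale

x = torch.randn(64, 1024, device="cuda", dtype=torch.half)
loss = model(x).float().sum()                        # step 3: accumulate the reduction in FP32
(loss * loss_scale).backward()                       # scaled backward pass through the FP16 model

# Copy the scaled FP16 gradients to the FP32 masters and unscale them.
for master, p in zip(master_params, model.parameters()):
    master.grad = p.grad.detach().float() / loss_scale
optimizer.step()                                     # step 1 (continued): update the FP32 masters

# Copy the updated FP32 masters back into the FP16 model weights.
with torch.no_grad():
    for master, p in zip(master_params, model.parameters()):
        p.copy_(master)
optimizer.zero_grad()
model.zero_grad()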

Mixed-Precision in PyTorch

PyTorch has comprehensive built-in support for mixed-precision training. Calling .half() on a module converts its parameters to FP16, and calling .half() on a tensor converts its data to FP16. Any operations performed on such modules or tensors will be carried out using fast FP16 arithmetic. PyTorch also has strong built-in support for the NVIDIA math libraries (cuBLAS and cuDNN). These libraries use Tensor Cores to perform GEMMs (e.g., fully connected layers) and convolutions on FP16 data. For a GEMM with dimensions [M, K] x [K, N] -> [M, N], cuBLAS can use Tensor Cores provided that M, K, and N are multiples of 8.
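
For example, the following short sketch (dimensions chosen as multiples of 8, and a Volta GPU assumed) runs a fully connected layer entirely in FP16, making the underlying GEMM eligible for Tensor Cores:

import torch

layer = torch.nn.Linear(1024, 4096).cuda().half()    # weights converted to FP16
x = torch.randn(512, 1024, device="cuda").half()     # input data converted to FP16

y = layer(x)      # FP16 GEMM: [512, 1024] x [1024, 4096] -> [512, 4096]
print(y.dtype)    # torch.float16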

Introducing Apex

We developed Apex to streamline the mixed-precision user experience and enable researchers to leverage mixed-precision training in their models more conveniently. Apex is a lightweight PyTorch extension that contains two alternative tools for mixed-precision training:

  1. Amp: A library for automatically enabling all the steps of mixed-precision training.
  2. FP16_Optimizer: A class that wraps an existing PyTorch optimizer instance. FP16_Optimizer handles master weights and loss scaling automatically, and can be integrated into an existing half-precision training script by changing only two lines.

We discuss these tools in greater depth below. Each represents a different point in the tradeoff space for enabling mixed-precision training.

Amp emphasizes simplicity by performing relatively low-level modifications to the running model. You need not worry about mixed types when writing or running your model training script. The price of simplicity is reduced control. Models that use PyTorch in less common ways may find Amp’s assumptions don’t fit as well. However, hooks exist to modify those assumptions as needed.

In contrast, the primary emphasis of FP16_Optimizer is control. It operates at the user-API level of PyTorch and so can be easily adapted to unusual or sophisticated applications. It offers slightly less simplicity than Amp since the top-level script is responsible for specifying precision of operations internal to the model.

We recommend that anyone getting started with mixed-precision training start with Amp. Those seeking more control or who find Amp’s restrictions limiting should look at FP16_Optimizer.

Drop-in Mixed-Precision Training: Amp

Our first tool for enabling mixed-precision training in PyTorch is Amp (“automatic mixed-precision”). Amp’s primary focus is simplicity: getting nearly all of the benefits of mixed-precision training without any explicit management of master weights, loss scale, or type conversions.

Integrating Amp into an existing PyTorch model

Integrating Amp into an existing PyTorch script requires two steps:

  1. Initialize Amp so it can insert the necessary modifications to PyTorch internal functions.
  2. Mark where backpropagation (.backward()) occurs so that Amp can both scale the loss and clear per-iteration state.

Step one is a single line of code, after importing amp from the apex package:

from apex import amp

amp_handle = amp.init(enabled=True)

You can instead pass enabled=False to make everything Amp does a no-op and leave execution unchanged. (Typically, you expose a single command-line argument that enables or disables Amp.)
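
For example, a command-line switch along these lines works well (the flag name --fp16 is an arbitrary choice, not part of Amp):

import argparse
from apex import amp

parser = argparse.ArgumentParser()
parser.add_argument("--fp16", action="store_true")   # pass --fp16 to turn mixed precision on
args = parser.parse_args()

amp_handle = amp.init(enabled=args.fp16)              # enabled=False makes Amp a no-op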

Step two requires you to identify where in your code the backward pass occurs. You’ll see a few lines of code that look like the following:

loss = criterion(…)
loss.backward()
optimizer.step()

To enable loss scaling, you simply wrap the backward pass in the Amp context manager:

loss = criterion(…)
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()

And that’s it. You can now re-run your script and have mixed-precision training enabled.

The Amp API offers additional features to handle complications like multiple optimizers, multiple backward passes, and working with custom C++ or CUDA layers not part of native PyTorch. Complete documentation can be found here.

How Amp works

Amp works at the logical level by employing a whitelist / blacklist model. PyTorch’s tensor operations include neural network functions like torch.nn.functional.conv2d, basic math functions like torch.log, and tensor methods like torch.Tensor.__add__ (called when you write a + b for two tensors). Note that these functions are a level below the neural network Module API. Modules (e.g., torch.nn.Conv2d) call into the corresponding functions for their implementation.

We divide the universe of functions into three sets:

  • Whitelist. Functions where we expect a speedup with FP16 math. The most common examples of these are the matrix multiply and convolution functions.
  • Blacklist. Functions for which 16 bits of precision may not be sufficient, so we want to ensure that inputs are in FP32. The most common examples of these are the neural net loss functions like softmax with cross entropy.
  • Everything else (all remaining functions). These are functions for which FP16 can work, but the cost of an FP32 -> FP16 cast to run them in FP16 isn’t worthwhile, since the speedup is small.

In principle, the job of Amp is straightforward. Whenever a PyTorch function gets called, Amp checks whether it is whitelisted, blacklisted, or neither. If whitelisted, cast all arguments to FP16. If blacklisted, cast all arguments to FP32. Finally, if neither, simply ensure all arguments are of the same type (casting to the widest type if necessary). In practice, though, implementing this policy is not entirely straightforward.
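
Conceptually, the decision made for each intercepted call looks something like the sketch below. The set contents and function are illustrative only, not Amp’s actual internals, and all tensor arguments are assumed to be floating point:

import torch

WHITELIST = {torch.nn.functional.linear, torch.nn.functional.conv2d}
BLACKLIST = {torch.nn.functional.softmax, torch.nn.functional.cross_entropy}

def cast_args(fn, tensor_args):
    if fn in WHITELIST:
        return [t.half() for t in tensor_args]    # run in FP16 for speed
    if fn in BLACKLIST:
        return [t.float() for t in tensor_args]   # keep FP32 for numerical safety
    # Neither: promote everything to the widest floating-point type present.
    widest = max((t.dtype for t in tensor_args), key=lambda d: torch.finfo(d).bits)
    return [t.to(widest) for t in tensor_args]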

Capturing function calls

Because PyTorch is so flexible and dynamic (a good thing!), it lacks a static model object or graph to latch onto and insert the casts described above. Instead, Amp does so dynamically by “monkey patching” the necessary functions to intercept and cast their arguments. For example, to ensure that torch.nn.functional.linear always casts its arguments to FP16, you can write code like this:

orig_linear = torch.nn.functional.linear

def wrapped_linear(*args):
    casted_args = []
    for arg in args:
        # Cast floating-point tensor arguments to FP16; pass everything else through unchanged.
        if torch.is_tensor(arg) and torch.is_floating_point(arg):
            casted_args.append(arg.half())
        else:
            casted_args.append(arg)
    return orig_linear(*casted_args)

torch.nn.functional.linear = wrapped_linear

Other subtleties exist to make the code more robust (different argument types, keyword arguments), but what Amp essentially does on a call to amp.init() is insert monkey patches on all of the relevant PyTorch functions so that arguments are cast appropriately at runtime.
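
For instance, a slightly more robust version of the patch above might also handle keyword arguments, along the lines of this sketch (again simplified, not Amp’s actual code):

import functools
import torch

def maybe_half(x):
    # Cast only floating-point tensors; leave ints, bools, and non-tensors untouched.
    if torch.is_tensor(x) and torch.is_floating_point(x):
        return x.half()
    return x

def patch_to_fp16(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        args = [maybe_half(a) for a in args]
        kwargs = {k: maybe_half(v) for k, v in kwargs.items()}
        return fn(*args, **kwargs)
    return wrapper

torch.nn.functional.linear = patch_to_fp16(torch.nn.functional.linear)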

Minimizing casts

One additional challenge with the function-based casting approach remains. Naively applied, any sort of parameter sharing induces multiple casts of the same weight on each iteration. For example, the nn.RNNCell module will call an RNN function once for each timestep with the same FP32 weight arguments.

To ensure each weight is cast FP32 -> FP16 no more than once per iteration, Amp keeps an internal cache of parameter casts and reuses the cast versions when appropriate. The context manager around the backward pass tells Amp when to clear the cache at each iteration.
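
A simplified sketch of such a cache, keyed by the FP32 parameter and cleared once per iteration, might look like this (illustrative only, not Amp’s implementation):

_cast_cache = {}

def cached_half(param):
    # Reuse the FP16 copy if this FP32 parameter was already cast this iteration.
    key = id(param)
    if key not in _cast_cache:
        _cast_cache[key] = param.half()
    return _cast_cache[key]

def clear_cast_cache():
    # Called once per iteration; Amp does the equivalent when the scale_loss context exits.
    _cast_cache.clear()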

Mixed-Precision Training with FP16_Optimizer

FP16_Optimizer automatically implements the master weights and loss scaling prescribed by mixed-precision training. Unlike Amp, it does not monkey patch any internal PyTorch functions, so anything within the model declared to run in FP16 will run in FP16, while anything declared to run in FP32 will run in FP32. In this way, FP16_Optimizer is less invasive and offers additional control. Here’s how that looks in a minimal working sample (which assumes FP16_Optimizer has been imported from Apex and that batch, dim_in, and dim_out are defined):

1   x = torch.cuda.FloatTensor(batch, dim_in).normal_().half()
2   y = torch.cuda.FloatTensor(batch, dim_out).normal_().half()
3   model = torch.nn.Linear(dim_in, dim_out).cuda().half()
4
5   optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
6   ### construct FP16_Optimizer
7   optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
8
9   loss_fn = torch.nn.MSELoss()
10
11  for t in range(200):
12      optimizer.zero_grad()
13      y_pred = model(x)
14      loss = loss_fn(y_pred.float(), y.float())
15      optimizer.backward(loss)  ### formerly loss.backward()
16      optimizer.step()

Only two lines differ compared to pure FP16 training. In line 7, FP16_Optimizer is constructed from an existing optimizer and told to use a static loss scale of 128. In line 15, the typical call to loss.backward() becomes optimizer.backward(loss). These two lines are all that’s required for FP16_Optimizer to transparently implement master weights and loss scaling.

An important thing to notice is that with FP16_Optimizer, the user specifies which parts of the model and which input data should use FP16, hence the .half() calls on x, y, and model in the sample above. This is not necessary with Amp, because conversions occur on the fly within the monkey-patched PyTorch functions.

FP16_Optimizer Options

You can select dynamic loss scaling instead of static loss scaling by changing the constructor call:

optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

We recommend dynamic loss scaling for production runs. It adjusts the loss scale on the fly without user intervention. FP16_Optimizer will report when it attempts a loss scale that is too high (resulting in overflow) and readjusts to a smaller value:

[OVERFLOW! Skipping step. Attempted loss scale: 1048576.0, reducing to 524288.0]

This behavior is normal and is not regarded as an error. Dynamic loss scaling may be modestly (<5%) slower than static loss scaling due to the additional overflow check it performs every iteration. See the Apex GitHub repo for a more detailed explanation of how dynamic loss scaling operates.
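
The essential idea can be sketched as follows; the constants and helper names are illustrative, not Apex’s actual implementation (see the linked documentation for the real details):

import torch

loss_scale = 2.0 ** 16
good_steps = 0

def grads_overflowed(params):
    # True if any gradient contains an inf or NaN.
    return any(p.grad is not None and not torch.isfinite(p.grad).all() for p in params)

def update_scale_and_check(params):
    # Returns False when the step should be skipped because of overflow.
    global loss_scale, good_steps
    if grads_overflowed(params):
        loss_scale /= 2.0          # overflow: shrink the scale and skip this step
        good_steps = 0
        return False
    good_steps += 1
    if good_steps == 2000:         # a long stretch without overflow: try a larger scale
        loss_scale *= 2.0
        good_steps = 0
    return True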

Loss scaling, whether static or dynamic, should not require retuning the learning rate or any other hyperparameters.

FP16_Optimizer supports optimizers like LBFGS that require closures, supports distributed training via torch.nn.parallel.DistributedDataParallel or apex.parallel.DistributedDataParallel, and supports saving and loading via the same interface as ordinary PyTorch optimizers (so you shouldn’t have to change the lines that save/restore your model).
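
Because the interface mirrors an ordinary PyTorch optimizer, checkpointing code can stay as it is. A minimal sketch, reusing model and optimizer from the sample above (the file name is arbitrary):

checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),   # optimizer is the FP16_Optimizer instance
}
torch.save(checkpoint, "checkpoint.pt")

# ...later, to resume...
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])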

Our simple examples demonstrate closure use, distributed training, and saving/loading. Additional examples demonstrate FP16_Optimizer in PyTorch’s Imagenet and word_language_model training scripts.

FP16_Optimizer Under the Hood

FP16_Optimizer’s constructor parses the existing optimizer’s parameters and notes which of them are FP16. For each FP16 weight, it creates an FP32 master weight. If any existing weights are already FP32, FP16_Optimizer lets the “master” weight simply be a reference to the existing weight for efficiency. In this way FP16_Optimizer handles models that have a mixture of FP16 and FP32 weights.
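
A simplified sketch of that construction logic (illustrative only, not the actual FP16_Optimizer code):

import torch

def build_master_params(params):
    masters = []
    for p in params:
        if p.dtype == torch.half:
            # FP16 weight: create a trainable FP32 master copy.
            masters.append(p.detach().clone().float().requires_grad_(True))
        else:
            # Already FP32: the "master" is simply a reference to the existing weight.
            masters.append(p)
    return masters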

During construction, FP16_Optimizer also copies the existing optimizer’s type, its hyperparameters, and any param_groups into which the weights are organized, so that weights are updated as the user intended. During optimizer.step() on Line 16 above (where optimizer is now an instance of FP16_Optimizer) the master weights are updated then copied to your original model’s weights (as the mixed-precision recipe prescribes). After the call to optimizer.step(), your model’s weights (which may be all FP16, or a mixture of FP16 and FP32) should be what you expect. Using FP16_Optimizer does not alter the structure of your original model or the types of its parameters. Any FP32 master weight copies are maintained independently outside the model.

Line 15 is present because FP16_Optimizer requires ownership of the backward pass in order to scale the loss, copy any scaled FP16 model gradients to FP32 master gradients, then downscale the master gradients.
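
In rough pseudocode, the work hidden inside optimizer.backward(loss) looks like the sketch below; the function and its arguments are hypothetical, not the real implementation:

def fp16_backward(loss, model_params, master_params, loss_scale):
    (loss.float() * loss_scale).backward()            # scale the loss, backpropagate through the FP16 model
    for master, p in zip(master_params, model_params):
        if master is not p:                           # an FP16 weight with a separate FP32 master
            master.grad = p.grad.detach().float()     # copy the scaled FP16 gradient to the FP32 master
        master.grad.div_(loss_scale)                  # downscale the master gradient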

FP16_Optimizer Tips

After the call to optimizer.backward(loss) on Line 15, the gradients found in the .grad attributes of your model’s FP16 weights might not be as expected. This is because they remain scaled by whatever loss scale FP16_Optimizer just used. However, FP16_Optimizer downscales master gradients before applying them to master weights, so master weight updates will be consistent with whatever learning rate you supply.

To accommodate some implementation details, FP16_Optimizer reserves the right to alter the original optimizer passed to its constructor. Therefore, the original optimizer instance should no longer be used after FP16_Optimizer has been constructed.

Get Started with Apex

Installation instructions can be found on the Apex GitHub page, and complete API documentation can be found here. Apex was developed in dialogue with deep learning researchers at NVIDIA and in the external community. It’s an open source project, and we welcome any suggestions, feature requests, bug reports, or contributions. Feel free to submit PRs and issues on GitHub, or leave a comment below.

Appendix: References and Links

NVIDIA Apex GitHub Repo

Apex Examples

Apex Documentation (Complete)

Dynamic Loss Scaler Documentation

Inside Volta (includes Tensor Core technical details)

Mixed-Precision for Deep Neural Networks blog post

Training Neural Networks with Mixed-Precision: Theory and Practice (GTC 2018 session)

Includes important concepts, benefits, dynamic loss scaling, and results

Training Neural Networks with Mixed-Precision: Real Examples (GTC 2018 session)

Practical walkthrough showing how to spot danger points, implement master weights + loss scaling in PyTorch and Tensorflow. Does not show dynamic loss scaling.

Training with Mixed-Precision documentation

Fairseq paper

Sentiment analysis example

ICLR Mixed-Precision Training paper
