Writing Your Own Optimizers in PyTorch
This article will teach you how to write your own optimizers in PyTorch - you know the kind, the ones where you can write something like
optimizer = MySOTAOptimizer(my_model.parameters(), lr=0.001)
for epoch in epochs:
    for batch in epoch:
        optimizer.zero_grad()
        outputs = my_model(batch)
        loss = loss_fn(outputs, true_values)
        loss.backward()
        optimizer.step()
The great thing about PyTorch is that it comes packaged with a solid standard library of optimizers that will cover all of your garden-variety machine learning needs. However, sometimes you’ll find that you need something just a little more specialized. Maybe you wrote your own optimization algorithm that works particularly well for the type of problem you’re working on, or maybe you’re looking to implement an optimizer from a recently published research paper that hasn’t yet made its way into the PyTorch standard library. No matter. Whatever your particular use case may be, PyTorch allows you to write optimizers quickly and easily, provided you know just a little bit about its internals. Let’s dive in.
Subclassing the PyTorch Optimizer Class
All optimizers in PyTorch need to inherit from torch.optim.Optimizer. This is a base class which handles all general optimization machinery. Within this class, there are two primary methods that you’ll need to override: __init__ and step. Let’s see how it’s done.
The __init__() Method
The __init__ method is where you’ll set all configuration settings for your optimizer. Your __init__ method must take a params argument which specifies an iterable of parameters that will be optimized. This iterable must have a deterministic ordering, so the user of your optimizer shouldn’t pass in something like a dictionary or a set. Usually a list of torch.Tensor objects is given.
Other typical parameters you’ll specify in the __init__ method include lr (the learning rate), weight_decay, betas for Adam-based optimizers, etc.
The __init__ method should also perform some basic checks on passed-in parameters. For example, an exception should be raised if the provided learning rate is negative.
In addition to params, the Optimizer base class requires a parameter called defaults on initialization. This should be a dictionary mapping parameter names to their default values. It can be constructed from the kwarg parameters collected in your optimizer class’ __init__ method. This will be important in what follows.
The last step in the __init__ method is a call to the Optimizer base class. This is performed by calling super() using the following general signature:
super(YourOptimizerName, self).__init__(params, defaults)
Implementing a Novel Optimizer from Scratch
Let’s investigate and reinforce the above methodology using an example taken from the HuggingFace pytorch-transformers NLP library. They implement a PyTorch version of a weight decay Adam optimizer from the BERT paper. First we’ll take a look at the class definition and __init__ method. Here are both combined.
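As a rough sketch of what such a class definition and __init__ method look like: the code below is reconstructed from the description in this article, so the default values and error messages are assumptions and may differ from the library's exact source.

```python
import torch
from torch.optim import Optimizer


class AdamW(Optimizer):
    """Constructor-only sketch of a weight decay Adam optimizer."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0.0, correct_bias=True):
        # Basic sanity checks on the hyperparameters (messages are illustrative).
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))

        # Dictionary of default hyperparameters handed to the base class.
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, correct_bias=correct_bias)
        super(AdamW, self).__init__(params, defaults)


# Usage: the defaults are copied into every parameter group.
weights = [torch.nn.Parameter(torch.zeros(3))]
optimizer = AdamW(weights, lr=0.01)
print(optimizer.param_groups[0]["lr"], optimizer.param_groups[0]["betas"])
```

Because the base class copies defaults into each parameter group, every hyperparameter can later be read back inside step() as group['lr'], group['betas'], and so on.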
You can see that the __init__ method accomplishes all the basic requirements listed above. It implements basic checks on the validity of all provided kwargs and raises exceptions if they are not met. It also constructs a dictionary of defaults from these required parameters. Finally, super() is called to initialize the Optimizer base class using the provided params and defaults.
The step() Method
The real magic happens in the step() method. This is where the optimizer’s logic is implemented and enacted on the provided parameters. Let’s take a look at how this happens.
The first thing to note in step(self, closure=None) is the presence of the closure keyword argument. If you consult the PyTorch documentation, you’ll see that closure is an optional callable that allows you to reevaluate the loss at multiple time steps. This is unnecessary for most optimizers, but is used in a few such as Conjugate Gradient and LBFGS. According to the docs, “the closure should clear the gradients, compute the loss, and return it”. We’ll leave it at that, since a closure is unnecessary for the AdamW optimizer.
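For a feel of what a closure looks like in practice, here is a small sketch using the built-in torch.optim.LBFGS on a toy quadratic. The objective and variable names are made up for illustration.

```python
import torch

# Toy problem: minimize f(x) = (x - 3)^2, starting from x = 0.
x = torch.nn.Parameter(torch.tensor([0.0]))
optimizer = torch.optim.LBFGS([x], lr=1.0)


def closure():
    # Clear the gradients, compute the loss, and return it.
    optimizer.zero_grad()
    loss = ((x - 3.0) ** 2).sum()
    loss.backward()
    return loss


initial_loss = closure().item()
for _ in range(5):
    # LBFGS may call the closure several times within a single step.
    optimizer.step(closure)
final_loss = closure().item()

print(initial_loss, final_loss)
```

An optimizer that never needs to re-evaluate the loss can still accept closure=None and simply call it once, if provided, at the top of step().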
The next thing you’ll notice about the AdamW step function is that it iterates over something called param_groups. The optimizer’s param_groups is a list of dictionaries which gives a simple way of breaking a model’s parameters into separate components for optimization. It allows the trainer of the model to segment the model parameters into separate units which can then be optimized at different times and with different settings. One use for multiple param_groups would be in training separate layers of a network using, for example, different learning rates. Another prominent use case arises in transfer learning. When fine-tuning a pretrained network, you may want to gradually unfreeze layers and add them to the optimization process as fine-tuning progresses. For this, param_groups are vital. Here’s an example given in the PyTorch documentation in which param_groups are specified for SGD in order to separately tune the different layers of a classifier.
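The pattern looks roughly like the following. The SGD call mirrors the per-parameter options example in the PyTorch documentation, while TinyClassifier is a hypothetical stand-in model added so the snippet runs on its own.

```python
import torch
from torch import nn, optim


class TinyClassifier(nn.Module):
    """Hypothetical model with a `base` and a `classifier` part."""

    def __init__(self):
        super().__init__()
        self.base = nn.Linear(8, 4)
        self.classifier = nn.Linear(4, 2)

    def forward(self, x):
        return self.classifier(torch.relu(self.base(x)))


model = TinyClassifier()

# Two parameter groups: the classifier overrides the default learning rate,
# the base falls back to the keyword defaults given to the optimizer.
optimizer = optim.SGD(
    [
        {"params": model.base.parameters()},
        {"params": model.classifier.parameters(), "lr": 1e-3},
    ],
    lr=1e-2,
    momentum=0.9,
)

for i, group in enumerate(optimizer.param_groups):
    print(i, group["lr"], group["momentum"])
```

Any option a group does not specify is filled in from the optimizer's defaults, which is why the defaults dictionary built in __init__ matters. Groups can also be added after construction with optimizer.add_param_group(...), which is what gradual unfreezing relies on.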
Now that we’ve covered some things specific to the PyTorch internals, let’s get to the algorithm. Here’s a link to the paper which originally proposed the AdamW algorithm. And here, from the paper, is a screenshot of the proposed update rules.
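For reference, the update carried out by the code in this walkthrough can be written as follows. This is transcribed from the code rather than from the paper's figure, so the notation (g_t for the gradient, m_t and v_t for the two moving averages, α for the learning rate, λ for the weight decay) is an assumption and the numbering will not match the paper's.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\alpha_t &= \alpha \, \frac{\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}}
  \quad \text{(with bias correction; otherwise } \alpha_t = \alpha \text{)} \\
\tilde{\theta}_t &= \theta_{t-1} - \alpha_t \, \frac{m_t}{\sqrt{v_t} + \epsilon} \\
\theta_t &= \tilde{\theta}_t - \alpha \lambda \, \tilde{\theta}_t
\end{aligned}
```

The last line is the decoupled weight decay: it shrinks the weights directly instead of being folded into the gradient, so it never passes through m_t or v_t.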
Let’s go through this line by line with the source code. First, we have the loop
for p in group['params']:
Nothing mysterious here. For each of our parameter groups, we’re iterating over the parameters within that group. Next.
if p.grad is None:
    continue
grad = p.grad.data
if grad.is_sparse:
    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
This is all simple stuff as well. If there is no gradient for the current parameter, we just skip it. Next, we get the actual plain Tensor object for the gradient by accessing p.grad.data. Finally, if the tensor is sparse, we raise an error because we are not going to consider implementing this for sparse objects.
Next, we access the current optimizer state with
state = self.state[p]
In PyTorch optimizers, the state is simply a dictionary associated with the optimizer, keyed by parameter, that holds whatever the algorithm needs to remember about each parameter between steps, such as step counts and running averages.
If this is the first time we’ve accessed the state of a given parameter, then we set the following defaults
if len(state) == 0:
    state['step'] = 0
    # Exponential moving average of gradient values
    state['exp_avg'] = torch.zeros_like(p.data)
    # Exponential moving average of squared gradient values
    state['exp_avg_sq'] = torch.zeros_like(p.data)
We start at step 0, with the exponential moving averages of the gradient and of the squared gradient both zeroed out and given the same shape as the parameter tensor (and therefore as its gradient).
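The same lazy per-parameter state can be observed on a built-in optimizer. The sketch below uses torch.optim.Adam, which happens to store entries under the same names; the toy parameter is made up for illustration.

```python
import torch

p = torch.nn.Parameter(torch.ones(2, 3))
optimizer = torch.optim.Adam([p], lr=0.1)

# Nothing is stored for `p` until the first call to step().
entries_before = len(optimizer.state)

p.sum().backward()
optimizer.step()

# After one step, the state for `p` holds the step count and both averages.
state = optimizer.state[p]
print(entries_before, sorted(state.keys()))
print(state["exp_avg"].shape, state["exp_avg_sq"].shape)
```

Keying the state by the parameter itself is what lets each tensor carry its own running averages, independent of which parameter group it belongs to.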
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
Next, we gather the parameters from the state dict that will be used in the computation of the update. We also increment the current step.
Now, we begin the actual updates. Here’s the code.
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
denom = exp_avg_sq.sqrt().add_(group['eps'])
step_size = group['lr']
if group['correct_bias']:  # No bias correction for Bert
    bias_correction1 = 1.0 - beta1 ** state['step']
    bias_correction2 = 1.0 - beta2 ** state['step']
    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
p.data.addcdiv_(-step_size, exp_avg, denom)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group['weight_decay'] > 0.0:
    p.data.add_(-group['lr'] * group['weight_decay'], p.data)
The above code corresponds to equations 6-12 in the algorithm from the paper. Following along with the math should be easy enough. What I’d like to take a closer look at are the built-in Tensor methods that allow us to do the in-place computations.
A nice, relatively hidden feature of PyTorch which you might not be aware of is that you can access any of the standard PyTorch functions, e.g. torch.add(), torch.mul(), etc. as in-place operations on the Tensors directly by appending an _ to the method name. Thus, taking a closer look at the first update, we find we can quickly compute it as
exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
rather than
exp_avg = torch.add(torch.mul(exp_avg, beta1), (1.0 - beta1) * grad)
Of course, there are a few special operations used here with which you may not be familiar, for example Tensor.addcmul_ and Tensor.addcdiv_. These take the input tensor and add to it the element-wise product or quotient, respectively, of the two other tensors, scaled by a scalar. If you need a more in-depth rundown of the various operations available on Tensor objects, I highly recommend checking out this post.
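A few toy tensors make the in-place behavior concrete. This sketch uses the keyword-argument forms (value=...) rather than the positional-scalar forms seen in the optimizer code above; the numbers are arbitrary.

```python
import torch

# Out-of-place vs. in-place: the trailing underscore modifies the tensor itself.
a = torch.tensor([1.0, 2.0])
b = torch.mul(a, 2.0)  # new tensor; `a` is unchanged at this point
a.mul_(2.0)            # `a` itself is now [2.0, 4.0]

# addcmul_: t += value * (t1 * t2), element-wise
t = torch.tensor([1.0, 2.0])
t.addcmul_(torch.tensor([2.0, 3.0]), torch.tensor([4.0, 5.0]), value=0.5)
# t is now [1 + 0.5*2*4, 2 + 0.5*3*5] = [5.0, 9.5]

# addcdiv_: u += value * (t1 / t2), element-wise
u = torch.tensor([1.0, 2.0])
u.addcdiv_(torch.tensor([4.0, 9.0]), torch.tensor([2.0, 3.0]), value=2.0)
# u is now [1 + 2*4/2, 2 + 2*9/3] = [5.0, 8.0]

print(a, b, t, u)
```

In-place updates matter in an optimizer because the state tensors and the parameters must keep their identity: the model and the state dictionary hold references to those exact objects.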
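You’ll also see that the learning rate is accessed again in the last line, where it scales the weight decay applied to the parameters. Once every parameter has been updated, step() returns the loss, which is None unless a closure was supplied.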
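Putting the pieces together, a complete optimizer along these lines might look like the sketch below. It is not the library's source: the class name SketchAdamW is hypothetical, validation is omitted for brevity, and it swaps p.grad.data and the positional-scalar calls for torch.no_grad() and the keyword-argument forms (alpha=, value=) used in more recent PyTorch code.

```python
import math

import torch
from torch.optim import Optimizer


class SketchAdamW(Optimizer):  # hypothetical name, for illustration only
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0.0, correct_bias=True):
        # Hyperparameter checks omitted here to keep the sketch short.
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("sparse gradients are not supported in this sketch")

                state = self.state[p]
                if len(state) == 0:  # first time we see this parameter
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]
                state["step"] += 1

                # Update the moving averages of the gradient and its square.
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p.addcdiv_(exp_avg, denom, value=-step_size)

                # Decoupled weight decay, applied directly to the weights.
                if group["weight_decay"] > 0.0:
                    p.add_(p, alpha=-group["lr"] * group["weight_decay"])
        return loss


# Usage on a toy quadratic: minimize (x - 3)^2 starting from x = 0.
x = torch.nn.Parameter(torch.tensor([0.0]))
optimizer = SketchAdamW([x], lr=0.1)
losses = []
for _ in range(200):
    optimizer.zero_grad()
    loss = ((x - 3.0) ** 2).sum()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
print(losses[0], losses[-1])
```

Decorating step() with torch.no_grad() keeps the parameter updates out of the autograd graph, which is the job .data does in the fragments above.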
And…that’s it! Constructing your own optimizers is as simple as that. Of course, you need to devise your own optimization algorithm first, which can be a little bit trickier ;). I’ll leave that one to you.
Special thanks to the authors of Hugging Face for implementing the AdamW optimizer in PyTorch.