Customize an Optimizer in PyTorch
In this post, we are going to implement our own optimizer instead of using the built-in PyTorch optimizers. Before we dive into the implementation details, we first need to understand what an optimizer in PyTorch actually does.
The standard way to use an optimizer during the training phase looks as follows:
optimizer.zero_grad()
output_data = model(input_data)
# compute loss
loss.backward()
optimizer.step()
The optimizer.zero_grad() call makes sure we don't mix up gradient values from different batches. loss.backward() performs the back-propagation algorithm, computing every tensor's gradient and storing it in the grad attribute. optimizer.step() then accesses the grad attribute of every tensor and applies a specific optimization algorithm to update the model's parameters.
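To make the pattern concrete, here is a minimal sketch of a full training loop that uses this sequence. The model, data, and loss function are throwaway placeholders chosen only for illustration:

import torch

# Toy model, data, and loss purely for illustration
model = torch.nn.Linear(10, 2)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

input_data = torch.randn(4, 10)
target = torch.randn(4, 2)

for epoch in range(3):
    optimizer.zero_grad()                 # clear gradients left over from the previous batch
    output_data = model(input_data)       # forward pass
    loss = loss_fn(output_data, target)   # compute loss
    loss.backward()                       # back-propagation: fills p.grad for every parameter
    optimizer.step()                      # update parameters using the stored gradients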
Programming model
To create our own optimizer, we need to inherit from torch.optim.Optimizer and reimplement two methods:
__init__(): define your own optimizer's hyperparameters
step(): implement the optimization algorithm
Here is the code template that we can follow:
import torch
from torch.optim import Optimizer
from typing import Callable, Iterable, Tuple


class CustomOpt(Optimizer):
    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def step(self, closure: Callable = None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                # State should be stored in this dictionary.
                state = self.state[p]

                # Access hyperparameters from the `group` dictionary.
                alpha = group["lr"]

                # ... apply your update rule to p.data here ...

        return loss
init method
To instantiate an optimizer, we need to pass in some parameters. Since our customized optimizer inherits from Optimizer, we must pass in an iterable of model parameters together with hyperparameters such as the learning rate, the momentum coefficients, or boolean flags that control the behavior during optimization. We also run some validity checks here to ensure the input values fall in the expected ranges.
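For example, assuming the CustomOpt class defined above, instantiation looks like this; the validity checks in __init__() reject a negative learning rate before any training happens:

model = torch.nn.Linear(10, 2)

# Valid hyperparameters: the defaults dict is stored and later
# exposed through every entry of optimizer.param_groups.
opt = CustomOpt(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-6)

# Invalid hyperparameters raise immediately, before any training step.
try:
    CustomOpt(model.parameters(), lr=-1.0)
except ValueError as e:
    print(e)  # Invalid learning rate: -1.0 - should be >= 0.0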
step method
In this function, to perform the optimization scheme we first need to iterate over param_groups. You may be curious about what param_groups is. In PyTorch, param_groups is an attribute of an optimizer object that contains all the parameter groups. These groups are dictionaries specifying different sections of parameters in your model that can have distinct learning rates, weight decay, and so on. This is useful when you want to apply different optimization strategies to different parts of the model.
Here is an example of how param_groups are set up and used in a PyTorch optimizer:
import torch

# Simple model with two parameterized layers
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 2)
)

# Accessing all parameters of the model
all_params = model.parameters()

# Creating the optimizer with a single default parameter group
optimizer = torch.optim.SGD(all_params, lr=0.01)

# Suppose we want to set a different learning rate for the first and second layers
layer1_params = model[0].parameters()
layer2_params = model[2].parameters()

# Creating a new optimizer with multiple parameter groups
custom_optimizer = torch.optim.SGD([
    {'params': layer1_params, 'lr': 0.01},
    {'params': layer2_params, 'lr': 0.02}  # Different learning rate for the second layer
])

# Accessing parameter groups
for param_group in custom_optimizer.param_groups:
    print(param_group['lr'])  # Prints 0.01 for the first group and 0.02 for the second
In the above example, custom_optimizer is set up with different learning rates for different layers of the model. This is achieved using param_groups. Each element in param_groups is a dictionary representing a single group of parameters, where you can individually adjust settings such as the learning rate.
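Since each group is just a dictionary, you can also modify its settings on the fly, for instance to decay the learning rate of a single group during training. A minimal sketch, reusing custom_optimizer from the example above:

# Halve the learning rate of the second layer's group only
custom_optimizer.param_groups[1]['lr'] *= 0.5

for param_group in custom_optimizer.param_groups:
    print(param_group['lr'])  # 0.01 for the first group, 0.01 for the second after the decay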
When we call optimizer.step(), these groups are exactly what the loop in our step() method walks over:
def step(self, closure: Callable = None):
    loss = None
    if closure is not None:
        loss = closure()
    for group in self.param_groups:
        for p in group["params"]:
            # Access hyperparameters from the `group` dictionary.
            # Here we get the specific learning rate we set for this group.
            alpha = group["lr"]
    return loss
So param_groups provides a convenient interface for customizing how different parts of the model learn.
An optimizer also has a very important attribute called state, which stores the internal state of the algorithm on the fly. In PyTorch, the state of an optimizer is a dictionary used to store and manage state information for each parameter that the optimizer updates. This state can include various items necessary for the optimization algorithm to function effectively over successive iterations of training. The actual contents of the state dictionary vary depending on the specific optimizer being used.
For example, for plain SGD (without momentum) the state is empty. But for Adam, the state includes running estimates of the first and second moments of the gradients (commonly referred to as m and v). These are essential for the algorithm to compute adaptive learning rates for each parameter.
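A quick way to see this difference is to inspect optimizer.state after a single update. This small sketch uses a throwaway model; the exact key names may differ slightly across PyTorch versions:

import torch

model = torch.nn.Linear(4, 1)
loss = model(torch.randn(2, 4)).sum()
loss.backward()

sgd = torch.optim.SGD(model.parameters(), lr=0.1)
sgd.step()
print(sgd.state[model.weight])  # {} -- plain SGD keeps no per-parameter state

adam = torch.optim.Adam(model.parameters(), lr=0.1)
adam.step()
print(adam.state[model.weight].keys())  # dict_keys(['step', 'exp_avg', 'exp_avg_sq'])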
Hands on AdamW
After understanding all of the things above (there may still be details I didn't cover, but these are enough to get us started), we can now implement our customized optimizer. In this post, I choose AdamW. But first we need to take a look at Adam [2].
The Adam algorithm is a method for stochastic optimization that calculates adaptive learning rates for each parameter. Here's how it works:
Requirements:
- $\alpha$: Stepsize
- $\beta_1, \beta_2 \in [0, 1)$: Exponential decay rates for the moment estimates
- $f(\theta)$: Stochastic objective function with parameters $\theta$
- $\theta_0$: Initial parameter vector
Initialization:
- $m_0 \leftarrow 0$: Initialize the first moment vector
- $v_0 \leftarrow 0$: Initialize the second moment vector
- $t \leftarrow 0$: Initialize the time step
Algorithm Steps:
While the parameters $\theta_t$ have not converged:
- $t \leftarrow t + 1$
- $g_t \leftarrow \nabla_\theta f_t(\theta_{t-1})$ (get the gradients w.r.t. the stochastic objective at time step $t$)
- $m_t \leftarrow \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t$ (update the biased first moment estimate)
- $v_t \leftarrow \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2$ (update the biased second raw moment estimate)
- $\hat{m}_t \leftarrow m_t / (1 - \beta_1^t)$ (compute the bias-corrected first moment estimate)
- $\hat{v}_t \leftarrow v_t / (1 - \beta_2^t)$ (compute the bias-corrected second raw moment estimate)
- $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$ (update the parameters)
Return:
- The resulting parameters $\theta_t$
Efficient Version:
For efficiency, the last three lines in the loop can be replaced with the following two lines:
- Update the learning rate: $\alpha_t = \alpha \cdot \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$
- Update the parameters: $\theta_t \leftarrow \theta_{t-1} - \alpha_t \cdot m_t / (\sqrt{v_t} + \epsilon)$
We use the efficient version here, and the code implementation is:
import math
import torch
from torch.optim import Optimizer
from typing import Callable, Iterable, Tuple


class AdamW(Optimizer):
    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def step(self, closure: Callable = None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                # State should be stored in this dictionary.
                state = self.state[p]

                # Access hyperparameters from the `group` dictionary.
                alpha = group["lr"]
                eps = group["eps"]
                beta1, beta2 = group["betas"]

                # Lazily initialize the per-parameter state on the first step.
                if 'step' not in state:
                    state['step'] = 0
                    state['first_moment'] = torch.zeros_like(p.data)
                    state['second_moment'] = torch.zeros_like(p.data)

                state['step'] += 1

                # Update the biased first and second moment estimates in place.
                state['first_moment'].mul_(beta1).add_(grad, alpha=1 - beta1)
                state['second_moment'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Efficient version: fold the bias corrections into the step size.
                if group["correct_bias"]:
                    first_bias_correction = 1 - beta1 ** state['step']
                    second_bias_correction = 1 - beta2 ** state['step']
                    alpha_t = alpha * math.sqrt(second_bias_correction) / first_bias_correction
                else:
                    alpha_t = alpha

                sqrt_second_moment = torch.sqrt(state['second_moment']) + eps
                p.data = p.data - alpha_t * state['first_moment'] / sqrt_second_moment

                # Here we use Decoupled Weight Decay Regularization [1].
                p.data = p.data - alpha * group['weight_decay'] * p.data

        return loss
Note that we need to update first_moment and second_moment in state in place, instead of creating copies. During the implementation, we need to be aware of which values in state must be modified in place and which must not, or we will make mistakes.
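To illustrate the difference, here is a small sketch comparing the in-place idiom used above with a version that allocates a new tensor on every step; the variable names are purely illustrative:

import torch

beta1 = 0.9
grad = torch.ones(3)
m = torch.zeros(3)

# In-place: modifies the tensor stored in the optimizer state directly.
m.mul_(beta1).add_(grad, alpha=1 - beta1)

# Copy: beta1 * m + ... allocates a fresh tensor, so the optimizer state must be
# reassigned (state['first_moment'] = ...) or the update is silently lost.
m_copy = beta1 * m + (1 - beta1) * grad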
References
[1] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101, 2017.
[2] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.