Customizing an optimizer in PyTorch

In this post, we are going to implement our own optimizer instead of using the built-in PyTorch optimizers. Before we dive into the implementation details, we first need to understand what an optimizer in PyTorch actually does.

The standard way to use an optimizer during the training phase looks like this:

optimizer.zero_grad()
output_data = model(input_data)
# compute loss
loss.backward()
optimizer.step()

The optimizer.zero_grad() call makes sure we don't mix gradients from different batches. loss.backward() runs the back-propagation algorithm, computes the gradient of every tensor that requires one, and stores it in that tensor's grad attribute. optimizer.step() then reads the grad attribute of each parameter and applies a specific optimization algorithm to update the model's parameters.
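For concreteness, here is a minimal, runnable version of that loop; the model, loss function, and data below are just placeholders for illustration:

import torch

# A tiny placeholder model and some random data, just to make the step runnable.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

input_data = torch.randn(4, 10)
target = torch.randn(4, 1)

optimizer.zero_grad()                # clear gradients left over from the previous batch
output_data = model(input_data)      # forward pass
loss = loss_fn(output_data, target)  # compute the loss
loss.backward()                      # back-propagation: fills every parameter's .grad
optimizer.step()                     # update the parameters using those gradients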


Programming model

To create our own optimizer, we need to inherit from torch.optim.Optimizer (from torch.optim import Optimizer) and reimplement two methods:

  • __init__(): define your optimizer's hyperparameters
  • step(): implement the optimization algorithm

Here is the code template that we can follow:

import math
from typing import Callable, Iterable, Tuple

import torch
from torch.optim import Optimizer


class CustomOpt(Optimizer):
    def __init__(
            self,
            params: Iterable[torch.nn.parameter.Parameter],
            lr: float = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.999),
            eps: float = 1e-6,
            weight_decay: float = 0.0,
            correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def step(self, closure: Callable = None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                # State should be stored in this dictionary.
                state = self.state[p]

                # Access hyperparameters from the `group` dictionary.
                alpha = group["lr"]
                
        return loss


__init__ method

To instantiate an optimizer, we need to pass in some arguments. Since our customized optimizer inherits from Optimizer, we must pass an iterable of model parameters along with hyperparameters such as the learning rate, momentum coefficients, or boolean flags that control the behavior during optimization.

We also perform some sanity checks here to make sure the input values fall in valid ranges.
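For example, instantiating the CustomOpt template above might look like this (the model here is just a placeholder):

import torch

# Placeholder model for illustration.
model = torch.nn.Linear(10, 2)

# Pass an iterable of parameters plus the hyperparameters validated in __init__.
opt = CustomOpt(
    model.parameters(),
    lr=5e-4,
    betas=(0.9, 0.999),
    eps=1e-6,
    weight_decay=0.01,
    correct_bias=True,
)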


step method

In this method, to perform the optimization scheme we first need to iterate over param_groups.

You may be curious about what param_groups is. In PyTorch, param_groups is an attribute of an optimizer object that contains all the parameter groups. Each group is a dictionary specifying a subset of your model's parameters along with its own settings such as learning rate and weight decay. This is useful when you want to apply different optimization strategies to different parts of the model.

Here is an example of how param_groups are set up and used in a PyTorch optimizer:

import torch

# Simple model with two parameters
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 2)
)

# Accessing all parameters of the model
all_params = model.parameters()

# Creating the optimizer with default parameter group
optimizer = torch.optim.SGD(all_params, lr=0.01)

# Suppose we want to set a different learning rate for the first and second layers
layer1_params = model[0].parameters()
layer2_params = model[2].parameters()

# Creating a new optimizer with multiple parameter groups
custom_optimizer = torch.optim.SGD([
    {'params': layer1_params, 'lr': 0.01},
    {'params': layer2_params, 'lr': 0.02}  # Different learning rate for the second layer
])

# Accessing parameter groups
for param_group in custom_optimizer.param_groups:
    print(param_group['lr'])  # This will print 0.01 for the first group and 0.02 for the second group

In the above example, custom_optimizer is set up with different learning rates for different layers of the model. This is achieved using param_groups. Each element in param_groups is a dictionary representing a single group of parameters, where you can individually adjust settings such as the learning rate.

When we call optimizer.step(), each group's hyperparameters are available inside the loop over param_groups:

def step(self, closure: Callable = None):
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            # Access hyperparameters from the `group` dictionary.
            alpha = group["lr"]
            # Within this group we see the specific learning rate we set for it.

    return loss

So param_groups basically provides a convenient interface to customize how different parts of the model learn.
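For example, continuing with custom_optimizer from above, you can adjust a group's settings on the fly, such as manually halving the learning rate of every group:

# Halve the learning rate of every parameter group (a manual scheduling step).
for param_group in custom_optimizer.param_groups:
    param_group['lr'] *= 0.5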

An optimizer also has a very important attribute called state, which stores the internal state of the algorithm as training progresses. In PyTorch, the state of an optimizer is a dictionary that holds per-parameter state information the optimization algorithm needs across successive training iterations. Its actual contents vary depending on the specific optimizer being used.

For example, for plain SGD (without momentum) the state is empty, but for Adam the state includes running estimates of the first and second moments of the gradients (commonly referred to as m and v), which the algorithm needs in order to compute adaptive learning rates for each parameter.
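As a quick check, you can inspect the state of a built-in Adam optimizer after one update. The key names printed below ('step', 'exp_avg', 'exp_avg_sq') are what recent PyTorch versions use for the step count and the two moment estimates, though they may vary across versions:

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Before the first step, the state dictionary is empty.
print(len(optimizer.state))  # 0

# Run one dummy update so the optimizer populates its per-parameter state.
model(torch.randn(4, 10)).sum().backward()
optimizer.step()

# Each parameter now has an entry with the step count and moment estimates.
for p, s in optimizer.state.items():
    print(list(s.keys()))  # e.g. ['step', 'exp_avg', 'exp_avg_sq']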


Hands on AdamW

After understanding all of the above (there may be details I didn't cover, but this is enough to get started), we can now implement our customized optimizer. In this post, I chose AdamW. But first we need to take a look at Adam [2].

The Adam algorithm is a method for stochastic optimization that calculates adaptive learning rates for each parameter. Here's how it works:

Requirements:

  • α: the stepsize (learning rate)
  • β1, β2 ∈ [0, 1): exponential decay rates for the moment estimates
  • f(θ): the stochastic objective function with parameters θ
  • θ_0: the initial parameter vector

Initialization:

  • Initialize the first moment vector: m_0 = 0
  • Initialize the second moment vector: v_0 = 0
  • Initialize the time step: t = 0

Algorithm Steps:

While the parameters θ_t have not converged:

  t ← t + 1
  g_t ← ∇_θ f_t(θ_{t-1})                        (gradient of the objective at step t)
  m_t ← β1 · m_{t-1} + (1 − β1) · g_t           (update biased first moment estimate)
  v_t ← β2 · v_{t-1} + (1 − β2) · g_t²          (update biased second raw moment estimate)
  m̂_t ← m_t / (1 − β1^t)                        (bias-corrected first moment estimate)
  v̂_t ← v_t / (1 − β2^t)                        (bias-corrected second moment estimate)
  θ_t ← θ_{t-1} − α · m̂_t / (√v̂_t + ε)          (update the parameters)

Return:

  • The resulting parameters θ_t

Efficient Version:

For efficiency, the last three lines in the loop can be replaced with the following two lines:

  α_t ← α · √(1 − β2^t) / (1 − β1^t)            (update the learning rate)
  θ_t ← θ_{t-1} − α_t · m_t / (√v_t + ε̂)        (update the parameters)
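To see why the two formulations agree (this follows the derivation in [2]), substitute the bias-corrected moments into the parameter update; the two forms differ only in how ε is scaled, with ε̂ = ε · √(1 − β2^t):

  θ_t = θ_{t-1} − α · m̂_t / (√v̂_t + ε)
      = θ_{t-1} − α · [m_t / (1 − β1^t)] / ( √(v_t / (1 − β2^t)) + ε )
      = θ_{t-1} − [α · √(1 − β2^t) / (1 − β1^t)] · m_t / (√v_t + ε · √(1 − β2^t))
      = θ_{t-1} − α_t · m_t / (√v_t + ε̂)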

We use the efficient version here, and the code implementation is:

class AdamW(Optimizer):
    def __init__(
            self,
            params: Iterable[torch.nn.parameter.Parameter],
            lr: float = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.999),
            eps: float = 1e-6,
            weight_decay: float = 0.0,
            correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def step(self, closure: Callable = None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                # State should be stored in this dictionary.
                state = self.state[p]

                # Access hyperparameters from the `group` dictionary.
                alpha = group["lr"]
                eps = group["eps"]
                beta1, beta2 = group["betas"]

                # Initialize per-parameter state on the first update.
                if 'step' not in state:
                    state['step'] = 0
                    state['first_moment'] = torch.zeros_like(p.data)
                    state['second_moment'] = torch.zeros_like(p.data)

                state['step'] += 1

                # Update the biased first and second moment estimates in place.
                state['first_moment'].mul_(beta1).add_(grad, alpha=1 - beta1)
                state['second_moment'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                first_bias_correction = 1 - beta1 ** state['step']
                second_bias_correction = 1 - beta2 ** state['step']

                # Efficient version: fold the bias corrections into the step size.
                alpha_t = alpha * math.sqrt(second_bias_correction) / first_bias_correction

                sqrt_second_moment = state['second_moment'].sqrt().add_(eps)
                p.data.addcdiv_(state['first_moment'], sqrt_second_moment, value=-alpha_t)

                # Decoupled weight decay regularization [1]: decay the parameters
                # directly, not through the gradient.
                p.data.mul_(1 - alpha * group['weight_decay'])

        return loss

Note that we update first_moment and second_moment in place in state, rather than allocating new tensors and rebinding them on every step.

During the implementation, we need to be clear about which quantities in state should be updated in place and which should not; otherwise it is easy to introduce subtle bugs.
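To try the custom AdamW out, drop it into a standard training loop; the tiny model and random data below are only placeholders:

import torch

# Placeholder model and random data, just to exercise the optimizer above.
model = torch.nn.Linear(10, 2)
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

for _ in range(5):
    inputs = torch.randn(8, 10)
    targets = torch.randn(8, 2)

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    print(loss.item())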

References

[1] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101, 2017.

[2] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.