Is there something wrong with torch.optim.SGD with momentum?


It seems that torch.optim.SGD may have a bug when momentum is added. From my understanding, one can implement SGD with momentum by simply providing some value for the momentum argument, such as

torch.optim.SGD(params, lr=0.01, momentum=0.9)

I suspect a potential bug because I tried to replicate the PyTorch Lightning tutorial on optimizers here. Rather than implementing the optimizers from scratch as in the tutorial, I used the optimizers from torch.optim directly. In particular, in cell [32], I replaced

SGDMom_points = train_curve(lambda params: SGDMomentum(params, lr=10, momentum=0.9))

by

SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9))

The result appears to be much worse than what is shown here. The result for Nesterov accelerated gradient, implemented as below, also appears worse than one would expect.

NAG_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9, nesterov=True))

Maybe my understanding is incorrect, but I can't spot anything suspicious, and I would really appreciate it if someone could check whether they get the same discrepancy.


As requested in the comments, I have added the code below. Note that nothing has changed except the few lines that compute SGD_points, SGDMom_points, and Adam_points.

from matplotlib import cm
import seaborn as sns
from matplotlib import pyplot as plt
import torch
import numpy as np

def pathological_curve_loss(w1, w2):
    # Example of a pathological curvature. There are many more possible, feel free to experiment here!
    x1_loss = torch.tanh(w1) ** 2 + 0.01 * torch.abs(w1)
    x2_loss = torch.sigmoid(w2)
    return x1_loss + x2_loss

def plot_curve(
    curve_fn, x_range=(-5, 5), y_range=(-5, 5), plot_3d=False, cmap=cm.viridis, title="Pathological curvature"
):
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d') if plot_3d else fig.gca()
    # fig.gca(projection="3d"), used in the original tutorial, is not supported by newer Matplotlib versions:
    # ax = fig.gca(projection="3d") if plot_3d else fig.gca()

    x = torch.arange(x_range[0], x_range[1], (x_range[1] - x_range[0]) / 100.0)
    y = torch.arange(y_range[0], y_range[1], (y_range[1] - y_range[0]) / 100.0)
    x, y = torch.meshgrid([x, y])
    z = curve_fn(x, y)
    x, y, z = x.numpy(), y.numpy(), z.numpy()

    if plot_3d:
        ax.plot_surface(x, y, z, cmap=cmap, linewidth=1, color="#000", antialiased=False)
        ax.set_zlabel("loss")
    else:
        ax.imshow(z.T[::-1], cmap=cmap, extent=(x_range[0], x_range[1], y_range[0], y_range[1]))
    plt.title(title)
    ax.set_xlabel(r"$w_1$")
    ax.set_ylabel(r"$w_2$")
    plt.tight_layout()
    return ax


# sns.reset_orig()
# _ = plot_curve(pathological_curve_loss, plot_3d=True)
# plt.show()

from torch import nn

def train_curve(optimizer_func, curve_func=pathological_curve_loss, num_updates=100, init=[5, 5]):
    """
    Args:
        optimizer_func: Constructor of the optimizer to use. Should only take a parameter list
        curve_func: Loss function (e.g. pathological curvature)
        num_updates: Number of updates/steps to take when optimizing
        init: Initial values of parameters. Must be a list/tuple with two elements representing w_1 and w_2
    Returns:
        Numpy array of shape [num_updates, 3] with [t,:2] being the parameter values at step t, and [t,2] the loss at t.
    """
    weights = nn.Parameter(torch.FloatTensor(init), requires_grad=True)
    optim = optimizer_func([weights])

    list_points = []
    for _ in range(num_updates):
        loss = curve_func(weights[0], weights[1])
        list_points.append(torch.cat([weights.data.detach(), loss.unsqueeze(dim=0).detach()], dim=0))
        optim.zero_grad()
        loss.backward()
        optim.step()
    points = torch.stack(list_points, dim=0).numpy()
    return points


# BEGIN only place changed from https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_intermediate.html
SGD_points = train_curve(lambda params: torch.optim.SGD(params, lr=10))
SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9))
Adam_points = train_curve(lambda params: torch.optim.Adam(params, lr=1))
# END only place changed from https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_intermediate.html

all_points = np.concatenate([SGD_points, SGDMom_points, Adam_points], axis=0)
ax = plot_curve(
    pathological_curve_loss,
    x_range=(-np.absolute(all_points[:, 0]).max(), np.absolute(all_points[:, 0]).max()),
    y_range=(all_points[:, 1].min(), all_points[:, 1].max()),
    plot_3d=False,
)
ax.plot(SGD_points[:, 0], SGD_points[:, 1], color="red", marker="o", zorder=1, label="SGD")
ax.plot(SGDMom_points[:, 0], SGDMom_points[:, 1], color="blue", marker="o", zorder=2, label="SGDMom")
ax.plot(Adam_points[:, 0], Adam_points[:, 1], color="grey", marker="o", zorder=3, label="Adam")
plt.legend()
plt.show()

1 Answer

Answered by user559678:

Okay, it turns out that PyTorch implements SGD with momentum exactly as described in the Note here. If I modify the tutorial implementation to

class SGDMomentum(OptimizerTemplate):
    def __init__(self, params, lr, momentum=0.0):
        super().__init__(params, lr)
        self.momentum = momentum  # Corresponds to beta_1 in the equation above
        self.param_momentum = {p: torch.zeros_like(p.data) for p in self.params}  # Dict to store m_t

    def update_param(self, p):
        # original tutorial (EMA-style) update:
        # self.param_momentum[p] = (1 - self.momentum) * p.grad + self.momentum * self.param_momentum[p]
        # PyTorch-style update, as described in the Note of the torch.optim.SGD docs:
        self.param_momentum[p] = p.grad + self.momentum * self.param_momentum[p]
        p_update = -self.lr * self.param_momentum[p]
        p.add_(p_update)

The two implementations now match precisely.
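
For anyone who wants to verify this without the tutorial's OptimizerTemplate class, here is a minimal standalone check (a sketch, assuming the default dampening=0 and nesterov=False): it applies the rule from the Note, b_t = g_t + momentum * b_{t-1} and p <- p - lr * b_t, by hand and compares it against torch.optim.SGD on the same gradients.

import torch

torch.manual_seed(0)
lr, momentum = 10.0, 0.9

p_manual = torch.nn.Parameter(torch.tensor([5.0, 5.0]))
p_sgd = torch.nn.Parameter(p_manual.detach().clone())
optim = torch.optim.SGD([p_sgd], lr=lr, momentum=momentum)

buf = torch.zeros_like(p_manual)
for _ in range(5):
    grad = torch.randn(2)  # feed the same gradient to both updates

    # manual heavy-ball update as in the Note of the torch.optim.SGD docs
    buf = grad + momentum * buf
    with torch.no_grad():
        p_manual -= lr * buf

    # torch.optim.SGD update on an identical gradient
    p_sgd.grad = grad.clone()
    optim.step()
    optim.zero_grad()

print(torch.allclose(p_manual, p_sgd))  # expected: True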

Alternatively, if we keep the original SGDMomentum implementation, the effective learning rate of the PyTorch implementation is a factor of 1/(1 - momentum) larger than in the tutorial's formulation. So if we change

SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9))

to

SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=1, momentum=0.9))

then it matches the curve in the tutorial precisely as well.
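
To sanity-check the 1/(1 - momentum) factor without rerunning the whole tutorial, here is another minimal sketch (again assuming the default dampening=0 and nesterov=False). It compares the tutorial's EMA-style rule at lr=10 against torch.optim.SGD at lr = 10 * (1 - 0.9) = 1 on the same sequence of gradients; the two trajectories should coincide up to floating-point error.

import torch

torch.manual_seed(0)
lr_tutorial, momentum = 10.0, 0.9

p_tutorial = torch.tensor([5.0, 5.0])
p_sgd = torch.nn.Parameter(p_tutorial.clone())
# rescale the learning rate by (1 - momentum) to compensate for the factor above
optim = torch.optim.SGD([p_sgd], lr=lr_tutorial * (1 - momentum), momentum=momentum)

m = torch.zeros_like(p_tutorial)
for _ in range(5):
    grad = torch.randn(2)  # same gradient for both rules

    # tutorial rule: m_t = (1 - beta) * g_t + beta * m_{t-1},  p <- p - lr * m_t
    m = (1 - momentum) * grad + momentum * m
    p_tutorial -= lr_tutorial * m

    # torch.optim.SGD with the rescaled learning rate
    p_sgd.grad = grad.clone()
    optim.step()
    optim.zero_grad()

print(torch.allclose(p_tutorial, p_sgd))  # expected: True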