SGD optimizer, lr value with loop over batch and epoch, in PyTorch


I am puzzled about how to handle the lr value when I loop over batches and epochs. I want to iterate over all epochs and keep the best model. At the same time, I iterate over the batches, and the lr value gets updated (it should stay fixed within the batch loop), so by the next epoch it is already small.

In the code below, X_train, X_test, y_train and y_test are already defined.

import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

# Define the model
model = nn.Sequential(
    nn.Linear(X_train.shape[1], 24),
    nn.ReLU(),
    nn.Linear(24, 12),
    nn.ReLU(),
    nn.Linear(12, 6),
    nn.ReLU(),
    nn.Linear(6, 1)
)

# loss function and optimizer
loss_fn = nn.MSELoss()  # mean square error

optimizer = optim.SGD(model.parameters(), lr=0.01)

scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1, end_factor=0.01,
                                  total_iters=100)


n_epochs = 10000   # number of epochs to run
batch_size = 64  # size of each batch
batch_start = torch.arange(0, len(X_train), batch_size)

# Hold the best model
best_mse = np.inf   # init to infinity
best_weights = None
history = []

for epoch in range(n_epochs):
    model.train()
    for start in batch_start:
        # take a batch
        X_batch = X_train[start:start+batch_size]
        y_batch = y_train[start:start+batch_size]
        # forward pass
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        # backward pass
        # zero the gradients accumulated from the previous batch
        optimizer.zero_grad()
        # backward propagation
        loss.backward()
        # update weights
        optimizer.step()
        # print progress
    # evaluate MSE on the test set at the end of each epoch
    model.eval()
    with torch.no_grad():
        y_pred = model(X_test)
        mse = float(loss_fn(y_pred, y_test))
    history.append(mse)
    scheduler.step()  # step the scheduler once per epoch, after the batch loop
    # print("Epoch %d: LIN lr %.5f" % (epoch, optimizer.param_groups[0]["lr"]))
    if mse < best_mse:
        best_mse = mse
        best_weights = copy.deepcopy(model.state_dict())

AFAIK, the lr value is updated by BOTH optimizer.step() and scheduler.step().
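To check which call actually changes the lr, I would print it around both calls. This is just a minimal sketch (untested), reusing the model, optimizer, scheduler, loss_fn and one batch from the code above:

# diagnostic sketch: read the lr from the optimizer's param group
# before and after each call, using the objects defined above
lr_before = optimizer.param_groups[0]["lr"]
optimizer.zero_grad()
loss_fn(model(X_batch), y_batch).backward()
optimizer.step()
print("lr after optimizer.step():", optimizer.param_groups[0]["lr"], "(was", lr_before, ")")
scheduler.step()
print("lr after scheduler.step():", optimizer.param_groups[0]["lr"])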

How can I keep the lr value fixed within the batch loop (i.e. not update the lr when optimizer.step() is called)?
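What I expect is that the lr read from optimizer.param_groups[0]["lr"] is the same for every batch of a given epoch and only shrinks once per epoch. A sketch of that expected behaviour (lrs_in_epoch is a hypothetical helper list, only for illustration):

for epoch in range(n_epochs):
    lrs_in_epoch = []  # hypothetical helper, only to illustrate the expectation
    for start in batch_start:
        lrs_in_epoch.append(optimizer.param_groups[0]["lr"])
        # ... forward / backward / optimizer.step() as in the code above ...
    # the lr should not change within the batch loop
    assert len(set(lrs_in_epoch)) == 1
    scheduler.step()  # the only place where the lr should be reduced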
