Before working on something more complex, where I knew I would have to implement my own backward pass, I wanted to try something nice and simple. So I tried to do linear regression with mean squared error loss using PyTorch. This went wrong (see the third approach below) when I defined my own backward method, and I suspect it's because I'm not thinking very clearly about what I need to send PyTorch as gradients. So what I think I need is some explanation, clarification, or advice on what PyTorch expects me to provide here, and in what form.
I am using PyTorch 1.7.0, so a bunch of old examples no longer work (different way of working with user-defined autograd functions as described in the documentation).
First approach (standard PyTorch MSE loss function)
Let's first do it the standard way without a custom loss function:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Let's generate some fake data
torch.manual_seed(42)
resid = torch.rand(100)
inputs = torch.tensor([ [ xx ] for xx in range(100)] , dtype=torch.float32)
labels = torch.tensor([ (2 + 0.5*yy + resid[yy]) for yy in range(100)], dtype=torch.float32)
# Now we define a linear regression model
class linearRegression(torch.nn.Module):
    def __init__(self, inputSize, outputSize):
        super(linearRegression, self).__init__()
        self.bn = torch.nn.BatchNorm1d(num_features=1)
        self.linear = torch.nn.Linear(inputSize, outputSize)

    def forward(self, inx):
        x = self.bn(inx)  # Adding BN to standardize input helps us use a higher learning rate
        x = self.linear(x)
        return x
model = linearRegression(1, 1)
# Using the standard mse_loss of PyTorch
epochs = 25
mseloss = F.mse_loss
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = mseloss(outputs.view(-1), labels)
    loss.backward()
    optimizer.step()
    scheduler.step()
    print(f'epoch {epoch}, loss {loss}')
This trains just fine: I get to a loss of about 0.0824, and a plot of the fit looks fine.
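As a sanity check on that final loss value, the least-squares line for this data can be computed in closed form. The sketch below rebuilds the data in vectorized form (which I assume is equivalent to the list comprehensions above) and uses the textbook formula slope = cov(x, y) / var(x). The names `x_c`, `y_c` and `ols_mse` are mine. Since the model is an affine function of the input in training mode, its training MSE should not be able to go below the value printed here.

```python
import torch

# Same synthetic data as above, built in vectorized form (assumed equivalent)
torch.manual_seed(42)
resid = torch.rand(100)
x = torch.arange(100, dtype=torch.float32)
y = 2 + 0.5 * x + resid

# Closed-form simple linear regression: slope = cov(x, y) / var(x)
x_c = x - x.mean()
y_c = y - y.mean()
slope = (x_c * y_c).sum() / (x_c ** 2).sum()
intercept = y.mean() - slope * x.mean()

# Training MSE of the least-squares line (a lower bound for any affine fit)
ols_mse = ((y - (intercept + slope * x)) ** 2).mean()
print(f'slope {slope.item():.4f}, intercept {intercept.item():.4f}, mse {ols_mse.item():.4f}')
```

The residuals are uniform on [0, 1), whose variance is 1/12, so an MSE somewhat below 0.09 is what one would expect from a good fit.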
Second approach (custom loss function, but relying on PyTorch's automatic gradient calculation)
So, now I replace the loss function with my own implementation of the MSE loss, but I still rely on PyTorch autograd. The only things I change here are defining the custom loss function, correspondingly defining the loss based on that, and a minor detail for how I hand over the predictions and true labels to the loss function.
#######################################################
class MyMSELoss(nn.Module):
    def __init__(self):
        super(MyMSELoss, self).__init__()

    def forward(self, inputs, targets):
        tmp = (inputs - targets)**2
        loss = torch.mean(tmp)
        return loss
#######################################################
model = linearRegression(1, 1)
mseloss = MyMSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
for epoch in range(epochs):
    model.train()
    outputs = model(inputs)
    loss = mseloss(outputs.view(-1), labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    print(f'epoch {epoch}, loss {loss}')
This gives results identical to those from the standard MSE loss function. The loss over epochs looks like this:
epoch 0, loss 884.2006225585938
epoch 1, loss 821.930908203125
epoch 2, loss 718.7732543945312
epoch 3, loss 538.1835327148438
epoch 4, loss 274.50909423828125
epoch 5, loss 55.115299224853516
epoch 6, loss 2.405021905899048
epoch 7, loss 0.47621214389801025
epoch 8, loss 0.1584305614233017
epoch 9, loss 0.09725229442119598
epoch 10, loss 0.0853077694773674
epoch 11, loss 0.08297089487314224
epoch 12, loss 0.08251354098320007
epoch 13, loss 0.08242412656545639
epoch 14, loss 0.08240655809640884
epoch 15, loss 0.08240310847759247
epoch 16, loss 0.08240246027708054
epoch 17, loss 0.08240233361721039
epoch 18, loss 0.08240240067243576
epoch 19, loss 0.08240223675966263
epoch 20, loss 0.08240225911140442
epoch 21, loss 0.08240220695734024
epoch 22, loss 0.08240220695734024
epoch 23, loss 0.08240220695734024
epoch 24, loss 0.08240220695734024
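The equivalence of the two losses can also be checked directly, without training, by comparing both the loss value and the gradient with respect to the predictions. This is only a minimal sketch on random tensors. The names `pred_builtin`, `pred_custom`, `same_value` and `same_grad` are made up for the illustration, and the loss class is repeated so the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyMSELoss(nn.Module):
    def forward(self, inputs, targets):
        return torch.mean((inputs - targets) ** 2)

torch.manual_seed(0)
targets = torch.randn(100)

# Two independent copies of the same predictions, so each loss gets its own .grad
pred_builtin = torch.randn(100, requires_grad=True)
pred_custom = pred_builtin.detach().clone().requires_grad_(True)

loss_builtin = F.mse_loss(pred_builtin, targets)
loss_custom = MyMSELoss()(pred_custom, targets)
loss_builtin.backward()
loss_custom.backward()

# Compare the loss values and the gradients w.r.t. the predictions
same_value = torch.allclose(loss_builtin, loss_custom)
same_grad = torch.allclose(pred_builtin.grad, pred_custom.grad)
print(same_value, same_grad)
```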
Third approach (custom loss function with my own backward method)
Now, the final version, where I implement my own gradients for the MSE. For that I define my own backward method in the loss function class and apparently need to do mseloss = MyMSELoss.apply.
from torch.autograd import Function
#######################################################
class MyMSELoss(Function):
    @staticmethod
    def forward(ctx, y_pred, y):
        ctx.save_for_backward(y_pred, y)
        return ( (y - y_pred)**2 ).mean()

    @staticmethod
    def backward(ctx, grad_output):
        y_pred, y = ctx.saved_tensors
        grad_input = torch.mean( -2.0 * (y - y_pred)).repeat(y_pred.shape[0])
        # This fails, as does grad_input = -2.0 * (y-y_pred)
        # I've also messed around with the sign and that's not the sole problem, either.
        return grad_input, None
#######################################################
model = linearRegression(1, 1)
mseloss = MyMSELoss.apply
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
for epoch in range(epochs):
    model.train()
    outputs = model(inputs)
    loss = mseloss(outputs.view(-1), labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    print(f'epoch {epoch}, loss {loss}')
This is where things go wrong: instead of decreasing, the training loss increases. It now looks like this:
epoch 0, loss 884.2006225585938
epoch 1, loss 3471.384033203125
epoch 2, loss 47768555520.0
epoch 3, loss 1.7422577779621402e+33
epoch 4, loss inf
epoch 5, loss nan
epoch 6, loss nan
epoch 7, loss nan
epoch 8, loss nan
epoch 9, loss nan
epoch 10, loss nan
epoch 11, loss nan
epoch 12, loss nan
epoch 13, loss nan
epoch 14, loss nan
epoch 15, loss nan
epoch 16, loss nan
epoch 17, loss nan
epoch 18, loss nan
epoch 19, loss nan
epoch 20, loss nan
epoch 21, loss nan
epoch 22, loss nan
epoch 23, loss nan
epoch 24, loss nan
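One way to pin down what backward is expected to return is to compare a hand-written gradient against the one autograd computes for the same expression. The sketch below is a diagnostic under my own assumptions, not a confirmed explanation of the run above. It assumes backward should return, for each element of y_pred, the partial derivative of the mean, i.e. 2 * (y_pred - y) / N, scaled by the incoming grad_output. The class name `MSELossSketch` is hypothetical, and double precision is used because torch.autograd.gradcheck compares against finite differences.

```python
import torch
from torch.autograd import Function, gradcheck

class MSELossSketch(Function):  # hypothetical name, for illustration only
    @staticmethod
    def forward(ctx, y_pred, y):
        ctx.save_for_backward(y_pred, y)
        return ((y - y_pred) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        y_pred, y = ctx.saved_tensors
        # Per-element derivative of mean((y - y_pred)^2) w.r.t. y_pred,
        # multiplied by the upstream gradient (chain rule)
        grad_y_pred = grad_output * 2.0 * (y_pred - y) / y_pred.numel()
        return grad_y_pred, None  # no gradient for the labels

torch.manual_seed(0)
y_pred = torch.randn(10, dtype=torch.double, requires_grad=True)
y = torch.randn(10, dtype=torch.double)

# Reference gradient from plain autograd
ref_loss = ((y - y_pred) ** 2).mean()
(ref_grad,) = torch.autograd.grad(ref_loss, y_pred)

# Gradient from the hand-written backward
loss = MSELossSketch.apply(y_pred, y)
(custom_grad,) = torch.autograd.grad(loss, y_pred)
matches = torch.allclose(custom_grad, ref_grad)

# The variant without the 1/N factor, for comparison with the reference
unscaled = -2.0 * (y - y_pred.detach())
off_by_n = torch.allclose(unscaled, ref_grad * y_pred.numel())

# Numerical check of backward against finite differences
gradcheck_ok = gradcheck(MSELossSketch.apply, (y_pred, y))
print(matches, off_by_n, gradcheck_ok)
```

If `off_by_n` holds, then -2.0 * (y - y_pred) has the right sign and shape but is N times larger than the autograd gradient. With N = 100 and the learning rate used above, that alone might be enough to make the loss diverge. The mean(...).repeat(...) variant differs in another way: it hands every element the same value instead of its own derivative.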
question from:
https://stackoverflow.com/questions/65947284/loss-with-custom-backward-function-in-pytorch-exploding-loss-in-simple-mse-exa