Backpropagation
CSE 891: Deep Learning
Vishnu Boddeti
class ComputationalGraph(object):
    """Skeleton of a computational graph for reverse-mode autodiff.

    Gates (nodes) are evaluated in topological order on the forward
    pass and in reverse topological order on the backward pass, so
    every gate's operands are ready before it runs and every gate's
    upstream gradient is ready before its local chain rule is applied.
    """
    # ... (graph construction / node bookkeeping omitted on the slide)

    def forward(self, inputs):
        """Run the graph forward on `inputs` and return the scalar loss.

        1. Pass `inputs` to the input gates.
        2. Forward every gate in topological order.
        """
        for gate in self.graph_nodes_topologically_sorted():
            gate.forward()
        # NOTE(review): `loss` is slide pseudocode — in a real
        # implementation this would be the output of the final gate.
        return loss  # final gate in the graph outputs the loss

    def backward(self, loss):
        """Backpropagate from `loss` to the inputs via the chain rule."""
        for gate in reversed(self.graph_nodes_topologically_sorted()):
            gate.backward()  # chain rule applied at each gate
        # NOTE(review): `input_gradients` is slide pseudocode — in a real
        # implementation these would be gathered from the input gates.
        return input_gradients
class Multiply(torch.autograd.Function):
    """Custom autograd op computing the elementwise product z = x * y.

    Forward stashes both operands; backward applies the chain rule:
    dL/dx = dL/dz * y and dL/dy = dL/dz * x.
    """

    @staticmethod
    def forward(ctx, x, y):
        # Keep both inputs around — each one is the other's local gradient.
        ctx.save_for_backward(x, y)
        return x * y

    @staticmethod
    def backward(ctx, grad_z):
        x, y = ctx.saved_tensors
        # Upstream gradient times local partials (dz/dx = y, dz/dy = x).
        return grad_z * y, grad_z * x