-Adwita Singh

This blog is going to focus more on the coding of the backward pass than on the general math required (simple chain rule from high-school partial differentiation tbh. If you know that, you're good to go).

Let us suppose we have a random net that looks as follows (don't ponder too hard on why I chucked random activations into it):

import torch

class RandomNet:
    def __init__(self):
        self.w1 = torch.randn(4, 6)   # input -> hidden
        self.w2 = torch.randn(6, 5)   # hidden -> hidden
        self.w3 = torch.randn(5, 2)   # hidden -> output

    def forward(self, x):
        z1 = x @ self.w1
        a1 = torch.relu(z1)

        z2 = a1 @ self.w2
        a2 = torch.tanh(z2)

        z3 = a2 @ self.w3
        return z1, a1, z2, a2, z3

Our first goal, in my opinion, is always to check the dimensions of each transformation, activation, and weight matrix here. That way, while writing down the backward pass, we can easily check the order of matrix multiplication.

x  -> input (batch_size, 4)
z1 -> matrix transformation (batch_size, 6)
a1 -> ReLU activation (batch_size, 6)
w1 -> weights to z1 (4, 6)

z2 -> matrix transformation (batch_size, 5)
a2 -> tanh activation (batch_size, 5)
w2 -> weights to z2 (6, 5)

z3 -> matrix transformation (batch_size, 2)
w3 -> weights to z3 (5, 2)
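
As a quick sanity check, we can push a dummy batch through the net and print the shapes (a minimal sketch; the batch size of 3 is arbitrary):

# Shape check: run a dummy batch through the forward pass
net = RandomNet()
x = torch.randn(3, 4)   # (batch_size, 4)
z1, a1, z2, a2, z3 = net.forward(x)

for name, t in [("z1", z1), ("a1", a1), ("z2", z2), ("a2", a2), ("z3", z3)]:
    print(name, tuple(t.shape))
# z1 (3, 6), a1 (3, 6), z2 (3, 5), a2 (3, 5), z3 (3, 2)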

Once we have checked the dimensions, we use the chain rule to compute the derivative of the loss with respect to each variable we care about.

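Written out for this net (a sketch in LaTeX, using the variable names from the code above), the three weight gradients expand as:

\frac{\partial L}{\partial w_3} = \frac{\partial L}{\partial z_3}\,\frac{\partial z_3}{\partial w_3}

\frac{\partial L}{\partial w_2} = \frac{\partial L}{\partial z_3}\,\frac{\partial z_3}{\partial a_2}\,\frac{\partial a_2}{\partial z_2}\,\frac{\partial z_2}{\partial w_2}

\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial z_3}\,\frac{\partial z_3}{\partial a_2}\,\frac{\partial a_2}{\partial z_2}\,\frac{\partial z_2}{\partial a_1}\,\frac{\partial a_1}{\partial z_1}\,\frac{\partial z_1}{\partial w_1}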

We compute the loss on the output z3 using Mean Squared Error (although any loss formula could have been used; cross-entropy is another common choice):

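For a batch of N examples with K outputs (K = 2 here), one common way to write the MSE in LaTeX is:

L = \frac{1}{NK}\sum_{i=1}^{N}\sum_{k=1}^{K}\left(z_{3,ik} - y_{ik}\right)^2

where y holds the true values; this averages over the whole batch and all outputs, which matches PyTorch's default mse_loss reduction.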

Once we have our loss function sorted, we measure the difference between our predicted values and the true values. This tells us how far our neural net is from its ideal form. Once we have the loss, we can compute its derivative with respect to the weights (w1, w2, and w3) of the layers of the network.
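
In code, assuming a target tensor y of shape (batch_size, 2) (the name y is my choice, not part of the net above), the loss is a one-liner:

loss = ((z3 - y) ** 2).mean()   # equivalent to torch.nn.functional.mse_loss(z3, y)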

We represent these derivatives as dw1, dw2, and dw3. To calculate the change in weights, we must first compute ∂Loss/∂z3, since z3 is the final output of the network. It must be noted that the backward pass takes place in the opposite direction of the forward pass (…which is obvious I suppose).

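With the MSE loss above, the first two pieces work out to (a sketch; the 1/(NK) factor comes from the mean, with N the batch size and K the number of outputs):

\frac{\partial L}{\partial z_3} = \frac{2}{NK}\,(z_3 - y), \qquad \frac{\partial L}{\partial w_3} = a_2^{\top}\,\frac{\partial L}{\partial z_3}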

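As a sketch of what that first backward step might look like in code (dz3 and dw3 are my names; the earlier layers follow the same pattern, walking back through the tanh and ReLU):

# Gradients for the output layer only (a sketch, using the MSE loss above)
N, K = z3.shape
dz3 = 2.0 * (z3 - y) / (N * K)   # dLoss/dz3, shape (batch_size, 2)
dw3 = a2.T @ dz3                 # dLoss/dw3, shape (5, 2) -- same as w3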