Backpropagation and the chain rule

To fit a function approximator (such as a neural network) to data, we need to define a loss function, such as the mean squared error often used in linear regression. To minimise such a (generally highly complicated) loss function, gradient descent is a natural numerical algorithm: repeatedly move the parameters a small step in the direction of the negative gradient of the loss.
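As a minimal sketch of the idea, here is gradient descent on the mean squared error for a one-dimensional linear regression $y \approx wx + b$. The data, the learning rate `lr` and the number of `steps` are illustrative choices, not anything prescribed above. Note that the gradient is written out by hand here, which is exactly what becomes impractical for a deep network.

```python
# Minimal sketch: fit y ≈ w*x + b by gradient descent on the mean squared error.
# The data, learning rate and step count are illustrative assumptions.

def mse(w, b, xs, ys):
    """Mean squared error of the line w*x + b on the data (xs, ys)."""
    n = len(xs)
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / n

def mse_grad(w, b, xs, ys):
    """Hand-derived gradient (dL/dw, dL/db) of the mean squared error."""
    n = len(xs)
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    dw = 2.0 * sum(r * x for r, x in zip(residuals, xs)) / n
    db = 2.0 * sum(residuals) / n
    return dw, db

def gradient_descent(xs, ys, lr=0.05, steps=2000):
    """Repeatedly step against the gradient, starting from w = b = 0."""
    w, b = 0.0, 0.0
    for _ in range(steps):
        dw, db = mse_grad(w, b, xs, ys)
        w -= lr * dw
        b -= lr * db
    return w, b

xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.0, 3.0, 5.0, 7.0, 9.0]   # generated exactly by y = 2x + 1
w, b = gradient_descent(xs, ys)
print(f"w = {w:.3f}, b = {b:.3f}, loss = {mse(w, b, xs, ys):.2e}")
```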

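OK, so how does one actually calculate these gradients efficiently? Consider a simple neural network in which the input $X_0$ and the weights $W_1$ are fed through a function to produce $X_1$, then $X_1$ and $W_2$ produce $X_2$, and finally $X_2$ and the target $y$ produce the loss $L$. Each of these objects is a tensor (e.g. vectors $X_i$, matrices $W_i$, a target $y$ of whatever shape, and a scalar $L$), and each step is a function taking the previous tensors as input.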

Well, you might observe that the tensors are in a chain. What do we do when we see a chain and we want to differentiate stuff? We use the chain rule.

To be more specific, we're interested in differentiating (i.e. taking the gradient of) the function that takes in the weights and outputs the loss, i.e. the loss function. This function is the composition of several functions, namely the steps leading from the weights to the loss. Since the gradient is just the collection of partial derivatives with respect to each parameter, descending along the gradient means updating every parameter along its own partial derivative. So it suffices to compute:

$$\frac{\partial L}{\partial W_2}=\frac{\partial L}{\partial X_2}\frac{\partial X_2}{\partial W_2}$$

$$\frac{\partial L}{\partial W_1}=\frac{\partial L}{\partial X_2}\frac{\partial X_2}{\partial X_1}\frac{\partial X_1}{\partial W_1}$$
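To make the two equations concrete, here is a hypothetical scalar version of the network in which every tensor is a plain number: $X_1 = W_1 X_0$, $X_2 = \tanh(W_2 X_1)$ and $L = (X_2 - y)^2$. These particular layer functions are assumptions chosen for illustration. The sketch multiplies the factors exactly as in the equations and compares the result with a central finite difference.

```python
import math

# Hypothetical scalar instance of the two-layer chain (all tensors are plain numbers):
#   X1 = W1 * X0,   X2 = tanh(W2 * X1),   L = (X2 - y)**2

def forward(W1, W2, X0, y):
    X1 = W1 * X0
    X2 = math.tanh(W2 * X1)
    L = (X2 - y) ** 2
    return X1, X2, L

def chain_rule_grads(W1, W2, X0, y):
    """dL/dW1 and dL/dW2 assembled factor by factor, as in the two equations."""
    X1, X2, _ = forward(W1, W2, X0, y)
    dL_dX2 = 2.0 * (X2 - y)
    dX2_dW2 = (1.0 - X2 ** 2) * X1   # tanh'(u) = 1 - tanh(u)^2, times du/dW2 = X1
    dX2_dX1 = (1.0 - X2 ** 2) * W2   # tanh'(u) times du/dX1 = W2
    dX1_dW1 = X0
    dL_dW2 = dL_dX2 * dX2_dW2
    dL_dW1 = dL_dX2 * dX2_dX1 * dX1_dW1
    return dL_dW1, dL_dW2

def finite_difference(f, x, h=1e-6):
    """Central-difference approximation of f'(x), used only as a check."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

W1, W2, X0, y = 0.7, -1.3, 0.5, 0.2
g1, g2 = chain_rule_grads(W1, W2, X0, y)
n1 = finite_difference(lambda w: forward(w, W2, X0, y)[2], W1)
n2 = finite_difference(lambda w: forward(W1, w, X0, y)[2], W2)
print(g1, n1)
print(g2, n2)
```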
Or in general, for a network with layers $X_0, \ldots, X_m$ and $L =: X_{m+1}$:

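$$\frac{\partial L}{\partial W_i} = \left[\prod_{k=m}^{i}\frac{\partial X_{k+1}}{\partial X_k}\right]\frac{\partial X_i}{\partial W_i} = \frac{\partial X_{m+1}}{\partial X_m}\frac{\partial X_m}{\partial X_{m-1}}\cdots\frac{\partial X_{i+1}}{\partial X_i}\frac{\partial X_i}{\partial W_i}$$
where the factors are multiplied in the order $k = m, m-1, \ldots, i$ from left to right.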
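The following sketch evaluates the general formula for a hypothetical chain of $m$ scalar layers $X_k = \tanh(W_k X_{k-1})$ with $L = (X_m - y)^2$ (again an assumed toy model). `grad_naive` recomputes the whole product separately for every $i$, while `grad_backprop` starts at the loss and carries the running factor $\frac{\partial L}{\partial X_i}$ from one layer to the next, so each factor is multiplied only once.

```python
import math

# Hypothetical scalar network: X_k = tanh(W_k * X_{k-1}) for k = 1..m, and
# L = X_{m+1} = (X_m - y)**2. W[k] holds W_k; W[0] is an unused placeholder.

def forward(W, X0, y):
    X = [X0]
    for k in range(1, len(W)):
        X.append(math.tanh(W[k] * X[k - 1]))
    L = (X[-1] - y) ** 2
    return X, L

def local_factor(W, X, y, k):
    """dX_{k+1}/dX_k; for k = m this is dL/dX_m."""
    m = len(W) - 1
    if k == m:
        return 2.0 * (X[m] - y)
    return (1.0 - X[k + 1] ** 2) * W[k + 1]

def grad_naive(W, X0, y):
    """Evaluate the product formula separately for each i (O(m^2) multiplications)."""
    X, _ = forward(W, X0, y)
    m = len(W) - 1
    grads = [0.0] * (m + 1)
    for i in range(1, m + 1):
        prod = 1.0
        for k in range(m, i - 1, -1):                  # k = m, m-1, ..., i
            prod *= local_factor(W, X, y, k)
        grads[i] = prod * (1.0 - X[i] ** 2) * X[i - 1]  # times dX_i/dW_i
    return grads

def grad_backprop(W, X0, y):
    """Same products, reusing the running factor dL/dX_i layer by layer (O(m))."""
    X, _ = forward(W, X0, y)
    m = len(W) - 1
    grads = [0.0] * (m + 1)
    dL_dX = local_factor(W, X, y, m)                    # dL/dX_m
    for i in range(m, 0, -1):
        grads[i] = dL_dX * (1.0 - X[i] ** 2) * X[i - 1]
        dL_dX *= (1.0 - X[i] ** 2) * W[i]               # now dL/dX_{i-1}
    return grads

W = [None, 0.9, -1.1, 0.6, 1.4]   # m = 4 layers
X0, y = 0.8, -0.3
print(grad_naive(W, X0, y))
print(grad_backprop(W, X0, y))
```

Both functions return the same numbers; the difference is that the naive version needs $O(m^2)$ multiplications while the backward sweep needs $O(m)$.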
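Note that each factor in this product is itself a tensor (the derivative of one layer's output with respect to its input or weights), so computing it involves tensor derivatives, and the products are really tensor contractions. ML programming packages such as TensorFlow, whose name refers to tensors flowing through a graph of operations, keep track of these derivatives for you.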
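For non-scalar layers the factors become genuine tensors. As an assumed example, take a linear layer $X_2 = W_2 X_1$ (with $W_2$ an $n \times p$ matrix) and loss $L = \lVert X_2 - y \rVert^2$. Then $\frac{\partial X_2}{\partial X_1}$ is just the matrix $W_2$, while $\frac{\partial X_2}{\partial W_2}$ is a three-index tensor; contracting it with the row vector $\frac{\partial L}{\partial X_2}$ collapses it to an outer product. The sketch uses plain lists so it needs no libraries.

```python
# Hypothetical linear layer X2 = W2 @ X1 (W2 is n-by-p, X1 has length p)
# with loss L = sum((X2 - y)**2). Plain lists keep the example dependency-free.

def matvec(W, x):
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]

def loss(W2, X1, y):
    X2 = matvec(W2, X1)
    return sum((a - b) ** 2 for a, b in zip(X2, y))

def grads(W2, X1, y):
    X2 = matvec(W2, X1)
    dL_dX2 = [2.0 * (a - b) for a, b in zip(X2, y)]   # row vector of length n
    # dX2/dX1 is the Jacobian W2 itself: the vector-Jacobian product gives dL/dX1.
    dL_dX1 = [sum(dL_dX2[i] * W2[i][j] for i in range(len(W2)))
              for j in range(len(X1))]
    # dX2/dW2 has entries d(X2)_i / d(W2)_{kj} = [i == k] * X1[j]; contracting
    # it with dL/dX2 over i leaves the outer product dL/dX2 (x) X1.
    dL_dW2 = [[dL_dX2[k] * X1[j] for j in range(len(X1))]
              for k in range(len(W2))]
    return dL_dX1, dL_dW2

W2 = [[1.0, -2.0, 0.5], [0.3, 0.0, 1.5]]
X1 = [0.2, -0.4, 1.0]
y = [1.0, 0.0]
gx, gW = grads(W2, X1, y)
print(gx)
print(gW)
```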

Keeping track of derivatives and computing the products on the spot is better than trying to derive a general closed-form expression for them, because a generic neural network may be much more complicated than the one we've described. It may, for instance, have arrows that skip a layer, in which case the derivative is no longer a single chain of products but a sum over all paths from the parameter to the loss. In general, the network may involve any sort of operation (activation functions are an obvious one), and as long as we can differentiate each of them, we can keep track of what we need to multiply. This algorithmic use of the chain rule is called backpropagation. It evaluates the products starting from the loss and working backwards, so that the factor $\frac{\partial L}{\partial X_k}$ computed for one layer is reused for every layer before it; this reuse is what makes it efficient.
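To illustrate "keeping track of derivatives", here is a minimal reverse-mode sketch in the spirit of what such packages do internally; the `Value` class and its methods are hypothetical and are not any real library's API. Each node records its parents together with the local derivative with respect to each of them, and `backward` visits the nodes from the loss back towards the inputs, accumulating gradients with `+=`. That accumulation is what handles an arrow that skips a layer.

```python
import math

class Value:
    """A scalar node in a computational graph that remembers how it was produced.

    Hypothetical minimal reverse-mode sketch: each node stores its parents
    together with the local derivative d(self)/d(parent).
    """

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents            # tuple of (parent Value, local derivative)

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, ((self, other.data), (other, self.data)))

    def __sub__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return self + other * -1.0

    def tanh(self):
        t = math.tanh(self.data)
        return Value(t, ((self, 1.0 - t * t),))

    def backward(self):
        """Set node.grad = d(self)/d(node) for every node this value depends on."""
        order, seen = [], set()

        def visit(node):                   # depth-first topological sort
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):       # every consumer is processed before its inputs
            for parent, local in node._parents:
                # Chain rule; += sums the contributions of all paths through this parent.
                parent.grad += node.grad * local

# A two-layer network with a skip connection: X2 = tanh(W2 * X1) + X1.
X0, y = Value(0.5), 0.2
W1, W2 = Value(0.7), Value(-1.3)
X1 = (W1 * X0).tanh()
X2 = (W2 * X1).tanh() + X1                 # X1 reaches X2 along two paths
diff = X2 - y
L = diff * diff
L.backward()
print(W1.grad, W2.grad)
```

Because `X1` feeds `X2` both directly and through `tanh(W2 * X1)`, its gradient is the sum of the contributions from both paths, and `W1.grad` inherits both.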

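Note that the chain rule itself has no problem at all with computing gradients over multiple data points: the loss is still a function of the network's parameters. With our algorithmic use, however, each feed-forward pass (and its backward pass) handles one data point, so we need to assemble the gradient for the whole batch from the per-sample gradients. Fortunately, if the loss is additive over data points (a sum or mean of per-sample losses), then by linearity of differentiation so is the gradient: the batch gradient is the sum (or mean) of the per-sample gradients.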
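A small check of the additivity claim, assuming a hypothetical one-parameter model $\hat{y} = wx$ with per-sample loss $(wx_n - y_n)^2$ and a batch loss equal to their sum. Summing the per-sample gradients (one per feed-forward pass) agrees with a finite-difference derivative of the batch loss.

```python
# Hypothetical per-sample loss for a one-parameter model y_hat = w * x:
#   l_n(w) = (w * x_n - y_n)**2, with batch loss L(w) = sum_n l_n(w).

def sample_loss(w, x, y):
    return (w * x - y) ** 2

def sample_grad(w, x, y):
    """d l_n / d w for a single data point (one forward/backward pass)."""
    return 2.0 * (w * x - y) * x

def batch_loss(w, xs, ys):
    return sum(sample_loss(w, x, y) for x, y in zip(xs, ys))

def batch_grad(w, xs, ys):
    """L is a sum of per-sample losses, so its gradient is the sum of per-sample gradients."""
    return sum(sample_grad(w, x, y) for x, y in zip(xs, ys))

xs, ys, w = [1.0, 2.0, -0.5], [0.5, 1.5, 0.0], 0.3
h = 1e-6
numeric = (batch_loss(w + h, xs, ys) - batch_loss(w - h, xs, ys)) / (2 * h)
print(batch_grad(w, xs, ys), numeric)
```

If the batch loss were the mean rather than the sum, both the loss and the gradient would simply be divided by the batch size.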
