Neural Network Backpropagation Derivation

I have spent a few days hand-rolling neural networks such as CNNs and RNNs. This post collects my notes on deriving backpropagation for a neural network. Backpropagation is one of the more complicated algorithms in machine learning to derive. There are many resources that explain how to compute gradients using backpropagation, but in my opinion most of them lack a simple example that demonstrates the problem and walks through the algorithm.

1. A Simple Neural Network

The following diagram shows the structure of a simple neural network used in this post.


2. Cost Function

A cost function reflects the distance between the ground truth and the predicted values. Given a simple sum-of-squared-errors function as the cost function, the goal is to find the best W, the one that leads to the lowest cost J. The cost function is often defined in a way that is mathematically convenient. Of course, there are multiple ways to measure the distance between the truth and the predicted values, and each cost function has its own area of application.
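
For reference, a sum-of-squared-errors cost is commonly written in the form below. The symbols are assumptions made for illustration: y_i is the ground truth for sample i, \hat{y}_i is the network's prediction for that sample, n is the number of training samples, and the factor 1/2 is a convention that cancels when differentiating.

```latex
J(W) = \sum_{i=1}^{n} \frac{1}{2}\left(y_i - \hat{y}_i\right)^2
```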

In a simplified two-dimensional space, the relationship between the cost J and W can be pictured as a curve with a lowest point.

We keep updating W so that the cost moves toward the lowest point. This algorithm is called gradient descent. By calculating the partial derivatives of J with respect to W, we know in which direction to adjust W to make the cost lower.


In a real scenario, W lives in a multi-dimensional space, which is hard to visualize here. Each update moves W in some direction within that space.
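
The update rule can be tried out on a very small problem. The sketch below is not the network from this post: it assumes a toy linear model y_hat = w[0] * x + w[1], made-up data, and a hypothetical learning rate and step count, just to show the cost dropping as W follows the negative gradient in a two-parameter space.

```python
def cost(w, xs, ys):
    """Sum of squared errors, with the conventional factor 1/2."""
    return sum(0.5 * (y - (w[0] * x + w[1])) ** 2 for x, y in zip(xs, ys))


def gradient(w, xs, ys):
    """Partial derivatives of the cost with respect to w[0] and w[1]."""
    g0 = sum(-(y - (w[0] * x + w[1])) * x for x, y in zip(xs, ys))
    g1 = sum(-(y - (w[0] * x + w[1])) for x, y in zip(xs, ys))
    return [g0, g1]


def gradient_descent(w, xs, ys, learning_rate=0.01, steps=2000):
    """Repeatedly step against the gradient; return final w and cost history."""
    history = [cost(w, xs, ys)]
    for _ in range(steps):
        g = gradient(w, xs, ys)
        w = [wi - learning_rate * gi for wi, gi in zip(w, g)]
        history.append(cost(w, xs, ys))
    return w, history


# Toy data generated from y = 2x + 1 (made up for this sketch).
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.0, 3.0, 5.0, 7.0, 9.0]
w_final, history = gradient_descent([0.0, 0.0], xs, ys)
print("final w:", w_final)
print("cost before:", history[0], "cost after:", history[-1])
```

With this data the weights should approach 2 and 1. A learning rate that is too large would make the cost grow instead of shrink, so the value used here is a choice that happens to suit this toy data.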

3. Partial Derivative Calculation

Let's first look at the second layer of the network.



(Note that f'(z) can be further simplified.)
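
As a sketch of what this step looks like in symbols, the notation below is assumed for illustration and may differ from the figures in this post: X is the input matrix with one row per sample, W^{(1)} and W^{(2)} are the first- and second-layer weights, z^{(2)} = X W^{(1)}, a^{(2)} = f(z^{(2)}), z^{(3)} = a^{(2)} W^{(2)}, \hat{y} = f(z^{(3)}), and J = \sum \frac{1}{2}(y - \hat{y})^2. Applying the chain rule to the second-layer weights then gives:

```latex
\frac{\partial J}{\partial W^{(2)}}
  = \left(a^{(2)}\right)^{T} \delta^{(3)},
\qquad
\delta^{(3)} = -(y - \hat{y}) \odot f'\!\left(z^{(3)}\right)
```

If f is the sigmoid function, which is an assumption here, the simplification of the derivative is f'(z) = f(z)(1 - f(z)).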

Similarly, we can calculate the partial derivatives with respect to the weights of the first layer.
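
A possible way to write this out, under assumed notation that may differ from the figures in this post: X is the input matrix with one row per sample, z^{(2)} = X W^{(1)}, a^{(2)} = f(z^{(2)}), z^{(3)} = a^{(2)} W^{(2)}, \hat{y} = f(z^{(3)}), and J = \sum \frac{1}{2}(y - \hat{y})^2. The error term of the output layer is propagated back through W^{(2)} before it is multiplied by the inputs:

```latex
\delta^{(3)} = -(y - \hat{y}) \odot f'\!\left(z^{(3)}\right),
\qquad
\delta^{(2)} = \left(\delta^{(3)} \left(W^{(2)}\right)^{T}\right) \odot f'\!\left(z^{(2)}\right),
\qquad
\frac{\partial J}{\partial W^{(1)}} = X^{T} \delta^{(2)}
```

The result has the same shape as W^{(1)}, so it can be subtracted from W^{(1)} directly in a gradient descent step.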



4. Visualization of the Calculation

Let's take the weights of the second layer as an example.

We first look at how a single sample contributes to the gradient, and then at how all 5 training samples contribute together.


The dimension of the result is 3 by 1, which is the same as that of W2.
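
The same bookkeeping can be checked numerically. The sketch below uses made-up hidden activations for 5 samples, a made-up 3 by 1 W2, and assumes a sigmoid activation with the cost J = sum of 1/2 (y - y_hat)^2; none of these numbers come from the post. It computes each sample's contribution and then sums them into the full gradient.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Made-up numbers: 5 samples, 3 hidden activations each, and a 3 by 1 W2.
a2 = [
    [0.2, 0.7, 0.5],
    [0.9, 0.1, 0.4],
    [0.3, 0.8, 0.6],
    [0.5, 0.5, 0.1],
    [0.6, 0.2, 0.9],
]
y = [1.0, 0.0, 1.0, 0.0, 1.0]
w2 = [0.4, -0.3, 0.8]

# Forward pass through the second layer for every sample.
z3 = [sum(a * w for a, w in zip(row, w2)) for row in a2]
y_hat = [sigmoid(z) for z in z3]

# delta3_i = -(y_i - y_hat_i) * f'(z3_i), using f'(z) = f(z) * (1 - f(z)).
delta3 = [-(yi - yh) * yh * (1.0 - yh) for yi, yh in zip(y, y_hat)]

# Contribution of each sample: its 3 activations scaled by its own delta3.
per_sample = [[a * d for a in row] for row, d in zip(a2, delta3)]

# Summing the 5 contributions gives the full gradient, i.e. a2^T times delta3.
grad_w2 = [[sum(contrib[j] for contrib in per_sample)] for j in range(3)]

print("shape:", len(grad_w2), "by", len(grad_w2[0]))
for row in grad_w2:
    print(row)
```

Each of the 5 rows in per_sample is one sample's share of the gradient, and adding them column by column is exactly what the matrix product does in one step.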
