Description

You know that gradient descent is a numerical method for minimizing a function. You also know that when you are doing regression, you want to minimize an error function. Thus, you use gradient descent to minimize the error function, genius!

Now, recall that gradient descent requires you to repeatedly calculate (evaluate) the (symbolic) gradient (and then move your parameters a little bit in the opposite direction). Calculating (evaluating) the gradient of an error function is expensive, and the more data points you have, the more expensive it gets. In other words:

The expense of calculating the gradient of an error function (wrt a single variable!) scales with respect to the number of data points you have!

This is because the derivative of the error function with respect to a single variable has a term for every feature of every data point. We can reduce this down to every feature of only one data point by

simply pretending we only have one data point!

Sounds silly, but that is exactly what stochastic gradient descent does. Every time we need to calculate a gradient, we pick a random data point from our set, and pretend that is the only data point we have when we calculate the gradient.

Example

First,

please don’t be alarmed at the large amount of large equations below.

They are all very simple to understand and follow. They represent simple concepts. Like all mathematical expressions, they communicate basic logic, but you cannot depend on just looking at the equations and getting it. You must read the surrounding context, and take your time to pause and ponder. Often, mathematical expressions are terse and packed with info. Once you understand the surrounding context, the equations are nothing but formalizations of those concepts.

Now, let’s make our discussion of stochastic gradient descent a bit more concrete by doing an example regression where we will use stochastic gradient descent to minimize our error function.

Let’s say that we have 3-dimensional data points that we are trying to find a pattern in. We will assume that our output is a linear combination of the 3 inputs, thus all we have to do is find 3 weights¹. Thus, we have an error function with 3 parameters. We will call these $w_1$ to $w_3$. Here is what our error function looks like:

$$E(w_1, w_2, w_3) = \sum_i \left(y^{(i)} - h(x^{(i)})\right)^2$$

Where

  • $x^{(i)}$ is a training instance (i.e. a single training example)
  • $y^{(i)}$ is the actual output of that training example
  • $h(x^{(i)})$ is the predicted output of that training example (with our weights)

Basically, we iterate through each of our training examples, and see how different the predicted output is from the actual output. We square the differences to 1) treat positive and negative differences alike and 2) penalize bigger differences more than smaller ones.

Let’s refine $h(x^{(i)})$:

$$h(x^{(i)}) = w_1 x_1^{(i)} + w_2 x_2^{(i)} + w_3 x_3^{(i)}$$

Where $x_1^{(i)}$ is the value for the 1st feature of data point $i$, $x_2^{(i)}$ is the value for the 2nd feature of data point $i$, and so on.

Let’s plug $h(x^{(i)})$ back into our error function:

$$E(w_1, w_2, w_3) = \sum_i \left(y^{(i)} - (w_1 x_1^{(i)} + w_2 x_2^{(i)} + w_3 x_3^{(i)})\right)^2$$
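To make this concrete, here is a sketch of the prediction and error function in plain Python (the names `ws`, `xs`, `ys`, `predict`, and `error` are my own, not from the text):

```python
# ws: the 3 weights [w1, w2, w3]
# xs: list of data points, each a list of 3 feature values
# ys: list of actual outputs, one per data point

def predict(ws, x):
    # h(x) = w1*x1 + w2*x2 + w3*x3
    return sum(w_j * x_j for w_j, x_j in zip(ws, x))

def error(ws, xs, ys):
    # E = sum over every data point of (actual - predicted)^2
    return sum((y - predict(ws, x)) ** 2 for x, y in zip(xs, ys))
```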

Let’s get the derivative of this (the error function) wrt $w_1$.

First, we will use the chain rule:

$$\frac{\partial E}{\partial w_1} = \sum_i 2\left(y^{(i)} - (w_1 x_1^{(i)} + w_2 x_2^{(i)} + w_3 x_3^{(i)})\right) \cdot \frac{\partial}{\partial w_1}\left(y^{(i)} - (w_1 x_1^{(i)} + w_2 x_2^{(i)} + w_3 x_3^{(i)})\right)$$

Since we are taking the derivative with respect to $w_1$, we can pretend all other weights are constants. Thus we can cross out some terms:

$$\frac{\partial}{\partial w_1}\left(y^{(i)} - (w_1 x_1^{(i)} + w_2 x_2^{(i)} + w_3 x_3^{(i)})\right) = -x_1^{(i)}$$

And so we are left with:

$$\frac{\partial E}{\partial w_1} = \sum_i -2\, x_1^{(i)} \left(y^{(i)} - (w_1 x_1^{(i)} + w_2 x_2^{(i)} + w_3 x_3^{(i)})\right)$$
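Here is that derivative as a Python sketch, to make the cost visible (`grad_w1` and the other names are hypothetical helpers of mine, not from the text). Note the loop over the entire dataset just to get the gradient with respect to one weight:

```python
def grad_w1(ws, xs, ys):
    # dE/dw1 = sum over ALL data points of -2 * x1 * (actual - predicted)
    total = 0.0
    for x, y in zip(xs, ys):
        predicted = sum(w_j * x_j for w_j, x_j in zip(ws, x))
        total += -2 * x[0] * (y - predicted)
    return total  # one gradient evaluation touched every single data point
```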

Notice,

the derivative of our error function with respect to a single variable has a term in it for every feature of every data point.

Since we are doing stochastic gradient descent, each time we need to calculate the gradient, we pick a single data point from our data, and pretend that is all we have. Thus each time we need to calculate a gradient, we still have a term for every feature, but of only one data point! It is critical to note that

we should pick a random data point each time we need to calculate the gradient

We should not pick the same data point over and over.
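Putting it all together, here is a minimal sketch of the whole stochastic gradient descent loop (the learning rate, step count, and zero initialization are arbitrary choices of mine, not prescribed by the text):

```python
import random

def sgd(xs, ys, lr=0.01, steps=5000):
    ws = [0.0, 0.0, 0.0]  # start all 3 weights at zero
    for _ in range(steps):
        # Pick ONE random data point and pretend it is all we have.
        i = random.randrange(len(xs))
        x, y = xs[i], ys[i]
        predicted = sum(w_j * x_j for w_j, x_j in zip(ws, x))
        # Gradient wrt each weight, using only this data point:
        # dE/dwj = -2 * xj * (actual - predicted)
        grads = [-2 * x_j * (y - predicted) for x_j in x]
        # Move each weight a little bit in the opposite direction.
        ws = [w_j - lr * g for w_j, g in zip(ws, grads)]
    return ws
```

Each step touches only one data point, so a single step stays cheap no matter how big the dataset is; you simply take many of them.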

That is it!

Summary

Once you understand regular gradient descent, and what error functions typically look like (at least for linear functions), understanding stochastic gradient descent is extremely easy. It’s just telling you

that every time you need to calculate the gradient, randomly pick one of your data points, and pretend that is all you have

Notes

Stochastic gradient descent is much less likely to get stuck in a local minimum. This is because the gradient (the direction that you’re going to move in) is always calculated (evaluated) based on a single random data point, so the noisy updates have a good chance of pulling you out of a local minimum.


  1. Or 4, if you wanna include a bias. But that will slightly complicate things, and thus require more words to explain, thus making this article bigger, which as you know, I do not like.