
A Simple Guide to Machine Learning: Linear Regression

Updated: Jun 19, 2020

Make accurate predictions of what will happen in the future, through regression analysis

Date: 2020/03/15

by Shan Swanlow


If you wish to use the code written below, you will need to install Python 3 and matplotlib. Instructions for their installation and usage are provided at the links above.


This guide is the first in a planned series of guides that aim to demystify machine learning and make it easy to learn (and apply) for all groups of people, whether they come from a mathematical background or not. Linear regression is one of the simplest problems in machine learning, and is therefore the easiest to enter the field with. We’ll be exploring this method as our first tutorial.

A Word on Regression

Linear regression is a type of statistical analysis within a field known as regression analysis. Regression analysis is a set of methods used to determine relationships between variables; in simpler terms, it is a set of tools for working out how certain variables affect other variables. A variable, in this case, is simply a measurement you could take of something in the real world: age, income, temperature, and height are all examples of variables. When we use regression analysis, we often want to find out what the relationship is between something like age and income, and using the relationship we discover, we can begin to make predictions about what is likely to happen in the future.

Linear regression is a subset of regression analysis that attempts to model the relationship between two variables by fitting a linear equation (a straight line) to the data. Do not worry if this sounds intimidating; we will explore this topic in detail. A basic understanding of graphing is needed first, so we will cover graphing below.

Gathering Data

When collecting data for analysis, people usually record several variables at once. This is important, as it gives you a better picture of what affects the variable you’re trying to understand. However, not all variables are treated the same. Some of the variables you record may not be affected by any of the others. For example, time isn’t affected by anything; it will continue to increase at the same rate, no matter what happens. Variables like this, which are not affected by the other variables, are known as independent variables. Other variables depend on those independent variables. For example, as time increases, so does age. Variables such as age (in this case) are known as dependent variables.

Understanding Data

Once you have obtained data, the next step is to make sense of it. Sometimes the data has an obvious relationship; for example, if I measured time and age, it would be clear that as time increases, so does my age. Other datasets, however, may have a relationship that is harder to see. Things like exchange rates fluctuate wildly over time, so looking at the raw numbers may not give much insight. If you draw a graph using the data, though, you may be able to see the relationship much more easily. Using matplotlib, this is straightforward; all we need to remember is that, by convention, the independent variable is plotted on the x-axis (i.e. horizontally).

For example, let’s look at a trivial example of sales of a product over a 5-year period:

Year    Sales
1       100000
2       125000
3       110000
4       120000
5       120000

It may be hard to see if there’s a relationship between years and sales. A graph might provide more insight into any potential patterns, so let’s use some code to sketch one. The only rule to remember is to place the independent variable on the x-axis; we achieve this by passing its values (the years) as the first list:

import matplotlib.pyplot as plt
plt.plot([1, 2, 3, 4, 5], [100000, 125000, 110000, 120000, 120000],
         'bo')  # 'bo' means to draw the points as blue circles
plt.show()  # open a window displaying the graph

Running the code above produces the following graph:

Line of Best Fit

It may be challenging to see a relationship here, but we can ascertain that sales are generally increasing over time. Graphing data in this way also enables us to make predictions, because we can draw a line of best fit. Imagine drawing a diagonal line that goes right through the middle of all the points, essentially “cutting” them in half, so that roughly as many points lie above the line as below it. That is a line of best fit: something that is used to show the general trend of data. A line of best fit for this graph looks like this:

It doesn’t touch any of the points, but it helps you in understanding the general relationship between the variables. Knowing this line of best fit allows us to predict what’s likely to happen in the future.

You might wonder how a line of best fit allows us to make predictions. For all graphs with a linear relationship, there is a mathematical equation that can generate them. In other words, for the line above, there is some formula that will allow you to generate y-values (like 100000), and if you graphed the y-values with their corresponding x-values, you would get the exact same line you see above. When you have this equation, you can insert any x-value into this equation, and get a future y-value. Practically speaking, this means that if you have the graph equation, you can figure out what sales will look like in 10 or 20 years. The equation for generating these points has a general form, and it is written as:

y = m * x + c, or, as it’s more commonly written in machine learning:

y = B0 + B1 * x

For this formula, y is the value you would plot on the y-axis (such as 100000 or 125000 for the table above), and x is the value on your x-axis (in this example, it’s our years). B1 is the gradient, which is a number expressing how much y changes each time x increases by 1. B0 is a constant that simply gets added to each value; it is the value of y when x is zero, so it is necessary whenever your y-values do not start from zero. In order for us to get the equation for the line of best fit, we need to find the constant and calculate the gradient, that is, calculate the values B0 and B1.
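
To make the equation concrete, here is a minimal sketch of it as a Python function. The function name predict and the values chosen for B0 and B1 are placeholders picked purely for illustration; they are not the values for our sales data, which we have yet to calculate.

```python
def predict(b0, b1, x):
    """Return the y-value on the line y = B0 + B1 * x."""
    return b0 + b1 * x

# Hypothetical values, chosen only to show how the equation behaves
b0 = 100000  # the constant: the value of y when x is 0
b1 = 5000    # the gradient: how much y changes when x increases by 1

for year in [1, 2, 3]:
    print(year, predict(b0, b1, year))
# prints:
# 1 105000
# 2 110000
# 3 115000
```

Each extra year adds exactly B1 (here 5000) to the result, which is what makes the graph a straight line.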

With all prerequisite knowledge out of the way, we can now begin using Machine Learning to solve linear regression.

Methods for Fitting a Linear Regression Model

Multiple methods exist to solve linear regression. The two most common are linear algebra and stochastic gradient descent. The linear algebra approach involves viewing the problem in terms of matrices, and using matrix operations to find the graph equation directly. We won’t be applying this method in this guide, although in practice it is often preferred over stochastic gradient descent.
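
Although we won’t use it in the rest of this guide, the direct approach is short enough to sketch for the case of a single x variable. The code below assumes the standard least-squares formulas for a straight line rather than the full matrix form, and the function name least_squares_line is our own. It uses the years and sales from the graph earlier in this guide.

```python
def least_squares_line(x_values, y_values):
    """Return (b0, b1) for the least-squares line y = B0 + B1 * x."""
    n = len(x_values)
    mean_x = sum(x_values) / n
    mean_y = sum(y_values) / n
    # How x and y vary together, and how x varies on its own
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(x_values, y_values))
    sxx = sum((x - mean_x) ** 2 for x in x_values)
    b1 = sxy / sxx             # gradient
    b0 = mean_y - b1 * mean_x  # constant
    return b0, b1

years = [1, 2, 3, 4, 5]
sales = [100000, 125000, 110000, 120000, 120000]
b0, b1 = least_squares_line(years, sales)
print(b0, b1)  # prints: 104500.0 3500.0
```

This gives the line in a single calculation, with no learning rate or training count to choose.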

Stochastic gradient descent is the other popular method used to solve for the graph equation, and is widely used in machine learning. This method involves the computer looking at the data you’ve given it one point at a time, making a guess for B0 and B1, seeing how inaccurate the resulting prediction is for the current point, and then adjusting the guess to reduce that inaccuracy before it moves on to the next data point. In other words, it learns from its mistakes and tries to lower its error as much as possible. Other variations of gradient descent exist, but we will focus on stochastic gradient descent for this tutorial.

Solving a Problem with Linear Regression

Continuing with the given dataset of years and sales, let’s attempt to find the graph equation using linear regression. Assume that we have been asked to calculate what the sales will be in 10 years’ time.

Understanding Stochastic Gradient Descent

The process our code will follow is this: we start with some initial values for B0 and B1. We use these values to create a graph equation, and then calculate y-values with it. Once we have these predicted y-values, we can see how far they are from the true y-values, the ones we were given. We do this by subtracting the actual y-value from our predicted y-value. This is more commonly known as an error calculation. Using some mathematics, we can find out where the error will be at its lowest; in other words, by knowing how to minimize the error, we can make our line calculation as accurate as possible.
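
The error calculation described above can be sketched in a few lines. The guesses for B0 and B1 below are arbitrary values picked for illustration, and the data is the years and sales plotted earlier in this guide.

```python
x_values = [1, 2, 3, 4, 5]
y_values = [100000, 125000, 110000, 120000, 120000]

# Arbitrary guesses, for illustration only
b0 = 100000
b1 = 5000

errors = []
for x, y in zip(x_values, y_values):
    predicted = b0 + b1 * x       # y-value from our guessed line
    errors.append(predicted - y)  # predicted minus actual

print(errors)  # prints: [5000, -15000, 5000, 0, 5000]
```

A positive error means the guess was too high for that year, a negative error means it was too low, and an error of 0 means the guessed line passes exactly through that point.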

Minimizing the error is what stochastic gradient descent aims to do. The process is fairly straightforward. You can imagine stochastic gradient descent as repeatedly stepping towards the point where your error is at its lowest. Bearing this in mind, all that you need is a way of telling the computer how to reach that point. This can be done with two numbers: a learning rate and a training count. The learning rate (commonly referred to as alpha) is a number that controls how strongly we adjust our guesses for B0 and B1 in response to the error that was calculated. Each adjustment is the learning rate multiplied by the error, so it scales with the size of the error; in other words, a larger error creates a larger adjustment, and a smaller error creates a smaller adjustment. This helps us move quickly when we are far from the lowest point and carefully when we are close to it. For stochastic gradient descent, we make this adjustment after each point (i.e. row in the table) is processed.

The training count is simply a number that we select, which tells the code how long to run the gradient descent process for. The process of continuously running gradient descent is known as training. Generally, we count training cycles in terms of complete passes through our sample data (the rows in our table). One complete pass through the sample data is known as an epoch.
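
As a rough illustration of a single adjustment, the sketch below processes one row of the table (year 1, sales of 100000), starting from B0 = 0 and B1 = 0 with a learning rate of 0.01. The function name sgd_step is our own; the adjustment it applies is the same one used in the full solution later in this guide.

```python
def sgd_step(b0, b1, x, y, alpha):
    """Adjust b0 and b1 using a single data point (x, y)."""
    error = (b0 + b1 * x) - y    # predicted minus actual
    b0 = b0 - alpha * error      # larger error, larger adjustment
    b1 = b1 - alpha * error * x  # the adjustment to B1 also scales with x
    return b0, b1

b0, b1 = sgd_step(0, 0, x=1, y=100000, alpha=0.01)
print(b0, b1)  # both values move from 0 to roughly 1000
```

The first guess predicted sales of 0, which is 100000 too low, so both values are pushed upwards. Repeating this step for every row, over many epochs, is all that training involves.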

Note that the starting values matter very little for linear regression. The error has only one lowest point, so gradient descent heads towards the same place no matter where it begins, provided the learning rate is small enough and it is given enough training.

Coding a solution

Following the explanation of gradient descent, we can now write a solution to get values for B0, B1, and display them as we go along. For this example, we will train for 300 epochs, though you can continue to train as long as you would like.

# Initial values
x_values = [1, 2, 4, 3, 5]
y_values = [100000, 125000, 110000, 120000, 120000]
# Perform Stochastic Gradient Descent
# y = mx + c is the same as y = B0 + B1 * x
b0 = 0
b1 = 0
alpha = 0.01  # learning rate
epochs = 300  # an epoch is a complete pass through the data set, i.e. the x/y values.
error = None
print("B0 B1")
for _ in range(0, epochs):
    for i in range(0, len(x_values)):
        predicted_result = b0 + b1 * x_values[i]  # computer calculates an estimate
        error = predicted_result - y_values[i]  # sees how inaccurate it is
        b0 -= alpha * error  # lowers the inaccuracy by adjusting according to alpha
        b1 -= alpha * error * x_values[i]  # same as above- adjust B1 based on alpha
        print(str(b0) + " " + str(b1))  # display the results in the terminal

Running the above code, you should get the following output:

These are the values for B0 and B1. We can now place a pair of B0 and B1 values (ideally the last pair, since it has had the most training) into the graph equation (y = B0 + B1 * x) and begin to make predictions. Before that, however, it would be useful to see how accurate our line of best fit is and whether we’re on track. Using the last calculated values for B0 and B1, we can graph what our line of best fit looks like using the following code. Make sure to paste this after the previous code:

import matplotlib.pyplot as plt
plt.plot(x_values, y_values, 'bo')
# Calculate the y values for our line of best fit, store them somewhere
# and then graph them:
best_fit_y_values = []
for i in range(0, len(x_values)):
    best_fit_y_values.append(100615.92742500186 +
                             x_values[i] * 4276.83798121538)
plt.plot(x_values, best_fit_y_values)
plt.show()  # open a window displaying the graph

This should produce the following graph:

The line of best fit is fairly accurate here, and isn’t much closer to one set of points than the other, so we can use this in our prediction calculations. By placing our found values into the Y = B0 + B1 * x equation, we can predict what sales will look like in 10 years:

100615.92742500186 + 10 * 4276.83798121538 

We can therefore say that in 10 years, sales should be around 143,384. If you train on the dataset for fewer or more epochs, you will obtain a different result.
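
The same arithmetic can be done in code. This small sketch plugs the B0 and B1 values quoted above into the graph equation; the variable names are our own.

```python
b0 = 100615.92742500186
b1 = 4276.83798121538

year = 10
predicted_sales = b0 + b1 * year
print(round(predicted_sales))  # prints: 143384
```

Changing the value of year lets you read off a prediction for any other point on the line.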

Bear in mind that the line of best fit merely gives a prediction based on existing data. It cannot be said that the sales predicted in 10 years' time will be the actual sales in 10 years' time. It may be more or less due to various factors, but the prediction helps greatly for those who need to plan. As an example, in the real world, linear regression is often used to predict population sizes and allocate/budget resources accordingly.
