Lab: Matplotlib
ACTL3143 & ACTL5111 Deep Learning for Actuaries
Matplotlib is a Python library for creating high-quality data visualisations. It can be used to build a wide variety of charts, and in this tutorial we will explore how to build line plots, scatter plots, bar plots, and histograms. Charts built using Matplotlib are highly customisable.
As a data scientist, the ability to visualise your data effectively is important, as it allows you to develop a deep understanding of your data. You’ll be able to spot trends and data characteristics that you can incorporate or account for in your modelling later.
Once Matplotlib is installed, you can import it into your Python program:
import matplotlib.pyplot as plt
Note that we specifically import pyplot, as opposed to Matplotlib itself. This is because pyplot is an interface for Matplotlib that makes the library work more like MATLAB: you first initialise a figure, and each subsequent function call then makes some change to that figure (source: https://matplotlib.org/stable/tutorials/introductory/pyplot.html).
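A minimal sketch of this "current figure" behaviour, assuming a standard Matplotlib install: `plt.gcf()` and `plt.gca()` return the figure and axes that pyplot is currently drawing on, so we can check that successive pyplot calls all modified the same figure. The data and title text are illustrative placeholders.

```python
import matplotlib.pyplot as plt

# Start a fresh figure; pyplot now treats it as the "current" figure
plt.figure()

# Each call below modifies that same current figure rather than creating a new one
plt.plot([0, 1, 2], [0, 1, 4])
plt.title("Built up one call at a time")

fig = plt.gcf()  # the current figure
ax = plt.gca()   # the current axes inside that figure

# One set of axes, holding one line and the title we set
print(len(fig.axes), len(ax.get_lines()), ax.get_title())
```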
Basic plot types
Line plot
Pyplot’s plot() function will create a line plot:
# Create sample data
x = [-2,-1,0,1,2]
y = [-4,-2,0,2,4]
# Create line plot
plt.plot(x,y)
As you can see, we have created a simple line plot. We can customise it by adding a title and axis labels, and even by changing the colour of the line:
# Create sample data
x = [-2,-1,0,1,2]
y = [-4,-2,0,2,4]
# Create line plot
plt.plot(x,y, color = "purple")
# Add title
plt.title("Plot of y = 2x")
# Add axes labels
plt.xlabel("x")
plt.ylabel("y")
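Besides colour names, the colour can be given as a hex code, and the line style and marker can be changed in the same call. The sketch below (illustrative data only; the second line is y = 2x - 1 so the two do not overlap) shows both the keyword form and pyplot's compact format string, where `"g:s"` means green, dotted line, square markers.

```python
import matplotlib.pyplot as plt

x = [-2, -1, 0, 1, 2]
y = [-4, -2, 0, 2, 4]

plt.figure()
# Hex code instead of a colour name, plus an explicit dashed line style and circle markers
(line1,) = plt.plot(x, y, color="#800080", linestyle="--", marker="o")
# Same idea with a format string: green ("g"), dotted (":"), square markers ("s")
(line2,) = plt.plot(x, [v - 1 for v in y], "g:s")
plt.title("Plot of y = 2x with custom line styles")
```

The keyword form is more explicit and easier to read; the format string is a convenient shorthand for quick exploratory plots.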
You can also add multiple lines to a plot:
# Create sample data
x = [-2,-1,0,1,2]
y1 = [-4,-2,0,2,4]
y2 = [6,3,0,-3,-6]
# Create line plot
plt.plot(x,y1, color = "purple")
plt.plot(x,y2, color = "green")
# Add title
plt.title("Plots of y = 2x and y = -3x")
# Add axes labels
plt.xlabel("x")
plt.ylabel("y")
Scatter plot
We use plt.scatter() to put together a scatter plot:
# Create sample data
x = [0, 1, 2, 3, 4, 5]
y = [0, 1, 4, 9, 16, 25]
# Create scatter plot
plt.scatter(x, y)
# Add title
plt.title("Scatter plot of y = x^2, x >= 0")
# Add axes labels
plt.xlabel("x")
plt.ylabel("y")
Bar plot
We use plt.bar() to put together a bar plot:
# Create sample data
x = [1, 2, 3, 4, 5]
y = [1, 4, 9, 16, 25]
# Create bar plot
plt.bar(x, y)
# Add title
plt.title("Bar plot of y = x^2, x >= 0")
# Add axes labels
plt.xlabel("x")
plt.ylabel("y")
Histogram
We use plt.hist() to put together a histogram:
# Create sample data
x = [1.2,1.5,1.7,2,2.1,2.2,2.8,3.6,4.1,4.4,4.9]
# Create histogram
plt.hist(x)
# Add title
plt.title("Histogram")
# Add axes labels
plt.xlabel("x")
plt.ylabel("Frequency")
By default, plt.hist() chooses the bins for you: it splits the range of the data into 10 bins of equal width.
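To see which bins were chosen, you can capture what `plt.hist()` returns: the count in each bin, the bin edges, and the drawn bars. This sketch assumes Matplotlib's default settings (where `rcParams["hist.bins"]` is 10); the variable names are illustrative.

```python
import matplotlib.pyplot as plt

x = [1.2, 1.5, 1.7, 2, 2.1, 2.2, 2.8, 3.6, 4.1, 4.4, 4.9]

plt.figure()
# plt.hist returns the bin counts, the bin edges and the drawn bar patches
counts, edges, patches = plt.hist(x)

print(len(counts))           # number of bins (10 with the default settings)
print(edges[0], edges[-1])   # the equal-width bins span the smallest to largest value
print(counts.sum())          # every data point is counted exactly once
```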
Advanced plot customisation
Histogram bin settings
While we are on the topic of histograms, let’s customise the histogram we have just created, specifically in terms of the bins.
You can set the number of bins that the histogram can have using the bins argument in plt.hist():
# Create sample data
x = [1.2,1.5,1.7,2,2.1,2.2,2.8,3.6,4.1,4.4,4.9]
# Create histogram with 4 bins
plt.hist(x, bins = 4)
# Add title
plt.title("Histogram, 4 bins")
# Add axes labels
plt.xlabel("x")
plt.ylabel("Frequency")
Alternatively, you can set custom bin edges:
# Create sample data
x = [1.2,1.5,1.7,2,2.1,2.2,2.8,3.6,4.1,4.4,4.9]
# Set custom bin edges
bin_edges = [0,1.5,3,4,5]
# Create histogram with 4 bins of custom width
plt.hist(x, bins = bin_edges, edgecolor = "black")
# Add title
plt.title("Histogram, 4 bins custom")
# Add axes labels
plt.xlabel("x")
plt.ylabel("Frequency")
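With custom edges it helps to know which bin a boundary value falls into. Each bin includes its left edge and excludes its right edge, except the last bin, which includes both. Values outside the outermost edges are not counted at all. The sketch below checks this on the same data, using NumPy's `np.histogram`, which follows the same edge convention; the trimmed edge list is an illustrative choice.

```python
import matplotlib.pyplot as plt
import numpy as np

x = [1.2, 1.5, 1.7, 2, 2.1, 2.2, 2.8, 3.6, 4.1, 4.4, 4.9]
bin_edges = [0, 1.5, 3, 4, 5]

plt.figure()
counts, edges, _ = plt.hist(x, bins=bin_edges, edgecolor="black")
# Bins are [0, 1.5), [1.5, 3), [3, 4) and [4, 5], so 1.5 lands in the second bin
print(counts)  # [1. 6. 1. 3.]

# Values outside the outermost edges are silently dropped
trimmed_counts, _ = np.histogram(x, bins=[1.5, 3, 4, 5])
print(trimmed_counts.sum())  # 10: the value 1.2 is left out
```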
Editing axes
Let’s go back to our line plot of y = 2x:
# Create sample data
x = [-2,-1,0,1,2]
y = [-4,-2,0,2,4]
# Create line plot
plt.plot(x,y, color = "purple")
# Add title
plt.title("Plot of y = 2x")
# Add axes labels
plt.xlabel("x")
plt.ylabel("y")
Notice that the tick marks on both the x- and y-axes are quite close together. You might prefer this, as it gives you more granularity; however, some may find it cluttered. We can set the positions of the tick marks using the plt.xticks() and plt.yticks() functions:
# Create sample data
x = [-2,-1,0,1,2]
y = [-4,-2,0,2,4]
# Create line plot
plt.plot(x,y, color = "purple")
# Add title
plt.title("Plot of y = 2x")
# Edit tick marks
plt.xticks(range(-2,3))
plt.yticks([-4,-2,0,2,4])
# Add axes labels
plt.xlabel("x")
plt.ylabel("y")
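`plt.xticks()` and `plt.yticks()` can also set the text displayed at each tick through their `labels` argument, and calling them with no arguments returns the current tick locations and labels. A small sketch; the label `"zero"` is purely illustrative.

```python
import matplotlib.pyplot as plt

x = [-2, -1, 0, 1, 2]
y = [-4, -2, 0, 2, 4]

plt.figure()
plt.plot(x, y, color="purple")
# Tick positions and the text shown at each position can be set together
plt.xticks([-2, 0, 2], labels=["-2", "zero", "2"])
plt.yticks([-4, 0, 4])

# With no arguments, plt.xticks() returns the current locations and label objects
locs, labels = plt.xticks()
print(list(locs), [lab.get_text() for lab in labels])
```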
As you can see, the x- and y-axes look significantly cleaner. We can make it easier to read off specific values by adding a grid using plt.grid(True):
# Create sample data
x = [-2,-1,0,1,2]
y = [-4,-2,0,2,4]
# Create line plot
plt.plot(x,y, color = "purple")
# Add title
plt.title("Plot of y = 2x")
# Edit tick marks
plt.xticks(range(-2,3))
plt.yticks([-4,-2,0,2,4])
# Add grid
plt.grid(True)
# Add axes labels
plt.xlabel("x")
plt.ylabel("y")
You can also edit the x- and y-axis limits by using plt.xlim() and plt.ylim():
# Create sample data
x = [-2,-1,0,1,2]
y = [-4,-2,0,2,4]
# Create line plot
plt.plot(x,y, color = "purple")
# Add title
plt.title("Plot of y = 2x")
# Set axis limits
plt.xlim((-3,3))
plt.ylim((-5,5))
# Add axes labels
plt.xlabel("x")
plt.ylabel("y")
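The limit functions are flexible: they accept the two limits as separate arguments as well as a tuple, you can fix just one end with the `bottom`/`top` (or `left`/`right`) keywords, and calling them with no arguments returns the current limits. A short sketch on the same data:

```python
import matplotlib.pyplot as plt

x = [-2, -1, 0, 1, 2]
y = [-4, -2, 0, 2, 4]

plt.figure()
plt.plot(x, y, color="purple")
plt.xlim(-3, 3)        # separate arguments work as well as a tuple
plt.ylim(bottom=-5)    # fix only the lower limit; the upper limit keeps its current value

# With no arguments, plt.xlim() and plt.ylim() return the current limits
print(plt.xlim(), plt.ylim())
```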
Formatting text
To format the text in a plot created using Matplotlib, you can use the fontsize and fontweight arguments of the various text functions, such as title, xlabel, and ylabel. These arguments allow you to specify the font size and font weight (i.e. thickness) of the text, respectively.
# Create sample data
x = [-2,-1,0,1,2]
y = [-4,-2,0,2,4]
# Create line plot
plt.plot(x,y, color = "purple")
# Add title and bold it
plt.title("Plot of y = 2x", fontweight = 'bold')
# Add axes labels and set their font sizes to 15
plt.xlabel("x", fontsize = 15)
plt.ylabel("y", fontsize = 15)
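`plt.title()`, `plt.xlabel()` and `plt.ylabel()` each return a `Text` object, so you can read back the properties that were actually applied. This sketch also uses a relative size (`"large"`), which Matplotlib scales from the default font size; the variable names are illustrative.

```python
import matplotlib.pyplot as plt

x = [-2, -1, 0, 1, 2]
y = [-4, -2, 0, 2, 4]

plt.figure()
plt.plot(x, y, color="purple")
title = plt.title("Plot of y = 2x", fontweight="bold")
xlab = plt.xlabel("x", fontsize=15)
# Relative sizes such as "small" or "large" also work
ylab = plt.ylabel("y", fontsize="large")

# Read back the weight and sizes that were applied
print(title.get_fontweight(), xlab.get_fontsize(), ylab.get_fontsize())
```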
You can use the fontstyle argument to specify whether you would like to italicise your text. The fontfamily argument allows you to specify the font family, such as “serif”, “sans-serif”, or “monospace”. If you want to use a specific font, you can use the fontname argument instead.
# Create sample data
x = [-2,-1,0,1,2]
y = [-4,-2,0,2,4]
# Create line plot
plt.plot(x,y, color = "purple")
# Add title and italicise it
plt.title("Plot of y = 2x", fontstyle = 'italic')
# Add axes labels with font size 15 and different font families
plt.xlabel("This is the x-axis", fontsize = 15, fontfamily = 'monospace')
plt.ylabel("This is the y-axis", fontsize = 15, fontfamily = 'serif')
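When several labels share the same formatting, one option is to collect the keyword arguments in a dictionary and unpack it with `**`, so the style is defined once. This is plain Python rather than a special Matplotlib feature, and `label_style` is a hypothetical name chosen for this sketch.

```python
import matplotlib.pyplot as plt

x = [-2, -1, 0, 1, 2]
y = [-4, -2, 0, 2, 4]

# Hypothetical shared style, defined once and reused via ** unpacking
label_style = {"fontsize": 15, "fontfamily": "serif", "fontstyle": "italic"}

plt.figure()
plt.plot(x, y, color="purple")
plt.title("Plot of y = 2x", fontweight="bold")
xlab = plt.xlabel("This is the x-axis", **label_style)
ylab = plt.ylabel("This is the y-axis", **label_style)
```

Changing the dictionary then restyles every label that uses it.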
Adding a legend
You can add a legend to your plot using the plt.legend() function. Notice that to label the lines in your plot, you need to use the label argument in the plt.plot() function, rather than the legend function itself:
# Create sample data
x = [-2,-1,0,1,2]
y1 = [-4,-2,0,2,4]
y2 = [6,3,0,-3,-6]
# Create line plot
plt.plot(x,y1, color = "purple", label = "y = 2x")
plt.plot(x,y2, color = "green", label = "y = -3x")
# Add title
plt.title("Plots of y = 2x and y = -3x")
# Add a legend to the top right hand corner
plt.legend(loc="upper right")
# Add axes labels
plt.xlabel("x")
plt.ylabel("y")
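Two related details: a line whose label starts with an underscore is left out of the legend, and `loc="best"` asks Matplotlib to place the legend where it overlaps the plotted data least. The sketch below adds an illustrative reference line at y = 0 that should not appear in the legend.

```python
import matplotlib.pyplot as plt

x = [-2, -1, 0, 1, 2]
y1 = [-4, -2, 0, 2, 4]
y2 = [6, 3, 0, -3, -6]

plt.figure()
plt.plot(x, y1, color="purple", label="y = 2x")
plt.plot(x, y2, color="green", label="y = -3x")
# Labels starting with an underscore are excluded from the legend
plt.plot(x, [0] * len(x), color="grey", label="_reference line")
leg = plt.legend(loc="best")

# Only the two labelled lines appear
print([t.get_text() for t in leg.get_texts()])
```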
Subplots
If you want to visualise multiple plots at a time in the form of a grid, you can use the plt.subplots() function:
# Create sample data
x = [-2,-1,0,1,2]
y1 = [-4,-2,0,2,4]
y2 = [6,3,0,-3,-6]
# Create a 1x2 grid of charts, with a figure size of 16x9 inches
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 9))
# Create a line plot in each space on the grid
ax[0].plot(x,y1, color = "purple")
ax[1].plot(x,y2, color = "green")
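Each element of `ax` is an `Axes` object with its own methods. Most pyplot functions have an `Axes` counterpart with a `set_` prefix (for example `ax.set_title()` instead of `plt.title()`), and `fig.suptitle()` adds one title to the whole figure. A sketch building on the grid above; the titles are illustrative.

```python
import matplotlib.pyplot as plt

x = [-2, -1, 0, 1, 2]
y1 = [-4, -2, 0, 2, 4]
y2 = [6, 3, 0, -3, -6]

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 9))
ax[0].plot(x, y1, color="purple")
ax[1].plot(x, y2, color="green")

# Axes methods use a set_ prefix where the pyplot functions do not
ax[0].set_title("y = 2x")
ax[1].set_title("y = -3x")
for a in ax:
    a.set_xlabel("x")
    a.set_ylabel("y")

# A single title for the whole figure
fig.suptitle("Plots of y = 2x and y = -3x")
```

With one row and two columns, `ax` is a one-dimensional array of two `Axes`; with more than one row and column it becomes two-dimensional and is indexed as `ax[row, col]`.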
Saving plots
Use the plt.savefig() function to save your plots. This function takes the name of the file you want to save your chart to, and Matplotlib infers the file format from its extension. This means you can save a chart in various formats, including PNG, JPEG, and TIFF.
Let’s fully build our line chart and save it to linechart.png.
# Create sample data
x = [-2,-1,0,1,2]
y1 = [-4,-2,0,2,4]
y2 = [6,3,0,-3,-6]
# Create line plot
plt.plot(x,y1, color = "purple", label = "y = 2x")
plt.plot(x,y2, color = "green", label = "y = -3x")
# Add title
plt.title("Plots of y = 2x and y = -3x")
# Add a legend to the top right hand corner
plt.legend(loc="upper right")
# Edit tick marks
plt.xticks(range(-2,3))
plt.yticks([-4,-2,0,2,4])
# Add grid
plt.grid(True)
# Add axes labels
plt.xlabel("x")
plt.ylabel("y")
# Save chart
plt.savefig("linechart.png")
The chart should now appear in the file explorer pane in Google Colab.
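`plt.savefig()` also accepts optional arguments such as `dpi` (resolution) and `bbox_inches="tight"` (trim surrounding whitespace). The sketch below saves one chart as both PNG and PDF and checks the file signatures. It writes into a temporary folder only so that it runs anywhere; in Colab you would normally just pass a file name. In a script, call `plt.savefig()` before `plt.show()`, since the figure may no longer be available once the window is closed.

```python
import os
import tempfile

import matplotlib.pyplot as plt

x = [-2, -1, 0, 1, 2]
y1 = [-4, -2, 0, 2, 4]

plt.figure()
plt.plot(x, y1, color="purple", label="y = 2x")
plt.legend(loc="upper right")

# Temporary output folder, used here only so the sketch is self-contained
out_dir = tempfile.mkdtemp()
png_path = os.path.join(out_dir, "linechart.png")
pdf_path = os.path.join(out_dir, "linechart.pdf")

# The format is inferred from each file's extension
plt.savefig(png_path, dpi=200, bbox_inches="tight")
plt.savefig(pdf_path)
plt.close()  # free the figure once it has been saved

print(sorted(os.listdir(out_dir)))
```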