Visualize data with seaborn¶

By John Kirch

In this tutorial, you will learn how to visualize data by using the seaborn library that extends the functionality of matplotlib. While learning, you will perform the following tasks:

  • Plot the cumulative monthly US passenger flights for more than a decade
  • Prepare views of Titanic passenger distributions based on various categories
  • Create combined and separate views of data distributions for different species of penguins
  • Estimate bill depth from bill length using a separate linear regression model for each species of penguin
  • Predict the survival probability of Titanic passengers using logistic regression based on age but separated by gender

Prerequisites¶

Before you start, make sure that:

  • You have installed DataSpell. This tutorial was created in DataSpell 2023.1.3. INFO: You can download DataSpell and use all its features for free during a 30-day trial period. Also, consider taking part in the DataSpell Early Access Program.
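  • You have Python 3.6 or newer on your computer. If you're using macOS or Linux, your computer may already have Python installed. You can get Python from python.org.
  • You have installed the seaborn library. The examples use parameters such as errorbar that were introduced in seaborn 0.12, so use that version or newer.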

Transitioning from matplotlib to seaborn¶

If you already worked your way through the previous tutorial, Visualize data with matplotlib, you may recall the last chart that plotted total flights by month from more than 12 years of US airport traffic statistics.

In the following example, you will create the same chart using the seaborn library. First, you need to add import seaborn as sns to the beginning of your first code cell.

In [ ]:
import seaborn as sns
import pandas as pd

# Read file
data = pd.read_csv('airlines.csv')

# Drop columns containing string data that cannot be aggregated
data = data.drop(columns=['Airport.Code', 'Airport.Name', 'Time.Label', 'Statistics.Carriers.Names'])

# Aggregate flight statistics by month, regardless of airport or year
data = data.groupby('Time.Month Name').sum().sort_values(by='Time.Month')

# Define the dimensions of the seaborn plot
width = 12
height = 4

# Set the dimensions of the seaborn plot
sns.set(rc={'figure.figsize': (width, height)})

# Create the line plot
by_month = sns.lineplot(data=data, x='Time.Month Name', y='Statistics.Flights.Total')

# Define a custom title and x/y axes labels
by_month.set(
    title='Total number of flights in US airports throughout the year',
    xlabel='Months (from June 2003 through January 2016)',
    ylabel='Total flights (in millions)'
)
Out[ ]:
[Text(0.5, 1.0, 'Total number of flights in US airports throughout the year'),
 Text(0.5, 0, 'Months (from June 2003 through January 2016)'),
 Text(0, 0.5, 'Total flights (in millions)')]

For comparison purposes, see the matplotlib code. Using either approach, we still needed to aggregate flight statistics by month. Once the data was prepared, with seaborn, we only needed to:

  1. Set the dimensions of the seaborn plot
  2. Define the plot using seaborn's lineplot() function
  3. Define the custom title and axis labels

From this example, you can see that seaborn's aim is to simplify the creation of data visualizations.

Data distributions¶

Now that you have learned some basic syntactical differences between matplotlib and seaborn, we will move on to our next goal: learning how to visualize data distributions. Until now, we have been working with aggregated data. To create histograms for visualizing data distributions, we need datasets that contain raw data.

List seaborn's sample datasets¶

With seaborn installed, you can use get_dataset_names() to get a list of seaborn's built-in datasets:

In [ ]:
# List all sample datasets
sns.get_dataset_names()
Out[ ]:
['anagrams',
 'anscombe',
 'attention',
 'brain_networks',
 'car_crashes',
 'diamonds',
 'dots',
 'dowjones',
 'exercise',
 'flights',
 'fmri',
 'geyser',
 'glue',
 'healthexp',
 'iris',
 'mpg',
 'penguins',
 'planets',
 'seaice',
 'taxis',
 'tips',
 'titanic']

Creating histograms using the "titanic" dataset¶

In this exercise, you will load seaborn's built-in "titanic" dataset and use the age column to create a histogram. First, you should load and view the dataset to get familiar with its columns and the types of data they contain:

In [ ]:
# Load the titanic dataset
titanic = sns.load_dataset('titanic')

# Display the data
titanic
Out[ ]:
     survived  pclass     sex   age  sibsp  parch     fare  embarked   class    who  adult_male  deck  embark_town  alive  alone
0           0       3    male  22.0      1      0   7.2500         S   Third    man        True   NaN  Southampton     no  False
1           1       1  female  38.0      1      0  71.2833         C   First  woman       False     C    Cherbourg    yes  False
2           1       3  female  26.0      0      0   7.9250         S   Third  woman       False   NaN  Southampton    yes   True
3           1       1  female  35.0      1      0  53.1000         S   First  woman       False     C  Southampton    yes  False
4           0       3    male  35.0      0      0   8.0500         S   Third    man        True   NaN  Southampton     no   True
...       ...     ...     ...   ...    ...    ...      ...       ...     ...    ...         ...   ...          ...    ...    ...
886         0       2    male  27.0      0      0  13.0000         S  Second    man        True   NaN  Southampton     no   True
887         1       1  female  19.0      0      0  30.0000         S   First  woman       False     B  Southampton    yes   True
888         0       3  female   NaN      1      2  23.4500         S   Third  woman       False   NaN  Southampton     no  False
889         1       1    male  26.0      0      0  30.0000         C   First    man        True     C    Cherbourg    yes   True
890         0       3    male  32.0      0      0   7.7500         Q   Third    man        True   NaN   Queenstown     no   True

891 rows × 15 columns

In [ ]:
titanic_age_dist = sns.displot(titanic, x='age')
titanic_age_dist.set(title='Titanic: Age distribution of passengers')
Out[ ]:
<seaborn.axisgrid.FacetGrid at 0x15450b210>

From the histogram you can see that a large portion of the passengers were between 18 and 35 years old.

Categories¶

In the next exercise, you will learn how to count passengers by the categories in the who column, which contains only 3 distinct values: child, man, or woman.

In [ ]:
passenger_type_dist = sns.catplot(titanic, x='who', kind='count')
passenger_type_dist.set(
    title='Titanic: Passengers',
    xlabel='',
    ylabel='Number of passengers'
)
Out[ ]:
<seaborn.axisgrid.FacetGrid at 0x1545a85d0>

When counting by categories, the seaborn catplot() function produces a bar chart with a distinct color for each category.
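
The heights of the bars in a count plot are simply the number of rows per category. If you want the numbers themselves, a pandas value_counts() call gives the same tally. The sketch below uses a tiny, made-up passengers DataFrame rather than the real dataset.

```python
import pandas as pd

# Made-up sample with the same three categories as the `who` column
passengers = pd.DataFrame({
    'who': ['man', 'woman', 'man', 'child', 'woman', 'man']
})

# One count per distinct category, largest first
counts = passengers['who'].value_counts()
print(counts)
```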

Multiple categories mapped against an attribute¶

Next, you will learn how to display multiple categories (gender and class) with respect to a single attribute (survivability) in order to visualize a data tendency. In this exercise, you will use the hue parameter for grouping by class. For the other category (gender), you will assign the column sex to the x axis. The central tendency we wish to expose is survivability based on these categories. Therefore, you will assign the survived column (having only two values, 0 or 1) to the y axis.

In [ ]:
survival = sns.catplot(
    titanic,
    x='sex',
    y='survived',
    hue='class',
    kind='bar'
)
survival.set(
    title='Titanic: Passenger survival by gender and class',
    xlabel='Gender',
    ylabel='Survivability'
)
Out[ ]:
<seaborn.axisgrid.FacetGrid at 0x15460ff90>

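With kind='bar', catplot() uses seaborn's barplot() function, which by default draws the mean of the y variable for each group. Because survived contains only 0 and 1, the height of each bar is the share of passengers in that group who survived. The thin, black vertical "error bars" represent a 95% confidence interval for each mean. You can use the errorbar parameter to show a different kind of interval or to hide the error bars.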

You should also take note that the hue parameter is available with most of seaborn's functions. It allows you to visualize categories. Each category is represented by a unique color. Also, a legend is displayed that maps category names to colors.

Create a combined histogram of a single variable across multiple categories¶

In the next set of exercises, we will switch to seaborn's "penguins" built-in dataset that provides a richer set of categories and variables. But first, you should load and view the dataset to get familiar with its columns and the types of data they contain:

In [ ]:
# Load and display the penguins dataset
penguins = sns.load_dataset('penguins')
penguins
Out[ ]:
     species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  body_mass_g     sex
0     Adelie  Torgersen            39.1           18.7              181.0       3750.0    Male
1     Adelie  Torgersen            39.5           17.4              186.0       3800.0  Female
2     Adelie  Torgersen            40.3           18.0              195.0       3250.0  Female
3     Adelie  Torgersen             NaN            NaN                NaN          NaN     NaN
4     Adelie  Torgersen            36.7           19.3              193.0       3450.0  Female
...      ...        ...             ...            ...                ...          ...     ...
339   Gentoo     Biscoe             NaN            NaN                NaN          NaN     NaN
340   Gentoo     Biscoe            46.8           14.3              215.0       4850.0  Female
341   Gentoo     Biscoe            50.4           15.7              222.0       5750.0    Male
342   Gentoo     Biscoe            45.2           14.8              212.0       5200.0  Female
343   Gentoo     Biscoe            49.9           16.1              213.0       5400.0    Male

344 rows × 7 columns

In this exercise, you will learn how to combine the distributions of 3 species of penguins across a single variable, flipper length.

In [ ]:
flipper_length = sns.displot(
    penguins,
    x='flipper_length_mm',
    hue='species',
    multiple='stack'
)
flipper_length.set(
    title='Cross-species penguin flipper length',
    xlabel='Flipper length (mm)'
)
Out[ ]:
<seaborn.axisgrid.FacetGrid at 0x154696810>

This stacked histogram reveals the overlap regarding flipper length across the different species. However, it doesn't provide a good view of the data distribution for each species.

Create separate histograms of a single variable for each category¶

In this exercise, you will see how replacing the multiple='stack' parameter with col='species' transforms the combined histogram into a separate histogram for each species:

In [ ]:
sns.displot(
    penguins,
    x='flipper_length_mm',
    hue='species',
    col='species'
)
Out[ ]:
<seaborn.axisgrid.FacetGrid at 0x154744090>

Multiple linear regression¶

Linear regression is useful for determining the relationship between two variables. In this exercise, you will use the seaborn lmplot() function to visualize a linear fit of bill depth (y) against bill length (x) for each species of penguin. This function creates a scatterplot of both variables. Then, because hue='species' is set, it fits a separate regression model y ~ x for each species and draws one regression line per species. Here, "multiple" refers to these separate models, not to a single model with several predictors.

In [ ]:
# Use the lmplot() function to draw a linear regression model
bill = sns.lmplot(
    penguins,
    x='bill_length_mm',
    y='bill_depth_mm',
    hue='species',
    height=5
)
bill.set(
    title='Linear regression model: Penguin bill length vs bill depth',
    xlabel='Bill length (mm)',
    ylabel='Bill depth (mm)'
)
Out[ ]:
<seaborn.axisgrid.FacetGrid at 0x1542f0fd0>

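By default, each regression line is truncated to the range of x values observed for its species, so its endpoints mark the smallest and largest bill lengths in that group. The slope and intercept of each line can be used to estimate y (bill depth) for a given x (bill length). The shaded band around each line shows a 95% confidence interval for the fit.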
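
lmplot() draws the fitted lines but does not return their coefficients. If you need the numbers, one option is to fit a least-squares line per group yourself. The sketch below does this with numpy.polyfit() on a made-up bills DataFrame with two hypothetical species. It illustrates the idea of a per-group fit and is not necessarily the exact routine seaborn uses internally.

```python
import numpy as np
import pandas as pd

# Made-up measurements for two hypothetical species
bills = pd.DataFrame({
    'species': ['A', 'A', 'A', 'B', 'B', 'B'],
    'bill_length_mm': [36.0, 40.0, 44.0, 44.0, 48.0, 52.0],
    'bill_depth_mm': [17.0, 18.0, 19.0, 14.0, 15.0, 16.0],
})

# Fit depth = slope * length + intercept separately for each species
fits = {}
for name, group in bills.groupby('species'):
    slope, intercept = np.polyfit(
        group['bill_length_mm'], group['bill_depth_mm'], deg=1
    )
    fits[name] = (slope, intercept)

# Estimate the bill depth of a species A penguin with a 42 mm bill
slope_a, intercept_a = fits['A']
estimated_depth = slope_a * 42.0 + intercept_a
print(f'Estimated depth: {estimated_depth:.1f} mm')
```

Note that a model fitted as y ~ x estimates y from x. To estimate bill length from bill depth, fit the model with the two variables swapped rather than inverting this line.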

Faceted logistic regression¶

Faceting allows you to split the data by category and display each subset in its own plot, side by side. Logistic regression is useful for estimating the probability of an event from one or more independent variables. In the next exercise, you will use the age and sex columns from the "titanic" dataset to predict the probability of survival: age is the independent variable, and sex splits the data into 2 plots, one for males and one for females.

In [ ]:
# Make a custom palette with gendered colors
pal = dict(male='#6495ED', female='#F08080')

# Show the survival probability as a function of age and sex
survival = sns.lmplot(
    titanic,
    x='age',
    y='survived',
    col='sex',
    hue='sex',
    palette=pal,
    y_jitter=.02,
    logistic=True,
    truncate=False
)
survival.set(xlim=(0, 80), ylim=(-.05, 1.05))
Out[ ]:
<seaborn.axisgrid.FacetGrid at 0x1542b6390>

In this exercise, you learned that the seaborn lmplot() function can perform both linear and logistic regression. By adding the logistic=True parameter, you can switch to logistic regression. Note that this option requires the statsmodels library to be installed.

y_jitter only affects the appearance of the scatter plot. Noise is added to the data after fitting the regression, not before. Because survived is either 0 or 1, the y-axis would otherwise span roughly (0, 1). Here, ylim is expanded by 0.05 at both ends of the y-axis to leave room for the jittered points.

Setting both the col and hue parameters to the sex column splits the data into 2 subsets, so lmplot() fits 2 independent logistic regression models, each based on age.
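
To see what a single one of these models computes, the sketch below fits a logistic curve to made-up age and survival data using a plain NumPy Newton-Raphson loop. This is a swapped-in technique for illustration: seaborn relies on statsmodels for logistic=True, and the fit_logistic() helper and the data here are hypothetical.

```python
import numpy as np

def fit_logistic(x, y, n_iter=25):
    """Fit P(y=1) = 1 / (1 + exp(-(b0 + b1 * x))) by Newton-Raphson."""
    # Design matrix with a column of ones for the intercept
    X = np.column_stack([np.ones_like(x, dtype=float), x])
    beta = np.zeros(2)
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-X @ beta))
        gradient = X.T @ (y - p)
        hessian = X.T @ (X * (p * (1.0 - p))[:, None])
        beta = beta + np.linalg.solve(hessian, gradient)
    return beta

# Made-up data: survival (1) becomes less frequent as age increases
age = np.array([5, 10, 20, 25, 30, 40, 50, 60, 70, 80], dtype=float)
survived = np.array([1, 1, 1, 0, 1, 0, 1, 0, 0, 0], dtype=float)

intercept, slope = fit_logistic(age, survived)

def predict(a):
    # Estimated survival probability at age `a`
    return 1.0 / (1.0 + np.exp(-(intercept + slope * a)))

prob_at_30 = predict(30.0)
print(f'Estimated survival probability at age 30: {prob_at_30:.2f}')
```

In the faceted chart, the same kind of curve is fitted twice: once using only the male passengers and once using only the female passengers.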

Summary¶

You have completed the seaborn visualization tutorial. Here's what you have done:

  • Compared the differences between matplotlib and seaborn using a line chart from the previous tutorial
  • Learned how to list and access seaborn's built-in datasets for experimenting with a variety of raw data
  • Visualized data distributions ranging from simple histograms to complex, category-based distribution across multiple variables
  • Built charts for visualizing linear regression models
  • Built charts for visualizing logistic regression models

Copyright © 2023 John Kirch