Cohort Analysis using Python: A Detailed Guide

A Beginner's Guide To Cohort Analysis Using Python

Cohort analysis is a type of behavioral analytics. It groups customers by a shared characteristic and tracks each group's behavior over time, most often to study buying patterns on e-commerce websites and other online businesses.

Cohort analysis uses time series data over the lifespan of a user to analyze, and help predict, their future consumption patterns. Large volumes of data are divided into groups, or cohorts, whose members share common characteristics over a specific time span.

Each cohort is a subset of the data that starts with a fresh set of users, which lets companies look only at the data relevant to the question at hand. Large companies use cohort analysis to study the habits of new as well as existing users so that they can optimize the user experience without having to sift blindly through millions of records covering a customer's whole lifespan.


Some common examples of Cohort Analysis

Let’s look at some examples in which cohort analysis can be used to study the behavioral patterns of customers or users.

  • In e-commerce, customers who started using the service two or three days ago, or who have just made their first purchase, can be grouped into one cohort, whereas people who have been on the website for one or two years can be grouped into another. Their purchase activity over the year can then be used to predict how often they will visit the website in the future, or how often they will click on a product and purchase it. This can be used to personalize their home feed so that the click-through rate increases and their retention on that website doesn't decline.
  • On websites such as YouTube, cohort analysis can be used to compare viewer retention for a person who logs in for the first time with that of a person who has been on the platform for a long time. To keep a new user on the platform, their home feed will be personalized with the most popular videos: those with a huge number of views, clickbait thumbnails, or popular search terms based on that user's location data. A long-time user will mostly receive video updates from their subscriptions and their most-watched channels.
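
The split between new and long-time users described above can be sketched in a few lines of pandas. This is only an illustration: the `users` table, the fixed reference date, and the tenure thresholds (7 and 365 days) are assumptions made up for the example, not values from any real platform.

```python
import pandas as pd

# Hypothetical users with signup dates; all values are invented for illustration.
users = pd.DataFrame({
    "user": ["ann", "bob", "cara", "dev"],
    "signup_date": pd.to_datetime(
        ["2024-05-29", "2023-01-15", "2024-05-30", "2022-06-01"]
    ),
})

# Fixed reference date so the example gives the same result every time.
today = pd.Timestamp("2024-06-01")

# Tenure in days since signup.
users["tenure_days"] = (today - users["signup_date"]).dt.days

# Assumed thresholds: up to 7 days is "new", up to a year is "established".
users["cohort"] = pd.cut(
    users["tenure_days"],
    bins=[-1, 7, 365, float("inf")],
    labels=["new", "established", "long-time"],
)

print(users[["user", "tenure_days", "cohort"]])
```

Each cohort could then be studied separately, for example by comparing how often its members return.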


Steps of Cohort Analysis

There are five main steps to performing cohort analysis. They are as follows:

Determining the main objective of the cohort analysis (determining the question): First and foremost, we need to determine the intent of the analysis, for example, to find out why people on YouTube stop watching videos after the 6-minute mark. This sets the goal of the analysis and focuses the large pool of information on a practical issue. It helps pinpoint the root cause, so that companies can work towards better business practices and a better user experience.

Defining the metrics to respond to the question: The next step is to identify what measures the problem. In the example above, we need to analyze at what minute a viewer leaves a video before moving on to something else, along with their watch time and click-through rate on YouTube.

Defining the particular groups or cohorts that will be relevant: To analyze users, we need to pick out groups of viewers who display common behavioral patterns. To do this, we analyze data from different users, identify the relevant similarities and differences between them, and then separate the users into specific cohorts.

Performing the Cohort Analysis: Now we use data visualization techniques to perform the cohort analysis based on the objective of the problem. This can be done in many programming languages, of which Python and R are the most commonly preferred. In Python, cohort analysis can be done using libraries such as pandas, NumPy, and seaborn. Heatmaps are usually used to display user retention because they present tabular data visually.

Testing the Results: Last but not least, the results need to be checked and tested to make sure that they can actually reduce company losses and optimize business practices. The analysis yields retention rates, and a heatmap (or any other suitable graph) of user retention helps us analyze and improve the experience for users.

[Figure: Steps of cohort analysis]

Components of a Cohort and Representation

Cohorts mainly have three components: time, size, and behavior. Time is the most commonly used component and is used to analyze a customer's retention on a specific platform. Size describes how much money or how many resources a customer usually spends. Behavior describes what the members of a cohort do during a specific period.

In this article, we will analyze the user retention rate, so we will group the data by the time component, build a pivot table, and represent that pivot table as a heatmap. Each row of the pivot table corresponds to the month in which a group of users first visited the e-commerce website or made their first purchase. Each column corresponds to the number of months since that first purchase, so a cell shows how many users from that row's cohort came back to purchase in that month.

For example, if a customer started using the website in April, their activity appears in the row named April. If they are still using the website in October, their activity in each month since April is counted under the corresponding monthly column of that row.
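
To make the layout concrete, here is a small hand-built table with the same shape. The counts are made up purely to show the structure and do not come from the dataset used later in this article.

```python
import pandas as pd

nan = float("nan")  # marks months that have not happened yet for a cohort

# Made-up counts, only to show the layout of a cohort table.
layout = pd.DataFrame(
    {1: [100.0, 80.0, 90.0], 2: [40.0, 35.0, nan], 3: [30.0, nan, nan]},
    index=["April", "May", "June"],
)
layout.index.name = "first month"
layout.columns.name = "months since first visit"

print(layout)
```

Reading the April row from left to right shows how many of the customers who started in April were active in their first, second, and third month. Later cohorts have fewer filled cells because less time has passed since they started.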

Cohort Analysis using Python

In Python, several libraries can be used to perform cohort analysis in a structured manner. We will use pandas (which builds on NumPy) to load and reshape the data, seaborn and matplotlib to plot the heatmaps, and openpyxl so that pandas can read Excel sheets. Let us begin.

You will need to install the libraries mentioned above if you don't already have them. Open your command prompt (administrator mode may be needed for a system-wide Python installation) and run the following commands.

pip install seaborn
pip install numpy
pip install openpyxl
pip install pandas
pip install matplotlib

If you're using Anaconda, open the Anaconda prompt (in administrator mode if required) and run the following commands to complete the installation.

conda install -c anaconda seaborn
conda install -c anaconda pandas
conda install -c anaconda openpyxl
conda install -c anaconda numpy
conda install -c conda-forge matplotlib
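
After installing, one way to confirm that the packages are visible to your interpreter is a quick check like the one below. It is only a convenience sketch: it looks up the package names listed above using the standard library and reports any that cannot be found.

```python
import importlib.util

# Package names taken from the install commands above.
required = ["pandas", "numpy", "seaborn", "matplotlib", "openpyxl"]

# find_spec returns None when a top-level package cannot be located.
missing = [name for name in required if importlib.util.find_spec(name) is None]

if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are installed.")
```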

Obtaining sample dataset

If you do not have a customer dataset of your own for the analysis, you can use a sample dataset from a public source. The dataset that I'm going to use is a sample dataset available on Kaggle.

You can also find datasets in the UCI Machine Learning Repository. In this project, we are using an online retail dataset to represent user retention rates. The picture below shows the dataset I used, which has 541909 entries.

[Figure: Online retail dataset]

Importing the Dataset

Now, we'll import the required modules and load the dataset. We read the Excel file into a DataFrame and display the column names along with the data type of each column.

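import datetime as dt  # used later to build the invoice month

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import openpyxl  # engine pandas uses to read .xlsx files

# load the data and take a look at its columns and data types
data = pd.read_excel("Online Retail.xlsx")
data.info()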

The output of the above code will look like the one shown below:

[Figure: output of data.info(), listing each column with its non-null count and data type]
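
If you want to see how many values are missing in each column before cleaning, a check along these lines can help. The small `sample` frame below is an invented stand-in that reuses only the column names mentioned in this article. With the real data you would call the same methods on `data`.

```python
import pandas as pd

# Tiny stand-in for the retail file; the values are invented for illustration.
sample = pd.DataFrame({
    "InvoiceNo": ["536365", "536366", "536367"],
    "InvoiceDate": pd.to_datetime(["2010-12-01", "2010-12-01", "2011-01-04"]),
    "CustomerID": [17850.0, float("nan"), 13047.0],
})

# Column names, non-null counts, and data types.
sample.info()

# Number of missing values in each column.
missing_per_column = sample.isna().sum()
print(missing_per_column)
```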

Dropping Rows with No CustomerID

To analyze the cohorts properly, we drop the rows where CustomerID is missing. Some invoices have no CustomerID (the non-null counts for CustomerID and InvoiceNo differ), and such rows cannot be assigned to a cohort. We also define a function that maps an invoice date to the first day of its month and apply it to the InvoiceDate column.

#eliminating rows without customer ID
data = data.dropna(subset=['CustomerID'])


#function for getting the invoice month
def getting_months(m):
    return dt.datetime(m.year, m.month,1)

#using the above function
data['Invoice-Month'] = data['InvoiceDate'].apply(getting_months)
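
As an alternative to calling `getting_months` on every row, pandas can truncate dates to the start of the month in a vectorized way. The sketch below uses a few made-up timestamps and is not the method used in the rest of this article.

```python
import pandas as pd

# A few made-up invoice timestamps.
dates = pd.Series(
    pd.to_datetime(["2011-03-17 10:15", "2011-03-31 23:59", "2011-04-01 00:00"])
)

# Convert to monthly periods, then back to timestamps at the start of each month.
month_start = dates.dt.to_period("M").dt.to_timestamp()

print(month_start)
```

Both approaches should give the first day of the invoice month. The vectorized form avoids a Python-level function call per row, which may matter on a dataset with several hundred thousand rows.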

Indexing the Column By First Month of Visit

Next, we add a column holding the month of each customer's first purchase (the cohort month). We then create a helper function that splits a date column into its day, month, and year.

#indexing a column for the first month visit of the customer
data['Cohort-Month'] =  data.groupby('CustomerID')['Invoice-Month'].transform('min')
data.head(30)

#return the day, month, and year of a datetime column as three Series
def get_date_elements(df, column):
    day = df[column].dt.day
    month = df[column].dt.month
    year = df[column].dt.year
    return day, month, year

#month and year of each invoice and of each customer's cohort (the day is not needed)
_, Invoiceofmonth, Invoiceofyear = get_date_elements(data, 'Invoice-Month')
_, Cohortofmonth, Cohortofyear = get_date_elements(data, 'Cohort-Month')
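
To see what `transform('min')` does in the step above, here is a small sketch on invented data. The `orders` frame and its values are assumptions for illustration only.

```python
import pandas as pd

# Invented orders for two customers.
orders = pd.DataFrame({
    "CustomerID": [1, 1, 2, 2, 2],
    "Invoice-Month": pd.to_datetime(
        ["2011-02-01", "2011-05-01", "2011-03-01", "2011-03-01", "2011-04-01"]
    ),
})

# transform('min') returns one value per row: the earliest month of that row's customer.
orders["Cohort-Month"] = orders.groupby("CustomerID")["Invoice-Month"].transform("min")

print(orders)
```

Unlike a plain aggregation, `transform` keeps the original number of rows, which is why the result can be assigned directly as a new column.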

Grouping Data With CustomerIDs

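To create the cohort index, we compute the number of months between each invoice month and the customer's cohort month, and add 1 so that the month of the first purchase has index 1. We then count the unique CustomerIDs for each combination of cohort month and cohort index.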

#cohort index: months elapsed since the cohort month, starting at 1
yeardifference = Invoiceofyear - Cohortofyear
monthdifference = Invoiceofmonth - Cohortofmonth
data['Cohort-Index'] = yeardifference*12 + monthdifference + 1

#count distinct customers for each cohort month and cohort index
cohort_data = data.groupby(['Cohort-Month','Cohort-Index'])['CustomerID'].nunique().reset_index()
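
The cohort index formula can be checked by hand on single values. The helper `cohort_index` below is a hypothetical name introduced only for this worked example. It applies the same arithmetic as the code above to plain integers.

```python
def cohort_index(cohort_year, cohort_month, invoice_year, invoice_month):
    """Months elapsed since the cohort month, counting the first month as 1."""
    year_difference = invoice_year - cohort_year
    month_difference = invoice_month - cohort_month
    return year_difference * 12 + month_difference + 1


# Purchase in the same month as the first purchase: index 1.
same_month = cohort_index(2010, 12, 2010, 12)

# First purchase in December 2010, later purchase in February 2011:
# (2011 - 2010) * 12 + (2 - 12) + 1 = 12 - 10 + 1 = 3
two_months_later = cohort_index(2010, 12, 2011, 2)

print(same_month, two_months_later)
```

The month difference can be negative when the invoice falls in a later year, and the year term compensates for it.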

Creating a Pivot Table

Now, it's time to create the pivot table that we will plot as heatmaps. We also change the index of the cohort table from timestamps to readable month names.

#pivot table creation: one row per cohort month, one column per cohort index
cohort_table = cohort_data.pivot(index='Cohort-Month', columns='Cohort-Index', values='CustomerID')

#show the cohort months as readable labels such as "December 2010"
cohort_table.index = cohort_table.index.strftime('%B %Y')
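
The effect of the pivot is easier to see on a few rows. In this sketch the `counts` frame is invented. It has the same three columns as `cohort_data`, with made-up customer counts.

```python
import pandas as pd

# Invented long-form counts: one row per (cohort month, cohort index) pair.
counts = pd.DataFrame({
    "Cohort-Month": pd.to_datetime(
        ["2011-01-01", "2011-01-01", "2011-01-01", "2011-02-01", "2011-02-01"]
    ),
    "Cohort-Index": [1, 2, 3, 1, 2],
    "CustomerID": [50, 20, 15, 40, 10],
})

# Rows become cohort months, columns become cohort indexes.
table = counts.pivot(index="Cohort-Month", columns="Cohort-Index", values="CustomerID")

# Replace the timestamps with readable month labels.
table.index = table.index.strftime("%B %Y")

print(table)
```

Combinations that do not occur in the long-form data, such as the third month of the February cohort here, appear as NaN in the pivoted table.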

Creating the Heatmaps

For the final portion of the code, we will create two heatmaps. The first plots the actual number of users in each cell. The second divides each row by the size of its cohort in the first month, which shows the user retention rates as percentages.

#creation of heatmap and visualization
plt.figure(figsize=(21,10))
sns.heatmap(cohort_table,annot=True,cmap='Greens')

#cohort for percentage analysis
new_cohort_table = cohort_table.divide(cohort_table.iloc[:,0],axis=0)

#creating a percentage visualization
plt.figure(figsize=(21,10))
colormap=sns.color_palette("mako", as_cmap=True)
sns.heatmap(new_cohort_table,annot=True,fmt='.0%',cmap=colormap)

#show the heatmaps
plt.show()
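
The percentage table is obtained by dividing every row by its first column. The sketch below applies the same `divide` call to a small invented table so that the arithmetic can be checked by hand.

```python
import pandas as pd

# Invented counts: customers active in month 1, 2, and 3 for two cohorts.
counts = pd.DataFrame(
    {1: [50.0, 40.0], 2: [20.0, 10.0], 3: [15.0, float("nan")]},
    index=["January 2011", "February 2011"],
)

# Divide each row by its first-month size; axis=0 aligns the divisor with the rows.
retention = counts.divide(counts.iloc[:, 0], axis=0)

print(retention.round(2))
```

For the January row this gives 50/50, 20/50, and 15/50. The first column is always 100% because every customer is active in the month of their first purchase.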

Complete Code for Cohort Analysis in Python

Now, run the code entirely as given below:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import openpyxl
import datetime as dt

# loading data and displaying
data = pd.read_excel("Online Retail.xlsx")


#eliminating rows without customer ID
data = data.dropna(subset=['CustomerID'])


#function for getting the invoice month
def getting_months(m):
    return dt.datetime(m.year, m.month,1)

#using the above function
data['Invoice-Month'] = data['InvoiceDate'].apply(getting_months)

#indexing a column for the first month visit of the customer
data['Cohort-Month'] =  data.groupby('CustomerID')['Invoice-Month'].transform('min')
data.head(30)

#function for data to create a series
def get_date_elements(df, column):
    day = df[column].dt.day
    month = df[column].dt.month
    year = df[column].dt.year
    return day, month, year 

# getting date for columns and invoice
_,Invoiceofmonth,Invoiceofyear =  get_date_elements(data,'Invoice-Month')
_,Cohortofmonth,Cohortofyear =  get_date_elements(data,'Cohort-Month')

#cohort index creation
yeardifference = Invoiceofyear -Cohortofyear
monthdifference = Invoiceofmonth - Cohortofmonth
data['Cohort-Index'] = yeardifference*12+monthdifference+1

#count distinct customers for each cohort month and cohort index
cohort_data = data.groupby(['Cohort-Month','Cohort-Index'])['CustomerID'].nunique().reset_index()

#pivot table creation
cohort_table = cohort_data.pivot(index='Cohort-Month', columns='Cohort-Index', values='CustomerID')

# changing index of the cohort table
cohort_table.index = cohort_table.index.strftime('%B %Y')

#creation of heatmap and visualization
plt.figure(figsize=(21,10))
sns.heatmap(cohort_table,annot=True,cmap='Greens')

#cohort for percentage analysis
new_cohort_table = cohort_table.divide(cohort_table.iloc[:,0],axis=0)

#creating a percentage visualization
plt.figure(figsize=(21,10))
colormap=sns.color_palette("mako", as_cmap=True)
sns.heatmap(new_cohort_table,annot=True,fmt='.0%',cmap=colormap)
#display the heatmaps.
plt.show()

The two heatmaps are given below:

[Figure: Normal Heatmap]

The percentage analysis is given below:

[Figure: Percentage Heatmap]

Summary

Cohort analysis is a helpful tool for improving business practices, and it can increase user retention if businesses implement changes based on the results. With data available everywhere, cohort analysis extracts useful information from the behavioral patterns of customers to help anticipate how a business will develop. Observing cohorts over time gives insight into user experience and helps in developing better tactics. Cohort analysis is not limited to Python: it can also be done using R and other programming languages.