Learn how to turn a data table into a heat map using Python!
To run this notebook, you need a few basic data-science libraries installed: NumPy, pandas, Matplotlib and seaborn.
Download this notebook.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# Minimalist look for every figure below: white background with a light grid
sns.set_style(style="whitegrid")
# Preview the two palettes used by the heat maps below; choose one that suits your data.
# Palette guide: https://seaborn.pydata.org/tutorial/color_palettes.html
for preview_palette in (
    sns.cubehelix_palette(8, start=.5, rot=-.75),
    sns.diverging_palette(220, 20, sep=10, n=10),
):
    sns.palplot(preview_palette)
# Toy dataset: a 6x6 table of integers to visualise
data = np.array([
    [14, 14, 64,  0,  7,  0],
    [ 1, 15, 52, 16, 11,  5],
    [ 6, 17, 53,  9,  7,  3],
    [ 0,  6, 24, 29, 35,  6],
    [ 0, 21, 18, 18, 31, 11],
    [ 0,  9, 21,  9,  9, 53],
])
# Tick labels: one "X<i>" per column and one "Y<i>" per row
labelsX = [f"X{col}" for col in range(1, data.shape[1] + 1)]
labelsY = [f"Y{row}" for row in range(1, data.shape[0] + 1)]
# Version 0: the raw table, rendered by pandas with the X labels as columns
# and the Y labels as rows (pandas is imported in the imports cell at the top).
pd.DataFrame(data, columns=labelsX, index=labelsY)
| | X1 | X2 | X3 | X4 | X5 | X6 |
---|---|---|---|---|---|---|
Y1 | 14 | 14 | 64 | 0 | 7 | 0 |
Y2 | 1 | 15 | 52 | 16 | 11 | 5 |
Y3 | 6 | 17 | 53 | 9 | 7 | 3 |
Y4 | 0 | 6 | 24 | 29 | 35 | 6 |
Y5 | 0 | 21 | 18 | 18 | 31 | 11 |
Y6 | 0 | 9 | 21 | 9 | 9 | 53 |
# Version 1: sequential palette, every cell annotated, labelled colour bar
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    data,
    ax=ax,  # draw on the axes created above rather than the implicit current axes
    linewidths=.5,
    square=True,
    annot=True,
    fmt="d",  # annotate each cell as an integer
    robust=True,  # colour limits from robust quantiles rather than the extremes
    xticklabels=labelsX,
    yticklabels=labelsY,
    cmap=sns.cubehelix_palette(8, start=.5, rot=-.75),
    cbar_kws={'label': 'Legend description', 'shrink': 0.8},
);  # trailing ';' suppresses the Axes repr in the cell output
# Version 2: Version 1 with the column tick labels on top and axis titles added
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    data,
    ax=ax,
    linewidths=.5,
    square=True,
    annot=True,
    fmt="d",
    robust=True,
    xticklabels=labelsX,
    yticklabels=labelsY,
    cmap=sns.cubehelix_palette(8, start=.5, rot=-.75),
    cbar_kws={'label': 'Legend description', 'shrink': 0.8},
)
ax.xaxis.tick_top()  # move the column (labelsX) tick labels above the grid
# Rows are labelled with labelsY, so the y axis gets the "Y" title. The title is
# drawn horizontally; right/centre alignment keeps it left of the row tick labels
# instead of centred on top of them.
ax.set_ylabel('Y axis\nLabel', rotation=0, ha='right', va='center')
ax.set_xlabel('X axis Label')
ax.tick_params(axis='y', labelrotation=0)  # keep the row tick labels horizontal
# Version 3: Version 2 with a diverging palette centred on a reference value
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    data,
    ax=ax,
    linewidths=.5,
    square=True,
    annot=True,
    fmt="d",
    robust=True,
    xticklabels=labelsX,
    yticklabels=labelsY,
    cmap=sns.diverging_palette(220, 20, sep=10, n=10),
    center=10,  # value mapped to the neutral middle of the diverging palette
    cbar_kws={'label': 'Legend description', 'shrink': 0.8},
)
ax.xaxis.tick_top()  # move the column (labelsX) tick labels above the grid
# Rows are labelled with labelsY, so the y axis gets the "Y" title, drawn
# horizontally and right-aligned so it sits left of the row tick labels.
ax.set_ylabel('Y axis\nLabel', rotation=0, ha='right', va='center')
ax.set_xlabel('X axis Label')
ax.tick_params(axis='y', labelrotation=0)  # keep the row tick labels horizontal
# Version 4: Version 3 with the colour bar moved above the heat map
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    data,
    ax=ax,
    linewidths=.5,
    square=True,
    annot=True,
    fmt="d",
    robust=True,
    xticklabels=labelsX,
    yticklabels=labelsY,
    cmap=sns.diverging_palette(220, 20, sep=10, n=10),
    center=10,  # value mapped to the neutral middle of the diverging palette
    # Forwarded to matplotlib's colorbar: build a separate colour-bar axes on top
    # of the heat map; shrink 0.78 sizes it relative to the heat-map axes.
    cbar_kws={'use_gridspec': False, 'location': 'top',
              'label': 'Legend description', 'shrink': 0.78},
)
ax.xaxis.tick_top()  # move the column (labelsX) tick labels above the grid
# Rows are labelled with labelsY, so the y axis gets the "Y" title, drawn
# horizontally and right-aligned so it sits left of the row tick labels.
ax.set_ylabel('Y axis\nLabel', rotation=0, ha='right', va='center')
ax.set_xlabel('X axis Label')
ax.tick_params(axis='y', labelrotation=0)  # keep the row tick labels horizontal