Heatmaps
A heatmap uses color to represent the magnitude of a variable across two dimensions, often to show how different categories or groups relate to each other.
By rendering the data as a grid of colored cells, a heatmap can quickly reveal patterns and relationships such as correlations, clusters, or hierarchical structures, which makes trends in large datasets easier to spot.
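To make the grid idea concrete, here is a minimal pure-Python sketch that arranges long-form (x, y, z) records into the rows and columns of a heatmap. The function long_to_grid and the sample records are hypothetical illustrations, not part of statsplotly, whose internals may well differ.

```python
from typing import Any, Dict, List, Optional, Tuple


def long_to_grid(
    records: List[Dict[str, Any]], x: str, y: str, z: str
) -> Tuple[List[Any], List[Any], List[List[Optional[Any]]]]:
    """Arrange long-form records into a grid: one row per y value, one column per x value."""
    # Keep the categories in order of first appearance
    xs = list(dict.fromkeys(record[x] for record in records))
    ys = list(dict.fromkeys(record[y] for record in records))
    # Cells without a matching record stay None, i.e. an empty cell in the heatmap
    grid: List[List[Optional[Any]]] = [[None] * len(xs) for _ in ys]
    for record in records:
        grid[ys.index(record[y])][xs.index(record[x])] = record[z]
    return xs, ys, grid


# Hypothetical records, in the same long form as a melted DataFrame
records = [
    {"nation": "A", "medal": "gold", "value": 3},
    {"nation": "B", "medal": "gold", "value": 1},
    {"nation": "A", "medal": "silver", "value": 2},
]
xs, ys, grid = long_to_grid(records, x="nation", y="medal", z="value")
print(xs, ys, grid)  # ['A', 'B'] ['gold', 'silver'] [[3, 1], [2, None]]
```

A plotting library then only has to pick a color for each cell from its value.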
import plotly.io as pio
pio.renderers.default = "sphinx_gallery"
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
import statsplotly
Note that statsplotly.heatmap does not operate on wide-form DataFrames, so the data must first be melted into long form:
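# Wide-form table: indexed by nation, one column per medal type
df = px.data.medals_wide(indexed=True)
fig = statsplotly.heatmap(
    # melt() turns the medal columns into (medal, value) rows; ignore_index=False keeps "nation"
    data=df.melt(ignore_index=False),
    x="nation",
    y="medal",
    z="value",
    opacity=0.8,
    color_palette=["#d4f542", "#4275f5"],
    axis="square",
)
fig.show()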
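The reshaping step can be checked on a small hand-made frame. This sketch assumes, as the use of y="medal" in the example implies, that medals_wide(indexed=True) names its column axis "medal"; the toy frame below mimics that layout with made-up numbers.

```python
import pandas as pd

# Hypothetical wide-form table: one row per nation, one column per medal type
wide = pd.DataFrame(
    {"gold": [3, 1], "silver": [2, 4], "bronze": [1, 0]},
    index=pd.Index(["A", "B"], name="nation"),
)
# Name the column axis so that melt() uses "medal" instead of the default "variable"
wide.columns.name = "medal"

# Each medal column becomes a block of (medal, value) rows; ignore_index=False keeps "nation"
long = wide.melt(ignore_index=False)
print(long)
```

The melted frame has one row per (nation, medal) pair, which is the layout heatmap expects for its x, y and z arguments.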
Setting a color logscale
One can also set the base of a logarithmic color scale with the logscale parameter:
# Stock prices indexed by date, one column per company
df = px.data.stocks().set_index("date")
# Parse the date strings into a DatetimeIndex
df.index = pd.DatetimeIndex(df.index, yearfirst=True, name="date")
fig = statsplotly.heatmap(
    data=df.melt(ignore_index=False, var_name="company", value_name="stock_value"),
    x="company",
    y="date",
    z="stock_value",
    color_palette="Greens_r",
    # Map the colors on a base-10 logarithmic scale
    logscale=10,
)
fig.show()
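As a rough illustration of what a color logscale changes, the sketch below maps values to a position along a two-color palette, either linearly or after taking logarithms in a given base. The helpers color_position and interpolate_hex are hypothetical; they show the generic technique and not how statsplotly implements logscale.

```python
import math
from typing import Optional


def color_position(value: float, vmin: float, vmax: float, base: Optional[float] = None) -> float:
    """Position of value along the colorscale, from 0 (vmin) to 1 (vmax)."""
    if base is not None:
        # A log scale needs strictly positive values
        value, vmin, vmax = (math.log(v, base) for v in (value, vmin, vmax))
    return (value - vmin) / (vmax - vmin)


def interpolate_hex(start: str, end: str, t: float) -> str:
    """Linearly blend two "#rrggbb" colors; t=0 gives start, t=1 gives end."""
    start_rgb = [int(start[i : i + 2], 16) for i in (1, 3, 5)]
    end_rgb = [int(end[i : i + 2], 16) for i in (1, 3, 5)]
    return "#" + "".join(f"{round(a + (b - a) * t):02x}" for a, b in zip(start_rgb, end_rgb))


values = [1, 10, 100]
linear = [color_position(v, 1, 100) for v in values]
logarithmic = [color_position(v, 1, 100, base=10) for v in values]
# On the log scale, 10 sits halfway between 1 and 100 instead of near the bottom
print(linear)
print(logarithmic)
print([interpolate_hex("#d4f542", "#4275f5", t) for t in logarithmic])
```

A log scale is mostly useful when values span several orders of magnitude, since a linear mapping would squeeze most of them into the first few colors.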
One coloraxis per subplot
Statsplotly makes it easy to draw heatmap subplots, each with its own coloraxis and colorbar.
Here, for example, we compute the cross-correlation matrix of the iris features for each species:
df = px.data.iris().set_index("species")
# One feature correlation matrix per species, stacked under a (species, index) MultiIndex
corr_df = pd.concat(
    [
        df.loc[species, ["sepal_length", "sepal_width", "petal_length", "petal_width"]].corr()
        for species in df.index.unique()
    ],
    keys=df.index.unique(),
    names=["species", "index"],
)
# One subplot row per species
fig = make_subplots(rows=len(corr_df.index.unique("species")), cols=1, shared_xaxes=True)
for i, species in enumerate(corr_df.index.unique("species"), 1):
    # Draw each species' matrix into its own row of the existing figure
    fig = statsplotly.heatmap(
        fig=fig,
        row=i,
        data=corr_df.loc[species].melt(
            ignore_index=False, var_name=species, value_name="correlation"
        ),
        x="index",
        y=species,
        z="correlation",
        title="Iris features cross-correlations",
        color_palette="reds",
        axis="equal",
    )
fig.layout.height = 800
fig.show()
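The stacked layout of corr_df, which the loop then takes apart one species at a time with corr_df.loc[species], can be reproduced on a tiny dataset. The group names and values below are made up; only the pandas construction mirrors the example.

```python
import pandas as pd

# Hypothetical toy data: two groups with three numeric features each
df = pd.DataFrame(
    {
        "x": [1, 2, 3, 4, 1, 2, 3, 4],
        "y": [2, 4, 6, 8, 8, 6, 4, 2],
        "z": [1, 3, 2, 4, 4, 1, 3, 2],
    },
    index=pd.Index(["a"] * 4 + ["b"] * 4, name="group"),
)

# One correlation matrix per group, stacked under a (group, index) MultiIndex
corr_df = pd.concat(
    [df.loc[group].corr() for group in df.index.unique()],
    keys=df.index.unique(),
    names=["group", "index"],
)

# Selecting one group drops the outer level; melting gives the long form for one subplot
long_a = corr_df.loc["a"].melt(ignore_index=False, var_name="a", value_name="correlation")
print(corr_df)
print(long_a)
```

Here y is an exact linear function of x in both groups, so the (x, y) correlation is +1 in group "a" and -1 in group "b".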
Heatmap slices
With the slicer parameter, statsplotly.heatmap can slice the data along a particular dimension. This is handy for inspecting subsets of the data individually:
df = px.data.iris().set_index("species")
def construct_slicable_cross_correlation_matrix(df: pd.DataFrame) -> pd.DataFrame:
    # Place each species' measurements side by side (columns become (species, feature)),
    # then correlate all (species, feature) columns pairwise, rows aligned by sample position
    corr_df = (
        pd.concat(
            [df.loc[species].reset_index(drop=True) for species in df.index.unique()],
            keys=df.index.unique(),
            axis=1,
        )
        .drop(columns=["species_id"], level=1)
        .corr()
    )
    # Flatten the (species, feature) labels into "species-feature" strings
    corr_df.columns = ["-".join(column) for column in corr_df.columns.to_flat_index()]
    # Keep each row's species as a column to slice on
    corr_df["species"] = corr_df.index.get_level_values("species")
    corr_df.index = ["-".join(idx) for idx in corr_df.index.to_flat_index()]
    return corr_df
fig = statsplotly.heatmap(
    data=construct_slicable_cross_correlation_matrix(df).melt(
        ignore_index=False, id_vars="species", value_name="correlation"
    ),
    x="index",
    y="variable",
    z="correlation",
    title="Iris features cross-correlations",
    color_palette="RdBu_r",
    color_limits=(-1, 1),
    opacity=1,
    # One slice per species
    slicer="species",
    axis="equal",
)
fig.layout.height = 800
fig.show()
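Conceptually, each slice holds the rows of one species, i.e. that species' features correlated against every column of the full matrix. The sketch below rebuilds the same long-form layout on a hypothetical two-species toy dataset and partitions it with groupby; it shows which rows each slice contains, not how statsplotly renders the slices.

```python
import pandas as pd

# Hypothetical toy input: two species with two features and the same number of samples each
df = pd.DataFrame(
    {"x": [1, 2, 3, 4, 1, 2, 3, 4], "y": [2, 4, 6, 9, 9, 6, 4, 2]},
    index=pd.Index(["a"] * 4 + ["b"] * 4, name="species"),
)

# Mirrors construct_slicable_cross_correlation_matrix (no species_id column to drop here)
corr_df = pd.concat(
    [df.loc[species].reset_index(drop=True) for species in df.index.unique()],
    keys=df.index.unique(),
    names=["species", "feature"],
    axis=1,
).corr()
corr_df.columns = ["-".join(column) for column in corr_df.columns.to_flat_index()]
corr_df["species"] = corr_df.index.get_level_values("species")
corr_df.index = ["-".join(idx) for idx in corr_df.index.to_flat_index()]
long_df = corr_df.melt(ignore_index=False, id_vars="species", value_name="correlation")

# Partition by the slicer column: one sub-frame per species
slices = {species: part for species, part in long_df.groupby("species")}
for species, part in slices.items():
    print(species, part.index.unique().tolist(), part["variable"].unique().tolist())
```

With two species and two features, each slice covers 2 rows by 4 columns of the 4 by 4 matrix.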
Full details of the API: heatmap().