import numpy as np
import pandas as pd
from typing import List, Optional, Tuple, Union
import seaborn as sns
import matplotlib.pyplot as plt
def view_table(
    df: pd.DataFrame,
    title: Optional[str] = None,
    title_fontsize: Optional[int] = 16,
    figsize: Optional[Tuple[float, float]] = (14, 4),
    min_color: float = -1,
    max_color: float = 1,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    xticklabels: Optional[List[str]] = None,
    yticklabels: Optional[List[str]] = None,
    annot: Union[bool, np.ndarray, pd.DataFrame] = True,
    fmt: str = ".2f",
    return_fig: bool = False,
    highlight_mask: Optional[Union[np.ndarray, pd.DataFrame]] = None,
    footnote: Optional[str] = None,
    footnote_fontsize: int = 10,
) -> Optional[plt.Figure]:
    """
    Display a numeric DataFrame as an annotated colour-coded heatmap table.

    Parameters
    ----------
    df : pd.DataFrame
        Numeric DataFrame to display. It is converted to ``float`` before
        plotting.
    title : str, optional
        Title displayed above the heatmap.
    title_fontsize : int, optional
        Font size of the title. Default is 16.
    figsize : Tuple[float, float], optional
        Width and height of the figure in inches. Default is (14, 4).
    min_color : float
        Data value mapped to the bottom of the colormap. Default is -1.
    max_color : float
        Data value mapped to the top of the colormap. Default is 1.
    xlabel : str, optional
        Label for the x-axis.
    ylabel : str, optional
        Label for the y-axis.
    xticklabels : List[str], optional
        Tick labels for the columns. Defaults to the DataFrame column names.
    yticklabels : List[str], optional
        Tick labels for the rows. Defaults to the DataFrame index.
    annot : bool or array-like of str
        If a bool, controls whether the numeric values of ``df`` are
        annotated. If a DataFrame or 2D array of strings is supplied, those
        strings are rendered as cell annotations verbatim (``fmt`` is
        ignored).
    fmt : str
        String format for annotations. Default is '.2f'. Ignored when
        ``annot`` is array-like.
    return_fig : bool
        If True, return the Matplotlib figure object instead of displaying.
    highlight_mask : array-like of bool, optional
        DataFrame or 2D array of the same shape as ``df``. Cells where the
        mask is True have their annotation text rendered in black and bold.
        Has no effect when ``annot`` is False. Cells whose value is NaN are
        not annotated by seaborn, so their mask entries are skipped.
    footnote : str, optional
        Free-text caption rendered below the heatmap, useful for noting the
        statistical test, the panel scope, or how to read the annotations.
        Multi-line strings are supported. Default is None (no footnote).
    footnote_fontsize : int, optional
        Font size for the footnote text. Default is 10.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure when ``return_fig`` is True. Otherwise the figure is shown
        with ``plt.show()`` and None is returned.

    Raises
    ------
    TypeError
        If ``df`` is not a DataFrame.
    ValueError
        If ``df`` is empty or cannot be converted to float, if the number of
        tick labels does not match the table, or if ``annot`` /
        ``highlight_mask`` do not have the same shape as ``df``.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError("Table must be a DataFrame")
    if df.empty:
        raise ValueError("Table must not be empty")
    try:
        df = df.astype(float)
    except (ValueError, TypeError) as err:
        # astype raises TypeError (not ValueError) for e.g. datetimes or
        # arbitrary objects; report both as the same "not numeric" error.
        raise ValueError("Table must be numeric") from err

    if xticklabels is None:
        xticklabels = df.columns.to_list()
    elif len(xticklabels) != len(df.columns):
        raise ValueError("Number of xticklabels must match number of columns")
    if yticklabels is None:
        yticklabels = df.index.to_list()
    elif len(yticklabels) != len(df.index):
        raise ValueError("Number of yticklabels must match number of rows")

    annot_fmt = fmt
    if isinstance(annot, (pd.DataFrame, np.ndarray)):
        annot_arr = annot.values if isinstance(annot, pd.DataFrame) else annot
        if annot_arr.shape != df.shape:
            raise ValueError(
                "annot array shape must match the DataFrame shape "
                f"{df.shape}, got {annot_arr.shape}."
            )
        annot = annot_arr
        # Custom annotations are pre-formatted strings; an empty format spec
        # renders them verbatim.
        annot_fmt = ""

    # Validate the highlight mask before any figure exists so that a bad mask
    # does not leave an orphaned open figure behind.
    mask_arr = None
    if highlight_mask is not None and annot is not False:
        if isinstance(highlight_mask, pd.DataFrame):
            mask_arr = highlight_mask.values
        else:
            mask_arr = np.asarray(highlight_mask)
        if mask_arr.shape != df.shape:
            raise ValueError(
                "highlight_mask shape must match the DataFrame shape "
                f"{df.shape}, got {mask_arr.shape}."
            )

    fig, ax = plt.subplots(figsize=figsize)
    # NOTE: sns.set changes global plotting defaults; kept after the figure is
    # created, as before, so the rendered output is unchanged.
    sns.set(style="ticks")
    sns.heatmap(
        df,
        cmap="vlag_r",
        vmin=min_color,
        vmax=max_color,
        square=False,
        linewidths=0.5,
        cbar_kws={"shrink": 0.5},
        annot=annot,
        fmt=annot_fmt,
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        ax=ax,
    )

    if mask_arr is not None:
        # seaborn lays out ax.texts in row-major order but creates no text for
        # masked (NaN) cells, so drop those mask entries to keep the pairing
        # between texts and mask values aligned.
        annotated = df.notna().values.ravel()
        for txt, hi in zip(ax.texts, mask_arr.ravel()[annotated]):
            if bool(hi):
                txt.set_color("black")
                txt.set_weight("bold")

    ax.set(xlabel=xlabel, ylabel=ylabel)
    ax.set_title(title, fontsize=title_fontsize)
    fig.tight_layout()

    if footnote:
        n_lines = footnote.count("\n") + 1
        # Reserve bottom margin for the heatmap + tick labels + optional
        # xlabel + footnote. Figure-relative units: tick labels ~0.06,
        # xlabel ~0.04 if present, then 0.04 per footnote line.
        xlabel_pad = 0.04 if ax.get_xlabel() else 0.0
        footnote_block = 0.04 * n_lines
        fig.subplots_adjust(bottom=0.10 + xlabel_pad + footnote_block)
        fig.text(
            0.5,
            0.02,
            footnote,
            ha="center",
            va="bottom",
            fontsize=footnote_fontsize,
            style="italic",
            wrap=True,
        )

    if return_fig:
        return fig
    plt.show()
    return None
if __name__ == "__main__":
    # Demo: render a 4x4 table whose columns hold the consecutive integers
    # 1-4, 5-8, 9-12 and 13-16, labelled Col1..Col4 / Row1..Row4.
    demo_columns = {
        f"Col{col + 1}": list(range(4 * col + 1, 4 * col + 5))
        for col in range(4)
    }
    demo_index = [f"Row{row}" for row in range(1, 5)]
    demo_frame = pd.DataFrame(demo_columns, index=demo_index)
    view_table(demo_frame, title="Table")