# Source code for macrosynergy.visuals.table

import pandas as pd
from typing import List, Optional, Tuple
import seaborn as sns
import matplotlib.pyplot as plt


def view_table(
    df: pd.DataFrame,
    title: Optional[str] = None,
    title_fontsize: Optional[int] = 16,
    figsize: Optional[Tuple[float, float]] = (14, 4),
    min_color: float = -1,
    max_color: float = 1,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    xticklabels: Optional[List[str]] = None,
    yticklabels: Optional[List[str]] = None,
    annot: bool = True,
    fmt: str = ".2f",
    return_fig: bool = False,
):
    """
    Displays a DataFrame representing a table as a heatmap.

    Parameters
    ----------
    df : ~pandas.DataFrame
        table to be displayed. Must be non-empty and castable to float.
    title : str, optional
        string of chart title. Default is None.
    title_fontsize : int, optional
        font size of chart header. Default is 16.
    figsize : Tuple[float, float], optional
        Tuple (w, h) of width and height of plot. Default is (14, 4).
    min_color : float
        minimum value of colorbar. Default is -1.
    max_color : float
        maximum value of colorbar. Default is 1.
    xlabel : str, optional
        string of x-axis label. Default is None.
    ylabel : str, optional
        string of y-axis label. Default is None.
    xticklabels : List[str], optional
        list of strings to label x-axis ticks. Default is None, in which case
        the column labels of `df` are used.
    yticklabels : List[str], optional
        list of strings to label y-axis ticks. Default is None, in which case
        the index labels of `df` are used.
    annot : bool
        whether to annotate heatmap with values. Default is True.
    fmt : str
        string format for annotations. Default is '.2f'.
    return_fig : bool
        If True, return the Matplotlib figure object instead of displaying.
        Default is False.

    Returns
    -------
    matplotlib.figure.Figure or None
        the figure holding the heatmap if `return_fig` is True; otherwise the
        plot is shown with `plt.show()` and nothing is returned.

    Raises
    ------
    TypeError
        if `df` is not a DataFrame.
    ValueError
        if `df` is empty, cannot be converted to float, or if the number of
        `xticklabels` / `yticklabels` does not match the number of columns /
        rows of `df`.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError("Table must be a DataFrame")

    if df.empty:
        raise ValueError("Table must not be empty")

    try:
        df = df.astype(float)
    except (ValueError, TypeError) as err:
        # pandas reports some non-numeric dtypes (e.g. datetimes) with a
        # TypeError rather than a ValueError; surface both uniformly and keep
        # the original exception as the cause.
        raise ValueError("Table must be numeric") from err

    if xticklabels is None:
        xticklabels = df.columns.to_list()
    elif len(xticklabels) != len(df.columns):
        raise ValueError("Number of xticklabels must match number of columns")

    if yticklabels is None:
        yticklabels = df.index.to_list()
    elif len(yticklabels) != len(df.index):
        raise ValueError("Number of yticklabels must match number of rows")

    fig, ax = plt.subplots(figsize=figsize)

    sns.set(style="ticks")
    sns.heatmap(
        df,
        cmap="vlag_r",
        vmin=min_color,
        vmax=max_color,
        square=False,
        linewidths=0.5,
        cbar_kws={"shrink": 0.5},
        annot=annot,
        fmt=fmt,
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        # Draw on the axes created above instead of relying on pyplot's
        # implicit "current axes" state.
        ax=ax,
    )

    ax.set(xlabel=xlabel, ylabel=ylabel)
    ax.set_title(title, fontsize=title_fontsize)

    fig.tight_layout()

    if return_fig:
        return fig
    else:
        plt.show()


if __name__ == "__main__":
    # Demo: render a small 4x4 integer table as a heatmap titled "Table".
    col_names = ["Col1", "Col2", "Col3", "Col4"]
    col_values = [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16],
    ]
    demo_table = pd.DataFrame(
        dict(zip(col_names, col_values)),
        index=["Row1", "Row2", "Row3", "Row4"],
    )

    view_table(demo_table, title="Table")