# Source code for macrosynergy.visuals.heatmap

"""
A subclass inheriting from `macrosynergy.visuals.plotter.Plotter`, designed to plot time
series data as a heatmap.
"""

from numbers import Number
from typing import List, Optional, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from seaborn.utils import relative_luminance

from macrosynergy.management.simulate import make_test_df
from macrosynergy.management.utils import downsample_df_on_real_date
from macrosynergy.visuals.plotter import Plotter, add_figure_footnote


class Heatmap(Plotter):
    """
    Class for plotting time series data as a heatmap.

    Inherits from `macrosynergy.visuals.plotter.Plotter`.

    Parameters
    ----------
    df : ~pandas.DataFrame
        A DataFrame with the following columns: 'cid', 'xcat', 'real_date', and at
        least one metric from - 'value', 'grading', 'eop_lag', or 'mop_lag'.
    cids : List[str]
        A list of cids to select from the DataFrame. If None, all cids are selected.
    xcats : List[str]
        A list of xcats to select from the DataFrame. If None, all xcats are
        selected.
    metrics : List[str]
        A list of metrics to select from the DataFrame. If None, all metrics are
        selected.
    start : str
        ISO-8601 formatted date. Select data from this date onwards. If None, all
        dates are selected.
    end : str
        ISO-8601 formatted date. Select data up to and including this date. If
        None, all dates are selected.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        cids: Optional[List[str]] = None,
        xcats: Optional[List[str]] = None,
        metrics: Optional[List[str]] = None,
        start: Optional[str] = None,
        end: Optional[str] = None,
        *args,
        **kwargs,
    ):
        # Argument handling is delegated to the ``Plotter`` base class. Extra
        # positional arguments are forwarded positionally; Python binds them before
        # the keyword arguments no matter where ``*args`` appears in the call.
        super().__init__(
            *args,
            df=df,
            cids=cids,
            xcats=xcats,
            metrics=metrics,
            start=start,
            end=end,
            **kwargs,
        )

    def plot(
        self,
        df: pd.DataFrame,
        figsize: Tuple[Number, Number] = (12, 8),
        x_axis_label: Optional[str] = None,
        y_axis_label: Optional[str] = None,
        axis_fontsize: int = 14,
        title: Optional[str] = None,
        title_fontsize: int = 22,
        title_xadjust: Number = 0.5,
        title_yadjust: Number = 1.0,
        footnote: Optional[str] = None,
        footnote_fontsize: int = 9,
        vmin: Optional[Number] = None,
        vmax: Optional[Number] = None,
        show: bool = True,
        save_to_file: Optional[str] = None,
        dpi: int = 300,
        return_figure: bool = False,
        on_axis: Optional[plt.Axes] = None,
        max_xticks: int = 50,
        cmap: Optional[Union[str, mpl.colors.Colormap]] = None,
        rotate_xticks: Optional[Number] = 0,
        rotate_yticks: Optional[Number] = 0,
        show_tick_lines: Optional[bool] = True,
        show_colorbar: Optional[bool] = True,
        show_annotations: Optional[bool] = False,
        show_boundaries: Optional[bool] = False,
        annotation_fontsize: int = 14,
        tick_fontsize: int = 13,
        *args,
        **kwargs,
    ) -> Optional[plt.Figure]:
        """
        Plots a DataFrame as a heatmap with the columns along the x-axis and rows
        along the y-axis.

        Parameters
        ----------
        df : ~pandas.DataFrame
            the data to draw, cell by cell. Its column labels become the x-axis
            tick labels and its index labels the y-axis tick labels.
        figsize : Tuple
            tuple specifying the size of the figure. Default is (12, 8). Ignored
            when `on_axis` is given.
        x_axis_label : str
            label for x-axis.
        y_axis_label : str
            label for y-axis.
        axis_fontsize : int
            the font size for the axis labels.
        title : str
            the figure's title.
        title_fontsize : int
            the font size for the title.
        title_xadjust : float
            sets the x position of the title text.
        title_yadjust : float
            sets the y position of the title text.
        footnote : str
            Optional text shown at the bottom-left of the figure canvas.
        footnote_fontsize : int
            Font size of the footnote. Default is 9.
        vmin : float
            optional minimum value for heatmap scale.
        vmax : float
            optional maximum value for heatmap scale.
        show : bool
            if True, the image is displayed and the method returns None, even when
            `return_figure` is True.
        save_to_file : str
            the path at which to save the heatmap as an image. If not specified,
            the plot will not be saved.
        dpi : int
            the resolution in dots per inch used if saving the figure.
        return_figure : bool
            if True (and `show` is False), the function will return the figure.
        on_axis : plt.Axes
            optional `plt.Axes` object to be used instead of creating a new one.
        max_xticks : int
            the maximum number of ticks to be displayed along the x axis. Only
            applied when `show_boundaries` is False. Default is 50.
        cmap : mpl.colors.Colormap
            string or matplotlib Colormap object specifying the colormap of the
            plot.
        rotate_xticks : int
            number of degrees to rotate the tick labels on the x-axis. Default is
            zero.
        rotate_yticks : int
            number of degrees to rotate the tick labels on the y-axis. Default is
            zero.
        show_tick_lines : bool
            if True, lines are shown for ticks. Default is True.
        show_colorbar : bool
            if True, the colorbar is shown. Default is True.
        show_annotations : bool
            if True, annotations display the value of each cell, rounded to one
            decimal place (NaN cells are left blank). Default is False.
        show_boundaries : bool
            if True, cells are divided by a grid. Default is False.
        annotation_fontsize : int
            sets the font size of the annotations.
        tick_fontsize : int
            sets the font size of tick labels.
        **kwargs
            forwarded to `matplotlib.axes.Axes.imshow`.

        Returns
        -------
        Optional[plt.Figure]
            the figure when `return_figure` is True and `show` is False, otherwise
            None.
        """
        if on_axis is not None:
            fig: plt.Figure = on_axis.get_figure()
            ax: plt.Axes = on_axis
        else:
            fig, ax = plt.subplots(figsize=figsize, layout="constrained")

        data = df.to_numpy()
        im = ax.imshow(
            data,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            aspect="auto",
            **kwargs,
        )

        xtick_labels = df.columns.to_list()
        ytick_labels = df.index.to_list()
        ax.set_xticks(np.arange(len(xtick_labels)), labels=xtick_labels)
        ax.set_yticks(np.arange(len(ytick_labels)), labels=ytick_labels)
        ax.set_xticklabels(
            xtick_labels,
            rotation=rotate_xticks,
            ha="center",
            minor=False,
        )
        ax.set_yticklabels(
            ytick_labels,
            rotation=rotate_yticks,
            ha="right",
            minor=False,
            rotation_mode="anchor",
        )
        ax.tick_params(axis="both", which="major", labelsize=tick_fontsize)

        if show_tick_lines:
            ax.tick_params(which="major", length=4, width=1, direction="out")

        # Act on ``ax`` explicitly: ``plt.grid`` targets pyplot's *current* axes,
        # which need not be the one passed in via ``on_axis``.
        ax.grid(False)

        if show_boundaries:
            # Minor ticks sit on the cell edges (offset by half a cell) so the
            # minor grid draws a white border around every cell.
            ax.spines[:].set_visible(False)
            ax.set_xticks(np.arange(len(xtick_labels) + 1) - 0.5, minor=True)
            ax.set_yticks(np.arange(len(ytick_labels) + 1) - 0.5, minor=True)
            ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
            ax.tick_params(which="minor", bottom=False, left=False)
        else:
            # Limits the number of ticks shown on the x-axis.
            ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=max_xticks - 1))

        ax.set_title(
            title,
            fontsize=title_fontsize,
            x=title_xadjust,
            y=title_yadjust,
        )

        if x_axis_label:
            ax.set_xlabel(x_axis_label, fontsize=axis_fontsize)
        if y_axis_label:
            ax.set_ylabel(y_axis_label, fontsize=axis_fontsize)

        if show_colorbar:
            ax.figure.colorbar(im, ax=ax)

        if show_annotations:
            rounded = np.around(data, decimals=1)
            # RGBA colour of every cell, computed once rather than once per cell.
            cell_colors = im.cmap(im.norm(im.get_array()))
            for (i, j), value in np.ndenumerate(rounded):
                if np.isnan(value):
                    continue
                # Dark text on light cells, white text on dark cells.
                lum = relative_luminance(cell_colors[i, j])
                text_color = ".15" if lum > 0.408 else "w"
                ax.text(
                    j,
                    i,
                    value,
                    color=text_color,
                    ha="center",
                    va="center",
                    size=annotation_fontsize,
                )

        ax.tick_params(axis="y", which="major", pad=8)

        add_figure_footnote(fig, footnote=footnote, fontsize=footnote_fontsize)

        if save_to_file:
            # ``fig.savefig`` rather than ``plt.savefig``: the latter saves pyplot's
            # current figure, which is not necessarily the one owning ``on_axis``.
            fig.savefig(
                save_to_file,
                dpi=dpi,
                bbox_inches="tight",
            )

        if show:
            plt.show()
            return None

        if return_figure:
            return fig
        return None

    def plot_metric(
        self,
        x_axis_column,
        y_axis_column,
        metric,
        xcats=None,
        cids=None,
        start=None,
        end=None,
        freq=None,
        agg="mean",
        figsize: Optional[Tuple[Number, Number]] = (12, 8),
        x_axis_label: Optional[str] = None,
        y_axis_label: Optional[str] = None,
        axis_fontsize: int = 14,
        title: Optional[str] = None,
        title_fontsize: int = 22,
        title_xadjust: Number = 0.5,
        title_yadjust: Number = 1.0,
        footnote: Optional[str] = None,
        footnote_fontsize: int = 9,
        vmin: Optional[Number] = None,
        vmax: Optional[Number] = None,
        show: bool = True,
        save_to_file: Optional[str] = None,
        dpi: int = 300,
        return_figure: bool = False,
        on_axis: Optional[plt.Axes] = None,
        max_xticks: int = 50,
        cmap: Optional[Union[str, mpl.colors.Colormap]] = None,
        rotate_xticks: Optional[Number] = 0,
        rotate_yticks: Optional[Number] = 0,
        show_tick_lines: Optional[bool] = True,
        show_colorbar: Optional[bool] = True,
        show_annotations: Optional[bool] = False,
        show_boundaries: Optional[bool] = False,
        annotation_fontsize: int = 14,
        tick_fontsize: int = 13,
        *args,
        **kwargs,
    ):
        """
        Plots a metric from the DataFrame as a heatmap.

        Parameters
        ----------
        x_axis_column : str
            the column to be used as the x-axis.
        y_axis_column : str
            the column to be used as the y-axis.
        metric : str
            the metric to be plotted. Must be one of 'value', 'eop_lag',
            'mop_lag' or 'grading'.
        xcats : List[str]
            a list of xcats to select from the DataFrame. If None, the instance's
            xcats are used.
        cids : List[str]
            a list of cids to select from the DataFrame. If None, the instance's
            cids are used.
        start : str
            ISO-8601 formatted date string. Select data from this date onwards. If
            None, the instance's start date is used.
        end : str
            ISO-8601 formatted date string. Select data up to and including this
            date. If None, the instance's end date is used.
        freq : str
            frequency to downsample the data. Default is None.
        agg : str
            aggregation method used when downsampling. Must be one of 'mean',
            'median', 'min', 'max', 'first' or 'last' (case-insensitive).
        figsize : Tuple[float, float]
            tuple specifying the size of the figure. Default is (12, 8). If None,
            the size is derived from the number of cells: width from the number
            of x-axis columns, height from the number of y-axis rows.
        x_axis_label : str
            label for x-axis.
        y_axis_label : str
            label for y-axis.
        axis_fontsize : int
            the font size for the axis labels.
        title : str
            the figure's title.
        title_fontsize : int
            the font size for the title.
        title_xadjust : float
            sets the x position of the title text.
        title_yadjust : float
            sets the y position of the title text.
        footnote : str
            Optional text shown at the bottom-left of the figure canvas.
        footnote_fontsize : int
            Font size of the footnote. Default is 9.
        vmin : float
            optional minimum value for heatmap scale. If None, the smaller of 0
            and the metric's minimum is used.
        vmax : float
            optional maximum value for heatmap scale. If None, the larger of 1 and
            the metric's maximum is used.
        show : bool
            if True, the image is displayed.
        save_to_file : str
            the path at which to save the heatmap as an image. If not specified,
            the plot will not be saved.
        dpi : int
            the resolution in dots per inch used if saving the figure.
        return_figure : bool
            if True, the function will return the value returned by `plot`.
        on_axis : plt.Axes
            optional `plt.Axes` object to be used instead of creating a new one.
        max_xticks : int
            the maximum number of ticks to be displayed along the x axis. Default
            is 50.
        cmap : mpl.colors.Colormap
            string or matplotlib Colormap object specifying the colormap of the
            plot.
        rotate_xticks : int
            number of degrees to rotate the tick labels on the x-axis. Default is
            zero.
        rotate_yticks : int
            number of degrees to rotate the tick labels on the y-axis. Default is
            zero.
        show_tick_lines : bool
            if True, lines are shown for ticks. Default is True.
        show_colorbar : bool
            if True, the colorbar is shown. Default is True.
        show_annotations : bool
            if True, annotations display the value of each cell. Default is False.
        show_boundaries : bool
            if True, cells are divided by a grid. Default is False.
        annotation_fontsize : int
            sets the font size of the annotations.
        tick_fontsize : int
            sets the font size of tick labels.

        Raises
        ------
        ValueError
            if `metric` or `agg` is not one of the supported values.
        TypeError
            if `agg` is not a string.
        """
        # Fall back to the selection stored on the instance.
        xcats = xcats or self.xcats
        cids = cids or self.cids
        start = start or self.start
        end = end or self.end

        # Validation checks not covered by Plotter.
        if metric not in ["value", "eop_lag", "mop_lag", "grading"]:
            raise ValueError(
                "`metric` must be either 'eop_lag', 'mop_lag', 'grading', or 'value'"
            )
        if not isinstance(agg, str):
            raise TypeError("`agg` must be a string")
        agg = agg.lower()
        if agg not in ["mean", "median", "min", "max", "first", "last"]:
            raise ValueError(
                "`agg` must be one of 'mean', 'median', 'min', 'max', 'first' or "
                "'last'"
            )

        # Apply the selection so that the xcats/cids/start/end arguments actually
        # restrict the plotted data.
        df = self.df
        if xcats:
            df = df[df["xcat"].isin([xcats] if isinstance(xcats, str) else xcats)]
        if cids:
            df = df[df["cid"].isin([cids] if isinstance(cids, str) else cids)]
        if start:
            df = df[df["real_date"] >= pd.Timestamp(start)]
        if end:
            df = df[df["real_date"] <= pd.Timestamp(end)]
        # Explicit copy: the column assignment below must never touch ``self.df``.
        df = df[["xcat", "cid", "real_date", metric]].copy()

        if freq:
            df = downsample_df_on_real_date(
                df=df, groupby_columns=["cid", "xcat"], freq=freq, agg=agg
            )

        if "real_date" not in (x_axis_column, y_axis_column):
            # Collapse the time dimension; only the metric needs averaging.
            df = (
                df.groupby(["xcat", "cid"], observed=True)[metric]
                .mean()
                .reset_index()
            )
        else:
            # String dates make readable, evenly spaced categorical tick labels.
            df["real_date"] = df["real_date"].dt.strftime("%Y-%m-%d")

        # Only derive a colour scale when the caller did not supply one.
        if vmax is None:
            vmax = max(1, df[metric].max())
        if vmin is None:
            vmin = min(0, df[metric].min())

        df = df.pivot_table(
            index=y_axis_column,
            columns=x_axis_column,
            values=metric,
            observed=True,
        )

        if figsize is None:
            # figsize is (width, height): width follows the number of x-axis
            # columns, height the number of y-axis rows.
            figsize = (
                max(df.shape[1] / 2, 15),
                max(1, df.shape[0] / 2),
            )
        elif isinstance(figsize, list):
            figsize = tuple(figsize)

        fig = self.plot(
            df=df,
            figsize=figsize,
            x_axis_label=x_axis_label,
            y_axis_label=y_axis_label,
            axis_fontsize=axis_fontsize,
            title=title,
            title_fontsize=title_fontsize,
            title_xadjust=title_xadjust,
            title_yadjust=title_yadjust,
            footnote=footnote,
            footnote_fontsize=footnote_fontsize,
            vmin=vmin,
            vmax=vmax,
            show=show,
            save_to_file=save_to_file,
            dpi=dpi,
            return_figure=return_figure,
            on_axis=on_axis,
            max_xticks=max_xticks,
            cmap=cmap,
            rotate_xticks=rotate_xticks,
            rotate_yticks=rotate_yticks,
            show_tick_lines=show_tick_lines,
            show_colorbar=show_colorbar,
            show_annotations=show_annotations,
            show_boundaries=show_boundaries,
            annotation_fontsize=annotation_fontsize,
            tick_fontsize=tick_fontsize,
        )

        if return_figure:
            return fig
if __name__ == "__main__":
    test_cids: List[str] = ["USD"]
    test_xcats: List[str] = ["FX", "IR"]

    # Build one test frame per metric, then merge them so each row carries
    # eop_lag, mop_lag and grading side by side.
    dfE: pd.DataFrame = make_test_df(
        cids=test_cids, xcats=test_xcats, style="sharp-hill"
    )
    dfM: pd.DataFrame = make_test_df(
        cids=test_cids, xcats=test_xcats, style="four-bit-sine"
    )
    dfG: pd.DataFrame = make_test_df(cids=test_cids, xcats=test_xcats, style="sine")

    dfE = dfE.rename(columns={"value": "eop_lag"})
    dfM = dfM.rename(columns={"value": "mop_lag"})
    dfG = dfG.rename(columns={"value": "grading"})

    mergeon = ["cid", "xcat", "real_date"]
    dfx: pd.DataFrame = pd.merge(pd.merge(dfE, dfM, on=mergeon), dfG, on=mergeon)

    heatmap = Heatmap(df=dfx, xcats=["FX"])

    # Pivot a copy instead of overwriting ``heatmap.df``, so the instance keeps
    # its long-format data and remains usable (e.g. for ``plot_metric``).
    pivoted = heatmap.df.copy()
    pivoted["real_date"] = pivoted["real_date"].dt.strftime("%Y-%m-%d")
    pivoted = pivoted.pivot_table(index="cid", columns="real_date", values="grading")

    heatmap.plot(pivoted, title="abc", rotate_xticks=90)