"""
A subclass inheriting from `macrosynergy.visuals.plotter.Plotter`, designed to plot time
series data as a heatmap.
"""
from numbers import Number
from typing import List, Optional, Tuple, Union
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from seaborn.utils import relative_luminance
from macrosynergy.management.simulate import make_test_df
from macrosynergy.management.utils import downsample_df_on_real_date
from macrosynergy.visuals.plotter import Plotter, add_figure_footnote
class Heatmap(Plotter):
    """
    Class for plotting time series data as a heatmap. Inherits from
    `macrosynergy.visuals.plotter.Plotter`.

    All constructor arguments are forwarded unchanged to the parent class.

    Parameters
    ----------
    df : ~pandas.DataFrame
        A DataFrame with the following columns: 'cid', 'xcat', 'real_date', and at
        least one metric from - 'value', 'grading', 'eop_lag', or 'mop_lag'.
    cids : List[str]
        A list of cids to select from the DataFrame. If None, all cids are selected.
    xcats : List[str]
        A list of xcats to select from the DataFrame. If None, all xcats are
        selected.
    metrics : List[str]
        A list of metrics to select from the DataFrame. If None, all metrics are
        selected.
    start : str
        ISO-8601 formatted date. Select data from this date onwards. If None, all
        dates are selected.
    end : str
        ISO-8601 formatted date. Select data up to and including this date. If
        None, all dates are selected.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        cids: Optional[List[str]] = None,
        xcats: Optional[List[str]] = None,
        metrics: Optional[List[str]] = None,
        start: Optional[str] = None,
        end: Optional[str] = None,
        *args,
        **kwargs,
    ):
        super().__init__(
            *args,
            df=df,
            cids=cids,
            xcats=xcats,
            metrics=metrics,
            start=start,
            end=end,
            **kwargs,
        )

    def plot(
        self,
        df: pd.DataFrame,
        figsize: Tuple[Number, Number] = (12, 8),
        x_axis_label: Optional[str] = None,
        y_axis_label: Optional[str] = None,
        axis_fontsize: int = 14,
        title: Optional[str] = None,
        title_fontsize: int = 22,
        title_xadjust: Number = 0.5,
        title_yadjust: Number = 1.0,
        footnote: Optional[str] = None,
        footnote_fontsize: int = 9,
        vmin: Optional[Number] = None,
        vmax: Optional[Number] = None,
        show: bool = True,
        save_to_file: Optional[str] = None,
        dpi: int = 300,
        return_figure: bool = False,
        on_axis: Optional[plt.Axes] = None,
        max_xticks: int = 50,
        cmap: Optional[Union[str, mpl.colors.Colormap]] = None,
        rotate_xticks: Optional[Number] = 0,
        rotate_yticks: Optional[Number] = 0,
        show_tick_lines: Optional[bool] = True,
        show_colorbar: Optional[bool] = True,
        show_annotations: Optional[bool] = False,
        show_boundaries: Optional[bool] = False,
        annotation_fontsize: int = 14,
        tick_fontsize: int = 13,
        *args,
        **kwargs,
    ) -> Optional[plt.Figure]:
        """
        Plots a DataFrame as a heatmap with the columns along the x-axis and rows
        along the y-axis.

        Parameters
        ----------
        df : ~pandas.DataFrame
            the table to draw. Its values colour the cells, its columns label the
            x-axis and its index labels the y-axis.
        figsize : Tuple
            tuple specifying the size of the figure. Default is (12, 8). Only used
            when a new figure is created, i.e. when `on_axis` is not given.
        x_axis_label : str
            label for x-axis.
        y_axis_label : str
            label for y-axis.
        axis_fontsize : int
            the font size for the axis labels.
        title : str
            the figure's title.
        title_fontsize : int
            the font size for the title.
        title_xadjust : float
            sets the x position of the title text.
        title_yadjust : float
            sets the y position of the title text.
        footnote : str
            Optional text passed to `add_figure_footnote` together with the figure.
        footnote_fontsize : int
            Font size of the footnote. Default is 9.
        vmin : float
            optional minimum value for heatmap scale.
        vmax : float
            optional maximum value for heatmap scale.
        show : bool
            if True, the image is displayed and the method returns None, even if
            `return_figure` is True.
        save_to_file : str
            the path at which to save the heatmap as an image. If not specified,
            the plot will not be saved.
        dpi : int
            the resolution in dots per inch used if saving the figure.
        return_figure : bool
            if True (and `show` is False), the function will return the figure.
        on_axis : plt.Axes
            optional `plt.Axes` object to be used instead of creating a new one.
        max_xticks : int
            the maximum number of ticks to be displayed along the x axis. Default
            is 50. Not applied when `show_boundaries` is True.
        cmap : mpl.colors.Colormap
            string or matplotlib Colormap object specifying the colormap of the
            plot.
        rotate_xticks : int
            number of degrees to rotate the tick labels on the x-axis. Default is
            zero.
        rotate_yticks : int
            number of degrees to rotate the tick labels on the y-axis. Default is
            zero.
        show_tick_lines : bool
            if True, lines are shown for ticks. Default is True.
        show_colorbar : bool
            if True, the colorbar is shown. Default is True.
        show_annotations : bool
            if True, annotations display the value of each cell, rounded to one
            decimal. NaN cells are left blank. Default is False.
        show_boundaries : bool
            if True, cells are divided by a grid. Default is False.
        annotation_fontsize : int
            sets the font size of the annotations.
        tick_fontsize : int
            sets the font size of tick labels.
        *args
            accepted for signature compatibility; not used.
        **kwargs
            forwarded to `Axes.imshow`.

        Returns
        -------
        Optional[plt.Figure]
            the figure if `return_figure` is True and `show` is False, else None.
        """
        if on_axis is not None:
            ax: plt.Axes = on_axis
            fig: plt.Figure = ax.get_figure()
        else:
            fig, ax = plt.subplots(figsize=figsize, layout="constrained")

        data = df.to_numpy()
        im = ax.imshow(
            data,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            aspect="auto",
            **kwargs,
        )

        # One major tick per cell centre, labelled from the frame's axes.
        xtick_labels = df.columns.to_list()
        ytick_labels = df.index.to_list()
        ax.set_xticks(np.arange(len(xtick_labels)))
        ax.set_xticklabels(
            xtick_labels,
            rotation=rotate_xticks,
            ha="center",
            minor=False,
        )
        ax.set_yticks(np.arange(len(ytick_labels)))
        ax.set_yticklabels(
            ytick_labels,
            rotation=rotate_yticks,
            ha="right",
            minor=False,
            rotation_mode="anchor",
        )

        ax.tick_params(axis="both", which="major", labelsize=tick_fontsize)
        if show_tick_lines:
            ax.tick_params(which="major", length=4, width=1, direction="out")

        # Act on `ax` explicitly: pyplot's "current axes" need not be the axes we
        # are drawing on when `on_axis` is supplied by the caller.
        ax.grid(False)

        if show_boundaries:
            # Minor ticks sit on the cell edges; a white minor grid drawn on them
            # visually separates the cells.
            ax.spines[:].set_visible(False)
            ax.set_xticks(np.arange(len(xtick_labels) + 1) - 0.5, minor=True)
            ax.set_yticks(np.arange(len(ytick_labels) + 1) - 0.5, minor=True)
            ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
            ax.tick_params(which="minor", bottom=False, left=False)
        else:
            # Limits the number of ticks shown on the x-axis. Ticks are restricted
            # to integer positions because only cell centres carry a label.
            ax.xaxis.set_major_locator(
                plt.MaxNLocator(nbins=max_xticks - 1, integer=True)
            )

        ax.set_title(
            title,
            fontsize=title_fontsize,
            x=title_xadjust,
            y=title_yadjust,
        )
        if x_axis_label:
            ax.set_xlabel(x_axis_label, fontsize=axis_fontsize)
        if y_axis_label:
            ax.set_ylabel(y_axis_label, fontsize=axis_fontsize)

        if show_colorbar:
            ax.figure.colorbar(im, ax=ax)

        if show_annotations:
            rounded = np.around(data, decimals=1)
            # RGBA colour of every cell, computed once instead of once per cell.
            cell_colors = im.cmap(im.norm(im.get_array()))
            for i, j in np.ndindex(*rounded.shape):
                if np.isnan(rounded[i, j]):
                    continue
                # Dark text on light cells, white text on dark cells.
                lum = relative_luminance(cell_colors[i, j])
                ax.text(
                    j,
                    i,
                    rounded[i, j],
                    color=".15" if lum > 0.408 else "w",
                    ha="center",
                    va="center",
                    size=annotation_fontsize,
                )

        ax.tick_params(axis="y", which="major", pad=8)
        add_figure_footnote(fig, footnote=footnote, fontsize=footnote_fontsize)

        if save_to_file:
            # Save the figure that owns `ax`, not pyplot's current figure.
            fig.savefig(
                save_to_file,
                dpi=dpi,
                bbox_inches="tight",
            )

        if show:
            # Showing takes precedence over returning the figure.
            plt.show()
            return

        if return_figure:
            return fig

    def plot_metric(
        self,
        x_axis_column,
        y_axis_column,
        metric,
        xcats=None,
        cids=None,
        start=None,
        end=None,
        freq=None,
        agg="mean",
        figsize: Optional[Tuple[Number, Number]] = (12, 8),
        x_axis_label: Optional[str] = None,
        y_axis_label: Optional[str] = None,
        axis_fontsize: int = 14,
        title: Optional[str] = None,
        title_fontsize: int = 22,
        title_xadjust: Number = 0.5,
        title_yadjust: Number = 1.0,
        footnote: Optional[str] = None,
        footnote_fontsize: int = 9,
        vmin: Optional[Number] = None,
        vmax: Optional[Number] = None,
        show: bool = True,
        save_to_file: Optional[str] = None,
        dpi: int = 300,
        return_figure: bool = False,
        on_axis: Optional[plt.Axes] = None,
        max_xticks: int = 50,
        cmap: Optional[Union[str, mpl.colors.Colormap]] = None,
        rotate_xticks: Optional[Number] = 0,
        rotate_yticks: Optional[Number] = 0,
        show_tick_lines: Optional[bool] = True,
        show_colorbar: Optional[bool] = True,
        show_annotations: Optional[bool] = False,
        show_boundaries: Optional[bool] = False,
        annotation_fontsize: int = 14,
        tick_fontsize: int = 13,
        *args,
        **kwargs,
    ):
        """
        Plots a metric from the DataFrame as a heatmap.

        The instance's DataFrame is optionally filtered and downsampled, pivoted
        into a `y_axis_column` by `x_axis_column` table of `metric`, and handed to
        `plot`.

        Parameters
        ----------
        x_axis_column : str
            the column to be used as the x-axis.
        y_axis_column : str
            the column to be used as the y-axis.
        metric : str
            the metric to be plotted. Must be one of 'value', 'eop_lag', 'mop_lag'
            or 'grading'.
        xcats : List[str]
            a list of xcats to select from the DataFrame. If None, all xcats held
            by the instance are used.
        cids : List[str]
            a list of cids to select from the DataFrame. If None, all cids held by
            the instance are used.
        start : str
            ISO-8601 formatted date string. Select data from this date onwards. If
            None, no lower date bound is applied.
        end : str
            ISO-8601 formatted date string. Select data up to and including this
            date. If None, no upper date bound is applied.
        freq : str
            frequency to downsample the data. Default is None (no downsampling).
        agg : str
            aggregation method used when downsampling. Must be one of 'mean',
            'median', 'min', 'max', 'first' or 'last' (case-insensitive). When
            'real_date' is on neither axis, observations are averaged over time
            regardless of `agg`.
        figsize : Tuple[float, float]
            tuple specifying the size of the figure. Default is (12, 8). If None, a
            size is derived from the shape of the pivoted table.
        x_axis_label : str
            label for x-axis.
        y_axis_label : str
            label for y-axis.
        axis_fontsize : int
            the font size for the axis labels.
        title : str
            the figure's title.
        title_fontsize : int
            the font size for the title.
        title_xadjust : float
            sets the x position of the title text.
        title_yadjust : float
            sets the y position of the title text.
        footnote : str
            Optional footnote text, passed through to `plot`.
        footnote_fontsize : int
            Font size of the footnote. Default is 9.
        vmin : float
            optional minimum value for heatmap scale. If None, the smaller of 0 and
            the minimum of the metric is used.
        vmax : float
            optional maximum value for heatmap scale. If None, the larger of 1 and
            the maximum of the metric is used.
        show : bool
            if True, the image is displayed.
        save_to_file : str
            the path at which to save the heatmap as an image. If not specified,
            the plot will not be saved.
        dpi : int
            the resolution in dots per inch used if saving the figure.
        return_figure : bool
            if True, the result of `plot` is returned (see `plot` for how this
            interacts with `show`).
        on_axis : plt.Axes
            optional `plt.Axes` object to be used instead of creating a new one.
        max_xticks : int
            the maximum number of ticks to be displayed along the x axis. Default
            is 50.
        cmap : mpl.colors.Colormap
            string or matplotlib Colormap object specifying the colormap of the
            plot.
        rotate_xticks : int
            number of degrees to rotate the tick labels on the x-axis. Default is
            zero.
        rotate_yticks : int
            number of degrees to rotate the tick labels on the y-axis. Default is
            zero.
        show_tick_lines : bool
            if True, lines are shown for ticks. Default is True.
        show_colorbar : bool
            if True, the colorbar is shown. Default is True.
        show_annotations : bool
            if True, annotations display the value of each cell. Default is False.
        show_boundaries : bool
            if True, cells are divided by a grid. Default is False.
        annotation_fontsize : int
            sets the font size of the annotations.
        tick_fontsize : int
            sets the font size of tick labels.
        *args, **kwargs
            accepted for signature compatibility; not used and not forwarded.

        Raises
        ------
        ValueError
            if `metric` or `agg` is not one of the supported values.
        TypeError
            if `agg` is not a string.
        """
        # Validation checks not covered by Plotter.
        if metric not in ("value", "eop_lag", "mop_lag", "grading"):
            raise ValueError(
                "`metric` must be either 'eop_lag', 'mop_lag', 'grading', or 'value'"
            )
        if not isinstance(agg, str):
            raise TypeError("`agg` must be a string")
        agg = agg.lower()
        if agg not in ("mean", "median", "min", "max", "first", "last"):
            raise ValueError(
                "`agg` must be one of 'mean', 'median', 'min', 'max', 'first' or "
                "'last'"
            )

        # Apply the selectors the caller passed explicitly. Omitted selectors leave
        # the instance's DataFrame as it is. A positional (NumPy) mask is used so
        # that a non-unique index cannot interfere with the row selection.
        source = self.df
        keep = np.ones(len(source), dtype=bool)
        if xcats:
            wanted = [xcats] if isinstance(xcats, str) else xcats
            keep &= source["xcat"].isin(wanted).to_numpy()
        if cids:
            wanted = [cids] if isinstance(cids, str) else cids
            keep &= source["cid"].isin(wanted).to_numpy()
        if start:
            keep &= (source["real_date"] >= pd.Timestamp(start)).to_numpy()
        if end:
            keep &= (source["real_date"] <= pd.Timestamp(end)).to_numpy()

        # Copy so that the in-place column edit below never touches `self.df`.
        df: pd.DataFrame = source.loc[
            keep, ["xcat", "cid", "real_date", metric]
        ].copy()

        if freq:
            df = downsample_df_on_real_date(
                df=df, groupby_columns=["cid", "xcat"], freq=freq, agg=agg
            )

        if "real_date" not in [x_axis_column, y_axis_column]:
            # No time axis requested: collapse each (xcat, cid) pair to its mean.
            df = df.groupby(["xcat", "cid"], observed=True).mean().reset_index()
        else:
            # Dates become plain strings so they render as compact tick labels.
            df["real_date"] = df["real_date"].dt.strftime("%Y-%m-%d")

        # Default colour scale spans at least [0, 1]; explicit limits win.
        if vmax is None:
            vmax = max(1, df[metric].max())
        if vmin is None:
            vmin = min(0, df[metric].min())

        df = df.pivot_table(
            index=y_axis_column,
            columns=x_axis_column,
            values=metric,
            observed=True,
        )

        if figsize is None:
            # NOTE(review): width is derived from the row count and height from the
            # column count; confirm this is intended rather than swapped.
            figsize = (
                max(df.shape[0] / 2, 15),
                max(1, df.shape[1] / 2),
            )
        elif isinstance(figsize, list):
            figsize = tuple(figsize)

        fig = self.plot(
            df=df,
            figsize=figsize,
            x_axis_label=x_axis_label,
            y_axis_label=y_axis_label,
            axis_fontsize=axis_fontsize,
            title=title,
            title_fontsize=title_fontsize,
            title_xadjust=title_xadjust,
            title_yadjust=title_yadjust,
            footnote=footnote,
            footnote_fontsize=footnote_fontsize,
            vmin=vmin,
            vmax=vmax,
            show=show,
            save_to_file=save_to_file,
            dpi=dpi,
            return_figure=return_figure,
            on_axis=on_axis,
            max_xticks=max_xticks,
            cmap=cmap,
            rotate_xticks=rotate_xticks,
            rotate_yticks=rotate_yticks,
            show_tick_lines=show_tick_lines,
            show_colorbar=show_colorbar,
            show_annotations=show_annotations,
            show_boundaries=show_boundaries,
            annotation_fontsize=annotation_fontsize,
            tick_fontsize=tick_fontsize,
        )
        if return_figure:
            return fig
if __name__ == "__main__":
    # Manual demo: build a synthetic panel carrying three metrics and draw the
    # 'grading' metric of one category as a heatmap.
    demo_cids: List[str] = ["USD"]  # extend with "EUR", "GBP" for more rows
    demo_xcats: List[str] = ["FX", "IR"]
    join_keys = ["cid", "xcat", "real_date"]

    # (metric column name, make_test_df style) - generated in this order.
    metric_styles = (
        ("eop_lag", "sharp-hill"),
        ("mop_lag", "four-bit-sine"),
        ("grading", "sine"),
    )
    panels: List[pd.DataFrame] = [
        make_test_df(cids=demo_cids, xcats=demo_xcats, style=style).rename(
            columns={"value": metric_name}
        )
        for metric_name, style in metric_styles
    ]

    # Left-to-right merge of the per-metric panels on the shared key columns.
    dfx: pd.DataFrame = panels[0]
    for panel in panels[1:]:
        dfx = pd.merge(dfx, panel, on=join_keys)

    heatmap = Heatmap(df=dfx, xcats=["FX"])

    # Reshape the stored frame by hand into a cid-by-date table and plot it.
    heatmap.df["real_date"] = heatmap.df["real_date"].dt.strftime("%Y-%m-%d")
    heatmap.df = heatmap.df.pivot_table(
        index="cid",
        columns="real_date",
        values="grading",
    )
    heatmap.plot(heatmap.df, title="abc", rotate_xticks=90)