Source code for macrosynergy.visuals.lagged_corr

"""
Functions used to visualize lagged correlation between two series.
"""

from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from macrosynergy.visuals import FacetPlot


def plot_lagged_correlation(
    df: pd.DataFrame,
    cids: List[str],
    xcats: List[str],
    lags: Union[int, Sequence] = 3,
    alpha: float = 0.05,
    remove_zero_predictor: bool = False,
    start: Optional[str] = None,
    end: Optional[str] = None,
    blacklist: Optional[Dict[str, List[str]]] = None,
    figsize: Tuple[float, float] = (16, 9),
    title: Optional[str] = None,
    share_x: bool = True,
    share_y: bool = True,
    zero: bool = False,
    **kwargs,
):
    """
    Plots a facet grid of lagged correlation plots for two given xcats and multiple
    cids.

    Parameters:
    -----------
    df : pd.DataFrame
        The input DataFrame with columns ['real_date', 'cid', 'xcat', 'value'].
    cids : List[str]
        List of cids to plot. A single cid given as a string is wrapped in a list.
    xcats : List[str]
        A list of two xcats to plot the lagged correlation between. `xcats[0]` is
        used as the signal (the series that is lagged) and `xcats[1]` as the target.
    lags : Union[int, Sequence], default=3
        Number of lags for the correlation calculation. If an integer, the lags from
        0 to lags are plotted. If a sequence is provided, the lags are plotted as
        given.
    alpha : float, default=0.05
        Forwarded to the per-panel plotting function through `plot_func_kwargs`.
    remove_zero_predictor : bool, default=False
        Remove zeros from the input series.
    start : str
        ISO-8601 formatted date string. Select data from this date onwards. If None,
        all dates are selected.
    end : str
        ISO-8601 formatted date string. Select data up to and including this date.
        If None, all dates are selected.
    blacklist : dict
        cross-sections with date ranges that should be excluded from the data frame.
        If one cross-section has several blacklist periods append numbers to the
        cross-section code.
    figsize : Tuple[float, float], default=(16,9)
        Figure size for the plot.
    title : Optional[str], default=None
        Title for the plot. If None, a title naming both xcats is generated.
    share_x : bool, default=True
        Share x-axis across all subplots.
    share_y : bool, default=True
        Share y-axis across all subplots.
    zero : bool, default=False
        Forwarded to the per-panel plotting function through `plot_func_kwargs`.
    kwargs : Dict
        Additional keyword arguments for the plot passed directly to
        Facetplot.lineplot.
    """
    _checks_plot_lc(
        df=df,
        cids=cids,
        xcats=xcats,
        lags=lags,
        remove_zero_predictor=remove_zero_predictor,
        start=start,
        end=end,
        blacklist=blacklist,
        figsize=figsize,
        title=title,
        share_x=share_x,
        share_y=share_y,
    )

    # The validator accepts a bare string for `cids` but cannot hand the wrapped
    # value back, so normalise here before the list is forwarded.
    if isinstance(cids, str):
        cids = [cids]

    if title is None:
        title = f"Lagged correlation for {xcats[0]} and {xcats[1]}"

    plot_func = _plot_lagged_corr
    plot_func_kwargs = {
        "lags": lags,
        "alpha": alpha,
        "zero": zero,
        "signal_xcat": xcats[0],
        "target_xcat": xcats[1],
        "remove_zero_predictor": remove_zero_predictor,
    }

    _lagged_corr_facetplot_wrapper(
        df=df,
        cids=cids,
        xcats=xcats,
        plot_func=plot_func,
        plot_func_kwargs=plot_func_kwargs,
        start=start,
        end=end,
        blacklist=blacklist,
        figsize=figsize,
        title=title,
        share_x=share_x,
        share_y=share_y,
        **kwargs,
    )


def _lagged_corr_facetplot_wrapper( df: pd.DataFrame, cids: List[str], xcats: List[str], plot_func: Callable, plot_func_kwargs: Dict, start: Optional[str] = None, end: Optional[str] = None, blacklist: Optional[Dict[str, List[str]]] = None, figsize: Tuple[float, float] = (16, 9), title: Optional[str] = None, share_x: bool = True, share_y: bool = True, **kwargs, ): with FacetPlot( df=df, xcats=xcats, cids=cids, intersect=True, start=start, end=end, blacklist=blacklist, tickers=None, metrics=["value"], ) as fp: if len(fp.cids) <= 3: kwargs["ncols"] = len(fp.cids) fp.cids = sorted(fp.cids) fp.lineplot( plot_func=plot_func, plot_func_kwargs=plot_func_kwargs, share_x=share_x, share_y=share_y, figsize=figsize, title=title, cid_grid=True, interpolate=True, legend=False, **kwargs, ) def _plot_lagged_corr( df, plt_dict, signal_xcat, target_xcat, ax=None, lags=[0, 1, 2, 3], remove_zero_predictor=True, **kwargs, ): """ Compute and plot cross-correlation. """ if isinstance(lags, int): lags = list(range(lags + 1)) cid = plt_dict["Y"][0].split("_")[0] target_df = ( df.loc[(cid, target_xcat), ["real_date", "value"]] .rename(columns={"value": "value_target"}) .reset_index(drop=True) ) signal_df = ( df.loc[(cid, signal_xcat), ["real_date", "value"]] .rename(columns={"value": "value_signal"}) .reset_index(drop=True) ) merged_df = target_df.merge(signal_df, on="real_date") cross_corrs = [] for lag in lags: shifted_signal = merged_df["value_signal"].shift(lag) shifted_target = merged_df["value_target"] valid_mask = shifted_signal.notna() & shifted_target.notna() if remove_zero_predictor: valid_mask &= shifted_signal != 0 corr = ( shifted_signal[valid_mask].corr(shifted_target[valid_mask]) if valid_mask.any() else np.nan ) cross_corrs.append(corr) if ax is None: fig, ax = plt.subplots(figsize=(8, 5)) ax.stem(lags, cross_corrs) ax.axhline(0, color="black", linestyle="--", lw=1) ax.set_title(cid) plt.xticks(lags) return ax def _checks_plot_lc( df: pd.DataFrame, cids: List[str], xcats: 
List[str], lags: int = 30, remove_zero_predictor: bool = False, start: Optional[str] = None, end: Optional[str] = None, blacklist: Optional[Dict[str, List[str]]] = None, figsize: Tuple[float, float] = (16, 9), title: Optional[str] = None, share_x: bool = True, share_y: bool = True, ): if not isinstance(df, pd.DataFrame): raise TypeError("`df` must be a pandas DataFrame.") if len(df.columns) < 4: df = df.copy().reset_index() if not isinstance(lags, (int, np.ndarray, list, tuple)): raise TypeError("`lags` must be an integer or list of integers.") if not isinstance(remove_zero_predictor, bool): raise TypeError("`remove_zero_predictor` must be a boolean.") if start is None: start: str = pd.Timestamp(df["real_date"].min()).strftime("%Y-%m-%d") if end is None: end: str = pd.Timestamp(df["real_date"].max()).strftime("%Y-%m-%d") if not isinstance(xcats, list): raise TypeError("`xcat` must be a string.") if not all(isinstance(xcat, str) for xcat in xcats): raise TypeError("All elements in `xcats` must be strings)") if isinstance(cids, str): cids: List[str] = [cids] if not isinstance(cids, list): raise TypeError("`cids` must be a list.") if not all(isinstance(cid, str) for cid in cids): raise TypeError("All elements in `cids` must be strings.") if blacklist: if not isinstance(blacklist, dict): raise TypeError("`blacklist` must be a dictionary.") for key, value in blacklist.items(): if not isinstance(key, str): raise TypeError("Keys in `blacklist` must be strings.") if not isinstance(value, list): raise TypeError("Values in `blacklist` must be lists.") if not isinstance(figsize, tuple): raise TypeError("`figsize` must be a tuple.") if title is not None and not isinstance(title, str): raise TypeError("`title` must be a string.") if not isinstance(share_x, bool): raise TypeError("`share_x` must be a boolean.") if not isinstance(share_y, bool): raise TypeError("`share_y` must be a boolean.") if __name__ == "__main__": import numpy as np from macrosynergy.management.simulate 
import make_test_df from macrosynergy.visuals import FacetPlot np.random.seed(42) cids: List[str] = [ "USD", "EUR", "GBP", "AUD", "CAD", "JPY", "CHF", "NZD", "SEK", "NOK", "DKK", "INR", ] xcats: List[str] = [ "FXXR", "EQXR", "RIR", "IR", "REER", "CPI", "PPI", "M2", "M1", "M0", "FXVOL", "FX", ] sel_cids: List[str] = [ "USD", "EUR", "GBP", "AUD", "CAD", "JPY", "CHF", "NZD", ] # ["USD", "EUR", "GBP"] sel_xcats: List[str] = ["FXXR", "EQXR", "RIR", "IR"] r_styles: List[str] = [ "linear", "decreasing-linear", "sharp-hill", "sine", "four-bit-sine", ] df: pd.DataFrame = make_test_df( cids=list(set(cids) - set(sel_cids)), xcats=xcats, start="2000-01-01", ) for rstyle, xcatx in zip(r_styles, sel_xcats): dfB: pd.DataFrame = make_test_df( cids=sel_cids, xcats=[xcatx], start="2000-01-01", style=rstyle, ) df: pd.DataFrame = pd.concat([df, dfB], axis=0) for ix, cidx in enumerate(sel_cids): df.loc[df["cid"] == cidx, "value"] = ( ((df[df["cid"] == cidx]["value"]) * (ix + 1)).reset_index(drop=True).copy() ) for ix, xcatx in enumerate(sel_xcats): df.loc[df["xcat"] == xcatx, "value"] = ( ((df[df["xcat"] == xcatx]["value"]) * (ix * 10 + 1)) .reset_index(drop=True) .copy() ) df.loc[df["xcat"] == "EQXR", "value"] *= ( np.arange(len(df.loc[df["xcat"] == "EQXR", "value"])) % 20 == 0 ) df["grading"] = np.nan plot_lagged_correlation( df, cids=sel_cids, xcats=["EQXR", "FXXR"], # title="ccf Facet Plot", remove_zero_predictor=True, lags=[1, 2], )