Source code for macrosynergy.download.dataquery

"""
Interface for downloading data from the JPMorgan DataQuery API. This module is not
intended to be used directly, but rather through macrosynergy.download.jpmaqs.py.
However, for use cases independent of JPMaQS, this module can be used directly to
download data from the JPMorgan DataQuery API.
"""

import concurrent.futures
import time
import os
import logging
import itertools
import base64
import uuid
import io
import warnings
import requests
from datetime import datetime, timezone
from typing import List, Optional, Dict, Union, Tuple, Any
from tqdm import tqdm

from macrosynergy import __version__ as ms_version_info
from macrosynergy.download.exceptions import (
    AuthenticationError,
    DownloadError,
    InvalidResponseError,
    HeartbeatError,
    NoContentError,
    KNOWN_EXCEPTIONS,
)
from macrosynergy.download.jpm_oauth import JPMorganOAuth
from macrosynergy.management.utils import (
    is_valid_iso_date,
    form_full_url,
)

# --- DataQuery hosts, OAuth identifiers and the default catalogue group ---
CERT_BASE_URL: str = "https://platform.jpmorgan.com/research/dataquery/api/v2"
OAUTH_BASE_URL: str = "https://api-developer.jpmorgan.com/research/dataquery-authe/api/v2"
OAUTH_TOKEN_URL: str = "https://authe.jpmchase.com/as/token.oauth2"
OAUTH_DQ_RESOURCE_ID: str = "JPMC:URI:RS-06785-DataQueryExternalApi-PROD"
JPMAQS_GROUP_ID: str = "JPMAQS"

# --- Pacing and retry policy ---
# Seconds to wait between consecutive requests (the API documentation asks for 200ms).
API_DELAY_PARAM: float = 0.25
# Fraction of the token lifetime (90%) after which a token is treated as expired.
TOKEN_EXPIRY_BUFFER: float = 0.9
# Attempts per individual HTTP request for transient errors.
API_RETRY_COUNT: int = 5
# Retry rounds for "high-level" (batched) downloads.
HL_RETRY_COUNT: int = 5
# Consecutive batch failures tolerated before a download is aborted.
MAX_CONTINUOUS_FAILURES: int = 5

# --- Endpoint paths, appended to the chosen base URL ---
HEARTBEAT_ENDPOINT: str = "/services/heartbeat"
TIMESERIES_ENDPOINT: str = "/expressions/time-series"
CATALOGUE_ENDPOINT: str = "/group/instruments"

# --- Suffixes used when building the X-Tracking-Id request header ---
HEARTBEAT_TRACKING_ID: str = "heartbeat"
OAUTH_TRACKING_ID: str = "oauth"
TIMESERIES_TRACKING_ID: str = "timeseries"
CATALOGUE_TRACKING_ID: str = "catalogue"

# Module-level logger. Records that reach ``debug_stream_handler`` are
# formatted and written into an in-memory ``io.StringIO`` buffer.
# NOTE(review): nothing visible in this module drains that buffer; with a
# verbose logger level it may keep growing for the life of the process.
logger: logging.Logger = logging.getLogger(__name__)

debug_stream_handler = logging.StreamHandler(stream=io.StringIO())
# NOTSET on the handler: every record the logger passes on is kept.
debug_stream_handler.setLevel(level=logging.NOTSET)
debug_stream_handler.setFormatter(
    fmt=logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(module)s - %(funcName)s :: %(message)s"
    )
)
logger.addHandler(hdlr=debug_stream_handler)


def validate_response(
    response: requests.Response,
    user_id: str,
) -> dict:
    """
    Validates a response from the API. Raises an exception if the response is
    invalid (e.g. if the response is not a 200 status code).

    Parameters
    ----------
    response : requests.Response
        response object from requests.request().
    user_id : str
        identifier of the requesting user; included in error messages.

    Raises
    ------
    AuthenticationError
        if the response is a 401 status code.
    HeartbeatError
        if a non-OK response came from the heartbeat endpoint.
    InvalidResponseError
        if the response is not OK, cannot be parsed as JSON, or parses to
        ``None``.

    Returns
    -------
    dict
        response as a dictionary. If the response is not valid, this function
        will raise an exception.
    """
    if not response.ok:
        logger.info("Response status is NOT OK : %s", response.status_code)
        error_str = format_invalid_response_msg(response, user_id)
        if response.status_code == 401:
            raise AuthenticationError(error_str)
        if HEARTBEAT_ENDPOINT in response.request.url:
            raise HeartbeatError(error_str)
        raise InvalidResponseError(
            f"Request did not return a 200 status code.\n{error_str}"
        )

    # Guard only the JSON decoding, so the "Response is empty" error below is
    # not caught here and mislabelled as a JSON parsing failure.
    # (KeyboardInterrupt is not an Exception and propagates untouched.)
    try:
        response_dict = response.json()
    except Exception as exc:
        error_str = format_invalid_response_msg(response, user_id)
        raise InvalidResponseError(
            error_str + f"Error parsing response as JSON: {exc}"
        ) from exc

    if response_dict is None:
        error_str = format_invalid_response_msg(response, user_id)
        raise InvalidResponseError(f"Response is empty.\n{error_str}")

    return response_dict
def format_invalid_response_msg(response: requests.Response, user_id: str) -> str:
    """
    Build a multi-line diagnostic message describing a failed API response.

    Intended only for responses already known to be erroneous, because the
    full response body (``response.text``) is embedded in the message.

    Parameters
    ----------
    response : requests.Response
        the offending response.
    user_id : str
        identifier of the requesting user.

    Returns
    -------
    str
        newline-separated report: response repr, user, URL, status code,
        headers, body, DataQuery interaction ID (or 'N/A') and a UTC timestamp.
    """
    message_lines = [
        f"Response: {response}\n",
        f"User ID: {user_id}\n",
        f"Requested URL: {response.request.url}\n",
        f"Response status code: {response.status_code}\n",
        f"Response headers: {response.headers}\n",
        f"Response text: {response.text}\n",
        f"DataQuery Interaction ID: {response.headers.get('x-dataquery-interaction-id', 'N/A')}\n",
        f"Timestamp (UTC): {datetime.now(timezone.utc).isoformat()}; \n",
    ]
    return "".join(message_lines)
def request_wrapper(
    url: str,
    headers: Optional[Dict] = None,
    params: Optional[Dict] = None,
    method: str = "get",
    tracking_id: Optional[str] = None,
    proxy: Optional[Dict] = None,
    cert: Optional[Tuple[str, str]] = None,
    **kwargs,
) -> dict:
    """
    Wrapper for requests that handles retries and logging.

    Parameters
    ----------
    url : str
        URL to request.
    headers : dict
        headers to send. The dict is copied, never modified; a default
        "User-Agent" and an "X-Tracking-Id" header are added to the copy.
    params : dict
        query parameters to send.
    method : str
        HTTP method to use. Must be one of "get" or "post". Defaults to "get".
    tracking_id : str
        default None, tracking ID of the request. A fresh UUID is always
        generated and, when `tracking_id` is given, prefixed to it as
        "uuid::<uuid>::<tracking_id>".
    proxy : dict
        default None, dictionary of proxy settings for the request.
    cert : Tuple[str, str]
        default None, tuple of certificate and key filenames.
    kwargs : dict
        `user_id` (used in error messages, default "unknown") and `verify`
        (TLS verification, default True) are consumed here; all remaining
        kwargs are passed to `requests.Request()`.

    Raises
    ------
    ValueError
        if the method is not one of "get" or "post".
    AuthenticationError
        immediately, if the response is a 401 status code.
    HeartbeatError
        if all retries failed and the last failure was a HeartbeatError.
    DownloadError
        if the request fails after `API_RETRY_COUNT` attempts.
    Exception
        any exception not listed in `KNOWN_EXCEPTIONS` is re-raised at once.

    Returns
    -------
    dict
        response as a dictionary.
    """
    if method not in ["get", "post"]:
        raise ValueError(f"Invalid method: {method}")

    user_id: str = kwargs.pop("user_id", "unknown")
    verify: bool = kwargs.pop("verify", True)

    # Copy so that a headers dict owned by the caller (e.g. an auth object
    # possibly shared between worker threads) is never mutated here.
    headers: Dict = dict(headers) if headers else {}
    if "User-Agent" not in headers:
        headers["User-Agent"] = f"MacrosynergyPackage/{ms_version_info}"

    # insert tracking info in headers
    uuid_str: str = str(uuid.uuid4())
    if not tracking_id:
        tracking_id = uuid_str
    else:
        tracking_id = f"uuid::{uuid_str}::{tracking_id}"
    headers["X-Tracking-Id"] = tracking_id

    log_url: str = form_full_url(url, params)
    logger.debug(f"Requesting URL: {log_url} with tracking_id: {tracking_id}")

    # NOTE : HeartbeatError is a special case - retried like the known
    # transient errors, but re-raised as-is if it was the final failure.
    retryable: Tuple = tuple(KNOWN_EXCEPTIONS) + (HeartbeatError,)
    raised_exceptions: List[Exception] = []
    error_statements: List[str] = []
    error_statement: str = ""
    retry_count: int = 0
    response: Optional[requests.Response] = None

    while retry_count < API_RETRY_COUNT:
        try:
            prepared_request: requests.PreparedRequest = requests.Request(
                method, url, headers=headers, params=params, **kwargs
            ).prepare()
            # Close the session after each attempt so its connection pool is
            # released rather than leaked.
            with requests.Session() as session:
                with session.send(
                    prepared_request,
                    proxies=proxy,
                    cert=cert,
                    timeout=300,
                    verify=verify,
                ) as response:
                    if isinstance(response, requests.Response):
                        return validate_response(response=response, user_id=user_id)
                    raise InvalidResponseError(
                        f"Request did not return a response.\n"
                        f"User ID: {user_id}\n"
                        f"Requested URL: {log_url}\n"
                        f"Timestamp (UTC): {datetime.now(timezone.utc).isoformat()}; \n"
                    )
        except Exception as exc:
            # authentication error, clearly not a transient error
            if isinstance(exc, AuthenticationError):
                raise

            error_statement = (
                f"Request to {log_url} failed with error {exc}. "
                f"User ID: {user_id}. "
                f"Retry count: {retry_count}. "
                f"Tracking ID: {tracking_id}"
            )
            raised_exceptions.append(exc)
            error_statements.append(error_statement)

            # NOTE: exceptions that need the code to break are raised here;
            # known (transient) ones are retried after a delay.
            if not isinstance(exc, retryable):
                raise
            logger.warning(error_statement)
            retry_count += 1
            time.sleep(API_DELAY_PARAM)

    if isinstance(raised_exceptions[-1], HeartbeatError):
        raise HeartbeatError(error_statement)

    errs_str = "\n\n".join(
        f"\t{exc} - \n\t\t{stmt}"
        for exc, stmt in zip(raised_exceptions, error_statements)
    )
    e_str = f"Request to {log_url} failed with error {raised_exceptions[-1]}. \n"
    e_str += "-" * 20 + "\n"
    if isinstance(response, requests.Response):
        e_str += f" Status code: {response.status_code}."
    e_str += (
        f" No longer retrying. Tracking ID: {tracking_id}\n"
        f"Exceptions raised:\n{errs_str}"
    )
    raise DownloadError(e_str)
class DataQueryOAuth(JPMorganOAuth):
    """
    OAuth authentication handler preconfigured for the JPMorgan DataQuery API.

    A thin subclass of ``JPMorganOAuth`` that translates DataQuery-flavoured
    argument names into the parent's constructor arguments and supplies
    DataQuery defaults.

    Parameters
    ----------
    client_id : str
        OAuth client ID.
    client_secret : str
        OAuth client secret.
    proxy : Optional[dict]
        forwarded to the parent as ``proxies``.
    token_url : str
        forwarded as ``auth_url``. Defaults to OAUTH_TOKEN_URL.
    dq_base_url : str
        forwarded as ``root_url``. Defaults to OAUTH_BASE_URL.
    dq_resource_id : str
        forwarded as ``resource``. Defaults to OAUTH_DQ_RESOURCE_ID.
    application_name : str
        forwarded unchanged. Defaults to "DataQueryHttpAPI".
    **kwargs
        any further keyword arguments are forwarded unchanged.
    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        proxy: Optional[dict] = None,
        token_url: str = OAUTH_TOKEN_URL,
        dq_base_url: str = OAUTH_BASE_URL,
        dq_resource_id: str = OAUTH_DQ_RESOURCE_ID,
        application_name: str = "DataQueryHttpAPI",
        **kwargs,
    ):
        translated = dict(
            auth_url=token_url,
            root_url=dq_base_url,
            resource=dq_resource_id,
            proxies=proxy,
            application_name=application_name,
        )
        super().__init__(
            client_id=client_id,
            client_secret=client_secret,
            **translated,
            **kwargs,
        )
class DataQueryCertAuth(object):
    """
    Class for handling certificate based authentication for the DataQuery API.

    Parameters
    ----------
    username : str
        username for the DataQuery API.
    password : str
        password for the DataQuery API.
    crt : str
        path to the certificate file.
    key : str
        path to the key file.
    proxy : Optional[dict]
        proxy settings; stored on the instance as ``self.proxy``.

    Raises
    ------
    TypeError
        if any of `username`, `password`, `crt` or `key` is not a string.
    FileNotFoundError
        if the certificate or key file is missing from the filesystem.
    """

    def __init__(
        self,
        username: str,
        password: str,
        crt: str,
        key: str,
        proxy: Optional[dict] = None,
    ):
        for varx, namex in zip([username, password], ["username", "password"]):
            if not isinstance(varx, str):
                raise TypeError(f"{namex} must be a <str> and not {type(varx)}.")

        # Pre-compute the base64 "username:password" credential used in the
        # HTTP Basic Authorization header.
        auth_str = f"{username:s}:{password:s}"
        self.auth: str = base64.b64encode(auth_str.encode("utf-8")).decode("utf-8")

        # Key and Certificate check
        for varx, namex in zip([crt, key], ["crt", "key"]):
            if not isinstance(varx, str):
                raise TypeError(f"{namex} must be a <str> and not {type(varx)}.")
            if not os.path.isfile(varx):
                raise FileNotFoundError(f"The file '{varx}' does not exist.")

        self.key: str = key
        self.crt: str = crt
        self.username: str = username
        self.password: str = password
        self.proxy: Optional[dict] = proxy

    def _get_user_id(self) -> str:
        """Return a user identifier string for use in error messages."""
        return "CertAuth_Username - " + self.username

    def get_headers(self) -> dict:
        """Return the Basic Authorization and User-Agent request headers."""
        headers = {
            "Authorization": f"Basic {self.auth:s}",
            "User-Agent": f"MacrosynergyPackage/DataQueryHttpAPI-CertAuth/{ms_version_info}",
        }
        return headers

    def get_auth(self) -> Dict[str, Union[str, Optional[Tuple[str, str]]]]:
        """
        Returns the authentication information as keyword arguments for
        `request_wrapper()`: a "headers" dict holding the Basic Authorization
        header and a "cert" tuple of (certificate path, key path).
        """
        headers = {"Authorization": f"Basic {self.auth:s}"}
        return {
            "headers": headers,
            "cert": (self.crt, self.key),
        }
def validate_download_args(
    expressions: List[str],
    start_date: str,
    end_date: str,
    show_progress: bool,
    endpoint: str,
    calender: str,
    frequency: str,
    conversion: str,
    nan_treatment: str,
    reference_data: str,
    retry_counter: int,
    delay_param: float,
    batch_size: int,
):
    """
    Validate the arguments passed to the `download_data()` method.

    Parameters
    ----------
    expressions : List[str]
        must be a list whose items are all strings.
    start_date, end_date : str
        ISO-8601 dates (YYYY-MM-DD).
    show_progress : bool
    endpoint, calender, frequency, conversion, nan_treatment, reference_data : str
        must be strings; their values are not checked here.
    retry_counter : int
    delay_param : float
        non-negative delay in seconds; an int is also accepted (bool is not).
    batch_size : int
        must be >= 1.

    Raises
    ------
    TypeError
        if any of the arguments are of the wrong type.
    ValueError
        if any of the arguments are semantically incorrect.

    Warns
    -----
    RuntimeWarning
        if `delay_param` is below 0.2 seconds or `batch_size` is above 20.

    Returns
    -------
    bool
        True if all arguments are valid.
    """
    if expressions is None:
        raise ValueError("`expressions` must be a list of strings.")
    if not isinstance(expressions, list):
        raise TypeError("`expressions` must be a list of strings.")
    if not all(isinstance(expr, str) for expr in expressions):
        raise TypeError("`expressions` must be a list of strings.")

    for varx, namex in zip([start_date, end_date], ["start_date", "end_date"]):
        if not isinstance(varx, str):
            raise TypeError(f"`{namex}` must be a string.")
        if not is_valid_iso_date(varx):
            raise ValueError(
                f"`{namex}` must be a string in the ISO-8601 format (YYYY-MM-DD)."
            )

    if not isinstance(show_progress, bool):
        raise TypeError("`show_progress` must be a boolean.")

    if not isinstance(retry_counter, int):
        raise TypeError("`retry_counter` must be an integer.")

    # Whole seconds (e.g. ``delay_param=1``) are valid delays too; bool is a
    # subclass of int and is rejected explicitly.
    if isinstance(delay_param, bool) or not isinstance(delay_param, (int, float)):
        raise TypeError("`delay_param` must be a float >=0.2 (seconds).")
    if delay_param < 0.0:
        raise ValueError("`delay_param` must be a float >=0.2 (seconds).")
    if delay_param < 0.2:
        warnings.warn(
            RuntimeWarning(
                "`delay_param` is too low; DataQuery API may reject requests. "
                "Minimum recommended value is 0.2 seconds. "
            )
        )

    if not isinstance(batch_size, int):
        raise TypeError("`batch_size` must be an integer.")
    if batch_size < 1:
        raise ValueError("`batch_size` must be an integer >=1.")
    elif batch_size > 20:
        # Only warns: the value is not clamped here.
        warnings.warn(
            RuntimeWarning(
                "`batch_size` is too high; DataQuery API's time-series endpoint "
                "accepts a maximum of 20 expressions per request. "
            )
        )

    vars_types_zip: zip = zip(
        [
            endpoint,
            calender,
            frequency,
            conversion,
            nan_treatment,
            reference_data,
        ],
        [
            "endpoint",
            "calender",
            "frequency",
            "conversion",
            "nan_treatment",
            "reference_data",
        ],
    )
    for varx, namex in vars_types_zip:
        if not isinstance(varx, str):
            raise TypeError(f"`{namex}` must be a string.")

    return True
class DataQueryInterface(object):
    """
    High level interface for the DataQuery API.

    When using OAuth authentication:

    Parameters
    ----------
    client_id : str
        client ID for the OAuth application.
    client_secret : str
        client secret for the OAuth application.

    When using certificate authentication:

    crt : str
        path to the certificate file.
    key : str
        path to the key file.
    username : str
        username for the DataQuery API.
    password : str
        password for the DataQuery API.

    Other parameters:

    proxy : dict
        proxy settings used for every request. Defaults to None.
    oauth : bool
        whether to use OAuth authentication. Defaults to True. If the OAuth
        credentials are missing, certificate authentication is used instead
        (a UserWarning is issued).
    debug : bool
        if True, the first failed batch re-raises its exception instead of
        being queued for retry. Defaults to False.
    batch_size : int
        default 20, number of expressions to send in a single request.
        Must be a number between 1 and 20 (both included).
    check_connection : bool
        whether to send a heartbeat request before each download.
        Defaults to True.
    base_url : str
        base URL for the DataQuery API. Defaults to OAUTH_BASE_URL; when
        certificate authentication is used and `base_url` is left at that
        default, CERT_BASE_URL is used instead.
    token_url : str
        token URL for the DataQuery API. Defaults to OAUTH_TOKEN_URL.
    suppress_warning : bool
        whether to suppress warnings. Defaults to True.
    custom_auth : Any
        custom authentication object, used when OAuth is not in effect (it
        takes precedence over certificate credentials). It must provide
        `get_auth()` and `_get_user_id()` methods. Defaults to None.
    verify : bool
        passed as `verify` to every request (TLS verification). Defaults to
        True.

    Raises
    ------
    TypeError
        if any of the parameters are of the wrong type.
    ValueError
        if any of the parameters are semantically incorrect.
    InvalidResponseError
        if the response from the server is not valid.
    DownloadError
        if the download fails to complete after a number of retries.
    HeartbeatError
        if the heartbeat (check connection) fails.
    Exception
        other exceptions may be raised by underlying functions.
    """

    def __init__(
        self,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        crt: Optional[str] = None,
        key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        proxy: Optional[dict] = None,
        oauth: bool = True,
        debug: bool = False,
        batch_size: int = 20,
        check_connection: bool = True,
        base_url: str = OAUTH_BASE_URL,
        token_url: str = OAUTH_TOKEN_URL,
        suppress_warning: bool = True,
        custom_auth=None,
        verify: bool = True,
    ):
        self._check_connection: bool = check_connection
        self.msg_errors: List[str] = []
        self.msg_warnings: List[str] = []
        self.unavailable_expressions: List[str] = []
        self.debug: bool = debug
        self.suppress_warning: bool = suppress_warning
        self.batch_size: int = batch_size

        for varx, namex, typex in [
            (client_id, "client_id", str),
            (client_secret, "client_secret", str),
            (crt, "crt", str),
            (key, "key", str),
            (username, "username", str),
            (password, "password", str),
            (proxy, "proxy", dict),
        ]:
            if not isinstance(varx, typex) and varx is not None:
                raise TypeError(f"{namex} must be a {typex} and not {type(varx)}.")

        self.auth: Optional[Union[DataQueryCertAuth, DataQueryOAuth]] = None
        if oauth and not all([client_id, client_secret]):
            warnings.warn(
                "OAuth authentication requested but client ID and/or client secret "
                "not found. Falling back to certificate authentication.",
                UserWarning,
            )
            if not all([username, password, crt, key]):
                raise ValueError(
                    "Certificate credentials not found. "
                    "Check the parameters passed to the DataQueryInterface class."
                )
            oauth = False

        self.verify: bool = verify
        if oauth:
            self.auth = DataQueryOAuth(
                client_id=client_id,
                client_secret=client_secret,
                token_url=token_url,
                proxy=proxy,
                verify=self.verify,
            )
        elif custom_auth is not None:
            self.auth = custom_auth
        else:
            # Certificate auth uses a different host; only override the base
            # URL when the caller left it at the OAuth default.
            if base_url == OAUTH_BASE_URL:
                base_url = CERT_BASE_URL
            self.auth = DataQueryCertAuth(
                username=username,
                password=password,
                crt=crt,
                key=key,
                proxy=proxy,
            )

        assert self.auth is not None, (
            "Unable to instantiate authentication object. "
            "Check the parameters passed to the DataQueryInterface class."
        )

        self.proxy: Optional[dict] = proxy
        self.base_url: str = base_url

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Log and print the exception; returning None lets it propagate.
        if exc_type:
            logger.error("Exception %s - %s", exc_type, exc_value)
            print(f"Exception: {exc_type} {exc_value}")

    def _get_unavailable_expressions(
        self,
        expected_exprs: List[str],
        dicts_list: List[Dict],
    ) -> List[str]:
        """
        Method to get the expressions that are not available in the response.
        Looks at the dict["attributes"][0]["expression"] field of each dict
        in the list, counting an expression as found only when its
        "time-series" field is not None.

        Parameters
        ----------
        expected_exprs : List[str]
            list of expressions that were requested.
        dicts_list : List[Dict]
            list of dicts to search for the expressions.

        Returns
        -------
        List[str]
            list of expressions that were not found in the dicts (unordered).
        """
        found_exprs: List[str] = [
            curr_dict["attributes"][0]["expression"]
            for curr_dict in dicts_list
            if curr_dict["attributes"][0]["time-series"] is not None
        ]
        return list(set(expected_exprs) - set(found_exprs))

    def check_connection(self, verbose=False, raise_error: bool = False) -> bool:
        """
        Check the connection to the DataQuery API using the Heartbeat endpoint.

        Parameters
        ----------
        verbose : bool
            whether to print whether the heartbeat succeeded. Useful for
            debugging. Defaults to False.
        raise_error : bool
            if True, raise instead of returning False when the heartbeat
            response does not report the service as available.
            Defaults to False.

        Raises
        ------
        ConnectionError
            wrapping a HeartbeatError, if `raise_error` is True and the check
            fails.
        HeartbeatError, DownloadError, AuthenticationError
            may propagate from `request_wrapper()` if the request itself fails.

        Returns
        -------
        bool
            True if the connection is successful, False otherwise.
        """
        logger.debug(f"Sleep before checking connection - {API_DELAY_PARAM} seconds")
        time.sleep(API_DELAY_PARAM)
        logger.debug("Check if connection can be established to JPMorgan DataQuery")
        js: dict = request_wrapper(
            url=self.base_url + HEARTBEAT_ENDPOINT,
            params={"data": "NO_REFERENCE_DATA"},
            proxy=self.proxy,
            tracking_id=HEARTBEAT_TRACKING_ID,
            verify=self.verify,
            **self.auth.get_auth(),
        )

        result: bool = True
        if (js is None) or (not isinstance(js, dict)) or ("info" not in js):
            logger.warning("Connection to JPMorgan DataQuery heartbeat failed")
            result = False

        if result:
            result = (int(js["info"]["code"]) == 200) and (
                js["info"]["message"] == "Service Available."
            )

        if verbose:
            print("Connection successful!" if result else "Connection failed.")

        if raise_error and not result:
            raise ConnectionError(HeartbeatError("Heartbeat failed."))

        return result

    def _fetch(
        self,
        url: str,
        params: dict = None,
        tracking_id: Optional[str] = None,
    ) -> List[Dict]:
        """
        Make a request to the DataQuery API using the specified parameters.
        Used to wrap a request in a thread for concurrent requests, or to
        simplify the code for single requests. Paginated responses are
        followed recursively via the "next" link.

        Parameters
        ----------
        url : str
            URL to request.
        params : dict
            parameters to send with the request.
        tracking_id : str
            tracking ID to use for the request.

        Raises
        ------
        NoContentError
            if DataQuery answers with an info code of 204.
        InvalidResponseError
            if the response from the server is not valid.
        Exception
            other exceptions may be raised by underlying functions.

        Returns
        -------
        List[Dict]
            list of dictionaries containing the response data.
        """
        downloaded_data: List[Dict] = []
        response: Dict = request_wrapper(
            url=url,
            params=params,
            proxy=self.proxy,
            tracking_id=tracking_id,
            verify=self.verify,
            **self.auth.get_auth(),
        )

        if (response is None) or ("instruments" not in response.keys()):
            if response is not None:
                if (
                    ("info" in response)
                    and ("code" in response["info"])
                    and (int(response["info"]["code"]) == 204)
                ):
                    raise NoContentError(
                        f"Content was not found for the request: {response}\n"
                        f"User ID: {self.auth._get_user_id()}\n"
                        f"URL: {form_full_url(url, params)}\n"
                        f"Timestamp (UTC): {datetime.now(timezone.utc).isoformat()}"
                    )

            raise InvalidResponseError(
                f"Invalid response from DataQuery: {response}\n"
                f"User ID: {self.auth._get_user_id()}\n"
                f"URL: {form_full_url(url, params)}\n"
                f"Timestamp (UTC): {datetime.now(timezone.utc).isoformat()}"
            )

        downloaded_data.extend(response["instruments"])

        # The second "links" entry carries the relative URL of the next page.
        if "links" in response.keys() and response["links"][1]["next"] is not None:
            logger.debug("DQ response paginated - get next response page")
            downloaded_data.extend(
                self._fetch(
                    url=self.base_url + response["links"][1]["next"],
                    params={},
                    tracking_id=tracking_id,
                )
            )

        return downloaded_data

    def _fetch_timeseries(
        self, url: str, params: dict, tracking_id: str = None, *args, **kwargs
    ) -> List[Dict]:
        """
        Exists to provide a wrapper for the `_fetch()` method that can be
        modified when inheriting from this class. This method is used by the
        `_concurrent_loop()` method. Extra args/kwargs are ignored here.
        """
        return self._fetch(url=url, params=params, tracking_id=tracking_id)

    def get_catalogue(
        self,
        group_id: str = JPMAQS_GROUP_ID,
        page_size: int = 1000,
        verbose: bool = True,
    ) -> List[str]:
        """
        Method to get the JPMaQS catalogue.
        Queries the DataQuery API's Groups/Search endpoint to get the list of
        tickers in the JPMaQS group. The group ID can be changed to fetch a
        different group's catalogue.

        Parameters
        ----------
        group_id : str
            the group ID to fetch the catalogue for. Defaults to "JPMAQS".
        page_size : int
            the number of tickers to fetch in a single request. Defaults to
            1000 (maximum allowed by the API).
        verbose : bool
            whether to print progress messages. Defaults to True.

        Raises
        ------
        TypeError
            if `group_id` is not a string or `page_size` is not an integer.
        ValueError
            if `page_size` is out of range, or if the downloaded catalogue is
            empty or corrupt.

        Returns
        -------
        List[str]
            list of all tickers in the requested group.
        """
        if not isinstance(group_id, str):
            raise TypeError("`group_id` must be a string.")

        pgsize_err = "`page_size` must be an integer between 1 and 1000."
        if not isinstance(page_size, int):
            raise TypeError(pgsize_err)
        elif (page_size < 1) or (page_size > 1000):
            raise ValueError(pgsize_err)

        if verbose:
            print(f"Downloading the {group_id} catalogue from DataQuery...")

        response_list: List[Dict] = self._fetch(
            url=self.base_url + CATALOGUE_ENDPOINT,
            params={"group-id": group_id, "limit": page_size},
            tracking_id=CATALOGUE_TRACKING_ID,
        )

        tickers: List[str] = [d["instrument-name"] for d in response_list]
        utkr_count: int = len(tickers)
        tkr_idx: List[int] = sorted(d["item"] for d in response_list)

        if not tkr_idx:
            raise ValueError(f"The downloaded {group_id} catalogue is empty.")

        # Items must be numbered 1..N with no gaps or duplicates.
        if not (
            (tkr_idx[0] == 1)
            and (tkr_idx[-1] == utkr_count)
            and (len(set(tkr_idx)) == utkr_count)
        ):
            raise ValueError("The downloaded catalogue is corrupt.")

        if verbose:
            print(f"Downloaded {group_id} catalogue with {utkr_count} tickers.")

        return tickers

    def _concurrent_loop(
        self,
        expr_batches: List[List[str]],
        show_progress: bool,
        url: str,
        params: dict,
        tracking_id: str,
        delay_param: float,
        *args,
        **kwargs,
    ) -> Tuple[List[Union[Dict, Any]], List[List[str]]]:
        """
        Concurrent loop to download data from the DataQuery API.
        Used by the `_download()` method.

        One request per batch is submitted to a thread pool, sleeping
        `delay_param` seconds between submissions; results are then collected
        in submission order.

        Raises
        ------
        AuthenticationError, KeyboardInterrupt
            immediately; pending requests are cancelled.
        DownloadError
            if more than MAX_CONTINUOUS_FAILURES consecutive batches fail.

        Returns
        -------
        Tuple[List[Union[Dict, Any]], List[List[str]]]
            tuple of two lists. The first list contains the downloaded data,
            and the second list contains the failed batches.
        """
        future_objects: List[concurrent.futures.Future] = []
        download_outputs: List[Union[Dict, Any]] = []
        failed_batches: List[List[str]] = []
        last_five_exc: List[Exception] = []
        continuous_failures: int = 0

        with concurrent.futures.ThreadPoolExecutor() as executor:
            for expr_batch in tqdm(
                expr_batches,
                desc="Requesting data",
                disable=not show_progress,
                total=len(expr_batches),
            ):
                curr_params: Dict = params.copy()
                curr_params["expressions"] = expr_batch
                future_objects.append(
                    executor.submit(
                        self._fetch_timeseries,
                        url=url,
                        params=curr_params,
                        tracking_id=tracking_id,
                        *args,
                        **kwargs,
                    )
                )
                # pace submissions to stay within the API's request rate
                time.sleep(delay_param)

            for ib, future in tqdm(
                enumerate(future_objects),
                desc="Downloading data",
                disable=not show_progress,
                total=len(future_objects),
            ):
                try:
                    # result() re-raises any exception raised in the worker
                    download_outputs.append(future.result())
                    continuous_failures = 0
                except (KeyboardInterrupt, AuthenticationError):
                    # Listed separately: KeyboardInterrupt is not an Exception
                    # and would otherwise bypass the cancellation below.
                    executor.shutdown(wait=False, cancel_futures=True)
                    raise
                except Exception as exc:
                    failed_batches.append(expr_batches[ib])
                    self.msg_errors.append(f"Batch {ib} failed with exception: {exc}")
                    continuous_failures += 1
                    # keep only the five most recent exceptions for reporting
                    last_five_exc = (last_five_exc + [exc])[-5:]
                    if continuous_failures > MAX_CONTINUOUS_FAILURES:
                        exc_str: str = "\n".join(str(e) for e in last_five_exc)
                        raise DownloadError(
                            f"Failed {continuous_failures} times to download data."
                            f" Last five exceptions: \n{exc_str}"
                        )

                    if self.debug:
                        raise exc

        return download_outputs, failed_batches

    def _chain_download_outputs(
        self,
        download_outputs: List[Union[Dict, Any]],
    ) -> List[Dict]:
        """
        Chain the download outputs from the concurrent loop into a single list.
        Used by the `download_data()` method. Exists to provide a method that
        can be modified when inheriting from this class.

        Parameters
        ----------
        download_outputs : List[Union[Dict, Any]]
            list of list of dictionaries/ other objects.

        Returns
        -------
        List[Dict]
            list of dictionaries/other objects.
        """
        return list(itertools.chain.from_iterable(download_outputs))

    def _download(
        self,
        expressions: List[str],
        params: dict,
        url: str,
        tracking_id: str,
        delay_param: float,
        show_progress: bool = False,
        retry_counter: int = 0,
        *args,
        **kwargs,
    ) -> List[dict]:
        """
        Backend method to download data from the DataQuery API.
        Used by the `download_data()` method.

        Failed batches are retried recursively with `retry_counter + 1` and a
        0.1s larger delay. The top-level call (retry_counter == 0) returns a
        flat list; retry levels return the un-chained per-batch lists.

        Raises
        ------
        DownloadError
            once `retry_counter` exceeds HL_RETRY_COUNT.
        """
        if 0 < retry_counter <= HL_RETRY_COUNT:
            print("Retrying failed downloads. Retry count:", retry_counter)

        if retry_counter > HL_RETRY_COUNT:
            error_str = (
                f"Failed {retry_counter} times to download all requested data.\n"
                "No longer retrying."
            )
            if len(self.msg_errors) > 0:
                error_str += "\n" + "\n".join(self.msg_errors)
            raise DownloadError(error_str)

        expr_batches: List[List[str]] = [
            expressions[i : i + self.batch_size]
            for i in range(0, len(expressions), self.batch_size)
        ]
        download_outputs: List[List[Dict]]
        failed_batches: List[List[str]]
        download_outputs, failed_batches = self._concurrent_loop(
            expr_batches=expr_batches,
            show_progress=show_progress,
            url=url,
            params=params,
            tracking_id=tracking_id,
            delay_param=delay_param,
            *args,
            **kwargs,
        )

        if len(failed_batches) > 0:
            flat_failed_batches: List[str] = list(
                itertools.chain.from_iterable(failed_batches)
            )
            logger.warning(
                "Failed batches %d - retry download for %d expressions",
                len(failed_batches),
                len(flat_failed_batches),
            )
            retried_output: List[dict] = self._download(
                expressions=flat_failed_batches,
                params=params,
                url=url,
                tracking_id=tracking_id,
                delay_param=delay_param + 0.1,
                show_progress=show_progress,
                retry_counter=retry_counter + 1,
                *args,
                **kwargs,
            )
            download_outputs.extend(retried_output)

        if retry_counter == 0:
            return self._chain_download_outputs(download_outputs)

        return download_outputs

    def download_data(
        self,
        expressions: List[str],
        start_date: str = "2000-01-01",
        end_date: str = None,
        show_progress: bool = False,
        endpoint: str = TIMESERIES_ENDPOINT,
        calender: str = "CAL_ALLDAYS",
        frequency: str = "FREQ_DAY",
        conversion: str = "CONV_LASTBUS_ABS",
        nan_treatment: str = "NA_NOTHING",
        reference_data: str = "NO_REFERENCE_DATA",
        retry_counter: int = 0,
        delay_param: float = API_DELAY_PARAM,
        batch_size: Optional[int] = None,
        *args,
        **kwargs,
    ) -> List[Dict]:
        """
        Download data from the DataQuery API.

        Parameters
        ----------
        expressions : List[str]
            list of expressions to download (requested in sorted order).
        start_date : str
            start date for the data in the ISO-8601 format (YYYY-MM-DD).
        end_date : str
            end date for the data in the ISO-8601 format (YYYY-MM-DD).
            Defaults to today. Swapped with `start_date` (with a warning) if
            it is earlier.
        show_progress : bool
            whether to show a progress bar for the download.
        endpoint : str
            endpoint to use for the download.
        calender : str
            calendar setting to use for the download.
        frequency : str
            frequency of data points to use for the download.
        conversion : str
            conversion setting to use for the download.
        nan_treatment : str
            NaN treatment setting to use for the download.
        reference_data : str
            reference data to pass to the API kwargs.
        retry_counter : int
            number of times the download has been retried.
        delay_param : float
            delay between requests to the API.
        batch_size : int
            expressions per request; defaults to `self.batch_size`. When
            given, it also replaces `self.batch_size` for later calls.

        Raises
        ------
        TypeError
            if `expressions` is a bare string, or any argument has the wrong
            type (see validate_download_args()).
        ValueError
            if any arguments are invalid or semantically incorrect
            (see validate_download_args()).
        DownloadError
            if the download fails.
        ConnectionError(HeartbeatError)
            if the heartbeat fails.
        Exception
            other exceptions may be raised by underlying functions.

        Returns
        -------
        List[Dict]
            list of dictionaries containing the response data.
        """
        tracking_id: str = TIMESERIES_TRACKING_ID
        if end_date is None:
            end_date = datetime.today().strftime("%Y-%m-%d")
            # NOTE : if "future dates" are passed, they must be passed by parent functions
            # see jpmaqs.py

        if batch_size is None:
            batch_size = self.batch_size

        # A bare string would be "sorted" into a list of single characters and
        # pass validation, so it is rejected before sorting.
        if isinstance(expressions, str):
            raise TypeError("`expressions` must be a list of strings.")
        # Leave None for validate_download_args to report as a ValueError.
        if expressions is not None:
            expressions = sorted(expressions)

        # NOTE : args are validated only here, not on retries (_download).
        validate_download_args(
            expressions=expressions,
            start_date=start_date,
            end_date=end_date,
            show_progress=show_progress,
            endpoint=endpoint,
            calender=calender,
            frequency=frequency,
            conversion=conversion,
            nan_treatment=nan_treatment,
            reference_data=reference_data,
            retry_counter=retry_counter,
            delay_param=delay_param,
            batch_size=batch_size,
        )
        self.batch_size = batch_size

        if datetime.strptime(end_date, "%Y-%m-%d") < datetime.strptime(
            start_date, "%Y-%m-%d"
        ):
            wStr = "Start date ({}) is after end-date ({}): swapping them!"
            logger.warning(wStr.format(start_date, end_date))
            warnings.warn(wStr.format(start_date, end_date), UserWarning)
            start_date, end_date = end_date, start_date

        # remove dashes from dates to match DQ format
        start_date = start_date.replace("-", "")
        end_date = end_date.replace("-", "")

        # check heartbeat before each "batch" of requests
        if self._check_connection:
            if not self.check_connection(verbose=True):
                raise ConnectionError(
                    HeartbeatError(
                        f"Heartbeat failed. Timestamp (UTC):"
                        f" {datetime.now(timezone.utc).isoformat()}\n"
                        f"User ID: {self.auth._get_user_id()}\n"
                    )
                )
            time.sleep(delay_param)

        logger.info(
            "Download %d expressions from DataQuery from %s to %s",
            len(expressions),
            datetime.strptime(start_date, "%Y%m%d").date(),
            datetime.strptime(end_date, "%Y%m%d").date(),
        )
        params_dict: Dict = {
            "format": "JSON",
            "start-date": start_date,
            "end-date": end_date,
            "calendar": calender,
            "frequency": frequency,
            "conversion": conversion,
            "nan_treatment": nan_treatment,
            "data": reference_data,
        }

        final_output: List[dict] = self._download(
            expressions=expressions,
            params=params_dict,
            url=self.base_url + endpoint,
            tracking_id=tracking_id,
            delay_param=delay_param,
            show_progress=show_progress,
            *args,
            **kwargs,
        )

        if (
            isinstance(final_output, list)
            and (len(final_output) > 0)
            and isinstance(final_output[0], dict)
        ):
            self.unavailable_expressions = self._get_unavailable_expressions(
                expected_exprs=expressions, dicts_list=final_output
            )
            logger.info(
                "Downloaded expressions: %d, unavailable: %d",
                len(final_output),
                len(self.unavailable_expressions),
            )

        return final_output
if __name__ == "__main__":
    # Smoke test: requires DQ_CLIENT_ID / DQ_CLIENT_SECRET in the environment.
    # (`os` is already imported at module level.)
    client_id: Optional[str] = os.getenv("DQ_CLIENT_ID")
    client_secret: Optional[str] = os.getenv("DQ_CLIENT_SECRET")

    expressions = ["DB(CFX,GBP,)"]
    with DataQueryInterface(
        client_id=client_id,
        client_secret=client_secret,
    ) as dq:
        # Raise instead of `assert` so the check still runs under `python -O`.
        dq.check_connection(verbose=True, raise_error=True)

        data = dq.download_data(
            expressions=expressions,
            start_date="2024-01-25",
            end_date="2024-02-05",
            show_progress=True,
        )

    print(data)
    print(f"Successfully downloaded data for {len(data)} expressions.")