from typing import Any, List

import datetime as dt
import logging

import pandas as pd

from Database import Database
from shared_utilities import (
    query_satisfied,
    query_uptodate,
    timeframe_to_minutes,
    unix_time_millis,
)

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class DataCache:
    """
    Fetches and manages candle data, limiting and optimising in-memory storage.

    Serves timestamped candle records from an in-memory cache, falling back to
    the database — and from there to the connected exchanges — whenever the
    cache cannot satisfy a query, then back-fills the cache with the result.

    Example usage:
    --------------
    db = DataCache(exchanges=some_exchanges_object)
    """

    def __init__(self, exchanges):
        """
        Initializes the DataCache class.

        :param exchanges: The exchanges object handling communication with
                          connected exchanges.
        """
        # Maximum number of tables to cache at any given time.
        # NOTE(review): not currently enforced anywhere in this class.
        self.max_tables = 50
        # Maximum number of records to be kept per table.
        # NOTE(review): not currently enforced anywhere in this class.
        self.max_records = 1000
        # A dictionary that holds all the cached records, keyed by table name.
        self.cached_data = {}
        # The class that handles the DB interactions.
        self.db = Database()
        # The class that handles exchange interactions.
        self.exchanges = exchanges

    def cache_exists(self, key: str) -> bool:
        """
        Return True if a cache entry exists for this key, False otherwise.
        """
        return key in self.cached_data

    def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
        """
        Merge new candle records into the cache entry stored under ``key``.

        Candles duplicated by overlapping requests (same ``open_time``) are
        dropped, keeping the newly supplied record, and the merged frame is
        kept sorted by ``open_time``.

        :param more_records: New candle records to merge into the cache.
        :param key: The access key.
        :return: None.
        """
        existing = self.get_cache(key) if self.cache_exists(key) else None
        if existing is None:
            # No previous entry: the new records simply seed the cache.
            # (Previously this leaned on pd.concat silently dropping a None
            # member, after get_cache had already emitted a spurious warning.)
            records = more_records.copy()
        else:
            # Combine the new candles with the previously cached dataframe.
            records = pd.concat([more_records, existing], axis=0, ignore_index=True)
        # Drop any duplicates from overlap, preferring the fresher records.
        records = records.drop_duplicates(subset="open_time", keep='first')
        # Sort the records by open_time.
        records = records.sort_values(by='open_time').reset_index(drop=True)
        # Replace the incomplete dataframe with the modified one.
        self.set_cache(data=records, key=key)
        return

    def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
        """
        Creates a new cache key and inserts some data.

        Todo: This is where this data will be passed to a cache server.

        :param data: The records to insert into cache.
        :param key: The index key for the data.
        :param do_not_overwrite: Flag to prevent overwriting existing data.
        :return: None
        """
        # If the flag is set don't overwrite existing data.
        if do_not_overwrite and key in self.cached_data:
            return
        # Assign the data
        self.cached_data[key] = data

    def update_cached_dict(self, cache_key, dict_key: str, data: Any) -> None:
        """
        Updates a dictionary stored in cache.

        Todo: This is where this data will be passed to a cache server.

        :param data: The records to insert into cache.
        :param cache_key: The cache index key for the dictionary.
        :param dict_key: The dictionary key for the data.
        :return: None
        """
        # Assign the data. Raises KeyError if cache_key has never been set.
        self.cached_data[cache_key].update({dict_key: data})

    def get_cache(self, key: str) -> Any:
        """
        Returns data indexed by key.

        Todo: This is where data will be retrieved from a cache server.

        :param key: The index key for the data.
        :return: Any|None - The requested data or None on key error.
        """
        if key not in self.cached_data:
            logger.warning(f"The requested cache key({key}) doesn't exist!")
            return None
        return self.cached_data[key]

    def get_records_since(self, key: str, start_datetime: dt.datetime, record_length: int,
                          ex_details: List[str]) -> pd.DataFrame:
        """
        Return any records from the cache indexed by ``key`` that are newer
        than ``start_datetime``, fetching from the DB (and ultimately the
        exchange) anything the cache is missing.

        :param ex_details: List[str] - Exchange details
                           [symbol, interval, exchange_name, user_name].
                           Note: _populate_db unpacks four values, so a
                           user_name entry is required.
        :param key: str - The dictionary table_name of the records.
        :param start_datetime: dt.datetime - The datetime of the first record requested.
        :param record_length: int - The timespan of the records.
        :return: pd.DataFrame - The requested records

        Example:
        --------
        records = data_cache.get_records_since(
            'BTC/USD_2h_binance',
            dt.datetime.utcnow() - dt.timedelta(minutes=60),
            60,
            ['BTC/USD', '2h', 'binance', 'user1'])
        """
        try:
            # End time of query defaults to the current time.
            end_datetime = dt.datetime.utcnow()
            if self.cache_exists(key=key):
                logger.debug('Getting records from cache.')
                # If the records exist, retrieve them from the cache.
                records = self.get_cache(key)
            else:
                # If they don't exist in cache, get them from the database.
                logger.debug(
                    f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
                records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_datetime,
                                                         rl=record_length, ex_details=ex_details)
                logger.debug(f'Got {len(records.index)} records from DB.')
                self.set_cache(data=records, key=key)

            # Check if the records in the cache go far enough back to satisfy the query.
            first_timestamp = query_satisfied(start_datetime=start_datetime, records=records,
                                              r_length_min=record_length)
            if first_timestamp:
                # The records didn't go far enough back if a timestamp was returned.
                # NOTE(review): utcfromtimestamp expects seconds — confirm
                # query_satisfied returns seconds, not milliseconds.
                end_time = dt.datetime.utcfromtimestamp(first_timestamp)
                logger.debug(f'Requesting additional records from {start_datetime} to {end_time}')
                # Request additional records from the database.
                additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime,
                                                                    et=end_time, rl=record_length,
                                                                    ex_details=ex_details)
                logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
                if not additional_records.empty:
                    # If more records were received, update the cache.
                    self.update_candle_cache(additional_records, key)
                    # Refresh the local view so the up-to-date check below
                    # operates on the merged data, not the stale frame.
                    records = self.get_cache(key)

            # Check if the records received are up-to-date.
            last_timestamp = query_uptodate(records=records, r_length_min=record_length)
            if last_timestamp:
                # The query was not up-to-date if a timestamp was returned.
                start_time = dt.datetime.utcfromtimestamp(last_timestamp)
                logger.debug(f'Requesting additional records from {start_time} to {end_datetime}')
                # Request additional records from the database.
                additional_records = self.get_records_since_from_db(table_name=key, st=start_time,
                                                                    et=end_datetime, rl=record_length,
                                                                    ex_details=ex_details)
                logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
                if not additional_records.empty:
                    self.update_candle_cache(additional_records, key)

            # Create a UTC timestamp (milliseconds).
            _timestamp = unix_time_millis(start_datetime)
            logger.debug(
                f"Start timestamp in human-readable form: {dt.datetime.utcfromtimestamp(_timestamp / 1000.0)}")
            # Return all records equal to or newer than the timestamp.
            result = self.get_cache(key).query('open_time >= @_timestamp')
            return result
        except Exception as e:
            logger.error(f"An error occurred: {str(e)}")
            raise

    def get_records_since_from_db(self, table_name: str, st: dt.datetime, et: dt.datetime, rl: float,
                                  ex_details: list) -> pd.DataFrame:
        """
        Returns records from a specified table that meet a criteria and ensures
        the records are complete. If the records do not go back far enough or
        are not up-to-date, it fetches additional records from an exchange and
        populates the table.

        :param table_name: Database table name.
        :param st: Start datetime.
        :param et: End datetime.
        :param rl: Timespan in minutes each record represents.
        :param ex_details: Exchange details [symbol, interval, exchange_name, user_name].
        :return: DataFrame of records.

        Example:
        --------
        records = db.get_records_since_from_db('test_table', start_time, end_time, 1,
                                               ['BTC/USDT', '1m', 'binance', 'user1'])
        """

        def add_data(data, tn, start_t, end_t):
            # Fetch a missing span from the exchange (via _populate_db) and
            # merge it into `data`, dropping overlap duplicates.
            new_records = self._populate_db(table_name=tn, start_time=start_t, end_time=end_t,
                                            ex_details=ex_details)
            # Bug fix: previously printed the literal text "exchange_name"
            # instead of the actual exchange name.
            logger.debug(f'Got {len(new_records.index)} records from {ex_details[2]}')
            if not new_records.empty:
                data = pd.concat([data, new_records], axis=0, ignore_index=True)
                data = data.drop_duplicates(subset="open_time", keep='first')
            return data

        if self.db.table_exists(table_name=table_name):
            logger.debug('Table existed retrieving records from DB')
            logger.debug(f'Requesting from {st} to {et}')
            records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time',
                                                      st=st, et=et)
            logger.debug(f'Got {len(records.index)} records from db')
        else:
            logger.debug(f'Table didnt exist fetching from {ex_details[2]}')
            # Expected record count: span in minutes divided by record length.
            temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
            logger.debug(f'Requesting from {st} to {et}, Should be {temp} records')
            records = self._populate_db(table_name=table_name, start_time=st, end_time=et,
                                        ex_details=ex_details)
            logger.debug(f'Got {len(records.index)} records from {ex_details[2]}')

        # Back-fill if the records don't reach far enough into the past.
        first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
        if first_timestamp:
            logger.debug(f'Records did not go far enough back. Requesting from {ex_details[2]}')
            logger.debug(f'first ts on record is: {first_timestamp}')
            # NOTE(review): utcfromtimestamp expects seconds — confirm
            # query_satisfied returns seconds, not milliseconds.
            end_time = dt.datetime.utcfromtimestamp(first_timestamp)
            logger.debug(f'Requesting from {st} to {end_time}')
            records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)

        # Forward-fill if the records are not up-to-date.
        last_timestamp = query_uptodate(records=records, r_length_min=rl)
        if last_timestamp:
            logger.debug(f'Records were not updated. Requesting from {ex_details[2]}.')
            logger.debug(f'the last record on file is: {last_timestamp}')
            start_time = dt.datetime.utcfromtimestamp(last_timestamp)
            logger.debug(f'Requesting from {start_time} to {et}')
            records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et)

        return records

    def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: list,
                     end_time: dt.datetime = None) -> pd.DataFrame:
        """
        Populates a database table with records from the exchange.

        :param table_name: Name of the table in the database.
        :param start_time: Start time to fetch the records from.
        :param end_time: End time to fetch the records until (optional,
                         defaults to the current UTC time).
        :param ex_details: Exchange details [symbol, interval, exchange_name, user_name].
        :return: DataFrame of the data downloaded.

        Example:
        --------
        records = db._populate_db('test_table', start_time,
                                  ['BTC/USDT', '1m', 'binance', 'user1'])
        """
        if end_time is None:
            end_time = dt.datetime.utcnow()

        # ex_details must carry exactly four entries.
        sym, inter, ex, un = ex_details
        records = self._fetch_candles_from_exchange(symbol=sym, interval=inter, exchange_name=ex,
                                                    user_name=un, start_datetime=start_time,
                                                    end_datetime=end_time)
        if not records.empty:
            self.db.insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex)
        else:
            logger.debug(f'No records inserted {records}')
        return records

    def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
                                     start_datetime: dt.datetime = None,
                                     end_datetime: dt.datetime = None) -> pd.DataFrame:
        """
        Fetches and returns all candles from the specified market, timeframe, and exchange.

        :param symbol: Symbol of the market.
        :param interval: Timeframe in the format '' (e.g., '15m', '4h').
        :param exchange_name: Name of the exchange.
        :param user_name: Name of the user.
        :param start_datetime: Start datetime for fetching data (optional,
                               defaults to 2017-01-01).
        :param end_datetime: End datetime for fetching data (optional,
                             defaults to the current UTC time).
        :return: DataFrame of candle data.

        Example:
        --------
        candles = db._fetch_candles_from_exchange('BTC/USDT', '1m', 'binance', 'user1',
                                                  start_time, end_time)
        """

        def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
            """
            Fills gaps in the data by replicating the next known data point
            for the missing periods.

            (Each synthetic row is a copy of the row *after* the gap with only
            its ``open_time`` adjusted.)

            :param records: DataFrame containing the original records.
            :param interval: Interval of the data (e.g., '1m', '5m').
            :return: DataFrame with gaps filled.

            Example:
            --------
            filled_records = fill_data_holes(df, '1m')
            """
            time_span = timeframe_to_minutes(interval)
            last_timestamp = None
            filled_records = []

            logger.info(f"Starting to fill data holes for interval: {interval}")

            for index, row in records.iterrows():
                time_stamp = row['open_time']

                # If last_timestamp is None, this is the first record
                if last_timestamp is None:
                    last_timestamp = time_stamp
                    filled_records.append(row)
                    logger.debug(f"First timestamp: {time_stamp}")
                    continue

                # Calculate the difference in milliseconds and minutes
                # (open_time is treated as a millisecond timestamp here).
                delta_ms = time_stamp - last_timestamp
                delta_minutes = (delta_ms / 1000) / 60
                logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}")

                # If the gap is larger than the time span of the interval, fill the gap
                if delta_minutes > time_span:
                    num_missing_rec = int(delta_minutes / time_span)
                    step = int(delta_ms / num_missing_rec)
                    logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}")
                    for ts in range(int(last_timestamp) + step, int(time_stamp), step):
                        new_row = row.copy()
                        new_row['open_time'] = ts
                        filled_records.append(new_row)
                        logger.debug(f"Filled timestamp: {ts}")

                filled_records.append(row)
                last_timestamp = time_stamp

            logger.info("Data holes filled successfully.")
            return pd.DataFrame(filled_records)

        # Default start date if not provided
        if start_datetime is None:
            start_datetime = dt.datetime(year=2017, month=1, day=1)

        # Default end date if not provided
        if end_datetime is None:
            end_datetime = dt.datetime.utcnow()

        # Check if start date is greater than end date
        if start_datetime > end_datetime:
            raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")

        # Get the exchange object
        exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)

        # Calculate the expected number of records
        temp = (((unix_time_millis(end_datetime) - unix_time_millis(
            start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval)
        logger.info(f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {temp}')

        # If start and end times are the same, set end_datetime to None
        if start_datetime == end_datetime:
            end_datetime = None

        # Fetch historical candlestick data from the exchange
        candles = exchange.get_historical_klines(symbol=symbol, interval=interval,
                                                 start_dt=start_datetime, end_dt=end_datetime)
        num_rec_records = len(candles.index)
        logger.info(f'{num_rec_records} candles retrieved from the exchange.')

        # Check if the retrieved data covers the expected time range
        open_times = candles.open_time
        estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval)

        # Fill in any missing data if the retrieved data is less than expected
        if num_rec_records < estimated_num_records:
            candles = fill_data_holes(candles, interval)

        return candles