import json
import datetime as dt
import logging
from typing import List, Any, Tuple

import numpy as np
import pandas as pd

from shared_utilities import unix_time_millis
from Database import Database

# Configure logging
logger = logging.getLogger(__name__)


def timeframe_to_timedelta(timeframe: str) -> pd.Timedelta | pd.DateOffset:
    """Convert a shorthand timeframe string (e.g. '5m', '1h', '3d') to an offset.

    Minutes/hours/days/weeks have a fixed length and map to ``pd.Timedelta``;
    months and years are calendar-dependent and map to ``pd.DateOffset``.

    :param timeframe: Digits followed by a unit code: m, h, d, w, M or Y.
    :return: A ``pd.Timedelta`` or ``pd.DateOffset`` of the same span.
    :raises ValueError: If the unit code is not recognised.
    """
    digits = int("".join(c for c in timeframe if c.isdigit()))
    unit = "".join(c for c in timeframe if c.isalpha())
    if unit == 'm':
        return pd.Timedelta(minutes=digits)
    elif unit == 'h':
        return pd.Timedelta(hours=digits)
    elif unit == 'd':
        return pd.Timedelta(days=digits)
    elif unit == 'w':
        return pd.Timedelta(weeks=digits)
    elif unit == 'M':
        return pd.DateOffset(months=digits)
    elif unit == 'Y':
        return pd.DateOffset(years=digits)
    else:
        raise ValueError(f"Invalid timeframe unit: {unit}")


def _timeframe_seconds(timeframe: str) -> float:
    """Return the (approximate) length of one timeframe interval in seconds.

    Fixed-length units use the exact ``Timedelta`` value. Calendar units
    ('M'/'Y') have no fixed length — ``pd.DateOffset`` has no
    ``total_seconds()`` — so they are approximated with the mean Gregorian
    month (30.4375 days) and year (365.25 days).
    """
    delta = timeframe_to_timedelta(timeframe)
    if isinstance(delta, pd.Timedelta):
        return delta.total_seconds()
    months = getattr(delta, 'months', 0) or 0
    years = getattr(delta, 'years', 0) or 0
    return (months * 30.4375 + years * 365.25) * 86400.0


def estimate_record_count(start_time, end_time, timeframe: str) -> int:
    """
    Estimate the number of records expected between start_time and end_time
    based on the given timeframe.
    Accepts either datetime objects or Unix timestamps in milliseconds.

    :param start_time: Start of the range (tz-aware datetime or epoch ms).
    :param end_time: End of the range (same type as start_time).
    :param timeframe: Timeframe string understood by timeframe_to_timedelta.
    :return: Expected record count (floor of span / interval).
    :raises ValueError: If the inputs are mixed types or naive datetimes.
    """
    if isinstance(start_time, (int, float, np.integer)) and isinstance(end_time, (int, float, np.integer)):
        # Timestamps arrive in milliseconds; convert to seconds and build
        # timezone-aware UTC datetimes (utcfromtimestamp is deprecated).
        start_datetime = dt.datetime.fromtimestamp(int(start_time) / 1000, tz=dt.timezone.utc)
        end_datetime = dt.datetime.fromtimestamp(int(end_time) / 1000, tz=dt.timezone.utc)
    elif isinstance(start_time, dt.datetime) and isinstance(end_time, dt.datetime):
        if start_time.tzinfo is None:
            raise ValueError("start_time is timezone naive. Please provide a timezone-aware datetime.")
        if end_time.tzinfo is None:
            raise ValueError("end_time is timezone naive. Please provide a timezone-aware datetime.")
        start_datetime = start_time
        end_datetime = end_time
    else:
        raise ValueError("start_time and end_time must be either both "
                         "datetime objects or both Unix timestamps in milliseconds.")

    total_seconds = (end_datetime - start_datetime).total_seconds()
    # _timeframe_seconds approximates calendar units ('M'/'Y'); calling
    # .total_seconds() on a pd.DateOffset would raise AttributeError.
    expected_records = total_seconds // _timeframe_seconds(timeframe)
    return int(expected_records)


class DataCache:
    """In-memory cache layered in front of the database and exchange servers.

    The cache is a single DataFrame holding two kinds of entries:

    - key/value entries (columns 'key' and 'data') managed by
      ``set_cache`` / ``get_cache`` / ``cache_exists``;
    - raw table rows tagged with a 'table' column, managed by
      ``get_or_fetch_rows`` and friends.

    NOTE(review): both schemas share one frame, so entries of one kind carry
    NaN in the other kind's columns — consider splitting into two frames.
    """

    # Toggle for the defensive isinstance checks sprinkled through fetches.
    TYPECHECKING_ENABLED = True

    def __init__(self, exchanges):
        self.db = Database()
        self.exchanges = exchanges
        # Single DataFrame for all cached data.
        # BUG FIX: this was previously assigned to `self.caches` (a dict),
        # while every method reads `self.cache` — the first cache access
        # raised AttributeError. Start with the key/value schema.
        self.cache = pd.DataFrame(columns=['key', 'data'])
        logger.info("DataCache initialized.")

    def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
        """
        Sets or updates an entry in the cache with the provided key.

        If the key already exists, the existing entry is replaced unless
        `do_not_overwrite` is True. In that case, the existing entry is
        preserved.

        Parameters:
            data: The data to be cached. This can be of any type.
            key: The unique key used to identify the cached data.
            do_not_overwrite: The default is False, meaning that the existing
                entry will be replaced.
        """
        if do_not_overwrite and key in self.cache['key'].values:
            return

        # Construct a new DataFrame row with the key and data
        new_row = pd.DataFrame({'key': [key], 'data': [data]})

        # If the key already exists in the cache, remove the old entry
        self.cache = self.cache[self.cache['key'] != key]

        # Append the new row to the cache
        self.cache = pd.concat([self.cache, new_row], ignore_index=True)
        # BUG FIX: replaced a leftover debug print() with a logger call.
        logger.debug(f'Cache set for key: {key}')

    def get_or_fetch_rows(self, table: str, filter_vals: Tuple[str, Any]) -> pd.DataFrame | None:
        """
        Retrieves rows from the cache if available; otherwise, queries the
        database and caches the result.

        :param table: Name of the database table to query.
        :param filter_vals: A tuple containing the column name and the value
            to filter by.
        :return: A DataFrame containing the requested rows, or None if no
            matching rows are found.
        """
        # Only consult the cache when it actually has the needed columns;
        # before any row-shaped data is cached, 'table' and the filter
        # column do not exist and indexing them would raise KeyError.
        if 'table' in self.cache.columns and filter_vals[0] in self.cache.columns:
            cache_filter = (self.cache['table'] == table) & (self.cache[filter_vals[0]] == filter_vals[1])
            cached_rows = self.cache[cache_filter]
            # If the data is found in the cache, return it.
            if not cached_rows.empty:
                return cached_rows

        # If the data is not found in the cache, query the database.
        rows = self.db.get_rows_where(table, filter_vals)
        if rows is not None:
            # Tag the rows with the table name and add them to the cache.
            rows['table'] = table
            self.cache = pd.concat([self.cache, rows])
        return rows

    def remove_row(self, filter_vals: Tuple[str, Any], additional_filter: Tuple[str, Any] = None,
                   remove_from_db: bool = True, table: str = None) -> None:
        """
        Removes a specific row from the cache and optionally from the
        database based on filter criteria.

        :param filter_vals: A tuple containing the column name and the value
            to filter by.
        :param additional_filter: An optional additional filter to apply.
        :param remove_from_db: If True, also removes the row from the
            database. Default is True.
        :param table: The name of the table from which to remove the row in
            the database (optional).
        """
        logger.debug(
            f"Removing row from cache: filter={filter_vals},"
            f" additional_filter={additional_filter}, remove_from_db={remove_from_db}, table={table}")

        # Only filter the cache when the referenced columns exist; otherwise
        # there is nothing cached to remove (and indexing would KeyError).
        if filter_vals[0] in self.cache.columns:
            cache_filter = (self.cache[filter_vals[0]] == filter_vals[1])
            if additional_filter and additional_filter[0] in self.cache.columns:
                cache_filter = cache_filter & (self.cache[additional_filter[0]] == additional_filter[1])

            # Remove the row from the cache
            self.cache = self.cache.drop(self.cache[cache_filter].index)
            logger.info(f"Row removed from cache: filter={filter_vals}")

        if remove_from_db and table:
            # Values are bound as parameters; the table and column names are
            # interpolated because SQL identifiers cannot be parameterized —
            # callers must not pass untrusted identifiers here.
            sql = f"DELETE FROM {table} WHERE {filter_vals[0]} = ?"
            params = [filter_vals[1]]
            if additional_filter:
                sql += f" AND {additional_filter[0]} = ?"
                params.append(additional_filter[1])

            # Execute the SQL query to remove the row from the database
            self.db.execute_sql(sql, params)
            logger.info(
                f"Row removed from database: table={table}, filter={filter_vals},"
                f" additional_filter={additional_filter}")

    def is_attr_taken(self, table: str, attr: str, val: Any) -> bool:
        """
        Checks if a specific attribute in a table is already taken.

        :param table: The name of the table to check.
        :param attr: The attribute to check (e.g., 'user_name', 'email').
        :param val: The value of the attribute to check.
        :return: True if the attribute is already taken, False otherwise.
        """
        # Fetch rows from the specified table where the attribute matches the given value
        result = self.get_or_fetch_rows(table=table, filter_vals=(attr, val))
        return result is not None and not result.empty

    def fetch_item(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any:
        """
        Retrieves a specific item from the cache or database, caching the
        result if necessary.

        :param item_name: The name of the column to retrieve.
        :param table_name: The name of the table where the item is stored.
        :param filter_vals: A tuple containing the column name and the value
            to filter by.
        :return: The value of the requested item.
        :raises ValueError: If the item is not found in either the cache or
            the database.
        """
        # Fetch the relevant rows from the cache or database
        rows = self.get_or_fetch_rows(table_name, filter_vals)
        if rows is not None and not rows.empty:
            # Return the specific item from the first matching row.
            return rows.iloc[0][item_name]

        # If the item is not found, raise an error.
        raise ValueError(f"Item '{item_name}' not found in '{table_name}' where {filter_vals[0]} = {filter_vals[1]}")

    def modify_item(self, table: str, filter_vals: Tuple[str, Any], field_name: str, new_data: Any) -> None:
        """
        Modifies a specific field in a row within the cache and updates the
        database accordingly.

        :param table: The name of the table where the data is stored.
        :param filter_vals: A tuple containing the column name and the value
            to filter by.
        :param field_name: The field to be updated.
        :param new_data: The new data to be set.
        :raises ValueError: If no matching row exists in cache or database.
        """
        # Retrieve the row from the cache or database
        row = self.get_or_fetch_rows(table, filter_vals)
        if row is None or row.empty:
            raise ValueError(f"Row not found in cache or database for {filter_vals[0]} = {filter_vals[1]}")

        # Work on an independent copy with a clean index: the fetched frame
        # may be a slice of the cache whose index does not contain 0, which
        # would make `.loc[0, ...]` fail or write a new row.
        row = row.copy().reset_index(drop=True)

        # Modify the specified field. Non-string data is serialized to a
        # JSON string before being inserted into the DataFrame.
        if isinstance(new_data, str):
            row.loc[0, field_name] = new_data
        else:
            row.loc[0, field_name] = json.dumps(new_data)

        # Update the cache by removing the old entry and adding the modified row
        self.cache = self.cache.drop(
            self.cache[(self.cache['table'] == table) & (self.cache[filter_vals[0]] == filter_vals[1])].index
        )
        self.cache = pd.concat([self.cache, row])

        # Update the database with the modified row ('id' is managed by the DB)
        self.db.insert_dataframe(row.drop(columns='id'), table)

    def insert_df(self, df: pd.DataFrame, table: str, skip_cache: bool = False) -> None:
        """
        Inserts data into the specified table in the database, with an option
        to skip cache insertion.

        :param df: The DataFrame containing the data to insert.
        :param table: The name of the table where the data should be inserted.
        :param skip_cache: If True, skips inserting the data into the cache.
            Default is False.
        """
        # Insert the data into the database
        self.db.insert_dataframe(df=df, table=table)

        # Optionally insert the data into the cache
        if not skip_cache:
            # Tag a copy so the caller's DataFrame is not mutated.
            tagged = df.copy()
            tagged['table'] = table  # Add table name for cache identification
            self.cache = pd.concat([self.cache, tagged])

    def insert_row(self, table: str, columns: tuple, values: tuple) -> None:
        """
        Inserts a single row into the specified table in the database.

        :param table: The name of the table where the row should be inserted.
        :param columns: A tuple of column names corresponding to the values.
        :param values: A tuple of values to insert into the specified columns.
        """
        self.db.insert_row(table=table, columns=columns, values=values)

    def get_records_since(self, start_datetime: dt.datetime, ex_details: List[str]) -> pd.DataFrame:
        """
        This gets up-to-date records from a specified market and exchange.

        :param start_datetime: The approximate time the first record should
            represent. Must be timezone-aware.
        :param ex_details: [asset, timeframe, exchange, user_name].
        :return: The records.
        :raises TypeError: If the arguments fail the optional type checks.
        :raises ValueError: If start_datetime is timezone naive.
        """
        if self.TYPECHECKING_ENABLED:
            if not isinstance(start_datetime, dt.datetime):
                raise TypeError("start_datetime must be a datetime object")
            if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
                raise TypeError("ex_details must be a list of strings")
            if not len(ex_details) == 4:
                raise TypeError("ex_details must include [asset, timeframe, exchange, user_name]")

        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")

        # utcnow() is deprecated; now(tz=...) yields the same aware instant.
        end_datetime = dt.datetime.now(dt.timezone.utc)

        try:
            args = {
                'start_datetime': start_datetime,
                'end_datetime': end_datetime,
                'ex_details': ex_details,
            }
            return self._get_or_fetch_from('data', **args)
        except Exception as e:
            logger.error(f"An error occurred: {str(e)}")
            raise

    def _get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame:
        """
        Fetches market records from a resource stack (cache, database,
        server). Fills an incomplete request by fetching down the stack, then
        updates the resources above the one that completed it.

        :param target: Starting point for the fetch: 'data', 'database' or
            'server'.
        :param kwargs: Details and credentials for the request
            (start_datetime, end_datetime, ex_details).
        :return: Records in a dataframe (possibly partial if the request
            could not be fully satisfied).
        """
        start_datetime = kwargs.get('start_datetime')
        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
        end_datetime = kwargs.get('end_datetime')
        if end_datetime.tzinfo is None:
            raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
        ex_details = kwargs.get('ex_details')
        timeframe = kwargs.get('ex_details')[1]

        if self.TYPECHECKING_ENABLED:
            if not isinstance(start_datetime, dt.datetime):
                raise TypeError("start_datetime must be a datetime object")
            if not isinstance(end_datetime, dt.datetime):
                raise TypeError("end_datetime must be a datetime object")
            if not isinstance(timeframe, str):
                raise TypeError("record_length must be a string representing the timeframe")
            if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
                raise TypeError("ex_details must be a list of strings")
            if not all([start_datetime, end_datetime, timeframe, ex_details]):
                raise ValueError("Missing required arguments")

        request_criteria = {
            'start_datetime': start_datetime,
            'end_datetime': end_datetime,
            'timeframe': timeframe,
        }
        key = self._make_key(ex_details)
        combined_data = pd.DataFrame()

        if target == 'data':
            resources = [self._get_candles_from_cache, self._get_from_database, self._get_from_server]
        elif target == 'database':
            resources = [self._get_from_database, self._get_from_server]
        elif target == 'server':
            resources = [self._get_from_server]
        else:
            raise ValueError('Not a valid Target!')

        for fetch_method in resources:
            result = fetch_method(**kwargs)
            if not result.empty:
                # Drop the 'id' column if it exists in the result
                if 'id' in result.columns:
                    result = result.drop(columns=['id'])
                # Concatenate, drop duplicates based on 'time', and sort
                combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='time').sort_values(
                    by='time')
                is_complete, request_criteria = self._data_complete(combined_data, **request_criteria)
                if is_complete:
                    # Back-fill the faster resources that missed the data.
                    if fetch_method in [self._get_from_database, self._get_from_server]:
                        self._update_candle_cache(combined_data, key)
                    if fetch_method == self._get_from_server:
                        self._populate_db(ex_details, combined_data)
                    return combined_data
            kwargs.update(request_criteria)  # Update kwargs with new start/end times for next fetch attempt

        logger.error('Unable to fetch the requested data.')
        return combined_data if not combined_data.empty else pd.DataFrame()

    def _get_candles_from_cache(self, **kwargs) -> pd.DataFrame:
        """Return cached candles for the request window, or an empty frame.

        Expects tz-aware 'start_datetime'/'end_datetime' and 'ex_details' in
        kwargs; filters the cached frame by its millisecond 'time' column.
        """
        start_datetime = kwargs.get('start_datetime')
        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
        end_datetime = kwargs.get('end_datetime')
        if end_datetime.tzinfo is None:
            raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
        ex_details = kwargs.get('ex_details')

        if self.TYPECHECKING_ENABLED:
            if not isinstance(start_datetime, dt.datetime):
                raise TypeError("start_datetime must be a datetime object")
            if not isinstance(end_datetime, dt.datetime):
                raise TypeError("end_datetime must be a datetime object")
            if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
                raise TypeError("ex_details must be a list of strings")
            if not all([start_datetime, end_datetime, ex_details]):
                raise ValueError("Missing required arguments")

        key = self._make_key(ex_details)
        logger.debug('Getting records from data.')
        df = self.get_cache(key)
        if df is None:
            logger.debug("Cache records didn't exist.")
            return pd.DataFrame()

        logger.debug('Filtering records.')
        df_filtered = df[(df['time'] >= unix_time_millis(start_datetime)) & (
                df['time'] <= unix_time_millis(end_datetime))].reset_index(drop=True)
        return df_filtered

    def _get_from_database(self, **kwargs) -> pd.DataFrame:
        """Return candles from the database table for this market, or an
        empty frame when the table does not exist yet."""
        start_datetime = kwargs.get('start_datetime')
        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
        end_datetime = kwargs.get('end_datetime')
        if end_datetime.tzinfo is None:
            raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
        ex_details = kwargs.get('ex_details')

        if self.TYPECHECKING_ENABLED:
            if not isinstance(start_datetime, dt.datetime):
                raise TypeError("start_datetime must be a datetime object")
            if not isinstance(end_datetime, dt.datetime):
                raise TypeError("end_datetime must be a datetime object")
            if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
                raise TypeError("ex_details must be a list of strings")
            if not all([start_datetime, end_datetime, ex_details]):
                raise ValueError("Missing required arguments")

        table_name = self._make_key(ex_details)
        if not self.db.table_exists(table_name):
            logger.debug('Records not in database.')
            return pd.DataFrame()

        logger.debug('Getting records from database.')
        return self.db.get_timestamped_records(table_name=table_name, timestamp_field='time',
                                               st=start_datetime, et=end_datetime)

    def _get_from_server(self, **kwargs) -> pd.DataFrame:
        """Fetch candles directly from the exchange server."""
        symbol = kwargs.get('ex_details')[0]
        interval = kwargs.get('ex_details')[1]
        exchange_name = kwargs.get('ex_details')[2]
        user_name = kwargs.get('ex_details')[3]
        start_datetime = kwargs.get('start_datetime')
        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
        end_datetime = kwargs.get('end_datetime')
        if end_datetime.tzinfo is None:
            raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")

        if self.TYPECHECKING_ENABLED:
            if not isinstance(symbol, str):
                raise TypeError("symbol must be a string")
            if not isinstance(interval, str):
                raise TypeError("interval must be a string")
            if not isinstance(exchange_name, str):
                raise TypeError("exchange_name must be a string")
            if not isinstance(user_name, str):
                raise TypeError("user_name must be a string")
            if not isinstance(start_datetime, dt.datetime):
                raise TypeError("start_datetime must be a datetime object")
            if not isinstance(end_datetime, dt.datetime):
                raise TypeError("end_datetime must be a datetime object")

        logger.debug('Getting records from server.')
        return self._fetch_candles_from_exchange(symbol, interval, exchange_name, user_name,
                                                 start_datetime, end_datetime)

    @staticmethod
    def _data_complete(data: pd.DataFrame, **kwargs) -> Tuple[bool, dict]:
        """
        Checks if the data completely satisfies the request.

        :param data: DataFrame containing the records.
        :param kwargs: Arguments required for completeness check
            (start_datetime, end_datetime, timeframe).
        :return: A tuple (is_complete, updated_request_criteria) where
            is_complete is True if the data is complete, False otherwise, and
            updated_request_criteria contains adjusted start/end times if
            data is incomplete.
        """
        if data.empty:
            logger.debug("Data is empty.")
            return False, kwargs  # No data at all, proceed with the full original request

        start_datetime: dt.datetime = kwargs.get('start_datetime')
        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
        end_datetime: dt.datetime = kwargs.get('end_datetime')
        if end_datetime.tzinfo is None:
            raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
        timeframe: str = kwargs.get('timeframe')

        temp_data = data.copy()
        # Convert 'time' to datetime with unit='ms' and localize to UTC
        temp_data['time_dt'] = pd.to_datetime(temp_data['time'], unit='ms', errors='coerce').dt.tz_localize('UTC')

        min_timestamp = temp_data['time_dt'].min()
        max_timestamp = temp_data['time_dt'].max()

        logger.debug(f"Data time range: {min_timestamp} to {max_timestamp}")
        logger.debug(f"Expected time range: {start_datetime} to {end_datetime}")

        tolerance = pd.Timedelta(seconds=5)

        # Initialize updated request criteria
        updated_request_criteria = kwargs.copy()

        # Check if data covers the required time range with tolerance
        if min_timestamp > start_datetime + timeframe_to_timedelta(timeframe) + tolerance:
            logger.debug("Data does not start early enough, even with tolerance.")
            updated_request_criteria['end_datetime'] = min_timestamp  # Fetch the missing earlier data
            return False, updated_request_criteria

        if max_timestamp < end_datetime - timeframe_to_timedelta(timeframe) - tolerance:
            logger.debug("Data does not extend late enough, even with tolerance.")
            updated_request_criteria['start_datetime'] = max_timestamp  # Fetch the missing later data
            return False, updated_request_criteria

        # Filter data between start_datetime and end_datetime
        mask = (temp_data['time_dt'] >= start_datetime) & (temp_data['time_dt'] <= end_datetime)
        data_in_range = temp_data.loc[mask]

        expected_count = estimate_record_count(start_datetime, end_datetime, timeframe)
        actual_count = len(data_in_range)
        logger.debug(f"Expected record count: {expected_count}, Actual record count: {actual_count}")

        # Allow a one-record slack for partial candles at the window edges.
        tolerance = 1
        if actual_count < (expected_count - tolerance):
            logger.debug("Insufficient records within the specified time range, even with tolerance.")
            return False, updated_request_criteria

        logger.debug("Data completeness check passed.")
        return True, kwargs

    def cache_exists(self, key: str) -> bool:
        """Return True if a key/value entry exists for `key`."""
        return key in self.cache['key'].values

    def get_cache(self, key: str) -> Any | None:
        """Return the data stored under `key`, or None if the key is absent."""
        # Check if the key exists in the cache
        if key not in self.cache['key'].values:
            logger.warning(f"The requested data key ({key}) doesn't exist!")
            return None
        # Retrieve the data associated with the key
        result = self.cache[self.cache['key'] == key]['data'].iloc[0]
        return result

    def _update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
        """Merge `more_records` into the cached candles under `key`,
        de-duplicated on 'time' and sorted chronologically."""
        logger.debug('Updating data with new records.')
        # Concatenate the new records with the existing data
        # (pd.concat silently drops a None entry when the cache was empty).
        records = pd.concat([self.get_cache(key), more_records], axis=0, ignore_index=True)
        # Drop duplicates based on 'time' and keep the first occurrence
        records = records.drop_duplicates(subset="time", keep='first')
        # Sort the records by 'time'
        records = records.sort_values(by='time').reset_index(drop=True)
        # Reindex 'id' to ensure the expected order
        records['id'] = range(1, len(records) + 1)
        # Set the updated DataFrame back to data
        self.set_cache(data=records, key=key)

    def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None:
        """
        Updates a dictionary stored in the DataFrame cache.

        :param data: The data to insert into the dictionary.
        :param cache_key: The cache key for the dictionary.
        :param dict_key: The key within the dictionary to update.
        :return: None
        :raises ValueError: If the cached entry is not a dictionary.
        :raises KeyError: If the cache key does not exist.
        """
        # Locate the row in the DataFrame that matches the cache_key
        cache_index = self.cache.index[self.cache['key'] == cache_key]

        if not cache_index.empty:
            # Update the dictionary stored in the 'data' column
            cache_dict = self.cache.at[cache_index[0], 'data']
            if isinstance(cache_dict, dict):
                cache_dict[dict_key] = data
                # Ensure the DataFrame is updated with the new dictionary
                self.cache.at[cache_index[0], 'data'] = cache_dict
            else:
                raise ValueError(f"Expected a dictionary in cache, but found {type(cache_dict)}.")
        else:
            raise KeyError(f"Cache key '{cache_key}' not found.")

    def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
                                     start_datetime: dt.datetime = None,
                                     end_datetime: dt.datetime = None) -> pd.DataFrame:
        """Download klines from the exchange and patch any detected gaps.

        Defaults to a window from 2017-01-01 UTC to now when boundaries are
        omitted. Returns an empty OHLCV frame when the exchange has no data.
        """
        if start_datetime is None:
            start_datetime = dt.datetime(year=2017, month=1, day=1, tzinfo=dt.timezone.utc)
        if end_datetime is None:
            # utcnow() is deprecated; now(tz=...) yields the same aware instant.
            end_datetime = dt.datetime.now(dt.timezone.utc)
        if start_datetime > end_datetime:
            raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")

        exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
        expected_records = estimate_record_count(start_datetime, end_datetime, interval)
        logger.info(
            f'Fetching historical data from {start_datetime} to {end_datetime}. '
            f'Expected records: {expected_records}')

        if start_datetime == end_datetime:
            end_datetime = None

        candles = exchange.get_historical_klines(symbol=symbol, interval=interval,
                                                 start_dt=start_datetime, end_dt=end_datetime)
        num_rec_records = len(candles.index)
        if num_rec_records == 0:
            logger.warning(f"No OHLCV data returned for {symbol}.")
            return pd.DataFrame(columns=['time', 'open', 'high', 'low', 'close', 'volume'])

        logger.info(f'{num_rec_records} candles retrieved from the exchange.')
        open_times = candles.time
        min_open_time = open_times.min()
        max_open_time = open_times.max()
        # Sanity check: epoch seconds would be ~1e9; milliseconds are >= 1e12.
        if min_open_time < 1e10:
            raise ValueError('Records are not in milliseconds')

        estimated_num_records = estimate_record_count(min_open_time, max_open_time, interval) + 1
        logger.info(f'Estimated number of records: {estimated_num_records}')
        if num_rec_records < estimated_num_records:
            logger.info('Detected gaps in the data, attempting to fill missing records.')
            candles = self._fill_data_holes(candles, interval)
        return candles

    @staticmethod
    def _fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
        """Fill gaps in a candle series by duplicating the row that follows
        each gap at evenly spaced synthetic timestamps.

        :param records: Candles with a millisecond 'time' column, ascending.
        :param interval: Timeframe string (used to derive the expected gap).
        :return: A new DataFrame including the synthetic filler rows.
        """
        # Expected spacing between records, in minutes. _timeframe_seconds
        # approximates calendar units instead of crashing on pd.DateOffset.
        time_span = _timeframe_seconds(interval) / 60
        last_timestamp = None
        filled_records = []

        logger.info(f"Starting to fill data holes for interval: {interval}")

        for index, row in records.iterrows():
            time_stamp = row['time']
            if last_timestamp is None:
                last_timestamp = time_stamp
                filled_records.append(row)
                logger.debug(f"First timestamp: {time_stamp}")
                continue

            delta_ms = time_stamp - last_timestamp
            delta_minutes = (delta_ms / 1000) / 60
            logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}")

            if delta_minutes > time_span:
                num_missing_rec = int(delta_minutes / time_span)
                step = int(delta_ms / num_missing_rec)
                logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}")
                for ts in range(int(last_timestamp) + step, int(time_stamp), step):
                    # Clone the next real candle and only rewrite its time.
                    new_row = row.copy()
                    new_row['time'] = ts
                    filled_records.append(new_row)
                    logger.debug(f"Filled timestamp: {ts}")

            filled_records.append(row)
            last_timestamp = time_stamp

        logger.info("Data holes filled successfully.")
        return pd.DataFrame(filled_records)

    def _populate_db(self, ex_details: List[str], data: pd.DataFrame = None) -> None:
        """Persist fetched candles into the per-market database table."""
        if data is None or data.empty:
            logger.debug(f'No records to insert {data}')
            return
        table_name = self._make_key(ex_details)
        sym, _, ex, _ = ex_details
        self.db.insert_candles_into_db(data, table_name=table_name, symbol=sym, exchange_name=ex)
        logger.info(f'Data inserted into table {table_name}')

    @staticmethod
    def _make_key(ex_details: List[str]) -> str:
        """Build the cache/table key '<symbol>_<timeframe>_<exchange>'."""
        sym, tf, ex, _ = ex_details
        key = f'{sym}_{tf}_{ex}'
        return key