From d288eebbecbce33dedfff05d449062b29fa739b3 Mon Sep 17 00:00:00 2001 From: Rob Date: Thu, 15 Aug 2024 22:39:38 -0300 Subject: [PATCH] Re-wrote DataCache.py and Implemented tests and got them all passing. --- src/DataCache.py | 323 +++++++++++-------- src/DataCache_v2.py | 454 ++++++++++++++++++++++++++ src/Database.py | 163 +++------- src/Exchange.py | 86 ++--- src/ExchangeInterface.py | 10 +- src/candles.py | 6 +- src/maintenence/debuging_testing.py | 46 ++- src/shared_utilities.py | 46 ++- tests/test_DataCache_v2.py | 478 ++++++++++++++++++++++++++++ 9 files changed, 1287 insertions(+), 325 deletions(-) create mode 100644 src/DataCache_v2.py create mode 100644 tests/test_DataCache_v2.py diff --git a/src/DataCache.py b/src/DataCache.py index cbbb5c8..c2388ca 100644 --- a/src/DataCache.py +++ b/src/DataCache.py @@ -5,13 +5,12 @@ from Database import Database from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes import logging - logger = logging.getLogger(__name__) class DataCache: """ - Fetches manages data limits and optimises memory storage. + Fetches and manages data limits and optimizes memory storage. Handles connections and operations for the given exchanges. Example usage: @@ -19,156 +18,229 @@ class DataCache: db = DataCache(exchanges=some_exchanges_object) """ + # Disable during production for improved performance. + TYPECHECKING_ENABLED = True + + NO_RECORDS_FOUND = float('nan') + def __init__(self, exchanges): """ Initializes the DataCache class. :param exchanges: The exchanges object handling communication with connected exchanges. """ - # Maximum number of tables to cache at any given time. self.max_tables = 50 - # Maximum number of records to be kept per table. self.max_records = 1000 - # A dictionary that holds all the cached records. self.cached_data = {} - # The class that handles the DB interactions. self.db = Database() - # The class that handles exchange interactions. self.exchanges = exchanges def cache_exists(self, key: str) -> bool: """ - Return False if a cache doesn't exist for this key. + Checks if a cache exists for the given key. + + :param key: The access key. + :return: True if cache exists, False otherwise. """ return key in self.cached_data def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None: """ + Adds records to existing cache. - :param more_records: Adds records to existing cache. + :param more_records: The new records to be added. :param key: The access key. :return: None. """ - # Combine the new candles with the previously cached dataframe. records = pd.concat([more_records, self.get_cache(key)], axis=0, ignore_index=True) - # Drop any duplicates from overlap. records = records.drop_duplicates(subset="open_time", keep='first') - # Sort the records by open_time. records = records.sort_values(by='open_time').reset_index(drop=True) - # Replace the incomplete dataframe with the modified one. self.set_cache(data=records, key=key) - return def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None: """ - Creates a new cache key and inserts some data. - Todo: This is where this data will be passed to a cache server. - :param data: The records to insert into cache. - :param key: The index key for the data. - :param do_not_overwrite: - Flag to prevent overwriting existing data. - :return: None + Creates a new cache key and inserts data. + + :param data: The records to insert into cache. + :param key: The index key for the data. + :param do_not_overwrite: Flag to prevent overwriting existing data. + :return: None """ - # If the flag is set don't overwrite existing data. if do_not_overwrite and key in self.cached_data: return - # Assign the data self.cached_data[key] = data - def update_cached_dict(self, cache_key, dict_key: str, data: Any) -> None: + def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None: """ - Updates a dictionary stored in cache. - Todo: This is where this data will be passed to a cache server. - :param data: - The records to insert into cache. - :param cache_key: - The cache index key for the dictionary. - :param dict_key: - The dictionary key for the data. - :return: None + Updates a dictionary stored in cache. + + :param data: The data to insert into cache. + :param cache_key: The cache index key for the dictionary. + :param dict_key: The dictionary key for the data. + :return: None """ - # Assign the data self.cached_data[cache_key].update({dict_key: data}) def get_cache(self, key: str) -> Any: """ - Returns data indexed by key. - Todo: This is where data will be retrieved from a cache server. + Returns data indexed by key. - :param key: The index key for the data. - :return: Any|None - The requested data or None on key error. + :param key: The index key for the data. + :return: Any|None - The requested data or None on key error. """ if key not in self.cached_data: - print(f"[WARNING: DataCache.py] The requested cache key({key}) doesn't exist!") + logger.warning(f"The requested cache key({key}) doesn't exist!") return None return self.cached_data[key] + def improved_get_records_since(self, key: str, start_datetime: dt.datetime, record_length: int, + ex_details: List[str]) -> pd.DataFrame: + """ + Fetches records since the specified start datetime. + + :param key: The cache key. + :param start_datetime: The start datetime to fetch records from. + :param record_length: The required number of records. + :param ex_details: Exchange details. + :return: DataFrame containing the records. + """ + try: + target = 'cache' + args = { + 'key': key, + 'start_datetime': start_datetime, + 'end_datetime': dt.datetime.utcnow(), + 'record_length': record_length, + 'ex_details': ex_details + } + + df = self.get_or_fetch_from(target=target, **args) + except Exception as e: + logger.error(f"An error occurred: {str(e)}") + raise + + def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame: + key = kwargs.get('key') + start_datetime = kwargs.get('start_datetime') + end_datetime = kwargs.get('start_datetime') + record_length = kwargs.get('record_length') + ex_details = kwargs.get('ex_details') + + if self.TYPECHECKING_ENABLED: + # Type checking + if not isinstance(key, str): + raise TypeError("key must be a string") + if not isinstance(start_datetime, dt.datetime): + raise TypeError("start_datetime must be a datetime object") + if not isinstance(end_datetime, dt.datetime): + raise TypeError("end_datetime must be a datetime object") + if not isinstance(record_length, int): + raise TypeError("record_length must be an integer") + if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details): + raise TypeError("ex_details must be a list of strings") + + # Ensure all required arguments are provided + if not all([key, start_datetime, record_length, ex_details]): + raise ValueError("Missing required arguments") + + def get_from_cache(): + return pd.DataFrame + + def get_from_database(): + return pd.DataFrame + + def get_from_server(): + return pd.DataFrame + + def data_complete(data, **kwargs) -> bool: + """Check if a dataframe completely satisfied a request.""" + sd = kwargs.get('start_datetime') + ed = kwargs.get('start_datetime') + rl = kwargs.get('record_length') + + is_complete = True + return is_complete + + request_criteria = { + 'start_datetime': start_datetime, + 'end_datetime': end_datetime, + 'record_length': record_length, + } + + if target == 'cache': + result = get_from_cache() + if data_complete(result, **request_criteria): + return result + else: + self.get_or_fetch_from('database', **kwargs) + elif target == 'database': + result = get_from_database() + if data_complete(result, **request_criteria): + return result + else: + self.get_or_fetch_from('server', **kwargs) + elif target == 'server': + result = get_from_server() + if data_complete(result, **request_criteria): + return result + else: + logger.error('Unable to fetch the requested data.') + else: + raise ValueError(f'Not a valid target: {target}') + def get_records_since(self, key: str, start_datetime: dt.datetime, record_length: int, ex_details: List[str]) -> pd.DataFrame: """ - Return any records from the cache indexed by table_name that are newer than start_datetime. + Fetches records since the specified start datetime. - :param ex_details: List[str] - Details to pass to the server - :param key: str - The dictionary table_name of the records. - :param start_datetime: dt.datetime - The datetime of the first record requested. - :param record_length: int - The timespan of the records. - - :return: pd.DataFrame - The Requested records - - Example: - -------- - records = data_cache.get_records_since('BTC/USD_2h_binance', dt.datetime.utcnow() - dt.timedelta(minutes=60), 60, ['BTC/USD', '2h', 'binance']) + :param key: The cache key. + :param start_datetime: The start datetime to fetch records from. + :param record_length: The required number of records. + :param ex_details: Exchange details. + :return: DataFrame containing the records. """ try: - # End time of query defaults to the current time. end_datetime = dt.datetime.utcnow() if self.cache_exists(key=key): logger.debug('Getting records from cache.') - # If the records exist, retrieve them from the cache. records = self.get_cache(key) else: - # If they don't exist in cache, get them from the database. logger.debug( f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}') - records = self.get_records_since_from_db(table_name=key, st=start_datetime, - et=end_datetime, rl=record_length, ex_details=ex_details) + records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_datetime, + rl=record_length, ex_details=ex_details) logger.debug(f'Got {len(records.index)} records from DB.') self.set_cache(data=records, key=key) - # Check if the records in the cache go far enough back to satisfy the query. - first_timestamp = query_satisfied(start_datetime=start_datetime, - records=records, + first_timestamp = query_satisfied(start_datetime=start_datetime, records=records, r_length_min=record_length) - if first_timestamp: - # The records didn't go far enough back if a timestamp was returned. + if pd.isna(first_timestamp): + logger.debug('No records found to satisfy the query, continuing to fetch more records.') + additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_datetime, + rl=record_length, ex_details=ex_details) + elif first_timestamp is not None: end_time = dt.datetime.utcfromtimestamp(first_timestamp) logger.debug(f'Requesting additional records from {start_datetime} to {end_time}') - # Request additional records from the database. - additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime, - et=end_time, rl=record_length, - ex_details=ex_details) + additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_time, + rl=record_length, ex_details=ex_details) + if not additional_records.empty: logger.debug(f'Got {len(additional_records.index)} additional records from DB.') - if not additional_records.empty: - # If more records were received, update the cache. - self.update_candle_cache(additional_records, key) + self.update_candle_cache(additional_records, key) - # Check if the records received are up-to-date. last_timestamp = query_uptodate(records=records, r_length_min=record_length) - - if last_timestamp: - # The query was not up-to-date if a timestamp was returned. + if last_timestamp is not None: start_time = dt.datetime.utcfromtimestamp(last_timestamp) logger.debug(f'Requesting additional records from {start_time} to {end_datetime}') - # Request additional records from the database. - additional_records = self.get_records_since_from_db(table_name=key, st=start_time, - et=end_datetime, rl=record_length, - ex_details=ex_details) + additional_records = self.get_records_since_from_db(table_name=key, st=start_time, et=end_datetime, + rl=record_length, ex_details=ex_details) logger.debug(f'Got {len(additional_records.index)} additional records from DB.') if not additional_records.empty: self.update_candle_cache(additional_records, key) - # Create a UTC timestamp. _timestamp = unix_time_millis(start_datetime) logger.debug(f"Start timestamp in human-readable form: {dt.datetime.utcfromtimestamp(_timestamp / 1000.0)}") - # Return all records equal to or newer than the timestamp. result = self.get_cache(key).query('open_time >= @_timestamp') return result @@ -176,64 +248,61 @@ class DataCache: logger.error(f"An error occurred: {str(e)}") raise - def get_records_since_from_db(self, table_name: str, st: dt.datetime, - et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame: + def get_records_since_from_db(self, table_name: str, st: dt.datetime, et: dt.datetime, rl: float, + ex_details: List[str]) -> pd.DataFrame: """ - Returns records from a specified table that meet a criteria and ensures the records are complete. - If the records do not go back far enough or are not up-to-date, it fetches additional records - from an exchange and populates the table. + Fetches records from the database since the specified start datetime. - :param table_name: Database table name. - :param st: Start datetime. - :param et: End datetime. - :param rl: Timespan in minutes each record represents. - :param ex_details: Exchange details [symbol, interval, exchange_name]. - :return: DataFrame of records. - - Example: - -------- - records = db.get_records_since('test_table', start_time, end_time, 1, ['BTC/USDT', '1m', 'binance']) + :param table_name: The name of the table in the database. + :param st: The start datetime to fetch records from. + :param et: The end datetime to fetch records until. + :param rl: The required number of records. + :param ex_details: Exchange details. + :return: DataFrame containing the records. """ - def add_data(data, tn, start_t, end_t): + def add_data(data: pd.DataFrame, tn: str, start_t: dt.datetime, end_t: dt.datetime) -> pd.DataFrame: new_records = self._populate_db(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details) - print(f'Got {len(new_records.index)} records from exchange_name') + logger.debug(f'Got {len(new_records.index)} records from exchange_name') if not new_records.empty: data = pd.concat([data, new_records], axis=0, ignore_index=True) data = data.drop_duplicates(subset="open_time", keep='first') return data if self.db.table_exists(table_name=table_name): - print('\nTable existed retrieving records from DB') - print(f'Requesting from {st} to {et}') + logger.debug('Table existed retrieving records from DB') + logger.debug(f'Requesting from {st} to {et}') records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=st, et=et) - print(f'Got {len(records.index)} records from db') + logger.debug(f'Got {len(records.index)} records from db') else: - print(f'\nTable didnt exist fetching from {ex_details[2]}') + logger.debug(f"Table didn't exist fetching from {ex_details[2]}") temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl - print(f'Requesting from {st} to {et}, Should be {temp} records') + logger.debug(f'Requesting from {st} to {et}, Should be {temp} records') records = self._populate_db(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details) - print(f'Got {len(records.index)} records from {ex_details[2]}') + logger.debug(f'Got {len(records.index)} records from {ex_details[2]}') first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl) - if first_timestamp: - print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}') - print(f'first ts on record is: {first_timestamp}') + if pd.isna(first_timestamp): + logger.debug('No records found to satisfy the query, continuing to fetch more records.') + records = add_data(data=records, tn=table_name, start_t=st, end_t=et) + elif first_timestamp: + logger.debug(f'Records did not go far enough back. Requesting from {ex_details[2]}') + logger.debug(f'First ts on record is: {first_timestamp}') end_time = dt.datetime.utcfromtimestamp(first_timestamp) - print(f'Requesting from {st} to {end_time}') + logger.debug(f'Requesting from {st} to {end_time}') records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time) last_timestamp = query_uptodate(records=records, r_length_min=rl) if last_timestamp: - print(f'\nRecords were not updated. Requesting from {ex_details[2]}.') - print(f'the last record on file is: {last_timestamp}') + logger.debug(f'Records were not updated. Requesting from {ex_details[2]}.') + logger.debug(f'The last record on file is: {last_timestamp}') start_time = dt.datetime.utcfromtimestamp(last_timestamp) - print(f'Requesting from {start_time} to {et}') + logger.debug(f'Requesting from {start_time} to {et}') records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et) return records - def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: list, + def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: List[str], end_time: dt.datetime = None) -> pd.DataFrame: """ Populates a database table with records from the exchange. @@ -243,10 +312,6 @@ class DataCache: :param end_time: End time to fetch the records until (optional). :param ex_details: Exchange details [symbol, interval, exchange_name, user_name]. :return: DataFrame of the data downloaded. - - Example: - -------- - records = db._populate_table('test_table', start_time, ['BTC/USDT', '1m', 'binance', 'user1']) """ if end_time is None: end_time = dt.datetime.utcnow() @@ -256,7 +321,7 @@ class DataCache: if not records.empty: self.db.insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex) else: - print(f'No records inserted {records}') + logger.debug(f'No records inserted {records}') return records def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str, @@ -272,10 +337,6 @@ class DataCache: :param start_datetime: Start datetime for fetching data (optional). :param end_datetime: End datetime for fetching data (optional). :return: DataFrame of candle data. - - Example: - -------- - candles = db._fetch_candles_from_exchange('BTC/USDT', '1m', 'binance', 'user1', start_time, end_time) """ def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame: @@ -285,10 +346,6 @@ class DataCache: :param records: DataFrame containing the original records. :param interval: Interval of the data (e.g., '1m', '5m'). :return: DataFrame with gaps filled. - - Example: - -------- - filled_records = fill_data_holes(df, '1m') """ time_span = timeframe_to_minutes(interval) last_timestamp = None @@ -299,20 +356,17 @@ class DataCache: for index, row in records.iterrows(): time_stamp = row['open_time'] - # If last_timestamp is None, this is the first record if last_timestamp is None: last_timestamp = time_stamp filled_records.append(row) logger.debug(f"First timestamp: {time_stamp}") continue - # Calculate the difference in milliseconds and minutes delta_ms = time_stamp - last_timestamp delta_minutes = (delta_ms / 1000) / 60 logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}") - # If the gap is larger than the time span of the interval, fill the gap if delta_minutes > time_span: num_missing_rec = int(delta_minutes / time_span) step = int(delta_ms / num_missing_rec) @@ -330,44 +384,47 @@ class DataCache: logger.info("Data holes filled successfully.") return pd.DataFrame(filled_records) - # Default start date if not provided if start_datetime is None: start_datetime = dt.datetime(year=2017, month=1, day=1) - # Default end date if not provided if end_datetime is None: end_datetime = dt.datetime.utcnow() - # Check if start date is greater than end date if start_datetime > end_datetime: raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.") - # Get the exchange object exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name) - # Calculate the expected number of records - temp = (((unix_time_millis(end_datetime) - unix_time_millis( + expected_records = (((unix_time_millis(end_datetime) - unix_time_millis( start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval) + logger.info( + f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {expected_records}') - logger.info(f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {temp}') - - # If start and end times are the same, set end_datetime to None if start_datetime == end_datetime: end_datetime = None - # Fetch historical candlestick data from the exchange candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime, end_dt=end_datetime) num_rec_records = len(candles.index) logger.info(f'{num_rec_records} candles retrieved from the exchange.') - # Check if the retrieved data covers the expected time range open_times = candles.open_time - estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval) + min_open_time = open_times.min() + max_open_time = open_times.max() + + if min_open_time < 1e10: + raise ValueError('Records are not in milliseconds') + + max_open_time /= 1000 + min_open_time /= 1000 + + estimated_num_records = ((max_open_time - min_open_time) / 60) / timeframe_to_minutes(interval) + + logger.info(f'Estimated number of records: {estimated_num_records}') - # Fill in any missing data if the retrieved data is less than expected if num_rec_records < estimated_num_records: + logger.info('Detected gaps in the data, attempting to fill missing records.') candles = fill_data_holes(candles, interval) return candles diff --git a/src/DataCache_v2.py b/src/DataCache_v2.py new file mode 100644 index 0000000..957cbcd --- /dev/null +++ b/src/DataCache_v2.py @@ -0,0 +1,454 @@ +from typing import List, Any +import pandas as pd +import datetime as dt +import logging +from shared_utilities import unix_time_millis +from Database import Database +import numpy as np + +# Configure logging +logger = logging.getLogger(__name__) + + +def timeframe_to_timedelta(timeframe: str) -> pd.Timedelta | pd.DateOffset: + digits = int("".join([i if i.isdigit() else "" for i in timeframe])) + unit = "".join([i if i.isalpha() else "" for i in timeframe]) + + if unit == 'm': + return pd.Timedelta(minutes=digits) + elif unit == 'h': + return pd.Timedelta(hours=digits) + elif unit == 'd': + return pd.Timedelta(days=digits) + elif unit == 'w': + return pd.Timedelta(weeks=digits) + elif unit == 'M': + return pd.DateOffset(months=digits) + elif unit == 'Y': + return pd.DateOffset(years=digits) + else: + raise ValueError(f"Invalid timeframe unit: {unit}") + + +def estimate_record_count(start_time, end_time, timeframe: str) -> int: + """ + Estimate the number of records expected between start_time and end_time based on the given timeframe. + Accepts either datetime objects or Unix timestamps in milliseconds. + """ + # Check if the input is in milliseconds (timestamp) + if isinstance(start_time, (int, float, np.integer)) and isinstance(end_time, (int, float, np.integer)): + # Convert timestamps from milliseconds to seconds for calculation + start_time = int(start_time) / 1000 + end_time = int(end_time) / 1000 + start_datetime = dt.datetime.utcfromtimestamp(start_time) + end_datetime = dt.datetime.utcfromtimestamp(end_time) + elif isinstance(start_time, dt.datetime) and isinstance(end_time, dt.datetime): + start_datetime = start_time + end_datetime = end_time + else: + raise ValueError("start_time and end_time must be either both " + "datetime objects or both Unix timestamps in milliseconds.") + + delta = timeframe_to_timedelta(timeframe) + total_seconds = (end_datetime - start_datetime).total_seconds() + expected_records = total_seconds // delta.total_seconds() + return int(expected_records) + + +def generate_expected_timestamps(start_datetime: dt.datetime, end_datetime: dt.datetime, + timeframe: str) -> pd.DatetimeIndex: + delta = timeframe_to_timedelta(timeframe) + if isinstance(delta, pd.Timedelta): + return pd.date_range(start=start_datetime, end=end_datetime, freq=delta) + elif isinstance(delta, pd.DateOffset): + current = start_datetime + timestamps = [] + while current <= end_datetime: + timestamps.append(current) + current += delta + return pd.DatetimeIndex(timestamps) + + +class DataCache: + TYPECHECKING_ENABLED = True + + def __init__(self, exchanges): + self.db = Database() + self.exchanges = exchanges + self.cached_data = {} + logger.info("DataCache initialized.") + + def get_records_since(self, start_datetime: dt.datetime, ex_details: List[str]) -> pd.DataFrame: + if self.TYPECHECKING_ENABLED: + if not isinstance(start_datetime, dt.datetime): + raise TypeError("start_datetime must be a datetime object") + if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details): + raise TypeError("ex_details must be a list of strings") + if not len(ex_details) == 4: + raise TypeError("ex_details must include [asset, timeframe, exchange, user_name]") + + try: + args = { + 'start_datetime': start_datetime, + 'end_datetime': dt.datetime.utcnow(), + 'ex_details': ex_details, + } + return self.get_or_fetch_from('cache', **args) + except Exception as e: + logger.error(f"An error occurred: {str(e)}") + raise + + def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame: + start_datetime = kwargs.get('start_datetime') + end_datetime = kwargs.get('end_datetime') + ex_details = kwargs.get('ex_details') + timeframe = kwargs.get('ex_details')[1] + + if self.TYPECHECKING_ENABLED: + if not isinstance(start_datetime, dt.datetime): + raise TypeError("start_datetime must be a datetime object") + if not isinstance(end_datetime, dt.datetime): + raise TypeError("end_datetime must be a datetime object") + if not isinstance(timeframe, str): + raise TypeError("record_length must be a string representing the timeframe") + if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details): + raise TypeError("ex_details must be a list of strings") + if not all([start_datetime, end_datetime, timeframe, ex_details]): + raise ValueError("Missing required arguments") + + request_criteria = { + 'start_datetime': start_datetime, + 'end_datetime': end_datetime, + 'timeframe': timeframe, + } + + key = self._make_key(ex_details) + combined_data = pd.DataFrame() + + if target == 'cache': + resources = [self.get_candles_from_cache, self.get_from_database, self.get_from_server] + elif target == 'database': + resources = [self.get_from_database, self.get_from_server] + elif target == 'server': + resources = [self.get_from_server] + else: + raise ValueError('Not a valid Target!') + + for fetch_method in resources: + result = fetch_method(**kwargs) + + if not result.empty: + combined_data = pd.concat([combined_data, result]).drop_duplicates() + + if not combined_data.empty and 'open_time' in combined_data.columns: + combined_data = combined_data.sort_values(by='open_time').drop_duplicates(subset='open_time', + keep='first') + + is_complete, request_criteria = self.data_complete(combined_data, **request_criteria) + if is_complete: + if fetch_method in [self.get_from_database, self.get_from_server]: + self.update_candle_cache(combined_data, key) + if fetch_method == self.get_from_server: + self._populate_db(ex_details, combined_data) + return combined_data + + kwargs.update(request_criteria) # Update kwargs with new start/end times for next fetch attempt + + logger.error('Unable to fetch the requested data.') + return combined_data if not combined_data.empty else pd.DataFrame() + + def get_candles_from_cache(self, **kwargs) -> pd.DataFrame: + start_datetime = kwargs.get('start_datetime') + end_datetime = kwargs.get('end_datetime') + ex_details = kwargs.get('ex_details') + + if self.TYPECHECKING_ENABLED: + if not isinstance(start_datetime, dt.datetime): + raise TypeError("start_datetime must be a datetime object") + if not isinstance(end_datetime, dt.datetime): + raise TypeError("end_datetime must be a datetime object") + if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details): + raise TypeError("ex_details must be a list of strings") + if not all([start_datetime, end_datetime, ex_details]): + raise ValueError("Missing required arguments") + + key = self._make_key(ex_details) + logger.debug('Getting records from cache.') + df = self.get_cache(key) + if df is None: + logger.debug("Cache records didn't exist.") + return pd.DataFrame() + logger.debug('Filtering records.') + df_filtered = df[(df['open_time'] >= unix_time_millis(start_datetime)) & ( + df['open_time'] <= unix_time_millis(end_datetime))].reset_index(drop=True) + return df_filtered + + def get_from_database(self, **kwargs) -> pd.DataFrame: + start_datetime = kwargs.get('start_datetime') + end_datetime = kwargs.get('end_datetime') + ex_details = kwargs.get('ex_details') + + if self.TYPECHECKING_ENABLED: + if not isinstance(start_datetime, dt.datetime): + raise TypeError("start_datetime must be a datetime object") + if not isinstance(end_datetime, dt.datetime): + raise TypeError("end_datetime must be a datetime object") + if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details): + raise TypeError("ex_details must be a list of strings") + if not all([start_datetime, end_datetime, ex_details]): + raise ValueError("Missing required arguments") + + table_name = self._make_key(ex_details) + if not self.db.table_exists(table_name): + logger.debug('Records not in database.') + return pd.DataFrame() + + logger.debug('Getting records from database.') + return self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=start_datetime, + et=end_datetime) + + def get_from_server(self, **kwargs) -> pd.DataFrame: + symbol = kwargs.get('ex_details')[0] + interval = kwargs.get('ex_details')[1] + exchange_name = kwargs.get('ex_details')[2] + user_name = kwargs.get('ex_details')[3] + start_datetime = kwargs.get('start_datetime') + end_datetime = kwargs.get('end_datetime') + + if self.TYPECHECKING_ENABLED: + if not isinstance(symbol, str): + raise TypeError("symbol must be a string") + if not isinstance(interval, str): + raise TypeError("interval must be a string") + if not isinstance(exchange_name, str): + raise TypeError("exchange_name must be a string") + if not isinstance(user_name, str): + raise TypeError("user_name must be a string") + if not isinstance(start_datetime, dt.datetime): + raise TypeError("start_datetime must be a datetime object") + if not isinstance(end_datetime, dt.datetime): + raise TypeError("end_datetime must be a datetime object") + + logger.debug('Getting records from server.') + return self._fetch_candles_from_exchange(symbol, interval, exchange_name, user_name, start_datetime, + end_datetime) + + @staticmethod + def data_complete(data: pd.DataFrame, **kwargs) -> (bool, dict): + """ + Checks if the data completely satisfies the request. + + :param data: DataFrame containing the records. + :param kwargs: Arguments required for completeness check. + :return: A tuple (is_complete, updated_request_criteria) where is_complete is True if the data is complete, + False otherwise, and updated_request_criteria contains adjusted start/end times if data is incomplete. + """ + start_datetime: dt.datetime = kwargs.get('start_datetime') + end_datetime: dt.datetime = kwargs.get('end_datetime') + timeframe: str = kwargs.get('timeframe') + + if data.empty: + logger.debug("Data is empty.") + return False, kwargs # No data at all, proceed with the full original request + + temp_data = data.copy() + + # Ensure 'open_time' is in datetime format + if temp_data['open_time'].dtype != ' start_datetime + timeframe_to_timedelta(timeframe) + tolerance: + logger.debug("Data does not start early enough, even with tolerance.") + updated_request_criteria['end_datetime'] = min_timestamp # Fetch the missing earlier data + return False, updated_request_criteria + + if max_timestamp < end_datetime - timeframe_to_timedelta(timeframe) - tolerance: + logger.debug("Data does not extend late enough, even with tolerance.") + updated_request_criteria['start_datetime'] = max_timestamp # Fetch the missing later data + return False, updated_request_criteria + + # Filter data between start_datetime and end_datetime + mask = (temp_data['open_time_dt'] >= start_datetime) & (temp_data['open_time_dt'] <= end_datetime) + data_in_range = temp_data.loc[mask] + + expected_count = estimate_record_count(start_datetime, end_datetime, timeframe) + actual_count = len(data_in_range) + + logger.debug(f"Expected record count: {expected_count}, Actual record count: {actual_count}") + + tolerance = 1 + if actual_count < (expected_count - tolerance): + logger.debug("Insufficient records within the specified time range, even with tolerance.") + return False, updated_request_criteria + + logger.debug("Data completeness check passed.") + return True, kwargs + + def cache_exists(self, key: str) -> bool: + return key in self.cached_data + + def get_cache(self, key: str) -> Any | None: + if key not in self.cached_data: + logger.warning(f"The requested cache key({key}) doesn't exist!") + return None + return self.cached_data[key] + + def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None: + logger.debug('Updating cache with new records.') + # Concatenate the new records with the existing cache + records = pd.concat([self.get_cache(key), more_records], axis=0, ignore_index=True) + # Drop duplicates based on 'open_time' and keep the first occurrence + records = records.drop_duplicates(subset="open_time", keep='first') + # Sort the records by 'open_time' + records = records.sort_values(by='open_time').reset_index(drop=True) + # Reindex 'id' to ensure the expected order + records['id'] = range(1, len(records) + 1) + # Set the updated DataFrame back to cache + self.set_cache(data=records, key=key) + + def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None: + """ + Updates a dictionary stored in cache. + + :param data: The data to insert into cache. + :param cache_key: The cache index key for the dictionary. + :param dict_key: The dictionary key for the data. + :return: None + """ + self.cached_data[cache_key].update({dict_key: data}) + + def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None: + if do_not_overwrite and key in self.cached_data: + return + self.cached_data[key] = data + logger.debug(f'Cache set for key: {key}') + + def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str, + start_datetime: dt.datetime = None, + end_datetime: dt.datetime = None) -> pd.DataFrame: + if start_datetime is None: + start_datetime = dt.datetime(year=2017, month=1, day=1) + + if end_datetime is None: + end_datetime = dt.datetime.utcnow() + + if start_datetime > end_datetime: + raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.") + + exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name) + + expected_records = estimate_record_count(start_datetime, end_datetime, interval) + logger.info( + f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {expected_records}') + + if start_datetime == end_datetime: + end_datetime = None + + candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime, + end_dt=end_datetime) + + num_rec_records = len(candles.index) + if num_rec_records == 0: + logger.warning(f"No OHLCV data returned for {symbol}.") + return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume']) + + logger.info(f'{num_rec_records} candles retrieved from the exchange.') + + open_times = candles.open_time + min_open_time = open_times.min() + max_open_time = open_times.max() + + if min_open_time < 1e10: + raise ValueError('Records are not in milliseconds') + + estimated_num_records = estimate_record_count(min_open_time, max_open_time, interval) + 1 + logger.info(f'Estimated number of records: {estimated_num_records}') + + if num_rec_records < estimated_num_records: + logger.info('Detected gaps in the data, attempting to fill missing records.') + candles = self.fill_data_holes(candles, interval) + + return candles + + @staticmethod + def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame: + time_span = timeframe_to_timedelta(interval).total_seconds() / 60 + last_timestamp = None + filled_records = [] + + logger.info(f"Starting to fill data holes for interval: {interval}") + + for index, row in records.iterrows(): + time_stamp = row['open_time'] + + if last_timestamp is None: + last_timestamp = time_stamp + filled_records.append(row) + logger.debug(f"First timestamp: {time_stamp}") + continue + + delta_ms = time_stamp - last_timestamp + delta_minutes = (delta_ms / 1000) / 60 + + logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}") + + if delta_minutes > time_span: + num_missing_rec = int(delta_minutes / time_span) + step = int(delta_ms / num_missing_rec) + logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}") + + for ts in range(int(last_timestamp) + step, int(time_stamp), step): + new_row = row.copy() + new_row['open_time'] = ts + filled_records.append(new_row) + logger.debug(f"Filled timestamp: {ts}") + + filled_records.append(row) + last_timestamp = time_stamp + + logger.info("Data holes filled successfully.") + return pd.DataFrame(filled_records) + + def _populate_db(self, ex_details: List[str], data: pd.DataFrame = None) -> None: + if data is None or data.empty: + logger.debug(f'No records to insert {data}') + return + + table_name = self._make_key(ex_details) + sym, _, ex, _ = ex_details + + self.db.insert_candles_into_db(data, table_name=table_name, symbol=sym, exchange_name=ex) + logger.info(f'Data inserted into table {table_name}') + + @staticmethod + def _make_key(ex_details: List[str]) -> str: + sym, tf, ex, _ = ex_details + key = f'{sym}_{tf}_{ex}' + return key + +# Example usage +# args = { +# 'start_datetime': dt.datetime.now() - dt.timedelta(hours=1), # Example start time +# 'ex_details': ['BTCUSDT', '15m', 'Binance', 'user1'], +# } +# +# exchanges = ExchangeHandler() +# data = DataCache(exchanges) +# df = data.get_records_since(**args) +# +# # Disabling type checking for a specific instance +# data.TYPECHECKING_ENABLED = False diff --git a/src/Database.py b/src/Database.py index cecbe7c..27f1aa4 100644 --- a/src/Database.py +++ b/src/Database.py @@ -1,6 +1,6 @@ import sqlite3 from functools import lru_cache -from typing import Any +from typing import Any, Dict, List, Tuple import config import datetime as dt import pandas as pd @@ -19,7 +19,7 @@ class SQLite: cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') """ - def __init__(self, db_file=None): + def __init__(self, db_file: str = None): self.db_file = db_file if db_file else config.DB_FILE self.connection = sqlite3.connect(self.db_file) @@ -41,11 +41,11 @@ class HDict(dict): hash(hdict) """ - def __hash__(self): + def __hash__(self) -> int: return hash(frozenset(self.items())) -def make_query(item: str, table: str, columns: list) -> str: +def make_query(item: str, table: str, columns: List[str]) -> str: """ Creates a SQL select query string with the required number of placeholders. @@ -53,40 +53,22 @@ def make_query(item: str, table: str, columns: list) -> str: :param table: The table to select from. :param columns: List of columns for the where clause. :return: The query string. - - Example: - -------- - query = make_query('id', 'test_table', ['name', 'age']) - # Result: 'SELECT id FROM test_table WHERE name = ? AND age = ?;' """ - an_itr = iter(columns) - k = next(an_itr) - where_str = f"SELECT {item} FROM {table} WHERE {k} = ?" - where_str += "".join([f" AND {k} = ?" for k in an_itr]) + ';' - return where_str + placeholders = " AND ".join([f"{col} = ?" for col in columns]) + return f"SELECT {item} FROM {table} WHERE {placeholders};" -def make_insert(table: str, values: tuple) -> str: +def make_insert(table: str, columns: Tuple[str, ...]) -> str: """ Creates a SQL insert query string with the required number of placeholders. :param table: The table to insert into. - :param values: Tuple of values to insert. + :param columns: Tuple of column names. :return: The query string. - - Example: - -------- - insert = make_insert('test_table', ('name', 'age')) - # Result: "INSERT INTO test_table ('name', 'age') VALUES(?, ?);" """ - itr1 = iter(values) - itr2 = iter(values) - k1 = next(itr1) - _ = next(itr2) - insert_str = f"INSERT INTO {table} ('{k1}'" - insert_str += "".join([f", '{k1}'" for k1 in itr1]) + ") VALUES(?" + "".join( - [", ?" for _ in enumerate(itr2)]) + ");" - return insert_str + col_names = ", ".join([f"'{col}'" for col in columns]) + placeholders = ", ".join(["?" for _ in columns]) + return f"INSERT INTO {table} ({col_names}) VALUES ({placeholders});" class Database: @@ -96,16 +78,11 @@ class Database: Example usage: -------------- - db = Database(exchanges=some_exchanges_object, db_file='test_db.sqlite') + db = Database(db_file='test_db.sqlite') db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') """ - def __init__(self, db_file=None): - """ - Initializes the Database class. - - :param db_file: Optional database file name. - """ + def __init__(self, db_file: str = None): self.db_file = db_file def execute_sql(self, sql: str) -> None: @@ -113,17 +90,12 @@ class Database: Executes a raw SQL statement. :param sql: SQL statement to execute. - - Example: - -------- - db = Database(exchanges=some_exchanges_object, db_file='test_db.sqlite') - db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') """ with SQLite(self.db_file) as con: cur = con.cursor() cur.execute(sql) - def get_item_where(self, item_name: str, table_name: str, filter_vals: tuple) -> int: + def get_item_where(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any: """ Returns an item from a table where the filter results should isolate a single row. @@ -131,42 +103,29 @@ class Database: :param table_name: Name of the table. :param filter_vals: Tuple of column name and value to filter by. :return: The item. - - Example: - -------- - item = db.get_item_where('name', 'test_table', ('id', 1)) - # Fetches the 'name' from 'test_table' where 'id' is 1 """ with SQLite(self.db_file) as con: cur = con.cursor() qry = make_query(item_name, table_name, [filter_vals[0]]) cur.execute(qry, (filter_vals[1],)) - if user_id := cur.fetchone(): - return user_id[0] + if result := cur.fetchone(): + return result[0] else: error = f"Couldn't fetch item {item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}" raise ValueError(error) - def get_rows_where(self, table: str, filter_vals: tuple) -> pd.DataFrame | None: + def get_rows_where(self, table: str, filter_vals: Tuple[str, Any]) -> pd.DataFrame | None: """ Returns a DataFrame containing all rows of a table that meet the filter criteria. :param table: Name of the table. :param filter_vals: Tuple of column name and value to filter by. :return: DataFrame of the query result or None if empty. - - Example: - -------- - rows = db.get_rows_where('test_table', ('name', 'test')) - # Fetches all rows from 'test_table' where 'name' is 'test' """ with SQLite(self.db_file) as con: - qry = f"SELECT * FROM {table} WHERE {filter_vals[0]}='{filter_vals[1]}'" - result = pd.read_sql(qry, con=con) - if not result.empty: - return result - else: - return None + qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?" + result = pd.read_sql(qry, con, params=(filter_vals[1],)) + return result if not result.empty else None def insert_dataframe(self, df: pd.DataFrame, table: str) -> None: """ @@ -174,30 +133,21 @@ class Database: :param df: DataFrame to insert. :param table: Name of the table. - - Example: - -------- - df = pd.DataFrame({'id': [1], 'name': ['test']}) - db.insert_dataframe(df, 'test_table') """ with SQLite(self.db_file) as con: df.to_sql(name=table, con=con, index=False, if_exists='append') - def insert_row(self, table: str, columns: tuple, values: tuple) -> None: + def insert_row(self, table: str, columns: Tuple[str, ...], values: Tuple[Any, ...]) -> None: """ Inserts a row into a specified table. :param table: Name of the table. :param columns: Tuple of column names. :param values: Tuple of values to insert. - - Example: - -------- - db.insert_row('test_table', ('id', 'name'), (1, 'test')) """ with SQLite(self.db_file) as conn: cursor = conn.cursor() - sql = make_insert(table=table, values=columns) + sql = make_insert(table=table, columns=columns) cursor.execute(sql, values) def table_exists(self, table_name: str) -> bool: @@ -206,11 +156,6 @@ class Database: :param table_name: Name of the table. :return: True if the table exists, False otherwise. - - Example: - -------- - exists = db._table_exists('test_table') - # Checks if 'test_table' exists in the database """ with SQLite(self.db_file) as conn: cursor = conn.cursor() @@ -220,7 +165,7 @@ class Database: return result is not None @lru_cache(maxsize=1000) - def get_from_static_table(self, item: str, table: str, indexes: HDict, create_id: bool = False) -> Any: + def get_from_static_table(self, item: str, table: str, indexes: dict, create_id: bool = False) -> Any: """ Returns the row id of an item from a table specified. If the item isn't listed in the table, it will insert the item into a new row and return the autoincremented id. The item received as a hashable @@ -231,23 +176,21 @@ class Database: :param indexes: Hashable dictionary of indexing columns and their values. :param create_id: If True, create a row if it doesn't exist and return the autoincrement ID. :return: The content of the field. - - Example: - -------- - exchange_id = db.get_from_static_table('id', 'exchange', HDict({'name': 'binance'}), create_id=True) """ with SQLite(self.db_file) as conn: cursor = conn.cursor() sql = make_query(item, table, list(indexes.keys())) cursor.execute(sql, tuple(indexes.values())) result = cursor.fetchone() + if result is None and create_id: sql = make_insert(table, tuple(indexes.keys())) cursor.execute(sql, tuple(indexes.values())) - sql = make_query(item, table, list(indexes.keys())) - cursor.execute(sql, tuple(indexes.values())) - result = cursor.fetchone() - return result[0] if result else None + result = cursor.lastrowid # Get the last inserted row ID + else: + result = result[0] if result else None + + return result def _fetch_exchange_id(self, exchange_name: str) -> int: """ @@ -255,25 +198,17 @@ class Database: :param exchange_name: Name of the exchange. :return: Primary ID of the exchange. - - Example: - -------- - exchange_id = db._fetch_exchange_id('binance') """ return self.get_from_static_table(item='id', table='exchange', create_id=True, indexes=HDict({'name': exchange_name})) def _fetch_market_id(self, symbol: str, exchange_name: str) -> int: """ - Returns the market ID for a trading pair listed in the database. + Returns the markets ID for a trading pair listed in the database. :param symbol: Symbol of the trading pair. :param exchange_name: Name of the exchange. :return: Market ID. - - Example: - -------- - market_id = db._fetch_market_id('BTC/USDT', 'binance') """ exchange_id = self._fetch_exchange_id(exchange_name) market_id = self.get_from_static_table(item='id', table='markets', create_id=True, @@ -289,21 +224,17 @@ class Database: :param table_name: Name of the table to insert into. :param symbol: Symbol of the trading pair. :param exchange_name: Name of the exchange. - - Example: - -------- - df = pd.DataFrame({ - 'open_time': [unix_time_millis(dt.datetime.utcnow())], - 'open': [1.0], - 'high': [1.0], - 'low': [1.0], - 'close': [1.0], - 'volume': [1.0] - }) - db._insert_candles_into_db(df, 'test_table', 'BTC/USDT', 'binance') """ market_id = self._fetch_market_id(symbol, exchange_name) - candlesticks.insert(0, 'market_id', market_id) + + # Check if 'market_id' column already exists in the DataFrame + if 'market_id' in candlesticks.columns: + # If it exists, set its value to the fetched market_id + candlesticks['market_id'] = market_id + else: + # If it doesn't exist, insert it as the first column + candlesticks.insert(0, 'market_id', market_id) + sql_create = f""" CREATE TABLE IF NOT EXISTS '{table_name}' ( id INTEGER PRIMARY KEY, @@ -332,21 +263,19 @@ class Database: :param st: Start datetime. :param et: End datetime (optional). :return: DataFrame of records. - - Example: - -------- - records = db.get_timestamped_records('test_table', 'open_time', start_time, end_time) """ with SQLite(self.db_file) as conn: start_stamp = unix_time_millis(st) if et is not None: end_stamp = unix_time_millis(et) q_str = ( - f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp} " - f"AND {timestamp_field} <= {end_stamp};" + f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= ? " + f"AND {timestamp_field} <= ?;" ) + records = pd.read_sql(q_str, conn, params=(start_stamp, end_stamp)) else: - q_str = f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp};" - records = pd.read_sql(q_str, conn) - records = records.drop('id', axis=1) + q_str = f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= ?;" + records = pd.read_sql(q_str, conn, params=(start_stamp,)) + + # records = records.drop('id', axis=1) Todo: Reminder I may need to put this back later. return records diff --git a/src/Exchange.py b/src/Exchange.py index ea8693a..3bde476 100644 --- a/src/Exchange.py +++ b/src/Exchange.py @@ -1,6 +1,6 @@ import ccxt import pandas as pd -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Tuple, Dict, List, Union, Any import time import logging @@ -84,6 +84,8 @@ class Exchange: Returns: int: The Unix timestamp in milliseconds. """ + if dt.tzinfo is None: + raise ValueError("datetime object must be timezone-aware or in UTC.") return int(dt.timestamp() * 1000) def _fetch_historical_klines(self, symbol: str, interval: str, @@ -103,6 +105,12 @@ class Exchange: if end_dt is None: end_dt = datetime.utcnow() + # Convert start_dt and end_dt to UTC if they are naive + if start_dt.tzinfo is None: + start_dt = start_dt.replace(tzinfo=timezone.utc) + if end_dt.tzinfo is None: + end_dt = end_dt.replace(tzinfo=timezone.utc) + max_interval = timedelta(days=200) data_frames = [] current_start = start_dt @@ -122,7 +130,7 @@ class Exchange: df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume'] candles_df = pd.DataFrame(candles, columns=df_columns) - candles_df['open_time'] = candles_df['open_time'] // 1000 + data_frames.append(candles_df) current_start = current_end @@ -515,39 +523,41 @@ class Exchange: return [] -# Usage Examples - -# Example 1: Initializing the Exchange class -api_keys = { - 'key': 'your_api_key', - 'secret': 'your_api_secret' -} -exchange = Exchange(name='Binance', api_keys=api_keys, exchange_id='binance') - -# Example 2: Fetching historical data -start_date = datetime(2022, 1, 1) -end_date = datetime(2022, 6, 1) -historical_data = exchange.get_historical_klines(symbol='BTC/USDT', interval='1d', - start_dt=start_date, end_dt=end_date) -print(historical_data) - -# Example 3: Fetching the current price of a symbol -current_price = exchange.get_price(symbol='BTC/USDT') -print(f"Current price of BTC/USDT: {current_price}") - -# Example 4: Placing a limit buy order -order_result, order_details = exchange.place_order(symbol='BTC/USDT', side='buy', type='limit', - timeInForce='GTC', quantity=0.001, price=30000) -print(order_result, order_details) - -# Example 5: Getting account balances -balances = exchange.get_balances() -print(balances) - -# Example 6: Fetching open orders -open_orders = exchange.get_open_orders() -print(open_orders) - -# Example 7: Fetching active trades -active_trades = exchange.get_active_trades() -print(active_trades) +# +# # Usage Examples +# +# # Example 1: Initializing the Exchange class +# api_keys = { +# 'key': 'your_api_key', +# 'secret': 'your_api_secret' +# } +# exchange = Exchange(name='Binance', api_keys=api_keys, exchange_id='binance') +# +# # Example 2: Fetching historical data +# start_date = datetime(2022, 1, 1) +# end_date = datetime(2022, 6, 1) +# historical_data = exchange.get_historical_klines(symbol='BTC/USDT', interval='1d', +# start_dt=start_date, end_dt=end_date) +# print(historical_data) +# +# # Example 3: Fetching the current price of a symbol +# current_price = exchange.get_price(symbol='BTC/USDT') +# print(f"Current price of BTC/USDT: {current_price}") +# +# # Example 4: Placing a limit buy order +# order_result, order_details = exchange.place_order(symbol='BTC/USDT', side='buy', type='limit', +# timeInForce='GTC', quantity=0.001, price=30000) +# print(order_result, order_details) +# +# # Example 5: Getting account balances +# balances = exchange.get_balances() +# print(balances) +# +# # Example 6: Fetching open orders +# open_orders = exchange.get_open_orders() +# print(open_orders) +# +# # Example 7: Fetching active trades +# active_trades = exchange.get_active_trades() +# print(active_trades) +# diff --git a/src/ExchangeInterface.py b/src/ExchangeInterface.py index 345b034..21296d9 100644 --- a/src/ExchangeInterface.py +++ b/src/ExchangeInterface.py @@ -1,8 +1,6 @@ import logging -import json from typing import List, Any, Dict import pandas as pd -import requests import ccxt from Exchange import Exchange @@ -46,7 +44,7 @@ class ExchangeInterface: self.add_exchange(user_name, exchange) return True except Exception as e: - logging.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}") + logger.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}") return False def add_exchange(self, user_name: str, exchange: Exchange): @@ -60,7 +58,7 @@ class ExchangeInterface: row = {'user': user_name, 'name': exchange.name, 'reference': exchange, 'balances': exchange.balances} self.exchange_data = add_row(self.exchange_data, row) except Exception as e: - logging.error(f"Couldn't create an instance of the exchange! {str(e)}") + logger.error(f"Couldn't create an instance of the exchange! {str(e)}") raise def get_exchange(self, ename: str, uname: str) -> Exchange: @@ -144,12 +142,12 @@ class ExchangeInterface: elif fetch_type == 'orders': data = reference.get_open_orders() else: - logging.error(f"Invalid fetch type: {fetch_type}") + logger.error(f"Invalid fetch type: {fetch_type}") return {} data_dict[name] = data except Exception as e: - logging.error(f"Error retrieving data for {name}: {str(e)}") + logger.error(f"Error retrieving data for {name}: {str(e)}") return data_dict diff --git a/src/candles.py b/src/candles.py index def398a..118c26d 100644 --- a/src/candles.py +++ b/src/candles.py @@ -43,12 +43,8 @@ class Candles: # Calculate the approximate start_datetime the first of n record will have. start_datetime = ts_of_n_minutes_ago(n=num_candles, candle_length=minutes_per_candle) - # Table name format is: __. Example: "BTCUSDT_15m_binance_spot" - key = f'{asset}_{timeframe}_{exchange}' - # Fetch records older than start_datetime. - candles = self.data.get_records_since(key=key, start_datetime=start_datetime, - record_length=minutes_per_candle, + candles = self.data.get_records_since(start_datetime=start_datetime, ex_details=[asset, timeframe, exchange, user_name]) if len(candles.index) < num_candles: timesince = dt.datetime.utcnow() - start_datetime diff --git a/src/maintenence/debuging_testing.py b/src/maintenence/debuging_testing.py index d8f9e2d..5d3e367 100644 --- a/src/maintenence/debuging_testing.py +++ b/src/maintenence/debuging_testing.py @@ -1,20 +1,38 @@ import ccxt +import pandas as pd +import datetime -def main(): - # Create an instance of the Binance exchange - binance = ccxt.binance({ - 'enableRateLimit': True, - 'verbose': False, # Ensure verbose mode is disabled - }) +def fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5): + # Initialize the exchange + exchange_class = getattr(ccxt, exchange_name) + exchange = exchange_class() - try: - # Load markets to test the connection - markets = binance.load_markets() - print("Markets loaded successfully") - except ccxt.BaseError as e: - print(f"Error loading markets: {str(e)}") + # Fetch historical candlestick data with a limit + ohlcv = exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit) + + # Convert to DataFrame for better readability + df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']) + + # Print the first few rows of the DataFrame + print("First few rows of the fetched OHLCV data:") + print(df.head()) + + # Print the timestamps in human-readable format + df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms') + print("\nFirst few timestamps in human-readable format:") + print(df[['timestamp', 'datetime']].head()) + + # Confirm the format of the timestamps + print("\nTimestamp format confirmation:") + for ts in df['timestamp']: + print(f"{ts} (milliseconds since Unix epoch)") -if __name__ == "__main__": - main() +# Example usage +exchange_name = 'binance' # Change this to your exchange +symbol = 'BTC/USDT' +timeframe = '5m' +since = int((datetime.datetime(2024, 8, 1) - datetime.datetime(1970, 1, 1)).total_seconds() * 1000) + +fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5) diff --git a/src/shared_utilities.py b/src/shared_utilities.py index fabe323..1e9c76f 100644 --- a/src/shared_utilities.py +++ b/src/shared_utilities.py @@ -36,19 +36,37 @@ def query_uptodate(records: pd.DataFrame, r_length_min: float) -> Union[float, N tolerance_minutes = 10 / 60 # 10 seconds tolerance in minutes if minutes_since_update > (r_length_min - tolerance_minutes): # Return the last timestamp in seconds - return ms_to_seconds(last_timestamp) + return last_timestamp return None -def ms_to_seconds(timestamp): +def ms_to_seconds(timestamp: float) -> float: + """ + Converts milliseconds to seconds. + + :param timestamp: The timestamp in milliseconds. + :return: The timestamp in seconds. + """ return timestamp / 1000 -def unix_time_seconds(d_time): +def unix_time_seconds(d_time: dt.datetime) -> float: + """ + Converts a datetime object to Unix timestamp in seconds. + + :param d_time: The datetime object to convert. + :return: The Unix timestamp in seconds. + """ return (d_time - epoch).total_seconds() -def unix_time_millis(d_time): +def unix_time_millis(d_time: dt.datetime) -> float: + """ + Converts a datetime object to Unix timestamp in milliseconds. + + :param d_time: The datetime object to convert. + :return: The Unix timestamp in milliseconds. + """ return (d_time - epoch).total_seconds() * 1000.0 @@ -72,6 +90,10 @@ def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length start_timestamp = unix_time_millis(start_datetime) print(f'Start timestamp: {start_timestamp}') + if records.empty: + print('No records found. Query cannot be satisfied.') + return float('nan') + # Get the oldest timestamp from the records passed in first_timestamp = float(records.open_time.min()) print(f'First timestamp in records: {first_timestamp}') @@ -84,17 +106,17 @@ def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length if start_timestamp <= first_timestamp + total_duration: return None - return first_timestamp / 1000 # Return in seconds + return first_timestamp @lru_cache(maxsize=500) -def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime.timestamp: +def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime: """ Returns the approximate datetime for the start of a candle that was 'n' candles ago. - :param n: int - The number of candles ago to calculate. - :param candle_length: float - The length of each candle in minutes. - :return: datetime - The approximate datetime for the start of the 'n'-th candle ago. + :param n: The number of candles ago to calculate. + :param candle_length: The length of each candle in minutes. + :return: The approximate datetime for the start of the 'n'-th candle ago. """ # Increment 'n' by 1 to ensure we account for the time that has passed since the last candle closed. n += 1 @@ -113,12 +135,12 @@ def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime.timestamp: @lru_cache(maxsize=20) -def timeframe_to_minutes(timeframe): +def timeframe_to_minutes(timeframe: str) -> int: """ Converts a string representing a timeframe into an integer representing the approximate minutes. - :param timeframe: str - Timeframe format is [multiplier:focus]. eg '15m', '4h', '1d' - :return: int - Minutes the timeframe represents ex. '2h'-> 120(minutes). + :param timeframe: Timeframe format is [multiplier:focus]. e.g., '15m', '4h', '1d' + :return: Minutes the timeframe represents, e.g., '2h' -> 120 (minutes). """ # Extract the numerical part of the timeframe param. digits = int("".join([i if i.isdigit() else "" for i in timeframe])) diff --git a/tests/test_DataCache_v2.py b/tests/test_DataCache_v2.py new file mode 100644 index 0000000..8d4f302 --- /dev/null +++ b/tests/test_DataCache_v2.py @@ -0,0 +1,478 @@ +from DataCache_v2 import DataCache +from ExchangeInterface import ExchangeInterface +import unittest +import pandas as pd +import datetime as dt +import os +from Database import SQLite, Database + + +class DataGenerator: + def __init__(self, timeframe_str): + """ + Initialize the DataGenerator with a timeframe string like '2h', '5m', '1d', '1w', '1M', or '1y'. + """ + # Initialize attributes with placeholder values + self.timeframe_amount = None + self.timeframe_unit = None + # Set the actual timeframe + self.set_timeframe(timeframe_str) + + def set_timeframe(self, timeframe_str): + """ + Set the timeframe unit and amount based on a string like '2h', '5m', '1d', '1w', '1M', or '1y'. + """ + self.timeframe_amount = int(timeframe_str[:-1]) + unit = timeframe_str[-1] + + if unit == 's': + self.timeframe_unit = 'seconds' + elif unit == 'm': + self.timeframe_unit = 'minutes' + elif unit == 'h': + self.timeframe_unit = 'hours' + elif unit == 'd': + self.timeframe_unit = 'days' + elif unit == 'w': + self.timeframe_unit = 'weeks' + elif unit == 'M': + self.timeframe_unit = 'months' + elif unit == 'Y': + self.timeframe_unit = 'years' + else: + raise ValueError( + "Unsupported timeframe unit. Use 's,m,h,d,w,M,Y'.") + + def create_table(self, num_rec=None, start=None, end=None): + """ + Create a table with simulated data. If both start and end are provided, num_rec is derived from the interval. + If neither are provided the table will have num_rec and end at the current time. + + Parameters: + num_rec (int, optional): The number of records to generate. + start (datetime, optional): The start time for the first record. + end (datetime, optional): The end time for the last record. + + Returns: + pd.DataFrame: A DataFrame with the simulated data. + """ + # If neither start nor end are provided. + if start is None and end is None: + end = dt.datetime.utcnow() + if num_rec is None: + raise ValueError("num_rec must be provided if both start and end are not specified.") + + # If only start is provided. + if start is not None and end is not None: + total_duration = (end - start).total_seconds() + interval_seconds = self.timeframe_amount * self._get_seconds_per_unit(self.timeframe_unit) + num_rec = int(total_duration // interval_seconds) + 1 + + # If only end is provided. + if end is not None and start is None: + if num_rec is None: + raise ValueError("num_rec must be provided if both start and end are not specified.") + interval_seconds = self.timeframe_amount * self._get_seconds_per_unit(self.timeframe_unit) + start = end - dt.timedelta(seconds=(num_rec - 1) * interval_seconds) + + # Ensure start is aligned to the timeframe interval + start = self.round_down_datetime(start, self.timeframe_unit[0], self.timeframe_amount) + + # Generate times + times = [self.unix_time_millis(start + self._delta(i)) for i in range(num_rec)] + + df = pd.DataFrame({ + 'market_id': 1, + 'open_time': times, + 'open': [100 + i for i in range(num_rec)], + 'high': [110 + i for i in range(num_rec)], + 'low': [90 + i for i in range(num_rec)], + 'close': [105 + i for i in range(num_rec)], + 'volume': [1000 + i for i in range(num_rec)] + }) + + return df + + @staticmethod + def _get_seconds_per_unit(unit): + """Helper method to convert timeframe units to seconds.""" + units_in_seconds = { + 'seconds': 1, + 'minutes': 60, + 'hours': 3600, + 'days': 86400, + 'weeks': 604800, + 'months': 2592000, # Assuming 30 days per month + 'years': 31536000 # Assuming 365 days per year + } + if unit not in units_in_seconds: + raise ValueError(f"Unsupported timeframe unit: {unit}") + return units_in_seconds[unit] + + def generate_incomplete_data(self, query_offset, num_rec=5): + """ + Generate data that is incomplete, i.e., starts before the query but doesn't fully satisfy it. + """ + query_start_time = self.x_time_ago(query_offset) + start_time_for_data = self.get_start_time(query_start_time) + return self.create_table(num_rec, start=start_time_for_data) + + @staticmethod + def generate_missing_section(df, drop_start=5, drop_end=8): + """ + Generate data with a missing section. + """ + df = df.drop(df.index[drop_start:drop_end]).reset_index(drop=True) + return df + + def get_start_time(self, query_start_time): + margin = 2 + delta_args = {self.timeframe_unit: margin * self.timeframe_amount} + return query_start_time - dt.timedelta(**delta_args) + + def x_time_ago(self, offset): + """ + Returns a datetime object representing the current time minus the offset in the specified units. + """ + delta_args = {self.timeframe_unit: offset} + return dt.datetime.utcnow() - dt.timedelta(**delta_args) + + def _delta(self, i): + """ + Returns a timedelta object for the ith increment based on the timeframe unit and amount. + """ + delta_args = {self.timeframe_unit: i * self.timeframe_amount} + return dt.timedelta(**delta_args) + + @staticmethod + def unix_time_millis(dt_obj): + """ + Convert a datetime object to Unix time in milliseconds. + """ + epoch = dt.datetime(1970, 1, 1) + return int((dt_obj - epoch).total_seconds() * 1000) + + @staticmethod + def round_down_datetime(dt_obj: dt.datetime, unit: str, interval: int) -> dt.datetime: + if unit == 's': # Round down to the nearest interval of seconds + seconds = (dt_obj.second // interval) * interval + dt_obj = dt_obj.replace(second=seconds, microsecond=0) + elif unit == 'm': # Round down to the nearest interval of minutes + minutes = (dt_obj.minute // interval) * interval + dt_obj = dt_obj.replace(minute=minutes, second=0, microsecond=0) + elif unit == 'h': # Round down to the nearest interval of hours + hours = (dt_obj.hour // interval) * interval + dt_obj = dt_obj.replace(hour=hours, minute=0, second=0, microsecond=0) + elif unit == 'd': # Round down to the nearest interval of days + days = (dt_obj.day // interval) * interval + dt_obj = dt_obj.replace(day=days, hour=0, minute=0, second=0, microsecond=0) + elif unit == 'w': # Round down to the nearest interval of weeks + dt_obj -= dt.timedelta(days=dt_obj.weekday() % (interval * 7)) + dt_obj = dt_obj.replace(hour=0, minute=0, second=0, microsecond=0) + elif unit == 'M': # Round down to the nearest interval of months + months = ((dt_obj.month - 1) // interval) * interval + 1 + dt_obj = dt_obj.replace(month=months, day=1, hour=0, minute=0, second=0, microsecond=0) + elif unit == 'y': # Round down to the nearest interval of years + years = (dt_obj.year // interval) * interval + dt_obj = dt_obj.replace(year=years, month=1, day=1, hour=0, minute=0, second=0, microsecond=0) + return dt_obj + + +class TestDataCacheV2(unittest.TestCase): + def setUp(self): + # Set up database and exchanges + self.exchanges = ExchangeInterface() + self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None) + self.db_file = 'test_db.sqlite' + self.database = Database(db_file=self.db_file) + + # Create necessary tables + sql_create_table_1 = f""" + CREATE TABLE IF NOT EXISTS test_table ( + id INTEGER PRIMARY KEY, + market_id INTEGER, + open_time INTEGER UNIQUE ON CONFLICT IGNORE, + open REAL NOT NULL, + high REAL NOT NULL, + low REAL NOT NULL, + close REAL NOT NULL, + volume REAL NOT NULL, + FOREIGN KEY (market_id) REFERENCES market (id) + )""" + sql_create_table_2 = """ + CREATE TABLE IF NOT EXISTS exchange ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE + )""" + sql_create_table_3 = """ + CREATE TABLE IF NOT EXISTS markets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT, + exchange_id INTEGER, + FOREIGN KEY (exchange_id) REFERENCES exchange(id) + )""" + with SQLite(db_file=self.db_file) as con: + con.execute(sql_create_table_1) + con.execute(sql_create_table_2) + con.execute(sql_create_table_3) + self.data = DataCache(self.exchanges) + self.data.db = self.database + + self.ex_details = ['BTC/USD', '2h', 'binance', 'test_guy'] + self.key = f'{self.ex_details[0]}_{self.ex_details[1]}_{self.ex_details[2]}' + + def tearDown(self): + if os.path.exists(self.db_file): + os.remove(self.db_file) + + def test_set_cache(self): + print('\nTesting set_cache() method without no-overwrite flag:') + self.data.set_cache(data='data', key=self.key) + attr = self.data.__getattribute__('cached_data') + self.assertEqual(attr[self.key], 'data') + print(' - Set cache without no-overwrite flag passed.') + + print('Testing set_cache() once again with new data without no-overwrite flag:') + self.data.set_cache(data='more_data', key=self.key) + attr = self.data.__getattribute__('cached_data') + self.assertEqual(attr[self.key], 'more_data') + print(' - Set cache with new data without no-overwrite flag passed.') + + print('Testing set_cache() method once again with more data with no-overwrite flag set:') + self.data.set_cache(data='even_more_data', key=self.key, do_not_overwrite=True) + attr = self.data.__getattribute__('cached_data') + self.assertEqual(attr[self.key], 'more_data') + print(' - Set cache with no-overwrite flag passed.') + + def test_cache_exists(self): + print('Testing cache_exists() method:') + self.assertFalse(self.data.cache_exists(key=self.key)) + print(' - Check for non-existent cache passed.') + + self.data.set_cache(data='data', key=self.key) + self.assertTrue(self.data.cache_exists(key=self.key)) + print(' - Check for existent cache passed.') + + def test_update_candle_cache(self): + print('Testing update_candle_cache() method:') + + # Initialize the DataGenerator with the 5-minute timeframe + data_gen = DataGenerator('5m') + + # Create initial DataFrame and insert into cache + df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0)) + print(f'Inserting this table into cache:\n{df_initial}\n') + self.data.set_cache(data=df_initial, key=self.key) + + # Create new DataFrame to be added to cache + df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0)) + print(f'Updating cache with this table:\n{df_new}\n') + self.data.update_candle_cache(more_records=df_new, key=self.key) + + # Retrieve the resulting DataFrame from cache + result = self.data.get_cache(key=self.key) + print(f'The resulting table in cache is:\n{result}\n') + + # Create the expected DataFrame + expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0)) + print(f'The expected open_time values are:\n{expected["open_time"].tolist()}\n') + + # Assert that the open_time values in the result match those in the expected DataFrame, in order + assert result['open_time'].tolist() == expected['open_time'].tolist(), \ + f"open_time values in result are {result['open_time'].tolist()}" \ + f" but expected {expected['open_time'].tolist()}" + + print(f'The results open_time values match:\n{result["open_time"].tolist()}\n') + print(' - Update cache with new records passed.') + + def test_update_cached_dict(self): + print('Testing update_cached_dict() method:') + self.data.set_cache(data={}, key=self.key) + self.data.update_cached_dict(cache_key=self.key, dict_key='sub_key', data='value') + + cache = self.data.get_cache(key=self.key) + self.assertEqual(cache['sub_key'], 'value') + print(' - Update dictionary in cache passed.') + + def test_get_cache(self): + print('Testing get_cache() method:') + self.data.set_cache(data='data', key=self.key) + result = self.data.get_cache(key=self.key) + self.assertEqual(result, 'data') + print(' - Retrieve cache passed.') + + def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None, num_rec=None, ex_details=None, + simulate_scenarios=None): + """ + Test the get_records_since() method by generating a table of simulated data, + inserting it into cache and/or database, and then querying the records. + + Parameters: + set_cache (bool): If True, the generated table is inserted into the cache. + set_db (bool): If True, the generated table is inserted into the database. + query_offset (int, optional): The offset in the timeframe units for the query. + num_rec (int, optional): The number of records to generate in the simulated table. + ex_details (list, optional): Exchange details to generate the cache key. + simulate_scenarios (str, optional): The type of scenario to simulate. Options are: + - 'not_enough_data': The table data doesn't go far enough back. + - 'incomplete_data': The table doesn't have enough records to satisfy the query. + - 'missing_section': The table has missing records in the middle. + """ + + print('Testing get_records_since() method:') + + # Use provided ex_details or fallback to the class attribute. + ex_details = ex_details or self.ex_details + # Generate a cache/database key using exchange details. + key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}' + + # Set default number of records if not provided. + num_rec = num_rec or 12 + table_timeframe = ex_details[1] # Extract timeframe from exchange details. + + # Initialize DataGenerator with the given timeframe. + data_gen = DataGenerator(table_timeframe) + + if simulate_scenarios == 'not_enough_data': + # Set query_offset to a time earlier than the start of the table data. + query_offset = (num_rec + 5) * data_gen.timeframe_amount + else: + # Default to querying for 1 record length less than the table duration. + query_offset = query_offset or (num_rec - 1) * data_gen.timeframe_amount + + if simulate_scenarios == 'incomplete_data': + # Set start time to generate fewer records than required. + start_time_for_data = data_gen.x_time_ago(num_rec * data_gen.timeframe_amount) + num_rec = 5 # Set a smaller number of records to simulate incomplete data. + else: + # No specific start time for data generation. + start_time_for_data = None + + # Create the initial data table. + df_initial = data_gen.create_table(num_rec, start=start_time_for_data) + + if simulate_scenarios == 'missing_section': + # Simulate missing section in the data by dropping records. + df_initial = data_gen.generate_missing_section(df_initial, drop_start=2, drop_end=5) + + # Convert 'open_time' to datetime for better readability. + temp_df = df_initial.copy() + temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms') + print(f'Table Created:\n{temp_df}') + + if set_cache: + # Insert the generated table into cache. + print('Inserting table into cache.') + self.data.set_cache(data=df_initial, key=key) + + if set_db: + # Insert the generated table into the database. + print('Inserting table into database.') + with SQLite(self.db_file) as con: + df_initial.to_sql(key, con, if_exists='replace', index=False) + + # Calculate the start time for querying the records. + start_datetime = data_gen.x_time_ago(query_offset) + # Defaults to current time if not provided to get_records_since() + query_end_time = dt.datetime.utcnow() + print(f'Requesting records from {start_datetime} to {query_end_time}') + + # Query the records since the calculated start time. + result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) + + # Filter the initial data table to match the query time. + expected = df_initial[df_initial['open_time'] >= data_gen.unix_time_millis(start_datetime)].reset_index( + drop=True) + temp_df = expected.copy() + temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms') + print(f'Expected table:\n{temp_df}') + + # Print the result from the query for comparison. + temp_df = result.copy() + temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms') + print(f'Resulting table:\n{temp_df}') + + if simulate_scenarios in ['not_enough_data', 'incomplete_data', 'missing_section']: + # Check that the result has more rows than the expected incomplete data. + assert result.shape[0] > expected.shape[ + 0], "Result has fewer or equal rows compared to the incomplete data." + print("\nThe returned DataFrames has filled in the missing data!") + else: + # Ensure the result and expected dataframes match in shape and content. + assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}" + pd.testing.assert_series_equal(result['open_time'], expected['open_time'], check_dtype=False) + print("\nThe DataFrames have the same shape and the 'open_time' columns match.") + + # Verify that the oldest timestamp in the result is within the allowed time difference. + oldest_timestamp = pd.to_datetime(result['open_time'].min(), unit='ms') + time_diff = oldest_timestamp - start_datetime + max_allowed_time_diff = dt.timedelta(**{data_gen.timeframe_unit: data_gen.timeframe_amount}) + + assert dt.timedelta(0) <= time_diff <= max_allowed_time_diff, \ + f"Oldest timestamp {oldest_timestamp} is not within " \ + f"{data_gen.timeframe_amount} {data_gen.timeframe_unit} of {start_datetime}" + + print(f'The first timestamp is {time_diff} from {start_datetime}') + + # Verify that the newest timestamp in the result is within the allowed time difference. + newest_timestamp = pd.to_datetime(result['open_time'].max(), unit='ms') + time_diff_end = abs(query_end_time - newest_timestamp) + + assert dt.timedelta(0) <= time_diff_end <= max_allowed_time_diff, \ + f"Newest timestamp {newest_timestamp} is not within {data_gen.timeframe_amount} " \ + f"{data_gen.timeframe_unit} of {query_end_time}" + + print(f'The last timestamp is {time_diff_end} from {query_end_time}') + print(' - Fetch records within the specified time range passed.') + + def test_get_records_since(self): + print('\nTest get_records_since with records set in cache') + self._test_get_records_since() + + print('\nTest get_records_since with records not in cache') + self._test_get_records_since(set_cache=False) + + print('\nTest get_records_since with records not in database') + self._test_get_records_since(set_cache=False, set_db=False) + + print('\nTest get_records_since with a different timeframe') + self._test_get_records_since(query_offset=None, num_rec=None, + ex_details=['BTC/USD', '15m', 'binance', 'test_guy']) + + print('\nTest get_records_since where data does not go far enough back') + self._test_get_records_since(simulate_scenarios='not_enough_data') + + print('\nTest get_records_since with incomplete data') + self._test_get_records_since(simulate_scenarios='incomplete_data') + + print('\nTest get_records_since with missing section in data') + self._test_get_records_since(simulate_scenarios='missing_section') + + def test_populate_db(self): + print('Testing _populate_db() method:') + # Create a table of candle records. + data_gen = DataGenerator(self.ex_details[1]) + data = data_gen.create_table(num_rec=5) + + self.data._populate_db(ex_details=self.ex_details, data=data) + + with SQLite(self.db_file) as con: + result = pd.read_sql(f'SELECT * FROM "{self.key}"', con) + self.assertFalse(result.empty) + print(' - Populate database with data passed.') + + def test_fetch_candles_from_exchange(self): + print('Testing _fetch_candles_from_exchange() method:') + start_time = dt.datetime.utcnow() - dt.timedelta(days=1) + end_time = dt.datetime.utcnow() + + result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance', + user_name='test_guy', start_datetime=start_time, + end_datetime=end_time) + self.assertIsInstance(result, pd.DataFrame) + self.assertFalse(result.empty) + print(' - Fetch candle data from exchange passed.') + + +if __name__ == '__main__': + unittest.main()