"""Unit tests for the DataCache_v2 module.

Also provides DataGenerator, a helper that builds simulated OHLCV candle
tables for exercising the cache and database code paths.
"""

from DataCache_v2 import DataCache
from ExchangeInterface import ExchangeInterface
import unittest
import pandas as pd
import datetime as dt
import os
from Database import SQLite, Database


class DataGenerator:
    """Generate simulated candle (OHLCV) data for a given timeframe."""

    # Single-letter unit codes mapped to timedelta-style unit names.
    # 'M' (months) is distinct from 'm' (minutes); both 'y' and 'Y' mean
    # years (the docstrings advertise '1y' but only 'Y' used to be accepted).
    _LETTER_TO_UNIT = {
        's': 'seconds',
        'm': 'minutes',
        'h': 'hours',
        'd': 'days',
        'w': 'weeks',
        'M': 'months',
        'y': 'years',
        'Y': 'years',
    }
    # Canonical letter for each unit name.  NOTE: 'months' must map to 'M',
    # never 'm' -- using timeframe_unit[0] collapsed months into minutes,
    # which silently skipped month alignment in create_table().
    _UNIT_TO_LETTER = {
        'seconds': 's',
        'minutes': 'm',
        'hours': 'h',
        'days': 'd',
        'weeks': 'w',
        'months': 'M',
        'years': 'y',
    }

    def __init__(self, timeframe_str):
        """
        Initialize the DataGenerator with a timeframe string like
        '2h', '5m', '1d', '1w', '1M', or '1y'.
        """
        # Placeholders; set_timeframe() assigns the real values.
        self.timeframe_amount = None
        self.timeframe_unit = None
        self.set_timeframe(timeframe_str)

    def set_timeframe(self, timeframe_str):
        """
        Set the timeframe unit and amount based on a string like
        '2h', '5m', '1d', '1w', '1M', '1y' or '1Y'.

        Raises:
            ValueError: If the trailing unit letter is not recognised.
        """
        self.timeframe_amount = int(timeframe_str[:-1])
        unit = timeframe_str[-1]
        try:
            self.timeframe_unit = self._LETTER_TO_UNIT[unit]
        except KeyError:
            raise ValueError(
                "Unsupported timeframe unit. Use 's,m,h,d,w,M,y'.") from None

    def create_table(self, num_rec=None, start=None, end=None):
        """
        Create a table with simulated data.

        If both start and end are provided, num_rec is derived from the
        interval.  If neither is provided the table will have num_rec rows
        and end at the current time.

        Parameters:
            num_rec (int, optional): The number of records to generate.
            start (datetime, optional): The start time for the first record.
            end (datetime, optional): The end time for the last record.

        Returns:
            pd.DataFrame: A DataFrame with the simulated data.

        Raises:
            ValueError: If num_rec is missing and cannot be derived.
        """
        interval_seconds = (self.timeframe_amount *
                            self._get_seconds_per_unit(self.timeframe_unit))

        # Neither start nor end provided: anchor the table at "now".
        if start is None and end is None:
            end = dt.datetime.utcnow()
            if num_rec is None:
                raise ValueError(
                    "num_rec must be provided if both start and end are not specified.")

        # Both start and end provided: derive num_rec from the interval.
        if start is not None and end is not None:
            total_duration = (end - start).total_seconds()
            num_rec = int(total_duration // interval_seconds) + 1

        # Only end provided: derive start from num_rec.
        if end is not None and start is None:
            if num_rec is None:
                raise ValueError(
                    "num_rec must be provided if both start and end are not specified.")
            start = end - dt.timedelta(seconds=(num_rec - 1) * interval_seconds)

        # Only start provided: num_rec is required.  The original code
        # crashed later with range(None); fail early with a clear message.
        if num_rec is None:
            raise ValueError("num_rec must be provided if only start is specified.")

        # Align start to the timeframe interval, using the canonical unit
        # letter (see _UNIT_TO_LETTER note about months vs minutes).
        start = self.round_down_datetime(
            start, self._UNIT_TO_LETTER[self.timeframe_unit], self.timeframe_amount)

        # One millisecond open_time per interval, plus simple ascending OHLCV.
        times = [self.unix_time_millis(start + self._delta(i))
                 for i in range(num_rec)]
        return pd.DataFrame({
            'market_id': 1,
            'open_time': times,
            'open': [100 + i for i in range(num_rec)],
            'high': [110 + i for i in range(num_rec)],
            'low': [90 + i for i in range(num_rec)],
            'close': [105 + i for i in range(num_rec)],
            'volume': [1000 + i for i in range(num_rec)],
        })

    @staticmethod
    def _get_seconds_per_unit(unit):
        """Helper method to convert timeframe units to seconds."""
        units_in_seconds = {
            'seconds': 1,
            'minutes': 60,
            'hours': 3600,
            'days': 86400,
            'weeks': 604800,
            'months': 2592000,    # Assuming 30 days per month
            'years': 31536000,    # Assuming 365 days per year
        }
        if unit not in units_in_seconds:
            raise ValueError(f"Unsupported timeframe unit: {unit}")
        return units_in_seconds[unit]

    def _unit_timedelta(self, units):
        """
        Return a timedelta spanning ``units`` timeframe units.

        Built from seconds so that 'months' and 'years' -- which
        dt.timedelta() does not accept as keyword arguments -- also work,
        using the approximations in _get_seconds_per_unit().
        """
        return dt.timedelta(
            seconds=units * self._get_seconds_per_unit(self.timeframe_unit))

    def generate_incomplete_data(self, query_offset, num_rec=5):
        """
        Generate data that is incomplete, i.e., starts before the query but
        doesn't fully satisfy it.
        """
        query_start_time = self.x_time_ago(query_offset)
        start_time_for_data = self.get_start_time(query_start_time)
        return self.create_table(num_rec, start=start_time_for_data)

    @staticmethod
    def generate_missing_section(df, drop_start=5, drop_end=8):
        """Return df with rows [drop_start, drop_end) removed (a data gap)."""
        return df.drop(df.index[drop_start:drop_end]).reset_index(drop=True)

    def get_start_time(self, query_start_time):
        """Return query_start_time minus a two-interval safety margin."""
        margin = 2
        return query_start_time - self._unit_timedelta(margin * self.timeframe_amount)

    def x_time_ago(self, offset):
        """
        Returns a datetime object representing the current time minus the
        offset in the specified units.
        """
        return dt.datetime.utcnow() - self._unit_timedelta(offset)

    def _delta(self, i):
        """
        Returns a timedelta object for the ith increment based on the
        timeframe unit and amount.
        """
        return self._unit_timedelta(i * self.timeframe_amount)

    @staticmethod
    def unix_time_millis(dt_obj):
        """Convert a (naive UTC) datetime object to Unix time in milliseconds."""
        epoch = dt.datetime(1970, 1, 1)
        return int((dt_obj - epoch).total_seconds() * 1000)

    @staticmethod
    def round_down_datetime(dt_obj: dt.datetime, unit: str, interval: int) -> dt.datetime:
        """
        Round dt_obj down to the nearest ``interval`` of the given
        single-letter unit ('s', 'm', 'h', 'd', 'w', 'M', 'y').

        Unknown letters return dt_obj unchanged (matches original behavior).
        """
        if unit == 's':
            # Round down to the nearest interval of seconds.
            seconds = (dt_obj.second // interval) * interval
            dt_obj = dt_obj.replace(second=seconds, microsecond=0)
        elif unit == 'm':
            # Round down to the nearest interval of minutes.
            minutes = (dt_obj.minute // interval) * interval
            dt_obj = dt_obj.replace(minute=minutes, second=0, microsecond=0)
        elif unit == 'h':
            # Round down to the nearest interval of hours.
            hours = (dt_obj.hour // interval) * interval
            dt_obj = dt_obj.replace(hour=hours, minute=0, second=0, microsecond=0)
        elif unit == 'd':
            # Days are 1-based: shift to 0-based before flooring, then back.
            # The old (day // interval) * interval could yield the invalid
            # day 0 and raise ValueError in replace().
            days = ((dt_obj.day - 1) // interval) * interval + 1
            dt_obj = dt_obj.replace(day=days, hour=0, minute=0,
                                    second=0, microsecond=0)
        elif unit == 'w':
            # Round down to the nearest interval of weeks.
            dt_obj -= dt.timedelta(days=dt_obj.weekday() % (interval * 7))
            dt_obj = dt_obj.replace(hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'M':
            # Months are 1-based, like days.
            months = ((dt_obj.month - 1) // interval) * interval + 1
            dt_obj = dt_obj.replace(month=months, day=1, hour=0, minute=0,
                                    second=0, microsecond=0)
        elif unit == 'y':
            # Round down to the nearest interval of years.
            years = (dt_obj.year // interval) * interval
            dt_obj = dt_obj.replace(year=years, month=1, day=1, hour=0,
                                    minute=0, second=0, microsecond=0)
        return dt_obj


class TestDataCacheV2(unittest.TestCase):
    """Integration-style tests for DataCache against a scratch SQLite file."""

    def setUp(self):
        """Create a throwaway SQLite database and a connected DataCache."""
        self.exchanges = ExchangeInterface()
        self.exchanges.connect_exchange(exchange_name='binance',
                                        user_name='test_guy', api_keys=None)
        self.db_file = 'test_db.sqlite'
        self.database = Database(db_file=self.db_file)
        # Candle table; open_time is UNIQUE so duplicate inserts are ignored.
        sql_create_table_1 = """
            CREATE TABLE IF NOT EXISTS test_table (
                id INTEGER PRIMARY KEY,
                market_id INTEGER,
                open_time INTEGER UNIQUE ON CONFLICT IGNORE,
                open REAL NOT NULL,
                high REAL NOT NULL,
                low REAL NOT NULL,
                close REAL NOT NULL,
                volume REAL NOT NULL,
                FOREIGN KEY (market_id) REFERENCES market (id)
            )"""
        sql_create_table_2 = """
            CREATE TABLE IF NOT EXISTS exchange (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT UNIQUE
            )"""
        # NOTE(review): this table is named 'markets' but test_table's FK
        # references 'market'.  SQLite only enforces FKs with
        # PRAGMA foreign_keys=ON, so the mismatch is currently harmless --
        # confirm which name is intended.
        sql_create_table_3 = """
            CREATE TABLE IF NOT EXISTS markets (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                symbol TEXT,
                exchange_id INTEGER,
                FOREIGN KEY (exchange_id) REFERENCES exchange(id)
            )"""
        with SQLite(db_file=self.db_file) as con:
            con.execute(sql_create_table_1)
            con.execute(sql_create_table_2)
            con.execute(sql_create_table_3)
        self.data = DataCache(self.exchanges)
        self.data.db = self.database
        # [symbol, timeframe, exchange, user] -- the key format used by the cache.
        self.ex_details = ['BTC/USD', '2h', 'binance', 'test_guy']
        self.key = f'{self.ex_details[0]}_{self.ex_details[1]}_{self.ex_details[2]}'

    def tearDown(self):
        """Remove the scratch database file."""
        if os.path.exists(self.db_file):
            os.remove(self.db_file)

    def test_set_cache(self):
        """set_cache() stores data and honours the do_not_overwrite flag."""
        print('\nTesting set_cache() method without no-overwrite flag:')
        self.data.set_cache(data='data', key=self.key)
        attr = getattr(self.data, 'cached_data')
        self.assertEqual(attr[self.key], 'data')
        print(' - Set cache without no-overwrite flag passed.')
        print('Testing set_cache() once again with new data without no-overwrite flag:')
        self.data.set_cache(data='more_data', key=self.key)
        attr = getattr(self.data, 'cached_data')
        self.assertEqual(attr[self.key], 'more_data')
        print(' - Set cache with new data without no-overwrite flag passed.')
        print('Testing set_cache() method once again with more data with no-overwrite flag set:')
        self.data.set_cache(data='even_more_data', key=self.key, do_not_overwrite=True)
        attr = getattr(self.data, 'cached_data')
        # With the flag set the earlier value must survive.
        self.assertEqual(attr[self.key], 'more_data')
        print(' - Set cache with no-overwrite flag passed.')

    def test_cache_exists(self):
        """cache_exists() reflects whether a key has been cached."""
        print('Testing cache_exists() method:')
        self.assertFalse(self.data.cache_exists(key=self.key))
        print(' - Check for non-existent cache passed.')
        self.data.set_cache(data='data', key=self.key)
        self.assertTrue(self.data.cache_exists(key=self.key))
        print(' - Check for existent cache passed.')

    def test_update_candle_cache(self):
        """update_candle_cache() appends new records to an existing table."""
        print('Testing update_candle_cache() method:')
        # Initialize the DataGenerator with the 5-minute timeframe.
        data_gen = DataGenerator('5m')
        # Create initial DataFrame and insert into cache.
        df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0))
        print(f'Inserting this table into cache:\n{df_initial}\n')
        self.data.set_cache(data=df_initial, key=self.key)
        # Create new DataFrame to be added to cache.
        df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0))
        print(f'Updating cache with this table:\n{df_new}\n')
        self.data.update_candle_cache(more_records=df_new, key=self.key)
        # Retrieve the resulting DataFrame from cache.
        result = self.data.get_cache(key=self.key)
        print(f'The resulting table in cache is:\n{result}\n')
        # The merged cache should equal one contiguous 6-record table.
        expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0))
        print(f'The expected open_time values are:\n{expected["open_time"].tolist()}\n')
        # unittest assertion instead of a bare assert (bare asserts are
        # stripped under python -O).
        self.assertEqual(
            result['open_time'].tolist(), expected['open_time'].tolist(),
            f"open_time values in result are {result['open_time'].tolist()}"
            f" but expected {expected['open_time'].tolist()}")
        print(f'The results open_time values match:\n{result["open_time"].tolist()}\n')
        print(' - Update cache with new records passed.')

    def test_update_cached_dict(self):
        """update_cached_dict() sets a sub-key inside a cached dictionary."""
        print('Testing update_cached_dict() method:')
        self.data.set_cache(data={}, key=self.key)
        self.data.update_cached_dict(cache_key=self.key, dict_key='sub_key', data='value')
        cache = self.data.get_cache(key=self.key)
        self.assertEqual(cache['sub_key'], 'value')
        print(' - Update dictionary in cache passed.')

    def test_get_cache(self):
        """get_cache() returns previously cached data."""
        print('Testing get_cache() method:')
        self.data.set_cache(data='data', key=self.key)
        result = self.data.get_cache(key=self.key)
        self.assertEqual(result, 'data')
        print(' - Retrieve cache passed.')

    def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None,
                                num_rec=None, ex_details=None, simulate_scenarios=None):
        """
        Test the get_records_since() method by generating a table of simulated
        data, inserting it into cache and/or database, and then querying the
        records.

        Parameters:
            set_cache (bool): If True, the generated table is inserted into the cache.
            set_db (bool): If True, the generated table is inserted into the database.
            query_offset (int, optional): The offset in the timeframe units for the query.
            num_rec (int, optional): The number of records to generate in the simulated table.
            ex_details (list, optional): Exchange details to generate the cache key.
            simulate_scenarios (str, optional): The type of scenario to simulate. Options are:
                - 'not_enough_data': The table data doesn't go far enough back.
                - 'incomplete_data': The table doesn't have enough records to satisfy the query.
                - 'missing_section': The table has missing records in the middle.
        """
        print('Testing get_records_since() method:')
        # Use provided ex_details or fall back to the class attribute.
        ex_details = ex_details or self.ex_details
        # Generate a cache/database key using exchange details.
        key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}'
        # Set default number of records if not provided.
        num_rec = num_rec or 12
        table_timeframe = ex_details[1]  # Extract timeframe from exchange details.
        data_gen = DataGenerator(table_timeframe)

        if simulate_scenarios == 'not_enough_data':
            # Query earlier than the start of the table data.
            query_offset = (num_rec + 5) * data_gen.timeframe_amount
        else:
            # Default to querying for 1 record length less than the table duration.
            query_offset = query_offset or (num_rec - 1) * data_gen.timeframe_amount

        if simulate_scenarios == 'incomplete_data':
            # Start the data where a full table would, but emit fewer records.
            start_time_for_data = data_gen.x_time_ago(num_rec * data_gen.timeframe_amount)
            num_rec = 5  # Smaller record count simulates incomplete data.
        else:
            start_time_for_data = None

        # Create the initial data table.
        df_initial = data_gen.create_table(num_rec, start=start_time_for_data)
        if simulate_scenarios == 'missing_section':
            # Simulate a gap by dropping a run of records from the middle.
            df_initial = data_gen.generate_missing_section(df_initial, drop_start=2, drop_end=5)

        # Convert 'open_time' to datetime for better readability.
        temp_df = df_initial.copy()
        temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms')
        print(f'Table Created:\n{temp_df}')

        if set_cache:
            print('Inserting table into cache.')
            self.data.set_cache(data=df_initial, key=key)
        if set_db:
            print('Inserting table into database.')
            with SQLite(self.db_file) as con:
                df_initial.to_sql(key, con, if_exists='replace', index=False)

        # Calculate the start time for querying the records.
        start_datetime = data_gen.x_time_ago(query_offset)
        # get_records_since() defaults its end to the current time.
        query_end_time = dt.datetime.utcnow()
        print(f'Requesting records from {start_datetime} to {query_end_time}')
        result = self.data.get_records_since(start_datetime=start_datetime,
                                             ex_details=ex_details)

        # Filter the initial data table down to the query window.
        expected = df_initial[
            df_initial['open_time'] >= data_gen.unix_time_millis(start_datetime)
        ].reset_index(drop=True)
        temp_df = expected.copy()
        temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms')
        print(f'Expected table:\n{temp_df}')
        temp_df = result.copy()
        temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms')
        print(f'Resulting table:\n{temp_df}')

        if simulate_scenarios in ('not_enough_data', 'incomplete_data', 'missing_section'):
            # The cache layer should have fetched the missing rows, so the
            # result must be strictly larger than the incomplete input.
            self.assertGreater(result.shape[0], expected.shape[0],
                               "Result has fewer or equal rows compared to the incomplete data.")
            print("\nThe returned DataFrames has filled in the missing data!")
        else:
            # Result and expectation must match in shape and open_time content.
            self.assertEqual(result.shape, expected.shape,
                             f"Shape mismatch: {result.shape} vs {expected.shape}")
            pd.testing.assert_series_equal(result['open_time'], expected['open_time'],
                                           check_dtype=False)
            print("\nThe DataFrames have the same shape and the 'open_time' columns match.")

        # The oldest timestamp must fall within one interval of the query start.
        oldest_timestamp = pd.to_datetime(result['open_time'].min(), unit='ms')
        time_diff = oldest_timestamp - start_datetime
        # Seconds-based helper so 'months'/'years' timeframes also work.
        max_allowed_time_diff = data_gen._unit_timedelta(data_gen.timeframe_amount)
        self.assertTrue(
            dt.timedelta(0) <= time_diff <= max_allowed_time_diff,
            f"Oldest timestamp {oldest_timestamp} is not within "
            f"{data_gen.timeframe_amount} {data_gen.timeframe_unit} of {start_datetime}")
        print(f'The first timestamp is {time_diff} from {start_datetime}')

        # The newest timestamp must fall within one interval of the query end.
        newest_timestamp = pd.to_datetime(result['open_time'].max(), unit='ms')
        time_diff_end = abs(query_end_time - newest_timestamp)
        self.assertTrue(
            dt.timedelta(0) <= time_diff_end <= max_allowed_time_diff,
            f"Newest timestamp {newest_timestamp} is not within {data_gen.timeframe_amount} "
            f"{data_gen.timeframe_unit} of {query_end_time}")
        print(f'The last timestamp is {time_diff_end} from {query_end_time}')
        print(' - Fetch records within the specified time range passed.')

    def test_get_records_since(self):
        """Run _test_get_records_since() across all supported scenarios."""
        print('\nTest get_records_since with records set in cache')
        self._test_get_records_since()
        print('\nTest get_records_since with records not in cache')
        self._test_get_records_since(set_cache=False)
        print('\nTest get_records_since with records not in database')
        self._test_get_records_since(set_cache=False, set_db=False)
        print('\nTest get_records_since with a different timeframe')
        self._test_get_records_since(query_offset=None, num_rec=None,
                                     ex_details=['BTC/USD', '15m', 'binance', 'test_guy'])
        print('\nTest get_records_since where data does not go far enough back')
        self._test_get_records_since(simulate_scenarios='not_enough_data')
        print('\nTest get_records_since with incomplete data')
        self._test_get_records_since(simulate_scenarios='incomplete_data')
        print('\nTest get_records_since with missing section in data')
        self._test_get_records_since(simulate_scenarios='missing_section')

    def test_populate_db(self):
        """_populate_db() writes candle records into the per-key table."""
        print('Testing _populate_db() method:')
        # Create a table of candle records.
        data_gen = DataGenerator(self.ex_details[1])
        data = data_gen.create_table(num_rec=5)
        self.data._populate_db(ex_details=self.ex_details, data=data)
        with SQLite(self.db_file) as con:
            result = pd.read_sql(f'SELECT * FROM "{self.key}"', con)
        self.assertFalse(result.empty)
        print(' - Populate database with data passed.')

    def test_fetch_candles_from_exchange(self):
        """_fetch_candles_from_exchange() returns a non-empty DataFrame."""
        print('Testing _fetch_candles_from_exchange() method:')
        start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
        end_time = dt.datetime.utcnow()
        result = self.data._fetch_candles_from_exchange(symbol='BTC/USD',
                                                        interval='2h',
                                                        exchange_name='binance',
                                                        user_name='test_guy',
                                                        start_datetime=start_time,
                                                        end_datetime=end_time)
        self.assertIsInstance(result, pd.DataFrame)
        self.assertFalse(result.empty)
        print(' - Fetch candle data from exchange passed.')


if __name__ == '__main__':
    unittest.main()