Rewrote DataCache.py, implemented tests, and got them all passing.

This commit is contained in:
Rob 2024-08-15 22:39:38 -03:00
parent c398a423a3
commit d288eebbec
9 changed files with 1287 additions and 325 deletions

View File

@ -5,13 +5,12 @@ from Database import Database
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DataCache: class DataCache:
""" """
Fetches manages data limits and optimises memory storage. Fetches and manages data limits and optimizes memory storage.
Handles connections and operations for the given exchanges. Handles connections and operations for the given exchanges.
Example usage: Example usage:
@ -19,156 +18,229 @@ class DataCache:
db = DataCache(exchanges=some_exchanges_object) db = DataCache(exchanges=some_exchanges_object)
""" """
# Disable during production for improved performance.
TYPECHECKING_ENABLED = True
NO_RECORDS_FOUND = float('nan')
def __init__(self, exchanges): def __init__(self, exchanges):
""" """
Initializes the DataCache class. Initializes the DataCache class.
:param exchanges: The exchanges object handling communication with connected exchanges. :param exchanges: The exchanges object handling communication with connected exchanges.
""" """
# Maximum number of tables to cache at any given time.
self.max_tables = 50 self.max_tables = 50
# Maximum number of records to be kept per table.
self.max_records = 1000 self.max_records = 1000
# A dictionary that holds all the cached records.
self.cached_data = {} self.cached_data = {}
# The class that handles the DB interactions.
self.db = Database() self.db = Database()
# The class that handles exchange interactions.
self.exchanges = exchanges self.exchanges = exchanges
def cache_exists(self, key: str) -> bool: def cache_exists(self, key: str) -> bool:
""" """
Return False if a cache doesn't exist for this key. Checks if a cache exists for the given key.
:param key: The access key.
:return: True if cache exists, False otherwise.
""" """
return key in self.cached_data return key in self.cached_data
def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None: def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
""" """
Adds records to existing cache.
:param more_records: Adds records to existing cache. :param more_records: The new records to be added.
:param key: The access key. :param key: The access key.
:return: None. :return: None.
""" """
# Combine the new candles with the previously cached dataframe.
records = pd.concat([more_records, self.get_cache(key)], axis=0, ignore_index=True) records = pd.concat([more_records, self.get_cache(key)], axis=0, ignore_index=True)
# Drop any duplicates from overlap.
records = records.drop_duplicates(subset="open_time", keep='first') records = records.drop_duplicates(subset="open_time", keep='first')
# Sort the records by open_time.
records = records.sort_values(by='open_time').reset_index(drop=True) records = records.sort_values(by='open_time').reset_index(drop=True)
# Replace the incomplete dataframe with the modified one.
self.set_cache(data=records, key=key) self.set_cache(data=records, key=key)
return
def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None: def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
""" """
Creates a new cache key and inserts some data. Creates a new cache key and inserts data.
Todo: This is where this data will be passed to a cache server.
:param data: The records to insert into cache. :param data: The records to insert into cache.
:param key: The index key for the data. :param key: The index key for the data.
:param do_not_overwrite: - Flag to prevent overwriting existing data. :param do_not_overwrite: Flag to prevent overwriting existing data.
:return: None :return: None
""" """
# If the flag is set don't overwrite existing data.
if do_not_overwrite and key in self.cached_data: if do_not_overwrite and key in self.cached_data:
return return
# Assign the data
self.cached_data[key] = data self.cached_data[key] = data
def update_cached_dict(self, cache_key, dict_key: str, data: Any) -> None: def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None:
""" """
Updates a dictionary stored in cache. Updates a dictionary stored in cache.
Todo: This is where this data will be passed to a cache server.
:param data: - The records to insert into cache. :param data: The data to insert into cache.
:param cache_key: - The cache index key for the dictionary. :param cache_key: The cache index key for the dictionary.
:param dict_key: - The dictionary key for the data. :param dict_key: The dictionary key for the data.
:return: None :return: None
""" """
# Assign the data
self.cached_data[cache_key].update({dict_key: data}) self.cached_data[cache_key].update({dict_key: data})
def get_cache(self, key: str) -> Any: def get_cache(self, key: str) -> Any:
""" """
Returns data indexed by key. Returns data indexed by key.
Todo: This is where data will be retrieved from a cache server.
:param key: The index key for the data. :param key: The index key for the data.
:return: Any|None - The requested data or None on key error. :return: Any|None - The requested data or None on key error.
""" """
if key not in self.cached_data: if key not in self.cached_data:
print(f"[WARNING: DataCache.py] The requested cache key({key}) doesn't exist!") logger.warning(f"The requested cache key({key}) doesn't exist!")
return None return None
return self.cached_data[key] return self.cached_data[key]
def improved_get_records_since(self, key: str, start_datetime: dt.datetime, record_length: int,
                               ex_details: List[str]) -> pd.DataFrame:
    """
    Fetches records since the specified start datetime.

    :param key: The cache key.
    :param start_datetime: The start datetime to fetch records from.
    :param record_length: The required number of records.
    :param ex_details: Exchange details.
    :return: DataFrame containing the records.
    :raises Exception: Re-raises anything get_or_fetch_from raises, after logging.
    """
    try:
        target = 'cache'
        args = {
            'key': key,
            'start_datetime': start_datetime,
            # End of the query window defaults to "now".
            'end_datetime': dt.datetime.utcnow(),
            'record_length': record_length,
            'ex_details': ex_details
        }
        # Bug fix: the fetched frame was previously assigned to a local and
        # discarded, so this method always returned None. Return it instead.
        return self.get_or_fetch_from(target=target, **args)
    except Exception as e:
        logger.error(f"An error occurred: {str(e)}")
        raise
def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame:
    """
    Resolve a data request against a tiered chain of sources: cache -> database -> server.

    Starts at `target`; when a tier cannot completely satisfy the request it
    recurses into the next slower tier with the same kwargs.

    Expected kwargs: key (str), start_datetime (dt.datetime),
    end_datetime (dt.datetime), record_length (int), ex_details (list of str).

    :param target: One of 'cache', 'database', 'server'.
    :return: The requested records.
    :raises TypeError: When TYPECHECKING_ENABLED and an argument has the wrong type.
    :raises ValueError: On missing arguments or an unknown target.
    """
    key = kwargs.get('key')
    start_datetime = kwargs.get('start_datetime')
    # Bug fix: this previously re-read 'start_datetime', so the end bound of
    # every request was wrong (equal to its start).
    end_datetime = kwargs.get('end_datetime')
    record_length = kwargs.get('record_length')
    ex_details = kwargs.get('ex_details')
    if self.TYPECHECKING_ENABLED:
        # Type checking
        if not isinstance(key, str):
            raise TypeError("key must be a string")
        if not isinstance(start_datetime, dt.datetime):
            raise TypeError("start_datetime must be a datetime object")
        if not isinstance(end_datetime, dt.datetime):
            raise TypeError("end_datetime must be a datetime object")
        if not isinstance(record_length, int):
            raise TypeError("record_length must be an integer")
        if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
            raise TypeError("ex_details must be a list of strings")
        # Ensure all required arguments are provided.
        if not all([key, start_datetime, record_length, ex_details]):
            raise ValueError("Missing required arguments")

    # NOTE(review): the three fetchers below are placeholders that return the
    # pd.DataFrame class itself — presumably to be implemented later.
    def get_from_cache():
        return pd.DataFrame

    def get_from_database():
        return pd.DataFrame

    def get_from_server():
        return pd.DataFrame

    def data_complete(data, **criteria) -> bool:
        """Check if a dataframe completely satisfied a request."""
        sd = criteria.get('start_datetime')
        # Bug fix: the end bound was read from 'start_datetime' here too.
        ed = criteria.get('end_datetime')
        rl = criteria.get('record_length')
        # Placeholder: always reports complete until real checks are written.
        is_complete = True
        return is_complete

    request_criteria = {
        'start_datetime': start_datetime,
        'end_datetime': end_datetime,
        'record_length': record_length,
    }
    if target == 'cache':
        result = get_from_cache()
        if data_complete(result, **request_criteria):
            return result
        # Bug fix: fall-through results were computed but never returned.
        return self.get_or_fetch_from('database', **kwargs)
    elif target == 'database':
        result = get_from_database()
        if data_complete(result, **request_criteria):
            return result
        return self.get_or_fetch_from('server', **kwargs)
    elif target == 'server':
        result = get_from_server()
        if data_complete(result, **request_criteria):
            return result
        logger.error('Unable to fetch the requested data.')
        # Honor the declared return type instead of implicitly returning None.
        return pd.DataFrame()
    else:
        raise ValueError(f'Not a valid target: {target}')
def get_records_since(self, key: str, start_datetime: dt.datetime, record_length: int, def get_records_since(self, key: str, start_datetime: dt.datetime, record_length: int,
ex_details: List[str]) -> pd.DataFrame: ex_details: List[str]) -> pd.DataFrame:
""" """
Return any records from the cache indexed by table_name that are newer than start_datetime. Fetches records since the specified start datetime.
:param ex_details: List[str] - Details to pass to the server :param key: The cache key.
:param key: str - The dictionary table_name of the records. :param start_datetime: The start datetime to fetch records from.
:param start_datetime: dt.datetime - The datetime of the first record requested. :param record_length: The required number of records.
:param record_length: int - The timespan of the records. :param ex_details: Exchange details.
:return: DataFrame containing the records.
:return: pd.DataFrame - The Requested records
Example:
--------
records = data_cache.get_records_since('BTC/USD_2h_binance', dt.datetime.utcnow() - dt.timedelta(minutes=60), 60, ['BTC/USD', '2h', 'binance'])
""" """
try: try:
# End time of query defaults to the current time.
end_datetime = dt.datetime.utcnow() end_datetime = dt.datetime.utcnow()
if self.cache_exists(key=key): if self.cache_exists(key=key):
logger.debug('Getting records from cache.') logger.debug('Getting records from cache.')
# If the records exist, retrieve them from the cache.
records = self.get_cache(key) records = self.get_cache(key)
else: else:
# If they don't exist in cache, get them from the database.
logger.debug( logger.debug(
f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}') f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
records = self.get_records_since_from_db(table_name=key, st=start_datetime, records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_datetime,
et=end_datetime, rl=record_length, ex_details=ex_details) rl=record_length, ex_details=ex_details)
logger.debug(f'Got {len(records.index)} records from DB.') logger.debug(f'Got {len(records.index)} records from DB.')
self.set_cache(data=records, key=key) self.set_cache(data=records, key=key)
# Check if the records in the cache go far enough back to satisfy the query. first_timestamp = query_satisfied(start_datetime=start_datetime, records=records,
first_timestamp = query_satisfied(start_datetime=start_datetime,
records=records,
r_length_min=record_length) r_length_min=record_length)
if first_timestamp: if pd.isna(first_timestamp):
# The records didn't go far enough back if a timestamp was returned. logger.debug('No records found to satisfy the query, continuing to fetch more records.')
additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_datetime,
rl=record_length, ex_details=ex_details)
elif first_timestamp is not None:
end_time = dt.datetime.utcfromtimestamp(first_timestamp) end_time = dt.datetime.utcfromtimestamp(first_timestamp)
logger.debug(f'Requesting additional records from {start_datetime} to {end_time}') logger.debug(f'Requesting additional records from {start_datetime} to {end_time}')
# Request additional records from the database. additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_time,
additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime, rl=record_length, ex_details=ex_details)
et=end_time, rl=record_length, if not additional_records.empty:
ex_details=ex_details)
logger.debug(f'Got {len(additional_records.index)} additional records from DB.') logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
if not additional_records.empty: self.update_candle_cache(additional_records, key)
# If more records were received, update the cache.
self.update_candle_cache(additional_records, key)
# Check if the records received are up-to-date.
last_timestamp = query_uptodate(records=records, r_length_min=record_length) last_timestamp = query_uptodate(records=records, r_length_min=record_length)
if last_timestamp is not None:
if last_timestamp:
# The query was not up-to-date if a timestamp was returned.
start_time = dt.datetime.utcfromtimestamp(last_timestamp) start_time = dt.datetime.utcfromtimestamp(last_timestamp)
logger.debug(f'Requesting additional records from {start_time} to {end_datetime}') logger.debug(f'Requesting additional records from {start_time} to {end_datetime}')
# Request additional records from the database. additional_records = self.get_records_since_from_db(table_name=key, st=start_time, et=end_datetime,
additional_records = self.get_records_since_from_db(table_name=key, st=start_time, rl=record_length, ex_details=ex_details)
et=end_datetime, rl=record_length,
ex_details=ex_details)
logger.debug(f'Got {len(additional_records.index)} additional records from DB.') logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
if not additional_records.empty: if not additional_records.empty:
self.update_candle_cache(additional_records, key) self.update_candle_cache(additional_records, key)
# Create a UTC timestamp.
_timestamp = unix_time_millis(start_datetime) _timestamp = unix_time_millis(start_datetime)
logger.debug(f"Start timestamp in human-readable form: {dt.datetime.utcfromtimestamp(_timestamp / 1000.0)}") logger.debug(f"Start timestamp in human-readable form: {dt.datetime.utcfromtimestamp(_timestamp / 1000.0)}")
# Return all records equal to or newer than the timestamp.
result = self.get_cache(key).query('open_time >= @_timestamp') result = self.get_cache(key).query('open_time >= @_timestamp')
return result return result
@ -176,64 +248,61 @@ class DataCache:
logger.error(f"An error occurred: {str(e)}") logger.error(f"An error occurred: {str(e)}")
raise raise
def get_records_since_from_db(self, table_name: str, st: dt.datetime, def get_records_since_from_db(self, table_name: str, st: dt.datetime, et: dt.datetime, rl: float,
et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame: ex_details: List[str]) -> pd.DataFrame:
""" """
Returns records from a specified table that meet a criteria and ensures the records are complete. Fetches records from the database since the specified start datetime.
If the records do not go back far enough or are not up-to-date, it fetches additional records
from an exchange and populates the table.
:param table_name: Database table name. :param table_name: The name of the table in the database.
:param st: Start datetime. :param st: The start datetime to fetch records from.
:param et: End datetime. :param et: The end datetime to fetch records until.
:param rl: Timespan in minutes each record represents. :param rl: The required number of records.
:param ex_details: Exchange details [symbol, interval, exchange_name]. :param ex_details: Exchange details.
:return: DataFrame of records. :return: DataFrame containing the records.
Example:
--------
records = db.get_records_since('test_table', start_time, end_time, 1, ['BTC/USDT', '1m', 'binance'])
""" """
def add_data(data, tn, start_t, end_t): def add_data(data: pd.DataFrame, tn: str, start_t: dt.datetime, end_t: dt.datetime) -> pd.DataFrame:
new_records = self._populate_db(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details) new_records = self._populate_db(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details)
print(f'Got {len(new_records.index)} records from exchange_name') logger.debug(f'Got {len(new_records.index)} records from exchange_name')
if not new_records.empty: if not new_records.empty:
data = pd.concat([data, new_records], axis=0, ignore_index=True) data = pd.concat([data, new_records], axis=0, ignore_index=True)
data = data.drop_duplicates(subset="open_time", keep='first') data = data.drop_duplicates(subset="open_time", keep='first')
return data return data
if self.db.table_exists(table_name=table_name): if self.db.table_exists(table_name=table_name):
print('\nTable existed retrieving records from DB') logger.debug('Table existed retrieving records from DB')
print(f'Requesting from {st} to {et}') logger.debug(f'Requesting from {st} to {et}')
records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=st, et=et) records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=st, et=et)
print(f'Got {len(records.index)} records from db') logger.debug(f'Got {len(records.index)} records from db')
else: else:
print(f'\nTable didnt exist fetching from {ex_details[2]}') logger.debug(f"Table didn't exist fetching from {ex_details[2]}")
temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
print(f'Requesting from {st} to {et}, Should be {temp} records') logger.debug(f'Requesting from {st} to {et}, Should be {temp} records')
records = self._populate_db(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details) records = self._populate_db(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details)
print(f'Got {len(records.index)} records from {ex_details[2]}') logger.debug(f'Got {len(records.index)} records from {ex_details[2]}')
first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl) first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
if first_timestamp: if pd.isna(first_timestamp):
print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}') logger.debug('No records found to satisfy the query, continuing to fetch more records.')
print(f'first ts on record is: {first_timestamp}') records = add_data(data=records, tn=table_name, start_t=st, end_t=et)
elif first_timestamp:
logger.debug(f'Records did not go far enough back. Requesting from {ex_details[2]}')
logger.debug(f'First ts on record is: {first_timestamp}')
end_time = dt.datetime.utcfromtimestamp(first_timestamp) end_time = dt.datetime.utcfromtimestamp(first_timestamp)
print(f'Requesting from {st} to {end_time}') logger.debug(f'Requesting from {st} to {end_time}')
records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time) records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)
last_timestamp = query_uptodate(records=records, r_length_min=rl) last_timestamp = query_uptodate(records=records, r_length_min=rl)
if last_timestamp: if last_timestamp:
print(f'\nRecords were not updated. Requesting from {ex_details[2]}.') logger.debug(f'Records were not updated. Requesting from {ex_details[2]}.')
print(f'the last record on file is: {last_timestamp}') logger.debug(f'The last record on file is: {last_timestamp}')
start_time = dt.datetime.utcfromtimestamp(last_timestamp) start_time = dt.datetime.utcfromtimestamp(last_timestamp)
print(f'Requesting from {start_time} to {et}') logger.debug(f'Requesting from {start_time} to {et}')
records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et) records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et)
return records return records
def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: list, def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: List[str],
end_time: dt.datetime = None) -> pd.DataFrame: end_time: dt.datetime = None) -> pd.DataFrame:
""" """
Populates a database table with records from the exchange. Populates a database table with records from the exchange.
@ -243,10 +312,6 @@ class DataCache:
:param end_time: End time to fetch the records until (optional). :param end_time: End time to fetch the records until (optional).
:param ex_details: Exchange details [symbol, interval, exchange_name, user_name]. :param ex_details: Exchange details [symbol, interval, exchange_name, user_name].
:return: DataFrame of the data downloaded. :return: DataFrame of the data downloaded.
Example:
--------
records = db._populate_table('test_table', start_time, ['BTC/USDT', '1m', 'binance', 'user1'])
""" """
if end_time is None: if end_time is None:
end_time = dt.datetime.utcnow() end_time = dt.datetime.utcnow()
@ -256,7 +321,7 @@ class DataCache:
if not records.empty: if not records.empty:
self.db.insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex) self.db.insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex)
else: else:
print(f'No records inserted {records}') logger.debug(f'No records inserted {records}')
return records return records
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str, def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
@ -272,10 +337,6 @@ class DataCache:
:param start_datetime: Start datetime for fetching data (optional). :param start_datetime: Start datetime for fetching data (optional).
:param end_datetime: End datetime for fetching data (optional). :param end_datetime: End datetime for fetching data (optional).
:return: DataFrame of candle data. :return: DataFrame of candle data.
Example:
--------
candles = db._fetch_candles_from_exchange('BTC/USDT', '1m', 'binance', 'user1', start_time, end_time)
""" """
def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame: def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
@ -285,10 +346,6 @@ class DataCache:
:param records: DataFrame containing the original records. :param records: DataFrame containing the original records.
:param interval: Interval of the data (e.g., '1m', '5m'). :param interval: Interval of the data (e.g., '1m', '5m').
:return: DataFrame with gaps filled. :return: DataFrame with gaps filled.
Example:
--------
filled_records = fill_data_holes(df, '1m')
""" """
time_span = timeframe_to_minutes(interval) time_span = timeframe_to_minutes(interval)
last_timestamp = None last_timestamp = None
@ -299,20 +356,17 @@ class DataCache:
for index, row in records.iterrows(): for index, row in records.iterrows():
time_stamp = row['open_time'] time_stamp = row['open_time']
# If last_timestamp is None, this is the first record
if last_timestamp is None: if last_timestamp is None:
last_timestamp = time_stamp last_timestamp = time_stamp
filled_records.append(row) filled_records.append(row)
logger.debug(f"First timestamp: {time_stamp}") logger.debug(f"First timestamp: {time_stamp}")
continue continue
# Calculate the difference in milliseconds and minutes
delta_ms = time_stamp - last_timestamp delta_ms = time_stamp - last_timestamp
delta_minutes = (delta_ms / 1000) / 60 delta_minutes = (delta_ms / 1000) / 60
logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}") logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}")
# If the gap is larger than the time span of the interval, fill the gap
if delta_minutes > time_span: if delta_minutes > time_span:
num_missing_rec = int(delta_minutes / time_span) num_missing_rec = int(delta_minutes / time_span)
step = int(delta_ms / num_missing_rec) step = int(delta_ms / num_missing_rec)
@ -330,44 +384,47 @@ class DataCache:
logger.info("Data holes filled successfully.") logger.info("Data holes filled successfully.")
return pd.DataFrame(filled_records) return pd.DataFrame(filled_records)
# Default start date if not provided
if start_datetime is None: if start_datetime is None:
start_datetime = dt.datetime(year=2017, month=1, day=1) start_datetime = dt.datetime(year=2017, month=1, day=1)
# Default end date if not provided
if end_datetime is None: if end_datetime is None:
end_datetime = dt.datetime.utcnow() end_datetime = dt.datetime.utcnow()
# Check if start date is greater than end date
if start_datetime > end_datetime: if start_datetime > end_datetime:
raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.") raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")
# Get the exchange object
exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name) exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
# Calculate the expected number of records expected_records = (((unix_time_millis(end_datetime) - unix_time_millis(
temp = (((unix_time_millis(end_datetime) - unix_time_millis(
start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval) start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval)
logger.info(
f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {expected_records}')
logger.info(f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {temp}')
# If start and end times are the same, set end_datetime to None
if start_datetime == end_datetime: if start_datetime == end_datetime:
end_datetime = None end_datetime = None
# Fetch historical candlestick data from the exchange
candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime, candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime,
end_dt=end_datetime) end_dt=end_datetime)
num_rec_records = len(candles.index) num_rec_records = len(candles.index)
logger.info(f'{num_rec_records} candles retrieved from the exchange.') logger.info(f'{num_rec_records} candles retrieved from the exchange.')
# Check if the retrieved data covers the expected time range
open_times = candles.open_time open_times = candles.open_time
estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval) min_open_time = open_times.min()
max_open_time = open_times.max()
if min_open_time < 1e10:
raise ValueError('Records are not in milliseconds')
max_open_time /= 1000
min_open_time /= 1000
estimated_num_records = ((max_open_time - min_open_time) / 60) / timeframe_to_minutes(interval)
logger.info(f'Estimated number of records: {estimated_num_records}')
# Fill in any missing data if the retrieved data is less than expected
if num_rec_records < estimated_num_records: if num_rec_records < estimated_num_records:
logger.info('Detected gaps in the data, attempting to fill missing records.')
candles = fill_data_holes(candles, interval) candles = fill_data_holes(candles, interval)
return candles return candles

454
src/DataCache_v2.py Normal file
View File

@ -0,0 +1,454 @@
from typing import List, Any
import pandas as pd
import datetime as dt
import logging
from shared_utilities import unix_time_millis
from Database import Database
import numpy as np
# Configure logging
logger = logging.getLogger(__name__)
def timeframe_to_timedelta(timeframe: str) -> pd.Timedelta | pd.DateOffset:
digits = int("".join([i if i.isdigit() else "" for i in timeframe]))
unit = "".join([i if i.isalpha() else "" for i in timeframe])
if unit == 'm':
return pd.Timedelta(minutes=digits)
elif unit == 'h':
return pd.Timedelta(hours=digits)
elif unit == 'd':
return pd.Timedelta(days=digits)
elif unit == 'w':
return pd.Timedelta(weeks=digits)
elif unit == 'M':
return pd.DateOffset(months=digits)
elif unit == 'Y':
return pd.DateOffset(years=digits)
else:
raise ValueError(f"Invalid timeframe unit: {unit}")
def estimate_record_count(start_time, end_time, timeframe: str) -> int:
"""
Estimate the number of records expected between start_time and end_time based on the given timeframe.
Accepts either datetime objects or Unix timestamps in milliseconds.
"""
# Check if the input is in milliseconds (timestamp)
if isinstance(start_time, (int, float, np.integer)) and isinstance(end_time, (int, float, np.integer)):
# Convert timestamps from milliseconds to seconds for calculation
start_time = int(start_time) / 1000
end_time = int(end_time) / 1000
start_datetime = dt.datetime.utcfromtimestamp(start_time)
end_datetime = dt.datetime.utcfromtimestamp(end_time)
elif isinstance(start_time, dt.datetime) and isinstance(end_time, dt.datetime):
start_datetime = start_time
end_datetime = end_time
else:
raise ValueError("start_time and end_time must be either both "
"datetime objects or both Unix timestamps in milliseconds.")
delta = timeframe_to_timedelta(timeframe)
total_seconds = (end_datetime - start_datetime).total_seconds()
expected_records = total_seconds // delta.total_seconds()
return int(expected_records)
def generate_expected_timestamps(start_datetime: dt.datetime, end_datetime: dt.datetime,
                                 timeframe: str) -> pd.DatetimeIndex:
    """
    Build the complete sequence of timestamps expected between the two bounds
    (inclusive) for the given timeframe.
    """
    step = timeframe_to_timedelta(timeframe)
    if isinstance(step, pd.Timedelta):
        # Fixed-width steps map directly onto a pandas date_range.
        return pd.date_range(start=start_datetime, end=end_datetime, freq=step)
    if isinstance(step, pd.DateOffset):
        # Calendar-based steps (months/years) are walked manually.
        stamps = []
        cursor = start_datetime
        while cursor <= end_datetime:
            stamps.append(cursor)
            cursor = cursor + step
        return pd.DatetimeIndex(stamps)
class DataCache:
TYPECHECKING_ENABLED = True
def __init__(self, exchanges):
    """
    Build the cache around a database handle and the supplied exchanges object.

    :param exchanges: Object managing communication with connected exchanges.
    """
    self.db = Database()        # Persistence layer for candle tables.
    self.exchanges = exchanges  # Exchange connection manager.
    self.cached_data = {}       # In-memory record store, keyed per request.
    logger.info("DataCache initialized.")
def get_records_since(self, start_datetime: dt.datetime, ex_details: List[str]) -> pd.DataFrame:
    """
    Fetch all records for an exchange/timeframe from start_datetime up to now.

    :param start_datetime: Earliest record time requested.
    :param ex_details: [asset, timeframe, exchange, user_name].
    :return: DataFrame resolved via the cache -> database -> server chain.
    :raises TypeError: On malformed arguments (when TYPECHECKING_ENABLED).
    """
    if self.TYPECHECKING_ENABLED:
        # Guard clauses: validate inputs before doing any work.
        if not isinstance(start_datetime, dt.datetime):
            raise TypeError("start_datetime must be a datetime object")
        if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
            raise TypeError("ex_details must be a list of strings")
        if len(ex_details) != 4:
            raise TypeError("ex_details must include [asset, timeframe, exchange, user_name]")
    try:
        request = {
            'start_datetime': start_datetime,
            'end_datetime': dt.datetime.utcnow(),
            'ex_details': ex_details,
        }
        return self.get_or_fetch_from('cache', **request)
    except Exception as e:
        logger.error(f"An error occurred: {str(e)}")
        raise
def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame:
    """
    Resolve a data request by walking an ordered chain of sources
    (cache -> database -> server), merging partial results until the request
    criteria are satisfied.

    Expected kwargs: start_datetime (dt.datetime), end_datetime (dt.datetime),
    ex_details (list of str; ex_details[1] is the timeframe string).

    :param target: Tier to start from: 'cache', 'database' or 'server'.
    :return: Combined records; empty DataFrame when nothing could be fetched.
    :raises TypeError: When TYPECHECKING_ENABLED and an argument has the wrong type.
    :raises ValueError: On missing arguments or an unknown target.

    NOTE(review): indentation was reconstructed from a flattened source —
    confirm the nesting of the completeness check against the original file.
    """
    start_datetime = kwargs.get('start_datetime')
    end_datetime = kwargs.get('end_datetime')
    ex_details = kwargs.get('ex_details')
    # NOTE(review): this indexes ex_details before the type checks below run,
    # so a missing/short ex_details raises here first — confirm intended.
    timeframe = kwargs.get('ex_details')[1]
    if self.TYPECHECKING_ENABLED:
        if not isinstance(start_datetime, dt.datetime):
            raise TypeError("start_datetime must be a datetime object")
        if not isinstance(end_datetime, dt.datetime):
            raise TypeError("end_datetime must be a datetime object")
        if not isinstance(timeframe, str):
            raise TypeError("record_length must be a string representing the timeframe")
        if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
            raise TypeError("ex_details must be a list of strings")
        if not all([start_datetime, end_datetime, timeframe, ex_details]):
            raise ValueError("Missing required arguments")
    # Criteria the accumulated data must satisfy; data_complete may narrow them.
    request_criteria = {
        'start_datetime': start_datetime,
        'end_datetime': end_datetime,
        'timeframe': timeframe,
    }
    key = self._make_key(ex_details)
    combined_data = pd.DataFrame()
    # Choose the ordered list of fetchers, starting at the requested tier.
    if target == 'cache':
        resources = [self.get_candles_from_cache, self.get_from_database, self.get_from_server]
    elif target == 'database':
        resources = [self.get_from_database, self.get_from_server]
    elif target == 'server':
        resources = [self.get_from_server]
    else:
        raise ValueError('Not a valid Target!')
    for fetch_method in resources:
        result = fetch_method(**kwargs)
        if not result.empty:
            # Merge this source's rows into the accumulated frame.
            combined_data = pd.concat([combined_data, result]).drop_duplicates()
        if not combined_data.empty and 'open_time' in combined_data.columns:
            # Keep the earliest occurrence of each candle, in time order.
            combined_data = combined_data.sort_values(by='open_time').drop_duplicates(subset='open_time',
                                                                                     keep='first')
            is_complete, request_criteria = self.data_complete(combined_data, **request_criteria)
            if is_complete:
                # Backfill faster tiers when the data came from a slower one.
                if fetch_method in [self.get_from_database, self.get_from_server]:
                    self.update_candle_cache(combined_data, key)
                if fetch_method == self.get_from_server:
                    self._populate_db(ex_details, combined_data)
                return combined_data
        kwargs.update(request_criteria)  # Update kwargs with new start/end times for next fetch attempt
    logger.error('Unable to fetch the requested data.')
    return combined_data if not combined_data.empty else pd.DataFrame()
def get_candles_from_cache(self, **kwargs) -> pd.DataFrame:
    """
    Returns cached candle records that fall within the requested time window.

    :param kwargs: Must contain 'start_datetime' and 'end_datetime' (datetime)
                   and 'ex_details' (list of strings).
    :return: DataFrame of matching records; empty when no cache entry exists.
    :raises TypeError: If an argument has the wrong type (when type checking is on).
    :raises ValueError: If a required argument is missing.
    """
    start = kwargs.get('start_datetime')
    end = kwargs.get('end_datetime')
    details = kwargs.get('ex_details')
    if self.TYPECHECKING_ENABLED:
        if not isinstance(start, dt.datetime):
            raise TypeError("start_datetime must be a datetime object")
        if not isinstance(end, dt.datetime):
            raise TypeError("end_datetime must be a datetime object")
        if not isinstance(details, list) or not all(isinstance(i, str) for i in details):
            raise TypeError("ex_details must be a list of strings")
    if not all([start, end, details]):
        raise ValueError("Missing required arguments")
    cache_key = self._make_key(details)
    logger.debug('Getting records from cache.')
    cached = self.get_cache(cache_key)
    if cached is None:
        logger.debug("Cache records didn't exist.")
        return pd.DataFrame()
    logger.debug('Filtering records.')
    # Cached 'open_time' values are epoch milliseconds; convert the window once.
    start_ms = unix_time_millis(start)
    end_ms = unix_time_millis(end)
    in_window = (cached['open_time'] >= start_ms) & (cached['open_time'] <= end_ms)
    return cached[in_window].reset_index(drop=True)
def get_from_database(self, **kwargs) -> pd.DataFrame:
    """
    Returns candle records persisted in the database for the requested window.

    :param kwargs: Must contain 'start_datetime' and 'end_datetime' (datetime)
                   and 'ex_details' (list of strings).
    :return: DataFrame of records; empty when the backing table does not exist.
    :raises TypeError: If an argument has the wrong type (when type checking is on).
    :raises ValueError: If a required argument is missing.
    """
    start = kwargs.get('start_datetime')
    end = kwargs.get('end_datetime')
    details = kwargs.get('ex_details')
    if self.TYPECHECKING_ENABLED:
        if not isinstance(start, dt.datetime):
            raise TypeError("start_datetime must be a datetime object")
        if not isinstance(end, dt.datetime):
            raise TypeError("end_datetime must be a datetime object")
        if not isinstance(details, list) or not all(isinstance(i, str) for i in details):
            raise TypeError("ex_details must be a list of strings")
    if not all([start, end, details]):
        raise ValueError("Missing required arguments")
    table = self._make_key(details)
    if not self.db.table_exists(table):
        logger.debug('Records not in database.')
        return pd.DataFrame()
    logger.debug('Getting records from database.')
    return self.db.get_timestamped_records(table_name=table, timestamp_field='open_time',
                                           st=start, et=end)
def get_from_server(self, **kwargs) -> pd.DataFrame:
    """
    Fetches candle records directly from the exchange server.

    :param kwargs: Must contain 'ex_details' ([symbol, interval, exchange_name,
                   user_name]), 'start_datetime' and 'end_datetime'.
    :return: DataFrame of candles returned by the exchange.
    :raises TypeError: If 'ex_details' is missing/too short, or any argument
                       has the wrong type (when type checking is on).
    """
    ex_details = kwargs.get('ex_details')
    # Validate before unpacking: indexing a missing or short ex_details would
    # otherwise surface as an opaque "'NoneType' object is not subscriptable".
    if not isinstance(ex_details, (list, tuple)) or len(ex_details) < 4:
        raise TypeError("ex_details must be a list of [symbol, interval, exchange_name, user_name]")
    symbol, interval, exchange_name, user_name = ex_details[:4]
    start_datetime = kwargs.get('start_datetime')
    end_datetime = kwargs.get('end_datetime')
    if self.TYPECHECKING_ENABLED:
        if not isinstance(symbol, str):
            raise TypeError("symbol must be a string")
        if not isinstance(interval, str):
            raise TypeError("interval must be a string")
        if not isinstance(exchange_name, str):
            raise TypeError("exchange_name must be a string")
        if not isinstance(user_name, str):
            raise TypeError("user_name must be a string")
        if not isinstance(start_datetime, dt.datetime):
            raise TypeError("start_datetime must be a datetime object")
        if not isinstance(end_datetime, dt.datetime):
            raise TypeError("end_datetime must be a datetime object")
    logger.debug('Getting records from server.')
    return self._fetch_candles_from_exchange(symbol, interval, exchange_name, user_name,
                                             start_datetime, end_datetime)
@staticmethod
def data_complete(data: pd.DataFrame, **kwargs) -> tuple[bool, dict]:
    """
    Checks if the data completely satisfies the request.

    :param data: DataFrame containing the records.
    :param kwargs: Arguments required for completeness check ('start_datetime',
                   'end_datetime' and 'timeframe').
    :return: A tuple (is_complete, updated_request_criteria) where is_complete is True if the data is complete,
             False otherwise, and updated_request_criteria contains adjusted start/end times if data is incomplete.
    """
    start_datetime: dt.datetime = kwargs.get('start_datetime')
    end_datetime: dt.datetime = kwargs.get('end_datetime')
    timeframe: str = kwargs.get('timeframe')
    if data.empty:
        logger.debug("Data is empty.")
        return False, kwargs  # No data at all, proceed with the full original request
    temp_data = data.copy()
    # Raw records store 'open_time' in epoch milliseconds.  Use the pandas
    # dtype predicate rather than comparing against the '<M8[ns]' string so
    # timezone-aware datetime columns are also accepted as-is.
    if pd.api.types.is_datetime64_any_dtype(temp_data['open_time']):
        temp_data['open_time_dt'] = temp_data['open_time']
    else:
        temp_data['open_time_dt'] = pd.to_datetime(temp_data['open_time'], unit='ms')
    min_timestamp = temp_data['open_time_dt'].min()
    max_timestamp = temp_data['open_time_dt'].max()
    logger.debug(f"Data time range: {min_timestamp} to {max_timestamp}")
    logger.debug(f"Expected time range: {start_datetime} to {end_datetime}")
    # Allow the window edges to be off by a few seconds before declaring a gap.
    edge_tolerance = pd.Timedelta(seconds=5)
    # Initialize updated request criteria.
    updated_request_criteria = kwargs.copy()
    # Check if data covers the required time range with tolerance.
    if min_timestamp > start_datetime + timeframe_to_timedelta(timeframe) + edge_tolerance:
        logger.debug("Data does not start early enough, even with tolerance.")
        updated_request_criteria['end_datetime'] = min_timestamp  # Fetch the missing earlier data
        return False, updated_request_criteria
    if max_timestamp < end_datetime - timeframe_to_timedelta(timeframe) - edge_tolerance:
        logger.debug("Data does not extend late enough, even with tolerance.")
        updated_request_criteria['start_datetime'] = max_timestamp  # Fetch the missing later data
        return False, updated_request_criteria
    # Filter data between start_datetime and end_datetime.
    mask = (temp_data['open_time_dt'] >= start_datetime) & (temp_data['open_time_dt'] <= end_datetime)
    data_in_range = temp_data.loc[mask]
    expected_count = estimate_record_count(start_datetime, end_datetime, timeframe)
    actual_count = len(data_in_range)
    logger.debug(f"Expected record count: {expected_count}, Actual record count: {actual_count}")
    # A single missing record inside the window is tolerated.
    count_tolerance = 1
    if actual_count < (expected_count - count_tolerance):
        logger.debug("Insufficient records within the specified time range, even with tolerance.")
        return False, updated_request_criteria
    logger.debug("Data completeness check passed.")
    return True, kwargs
def cache_exists(self, key: str) -> bool:
return key in self.cached_data
def get_cache(self, key: str) -> Any | None:
if key not in self.cached_data:
logger.warning(f"The requested cache key({key}) doesn't exist!")
return None
return self.cached_data[key]
def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
    """
    Merges *more_records* into the cached DataFrame stored under *key*.

    Duplicate 'open_time' rows keep their first occurrence, the result is
    re-sorted chronologically and the 'id' column is renumbered from 1.

    :param more_records: New candle records to fold into the cache.
    :param key: The cache index key.
    """
    logger.debug('Updating cache with new records.')
    merged = (
        pd.concat([self.get_cache(key), more_records], axis=0, ignore_index=True)
        .drop_duplicates(subset="open_time", keep='first')
        .sort_values(by='open_time')
        .reset_index(drop=True)
    )
    # Renumber ids so they stay contiguous after the merge.
    merged['id'] = range(1, len(merged) + 1)
    self.set_cache(data=merged, key=key)
def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None:
"""
Updates a dictionary stored in cache.
:param data: The data to insert into cache.
:param cache_key: The cache index key for the dictionary.
:param dict_key: The dictionary key for the data.
:return: None
"""
self.cached_data[cache_key].update({dict_key: data})
def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
if do_not_overwrite and key in self.cached_data:
return
self.cached_data[key] = data
logger.debug(f'Cache set for key: {key}')
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
                                 start_datetime: dt.datetime = None,
                                 end_datetime: dt.datetime = None) -> pd.DataFrame:
    """
    Downloads historical candles for *symbol* from the named exchange.

    :param symbol: Trading pair symbol, e.g. 'BTCUSDT'.
    :param interval: Timeframe string, e.g. '15m'.
    :param exchange_name: Exchange the user is connected to.
    :param user_name: User whose exchange connection should be used.
    :param start_datetime: Start of the requested window; defaults to 2017-01-01.
    :param end_datetime: End of the requested window; defaults to current UTC time.
    :return: DataFrame of candles; empty (with the OHLCV columns) when the
             exchange returns no data.
    :raises ValueError: If start is after end, or the returned timestamps are
                        not in milliseconds.
    """
    if start_datetime is None:
        start_datetime = dt.datetime(year=2017, month=1, day=1)
    if end_datetime is None:
        end_datetime = dt.datetime.utcnow()
    if start_datetime > end_datetime:
        raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")
    exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
    expected_records = estimate_record_count(start_datetime, end_datetime, interval)
    logger.info(
        f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {expected_records}')
    # NOTE(review): identical start/end is turned into an open-ended request
    # (end_dt=None) — presumably so the exchange returns the latest candles;
    # confirm against Exchange.get_historical_klines.
    if start_datetime == end_datetime:
        end_datetime = None
    candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime,
                                             end_dt=end_datetime)
    num_rec_records = len(candles.index)
    if num_rec_records == 0:
        logger.warning(f"No OHLCV data returned for {symbol}.")
        return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume'])
    logger.info(f'{num_rec_records} candles retrieved from the exchange.')
    open_times = candles.open_time
    min_open_time = open_times.min()
    max_open_time = open_times.max()
    # Sanity check: millisecond epochs exceed 1e10; smaller values suggest
    # second-based timestamps slipped through.
    if min_open_time < 1e10:
        raise ValueError('Records are not in milliseconds')
    # NOTE(review): estimate_record_count is called with datetime objects
    # above but with millisecond ints here — verify it accepts both forms.
    estimated_num_records = estimate_record_count(min_open_time, max_open_time, interval) + 1
    logger.info(f'Estimated number of records: {estimated_num_records}')
    # Fewer rows than the span implies means there are gaps; fill them so
    # downstream consumers see a contiguous series.
    if num_rec_records < estimated_num_records:
        logger.info('Detected gaps in the data, attempting to fill missing records.')
        candles = self.fill_data_holes(candles, interval)
    return candles
@staticmethod
def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
    """
    Plugs gaps in a candle series by inserting synthetic rows.

    Whenever two consecutive 'open_time' values (epoch milliseconds) are more
    than one interval apart, copies of the later row are inserted at evenly
    spaced timestamps so the series becomes contiguous.

    :param records: Candle records ordered by 'open_time'.
    :param interval: Timeframe string, e.g. '15m'.
    :return: New DataFrame containing the original plus synthetic rows.
    """
    interval_minutes = timeframe_to_timedelta(interval).total_seconds() / 60
    prev_ts = None
    output_rows = []
    logger.info(f"Starting to fill data holes for interval: {interval}")
    for _, row in records.iterrows():
        current_ts = row['open_time']
        if prev_ts is None:
            # Nothing to compare against yet — keep the first row as-is.
            prev_ts = current_ts
            output_rows.append(row)
            logger.debug(f"First timestamp: {current_ts}")
            continue
        gap_ms = current_ts - prev_ts
        gap_minutes = (gap_ms / 1000) / 60
        logger.debug(f"Timestamp: {current_ts}, Delta minutes: {gap_minutes}")
        if gap_minutes > interval_minutes:
            missing = int(gap_minutes / interval_minutes)
            step = int(gap_ms / missing)
            logger.debug(f"Gap detected. Filling {missing} records with step: {step}")
            for synthetic_ts in range(int(prev_ts) + step, int(current_ts), step):
                synthetic = row.copy()
                synthetic['open_time'] = synthetic_ts
                output_rows.append(synthetic)
                logger.debug(f"Filled timestamp: {synthetic_ts}")
        output_rows.append(row)
        prev_ts = current_ts
    logger.info("Data holes filled successfully.")
    return pd.DataFrame(output_rows)
def _populate_db(self, ex_details: List[str], data: pd.DataFrame = None) -> None:
    """
    Persists freshly fetched candle records into their database table.

    :param ex_details: [symbol, timeframe, exchange_name, user_name].
    :param data: Candle records to insert; skipped when None or empty.
    """
    if data is None or data.empty:
        logger.debug(f'No records to insert {data}')
        return
    table_name = self._make_key(ex_details)
    symbol, _, exchange_name, _ = ex_details
    self.db.insert_candles_into_db(data, table_name=table_name, symbol=symbol,
                                   exchange_name=exchange_name)
    logger.info(f'Data inserted into table {table_name}')
@staticmethod
def _make_key(ex_details: List[str]) -> str:
sym, tf, ex, _ = ex_details
key = f'{sym}_{tf}_{ex}'
return key
# Example usage
# args = {
# 'start_datetime': dt.datetime.now() - dt.timedelta(hours=1), # Example start time
# 'ex_details': ['BTCUSDT', '15m', 'Binance', 'user1'],
# }
#
# exchanges = ExchangeHandler()
# data = DataCache(exchanges)
# df = data.get_records_since(**args)
#
# # Disabling type checking for a specific instance
# data.TYPECHECKING_ENABLED = False

View File

@ -1,6 +1,6 @@
import sqlite3 import sqlite3
from functools import lru_cache from functools import lru_cache
from typing import Any from typing import Any, Dict, List, Tuple
import config import config
import datetime as dt import datetime as dt
import pandas as pd import pandas as pd
@ -19,7 +19,7 @@ class SQLite:
cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
""" """
def __init__(self, db_file=None): def __init__(self, db_file: str = None):
self.db_file = db_file if db_file else config.DB_FILE self.db_file = db_file if db_file else config.DB_FILE
self.connection = sqlite3.connect(self.db_file) self.connection = sqlite3.connect(self.db_file)
@ -41,11 +41,11 @@ class HDict(dict):
hash(hdict) hash(hdict)
""" """
def __hash__(self): def __hash__(self) -> int:
return hash(frozenset(self.items())) return hash(frozenset(self.items()))
def make_query(item: str, table: str, columns: list) -> str: def make_query(item: str, table: str, columns: List[str]) -> str:
""" """
Creates a SQL select query string with the required number of placeholders. Creates a SQL select query string with the required number of placeholders.
@ -53,40 +53,22 @@ def make_query(item: str, table: str, columns: list) -> str:
:param table: The table to select from. :param table: The table to select from.
:param columns: List of columns for the where clause. :param columns: List of columns for the where clause.
:return: The query string. :return: The query string.
Example:
--------
query = make_query('id', 'test_table', ['name', 'age'])
# Result: 'SELECT id FROM test_table WHERE name = ? AND age = ?;'
""" """
an_itr = iter(columns) placeholders = " AND ".join([f"{col} = ?" for col in columns])
k = next(an_itr) return f"SELECT {item} FROM {table} WHERE {placeholders};"
where_str = f"SELECT {item} FROM {table} WHERE {k} = ?"
where_str += "".join([f" AND {k} = ?" for k in an_itr]) + ';'
return where_str
def make_insert(table: str, values: tuple) -> str: def make_insert(table: str, columns: Tuple[str, ...]) -> str:
""" """
Creates a SQL insert query string with the required number of placeholders. Creates a SQL insert query string with the required number of placeholders.
:param table: The table to insert into. :param table: The table to insert into.
:param values: Tuple of values to insert. :param columns: Tuple of column names.
:return: The query string. :return: The query string.
Example:
--------
insert = make_insert('test_table', ('name', 'age'))
# Result: "INSERT INTO test_table ('name', 'age') VALUES(?, ?);"
""" """
itr1 = iter(values) col_names = ", ".join([f"'{col}'" for col in columns])
itr2 = iter(values) placeholders = ", ".join(["?" for _ in columns])
k1 = next(itr1) return f"INSERT INTO {table} ({col_names}) VALUES ({placeholders});"
_ = next(itr2)
insert_str = f"INSERT INTO {table} ('{k1}'"
insert_str += "".join([f", '{k1}'" for k1 in itr1]) + ") VALUES(?" + "".join(
[", ?" for _ in enumerate(itr2)]) + ");"
return insert_str
class Database: class Database:
@ -96,16 +78,11 @@ class Database:
Example usage: Example usage:
-------------- --------------
db = Database(exchanges=some_exchanges_object, db_file='test_db.sqlite') db = Database(db_file='test_db.sqlite')
db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
""" """
def __init__(self, db_file=None): def __init__(self, db_file: str = None):
"""
Initializes the Database class.
:param db_file: Optional database file name.
"""
self.db_file = db_file self.db_file = db_file
def execute_sql(self, sql: str) -> None: def execute_sql(self, sql: str) -> None:
@ -113,17 +90,12 @@ class Database:
Executes a raw SQL statement. Executes a raw SQL statement.
:param sql: SQL statement to execute. :param sql: SQL statement to execute.
Example:
--------
db = Database(exchanges=some_exchanges_object, db_file='test_db.sqlite')
db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
""" """
with SQLite(self.db_file) as con: with SQLite(self.db_file) as con:
cur = con.cursor() cur = con.cursor()
cur.execute(sql) cur.execute(sql)
def get_item_where(self, item_name: str, table_name: str, filter_vals: tuple) -> int: def get_item_where(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any:
""" """
Returns an item from a table where the filter results should isolate a single row. Returns an item from a table where the filter results should isolate a single row.
@ -131,42 +103,29 @@ class Database:
:param table_name: Name of the table. :param table_name: Name of the table.
:param filter_vals: Tuple of column name and value to filter by. :param filter_vals: Tuple of column name and value to filter by.
:return: The item. :return: The item.
Example:
--------
item = db.get_item_where('name', 'test_table', ('id', 1))
# Fetches the 'name' from 'test_table' where 'id' is 1
""" """
with SQLite(self.db_file) as con: with SQLite(self.db_file) as con:
cur = con.cursor() cur = con.cursor()
qry = make_query(item_name, table_name, [filter_vals[0]]) qry = make_query(item_name, table_name, [filter_vals[0]])
cur.execute(qry, (filter_vals[1],)) cur.execute(qry, (filter_vals[1],))
if user_id := cur.fetchone(): if result := cur.fetchone():
return user_id[0] return result[0]
else: else:
error = f"Couldn't fetch item {item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}" error = f"Couldn't fetch item {item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}"
raise ValueError(error) raise ValueError(error)
def get_rows_where(self, table: str, filter_vals: tuple) -> pd.DataFrame | None: def get_rows_where(self, table: str, filter_vals: Tuple[str, Any]) -> pd.DataFrame | None:
""" """
Returns a DataFrame containing all rows of a table that meet the filter criteria. Returns a DataFrame containing all rows of a table that meet the filter criteria.
:param table: Name of the table. :param table: Name of the table.
:param filter_vals: Tuple of column name and value to filter by. :param filter_vals: Tuple of column name and value to filter by.
:return: DataFrame of the query result or None if empty. :return: DataFrame of the query result or None if empty.
Example:
--------
rows = db.get_rows_where('test_table', ('name', 'test'))
# Fetches all rows from 'test_table' where 'name' is 'test'
""" """
with SQLite(self.db_file) as con: with SQLite(self.db_file) as con:
qry = f"SELECT * FROM {table} WHERE {filter_vals[0]}='{filter_vals[1]}'" qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?"
result = pd.read_sql(qry, con=con) result = pd.read_sql(qry, con, params=(filter_vals[1],))
if not result.empty: return result if not result.empty else None
return result
else:
return None
def insert_dataframe(self, df: pd.DataFrame, table: str) -> None: def insert_dataframe(self, df: pd.DataFrame, table: str) -> None:
""" """
@ -174,30 +133,21 @@ class Database:
:param df: DataFrame to insert. :param df: DataFrame to insert.
:param table: Name of the table. :param table: Name of the table.
Example:
--------
df = pd.DataFrame({'id': [1], 'name': ['test']})
db.insert_dataframe(df, 'test_table')
""" """
with SQLite(self.db_file) as con: with SQLite(self.db_file) as con:
df.to_sql(name=table, con=con, index=False, if_exists='append') df.to_sql(name=table, con=con, index=False, if_exists='append')
def insert_row(self, table: str, columns: tuple, values: tuple) -> None: def insert_row(self, table: str, columns: Tuple[str, ...], values: Tuple[Any, ...]) -> None:
""" """
Inserts a row into a specified table. Inserts a row into a specified table.
:param table: Name of the table. :param table: Name of the table.
:param columns: Tuple of column names. :param columns: Tuple of column names.
:param values: Tuple of values to insert. :param values: Tuple of values to insert.
Example:
--------
db.insert_row('test_table', ('id', 'name'), (1, 'test'))
""" """
with SQLite(self.db_file) as conn: with SQLite(self.db_file) as conn:
cursor = conn.cursor() cursor = conn.cursor()
sql = make_insert(table=table, values=columns) sql = make_insert(table=table, columns=columns)
cursor.execute(sql, values) cursor.execute(sql, values)
def table_exists(self, table_name: str) -> bool: def table_exists(self, table_name: str) -> bool:
@ -206,11 +156,6 @@ class Database:
:param table_name: Name of the table. :param table_name: Name of the table.
:return: True if the table exists, False otherwise. :return: True if the table exists, False otherwise.
Example:
--------
exists = db._table_exists('test_table')
# Checks if 'test_table' exists in the database
""" """
with SQLite(self.db_file) as conn: with SQLite(self.db_file) as conn:
cursor = conn.cursor() cursor = conn.cursor()
@ -220,7 +165,7 @@ class Database:
return result is not None return result is not None
@lru_cache(maxsize=1000) @lru_cache(maxsize=1000)
def get_from_static_table(self, item: str, table: str, indexes: HDict, create_id: bool = False) -> Any: def get_from_static_table(self, item: str, table: str, indexes: dict, create_id: bool = False) -> Any:
""" """
Returns the row id of an item from a table specified. If the item isn't listed in the table, Returns the row id of an item from a table specified. If the item isn't listed in the table,
it will insert the item into a new row and return the autoincremented id. The item received as a hashable it will insert the item into a new row and return the autoincremented id. The item received as a hashable
@ -231,23 +176,21 @@ class Database:
:param indexes: Hashable dictionary of indexing columns and their values. :param indexes: Hashable dictionary of indexing columns and their values.
:param create_id: If True, create a row if it doesn't exist and return the autoincrement ID. :param create_id: If True, create a row if it doesn't exist and return the autoincrement ID.
:return: The content of the field. :return: The content of the field.
Example:
--------
exchange_id = db.get_from_static_table('id', 'exchange', HDict({'name': 'binance'}), create_id=True)
""" """
with SQLite(self.db_file) as conn: with SQLite(self.db_file) as conn:
cursor = conn.cursor() cursor = conn.cursor()
sql = make_query(item, table, list(indexes.keys())) sql = make_query(item, table, list(indexes.keys()))
cursor.execute(sql, tuple(indexes.values())) cursor.execute(sql, tuple(indexes.values()))
result = cursor.fetchone() result = cursor.fetchone()
if result is None and create_id: if result is None and create_id:
sql = make_insert(table, tuple(indexes.keys())) sql = make_insert(table, tuple(indexes.keys()))
cursor.execute(sql, tuple(indexes.values())) cursor.execute(sql, tuple(indexes.values()))
sql = make_query(item, table, list(indexes.keys())) result = cursor.lastrowid # Get the last inserted row ID
cursor.execute(sql, tuple(indexes.values())) else:
result = cursor.fetchone() result = result[0] if result else None
return result[0] if result else None
return result
def _fetch_exchange_id(self, exchange_name: str) -> int: def _fetch_exchange_id(self, exchange_name: str) -> int:
""" """
@ -255,25 +198,17 @@ class Database:
:param exchange_name: Name of the exchange. :param exchange_name: Name of the exchange.
:return: Primary ID of the exchange. :return: Primary ID of the exchange.
Example:
--------
exchange_id = db._fetch_exchange_id('binance')
""" """
return self.get_from_static_table(item='id', table='exchange', create_id=True, return self.get_from_static_table(item='id', table='exchange', create_id=True,
indexes=HDict({'name': exchange_name})) indexes=HDict({'name': exchange_name}))
def _fetch_market_id(self, symbol: str, exchange_name: str) -> int: def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
""" """
Returns the market ID for a trading pair listed in the database. Returns the markets ID for a trading pair listed in the database.
:param symbol: Symbol of the trading pair. :param symbol: Symbol of the trading pair.
:param exchange_name: Name of the exchange. :param exchange_name: Name of the exchange.
:return: Market ID. :return: Market ID.
Example:
--------
market_id = db._fetch_market_id('BTC/USDT', 'binance')
""" """
exchange_id = self._fetch_exchange_id(exchange_name) exchange_id = self._fetch_exchange_id(exchange_name)
market_id = self.get_from_static_table(item='id', table='markets', create_id=True, market_id = self.get_from_static_table(item='id', table='markets', create_id=True,
@ -289,21 +224,17 @@ class Database:
:param table_name: Name of the table to insert into. :param table_name: Name of the table to insert into.
:param symbol: Symbol of the trading pair. :param symbol: Symbol of the trading pair.
:param exchange_name: Name of the exchange. :param exchange_name: Name of the exchange.
Example:
--------
df = pd.DataFrame({
'open_time': [unix_time_millis(dt.datetime.utcnow())],
'open': [1.0],
'high': [1.0],
'low': [1.0],
'close': [1.0],
'volume': [1.0]
})
db._insert_candles_into_db(df, 'test_table', 'BTC/USDT', 'binance')
""" """
market_id = self._fetch_market_id(symbol, exchange_name) market_id = self._fetch_market_id(symbol, exchange_name)
candlesticks.insert(0, 'market_id', market_id)
# Check if 'market_id' column already exists in the DataFrame
if 'market_id' in candlesticks.columns:
# If it exists, set its value to the fetched market_id
candlesticks['market_id'] = market_id
else:
# If it doesn't exist, insert it as the first column
candlesticks.insert(0, 'market_id', market_id)
sql_create = f""" sql_create = f"""
CREATE TABLE IF NOT EXISTS '{table_name}' ( CREATE TABLE IF NOT EXISTS '{table_name}' (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
@ -332,21 +263,19 @@ class Database:
:param st: Start datetime. :param st: Start datetime.
:param et: End datetime (optional). :param et: End datetime (optional).
:return: DataFrame of records. :return: DataFrame of records.
Example:
--------
records = db.get_timestamped_records('test_table', 'open_time', start_time, end_time)
""" """
with SQLite(self.db_file) as conn: with SQLite(self.db_file) as conn:
start_stamp = unix_time_millis(st) start_stamp = unix_time_millis(st)
if et is not None: if et is not None:
end_stamp = unix_time_millis(et) end_stamp = unix_time_millis(et)
q_str = ( q_str = (
f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp} " f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= ? "
f"AND {timestamp_field} <= {end_stamp};" f"AND {timestamp_field} <= ?;"
) )
records = pd.read_sql(q_str, conn, params=(start_stamp, end_stamp))
else: else:
q_str = f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp};" q_str = f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= ?;"
records = pd.read_sql(q_str, conn) records = pd.read_sql(q_str, conn, params=(start_stamp,))
records = records.drop('id', axis=1)
# records = records.drop('id', axis=1) Todo: Reminder I may need to put this back later.
return records return records

View File

@ -1,6 +1,6 @@
import ccxt import ccxt
import pandas as pd import pandas as pd
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from typing import Tuple, Dict, List, Union, Any from typing import Tuple, Dict, List, Union, Any
import time import time
import logging import logging
@ -84,6 +84,8 @@ class Exchange:
Returns: Returns:
int: The Unix timestamp in milliseconds. int: The Unix timestamp in milliseconds.
""" """
if dt.tzinfo is None:
raise ValueError("datetime object must be timezone-aware or in UTC.")
return int(dt.timestamp() * 1000) return int(dt.timestamp() * 1000)
def _fetch_historical_klines(self, symbol: str, interval: str, def _fetch_historical_klines(self, symbol: str, interval: str,
@ -103,6 +105,12 @@ class Exchange:
if end_dt is None: if end_dt is None:
end_dt = datetime.utcnow() end_dt = datetime.utcnow()
# Convert start_dt and end_dt to UTC if they are naive
if start_dt.tzinfo is None:
start_dt = start_dt.replace(tzinfo=timezone.utc)
if end_dt.tzinfo is None:
end_dt = end_dt.replace(tzinfo=timezone.utc)
max_interval = timedelta(days=200) max_interval = timedelta(days=200)
data_frames = [] data_frames = []
current_start = start_dt current_start = start_dt
@ -122,7 +130,7 @@ class Exchange:
df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume'] df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume']
candles_df = pd.DataFrame(candles, columns=df_columns) candles_df = pd.DataFrame(candles, columns=df_columns)
candles_df['open_time'] = candles_df['open_time'] // 1000
data_frames.append(candles_df) data_frames.append(candles_df)
current_start = current_end current_start = current_end
@ -515,39 +523,41 @@ class Exchange:
return [] return []
# Usage Examples #
# # Usage Examples
# Example 1: Initializing the Exchange class #
api_keys = { # # Example 1: Initializing the Exchange class
'key': 'your_api_key', # api_keys = {
'secret': 'your_api_secret' # 'key': 'your_api_key',
} # 'secret': 'your_api_secret'
exchange = Exchange(name='Binance', api_keys=api_keys, exchange_id='binance') # }
# exchange = Exchange(name='Binance', api_keys=api_keys, exchange_id='binance')
# Example 2: Fetching historical data #
start_date = datetime(2022, 1, 1) # # Example 2: Fetching historical data
end_date = datetime(2022, 6, 1) # start_date = datetime(2022, 1, 1)
historical_data = exchange.get_historical_klines(symbol='BTC/USDT', interval='1d', # end_date = datetime(2022, 6, 1)
start_dt=start_date, end_dt=end_date) # historical_data = exchange.get_historical_klines(symbol='BTC/USDT', interval='1d',
print(historical_data) # start_dt=start_date, end_dt=end_date)
# print(historical_data)
# Example 3: Fetching the current price of a symbol #
current_price = exchange.get_price(symbol='BTC/USDT') # # Example 3: Fetching the current price of a symbol
print(f"Current price of BTC/USDT: {current_price}") # current_price = exchange.get_price(symbol='BTC/USDT')
# print(f"Current price of BTC/USDT: {current_price}")
# Example 4: Placing a limit buy order #
order_result, order_details = exchange.place_order(symbol='BTC/USDT', side='buy', type='limit', # # Example 4: Placing a limit buy order
timeInForce='GTC', quantity=0.001, price=30000) # order_result, order_details = exchange.place_order(symbol='BTC/USDT', side='buy', type='limit',
print(order_result, order_details) # timeInForce='GTC', quantity=0.001, price=30000)
# print(order_result, order_details)
# Example 5: Getting account balances #
balances = exchange.get_balances() # # Example 5: Getting account balances
print(balances) # balances = exchange.get_balances()
# print(balances)
# Example 6: Fetching open orders #
open_orders = exchange.get_open_orders() # # Example 6: Fetching open orders
print(open_orders) # open_orders = exchange.get_open_orders()
# print(open_orders)
# Example 7: Fetching active trades #
active_trades = exchange.get_active_trades() # # Example 7: Fetching active trades
print(active_trades) # active_trades = exchange.get_active_trades()
# print(active_trades)
#

View File

@ -1,8 +1,6 @@
import logging import logging
import json
from typing import List, Any, Dict from typing import List, Any, Dict
import pandas as pd import pandas as pd
import requests
import ccxt import ccxt
from Exchange import Exchange from Exchange import Exchange
@ -46,7 +44,7 @@ class ExchangeInterface:
self.add_exchange(user_name, exchange) self.add_exchange(user_name, exchange)
return True return True
except Exception as e: except Exception as e:
logging.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}") logger.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}")
return False return False
def add_exchange(self, user_name: str, exchange: Exchange): def add_exchange(self, user_name: str, exchange: Exchange):
@ -60,7 +58,7 @@ class ExchangeInterface:
row = {'user': user_name, 'name': exchange.name, 'reference': exchange, 'balances': exchange.balances} row = {'user': user_name, 'name': exchange.name, 'reference': exchange, 'balances': exchange.balances}
self.exchange_data = add_row(self.exchange_data, row) self.exchange_data = add_row(self.exchange_data, row)
except Exception as e: except Exception as e:
logging.error(f"Couldn't create an instance of the exchange! {str(e)}") logger.error(f"Couldn't create an instance of the exchange! {str(e)}")
raise raise
def get_exchange(self, ename: str, uname: str) -> Exchange: def get_exchange(self, ename: str, uname: str) -> Exchange:
@ -144,12 +142,12 @@ class ExchangeInterface:
elif fetch_type == 'orders': elif fetch_type == 'orders':
data = reference.get_open_orders() data = reference.get_open_orders()
else: else:
logging.error(f"Invalid fetch type: {fetch_type}") logger.error(f"Invalid fetch type: {fetch_type}")
return {} return {}
data_dict[name] = data data_dict[name] = data
except Exception as e: except Exception as e:
logging.error(f"Error retrieving data for {name}: {str(e)}") logger.error(f"Error retrieving data for {name}: {str(e)}")
return data_dict return data_dict

View File

@ -43,12 +43,8 @@ class Candles:
# Calculate the approximate start_datetime the first of n record will have. # Calculate the approximate start_datetime the first of n record will have.
start_datetime = ts_of_n_minutes_ago(n=num_candles, candle_length=minutes_per_candle) start_datetime = ts_of_n_minutes_ago(n=num_candles, candle_length=minutes_per_candle)
# Table name format is: <symbol>_<timeframe>_<exchange_name>. Example: "BTCUSDT_15m_binance_spot"
key = f'{asset}_{timeframe}_{exchange}'
# Fetch records older than start_datetime. # Fetch records older than start_datetime.
candles = self.data.get_records_since(key=key, start_datetime=start_datetime, candles = self.data.get_records_since(start_datetime=start_datetime,
record_length=minutes_per_candle,
ex_details=[asset, timeframe, exchange, user_name]) ex_details=[asset, timeframe, exchange, user_name])
if len(candles.index) < num_candles: if len(candles.index) < num_candles:
timesince = dt.datetime.utcnow() - start_datetime timesince = dt.datetime.utcnow() - start_datetime

View File

@ -1,20 +1,38 @@
import ccxt import ccxt
import pandas as pd
import datetime
def fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5):
    """Fetch a handful of OHLCV candles from an exchange and print them.

    :param exchange_name: ccxt exchange id, e.g. 'binance'.
    :param symbol: Market symbol, e.g. 'BTC/USDT'.
    :param timeframe: Candle timeframe string, e.g. '5m'.
    :param since: Start time in milliseconds since the Unix epoch.
    :param limit: Maximum number of candles to request.
    """
    # Instantiate the exchange class named by exchange_name.
    exchange = getattr(ccxt, exchange_name)()

    # Ask the exchange for up to `limit` candles starting at `since`.
    ohlcv = exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit)

    # A DataFrame prints far more readably than the raw list of lists.
    df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])

    print("First few rows of the fetched OHLCV data:")
    print(df.head())

    # Add a human-readable datetime column derived from the ms timestamps.
    df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
    print("\nFirst few timestamps in human-readable format:")
    print(df[['timestamp', 'datetime']].head())

    print("\nTimestamp format confirmation:")
    for ts in df['timestamp']:
        print(f"{ts} (milliseconds since Unix epoch)")


# Example usage
exchange_name = 'binance'  # Change this to your exchange
symbol = 'BTC/USDT'
timeframe = '5m'
# 2024-08-01 00:00:00 UTC expressed as milliseconds since the epoch.
since = int((datetime.datetime(2024, 8, 1) - datetime.datetime(1970, 1, 1)).total_seconds() * 1000)
fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5)

View File

@ -36,19 +36,37 @@ def query_uptodate(records: pd.DataFrame, r_length_min: float) -> Union[float, N
tolerance_minutes = 10 / 60 # 10 seconds tolerance in minutes tolerance_minutes = 10 / 60 # 10 seconds tolerance in minutes
if minutes_since_update > (r_length_min - tolerance_minutes): if minutes_since_update > (r_length_min - tolerance_minutes):
# Return the last timestamp in seconds # Return the last timestamp in seconds
return ms_to_seconds(last_timestamp) return last_timestamp
return None return None
def ms_to_seconds(timestamp: float) -> float:
    """Convert a timestamp expressed in milliseconds to seconds.

    :param timestamp: Timestamp in milliseconds.
    :return: The same instant expressed in seconds.
    """
    milliseconds_per_second = 1000
    return timestamp / milliseconds_per_second
def unix_time_seconds(d_time: dt.datetime) -> float:
    """Convert a datetime object to its Unix timestamp in seconds.

    :param d_time: The datetime object to convert (naive, same basis as ``epoch``).
    :return: Seconds elapsed since the Unix epoch.
    """
    elapsed = d_time - epoch
    return elapsed.total_seconds()
def unix_time_millis(d_time: dt.datetime) -> float:
    """Convert a datetime object to its Unix timestamp in milliseconds.

    :param d_time: The datetime object to convert (naive, same basis as ``epoch``).
    :return: Milliseconds elapsed since the Unix epoch.
    """
    elapsed = d_time - epoch
    return elapsed.total_seconds() * 1000.0
@ -72,6 +90,10 @@ def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length
start_timestamp = unix_time_millis(start_datetime) start_timestamp = unix_time_millis(start_datetime)
print(f'Start timestamp: {start_timestamp}') print(f'Start timestamp: {start_timestamp}')
if records.empty:
print('No records found. Query cannot be satisfied.')
return float('nan')
# Get the oldest timestamp from the records passed in # Get the oldest timestamp from the records passed in
first_timestamp = float(records.open_time.min()) first_timestamp = float(records.open_time.min())
print(f'First timestamp in records: {first_timestamp}') print(f'First timestamp in records: {first_timestamp}')
@ -84,17 +106,17 @@ def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length
if start_timestamp <= first_timestamp + total_duration: if start_timestamp <= first_timestamp + total_duration:
return None return None
return first_timestamp / 1000 # Return in seconds return first_timestamp
@lru_cache(maxsize=500) @lru_cache(maxsize=500)
def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime.timestamp: def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime:
""" """
Returns the approximate datetime for the start of a candle that was 'n' candles ago. Returns the approximate datetime for the start of a candle that was 'n' candles ago.
:param n: int - The number of candles ago to calculate. :param n: The number of candles ago to calculate.
:param candle_length: float - The length of each candle in minutes. :param candle_length: The length of each candle in minutes.
:return: datetime - The approximate datetime for the start of the 'n'-th candle ago. :return: The approximate datetime for the start of the 'n'-th candle ago.
""" """
# Increment 'n' by 1 to ensure we account for the time that has passed since the last candle closed. # Increment 'n' by 1 to ensure we account for the time that has passed since the last candle closed.
n += 1 n += 1
@ -113,12 +135,12 @@ def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime.timestamp:
@lru_cache(maxsize=20) @lru_cache(maxsize=20)
def timeframe_to_minutes(timeframe): def timeframe_to_minutes(timeframe: str) -> int:
""" """
Converts a string representing a timeframe into an integer representing the approximate minutes. Converts a string representing a timeframe into an integer representing the approximate minutes.
:param timeframe: str - Timeframe format is [multiplier:focus]. eg '15m', '4h', '1d' :param timeframe: Timeframe format is [multiplier:focus]. e.g., '15m', '4h', '1d'
:return: int - Minutes the timeframe represents ex. '2h'-> 120(minutes). :return: Minutes the timeframe represents, e.g., '2h' -> 120 (minutes).
""" """
# Extract the numerical part of the timeframe param. # Extract the numerical part of the timeframe param.
digits = int("".join([i if i.isdigit() else "" for i in timeframe])) digits = int("".join([i if i.isdigit() else "" for i in timeframe]))

478
tests/test_DataCache_v2.py Normal file
View File

@ -0,0 +1,478 @@
from DataCache_v2 import DataCache
from ExchangeInterface import ExchangeInterface
import unittest
import pandas as pd
import datetime as dt
import os
from Database import SQLite, Database
class DataGenerator:
    """Builds simulated OHLCV candle tables for exercising DataCache.

    Configured with a timeframe string such as '5m' or '2h'; produces
    pandas DataFrames shaped like the candle tables stored by the cache
    and database layers (columns: market_id, open_time, open, high, low,
    close, volume; open_time in Unix milliseconds).
    """

    # Translate the long unit name to the single-character code that
    # round_down_datetime() understands.  timeframe_unit[0] must NOT be
    # used for this: 'months'[0] == 'm' would silently round to minutes.
    _UNIT_CODES = {'seconds': 's', 'minutes': 'm', 'hours': 'h',
                   'days': 'd', 'weeks': 'w', 'months': 'M', 'years': 'y'}

    def __init__(self, timeframe_str):
        """
        Initialize the DataGenerator with a timeframe string like '2h', '5m',
        '1d', '1w', '1M', or '1Y'.
        """
        # Placeholders; set_timeframe() assigns the real values.
        self.timeframe_amount = None
        self.timeframe_unit = None
        self.set_timeframe(timeframe_str)

    def set_timeframe(self, timeframe_str):
        """
        Set the timeframe unit and amount based on a string like '2h', '5m',
        '1d', '1w', '1M', or '1Y'.

        :raises ValueError: If the unit suffix is not one of s/m/h/d/w/M/Y.
        """
        self.timeframe_amount = int(timeframe_str[:-1])
        names = {'s': 'seconds', 'm': 'minutes', 'h': 'hours', 'd': 'days',
                 'w': 'weeks', 'M': 'months', 'Y': 'years'}
        unit = timeframe_str[-1]
        if unit not in names:
            raise ValueError(
                "Unsupported timeframe unit. Use 's,m,h,d,w,M,Y'.")
        self.timeframe_unit = names[unit]

    def create_table(self, num_rec=None, start=None, end=None):
        """
        Create a table with simulated data. If both start and end are provided,
        num_rec is derived from the interval. If neither is provided the table
        will have num_rec records and end at the current time.

        :param num_rec: The number of records to generate (required unless
            both start and end are given).
        :param start: The start time for the first record.
        :param end: The end time for the last record.
        :return: pd.DataFrame with the simulated candle data.
        :raises ValueError: If num_rec cannot be determined.
        """
        interval_seconds = self.timeframe_amount * self._get_seconds_per_unit(self.timeframe_unit)
        # Neither endpoint given: anchor the table so it ends "now".
        if start is None and end is None:
            end = dt.datetime.utcnow()
            if num_rec is None:
                raise ValueError("num_rec must be provided if both start and end are not specified.")
        # Both endpoints given: derive the record count from the span.
        if start is not None and end is not None:
            total_duration = (end - start).total_seconds()
            num_rec = int(total_duration // interval_seconds) + 1
        # Only end given: derive start from the record count.
        if end is not None and start is None:
            if num_rec is None:
                raise ValueError("num_rec must be provided if both start and end are not specified.")
            start = end - dt.timedelta(seconds=(num_rec - 1) * interval_seconds)
        # Only start given: the record count is mandatory.
        if num_rec is None:
            raise ValueError("num_rec must be provided when only start is specified.")
        # Align the first record to the timeframe grid (see _UNIT_CODES note).
        start = self.round_down_datetime(start, self._UNIT_CODES[self.timeframe_unit], self.timeframe_amount)
        # One timestamp per record, spaced by the configured timeframe.
        times = [self.unix_time_millis(start + self._delta(i)) for i in range(num_rec)]
        return pd.DataFrame({
            'market_id': 1,
            'open_time': times,
            'open': [100 + i for i in range(num_rec)],
            'high': [110 + i for i in range(num_rec)],
            'low': [90 + i for i in range(num_rec)],
            'close': [105 + i for i in range(num_rec)],
            'volume': [1000 + i for i in range(num_rec)],
        })

    @staticmethod
    def _get_seconds_per_unit(unit):
        """Return the (approximate) number of seconds in one timeframe unit."""
        units_in_seconds = {
            'seconds': 1,
            'minutes': 60,
            'hours': 3600,
            'days': 86400,
            'weeks': 604800,
            'months': 2592000,  # Assuming 30 days per month
            'years': 31536000,  # Assuming 365 days per year
        }
        if unit not in units_in_seconds:
            raise ValueError(f"Unsupported timeframe unit: {unit}")
        return units_in_seconds[unit]

    def _span(self, n_units):
        """
        Return a timedelta covering n_units of the configured unit.

        Built from seconds because dt.timedelta has no 'months'/'years'
        keyword arguments; those units use the fixed approximations from
        _get_seconds_per_unit().
        """
        return dt.timedelta(seconds=n_units * self._get_seconds_per_unit(self.timeframe_unit))

    def generate_incomplete_data(self, query_offset, num_rec=5):
        """
        Generate data that is incomplete, i.e., starts before the query but
        doesn't have enough records to fully satisfy it.
        """
        query_start_time = self.x_time_ago(query_offset)
        start_time_for_data = self.get_start_time(query_start_time)
        return self.create_table(num_rec, start=start_time_for_data)

    @staticmethod
    def generate_missing_section(df, drop_start=5, drop_end=8):
        """Drop rows [drop_start, drop_end) to simulate a gap in the data."""
        return df.drop(df.index[drop_start:drop_end]).reset_index(drop=True)

    def get_start_time(self, query_start_time):
        """Return a start time two timeframe intervals before query_start_time."""
        margin = 2
        return query_start_time - self._span(margin * self.timeframe_amount)

    def x_time_ago(self, offset):
        """Return the current UTC time minus `offset` units of the timeframe."""
        return dt.datetime.utcnow() - self._span(offset)

    def _delta(self, i):
        """Return the time offset of the i-th record from the table start."""
        return self._span(i * self.timeframe_amount)

    @staticmethod
    def unix_time_millis(dt_obj):
        """Convert a naive (UTC) datetime to integer Unix milliseconds."""
        epoch = dt.datetime(1970, 1, 1)
        return int((dt_obj - epoch).total_seconds() * 1000)

    @staticmethod
    def round_down_datetime(dt_obj: dt.datetime, unit: str, interval: int) -> dt.datetime:
        """
        Round dt_obj down to the nearest multiple of `interval` in `unit`.

        :param dt_obj: The datetime to round.
        :param unit: Single-character unit code: 's','m','h','d','w','M','y'.
        :param interval: The grid size in that unit.
        :return: The rounded datetime (returned unchanged if unit is unknown).
        """
        if unit == 's':  # Round down to the nearest interval of seconds
            seconds = (dt_obj.second // interval) * interval
            dt_obj = dt_obj.replace(second=seconds, microsecond=0)
        elif unit == 'm':  # Round down to the nearest interval of minutes
            minutes = (dt_obj.minute // interval) * interval
            dt_obj = dt_obj.replace(minute=minutes, second=0, microsecond=0)
        elif unit == 'h':  # Round down to the nearest interval of hours
            hours = (dt_obj.hour // interval) * interval
            dt_obj = dt_obj.replace(hour=hours, minute=0, second=0, microsecond=0)
        elif unit == 'd':
            # Days are 1-based: shift to 0-based before snapping, otherwise
            # `day // interval * interval` could produce the invalid day 0.
            days = ((dt_obj.day - 1) // interval) * interval + 1
            dt_obj = dt_obj.replace(day=days, hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'w':
            # NOTE(review): this snaps to the most recent Monday; weekday() is
            # always < 7, so the modulo never groups multiple weeks together.
            dt_obj -= dt.timedelta(days=dt_obj.weekday() % (interval * 7))
            dt_obj = dt_obj.replace(hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'M':  # Round down to the nearest interval of months (1-based)
            months = ((dt_obj.month - 1) // interval) * interval + 1
            dt_obj = dt_obj.replace(month=months, day=1, hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'y':  # Round down to the nearest interval of years
            years = (dt_obj.year // interval) * interval
            dt_obj = dt_obj.replace(year=years, month=1, day=1, hour=0, minute=0, second=0, microsecond=0)
        return dt_obj
class TestDataCacheV2(unittest.TestCase):
    """End-to-end tests for DataCache_v2 against a temporary SQLite database.

    Each test gets a fresh database file (created in setUp, removed in
    tearDown) and a DataCache wired to a live ExchangeInterface connection.
    """

    def setUp(self):
        """Create the exchange connection, the SQLite file and its tables."""
        # Set up database and exchanges
        self.exchanges = ExchangeInterface()
        # NOTE(review): connects with api_keys=None — presumably public
        # endpoints only; confirm whether network access is required here.
        self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)
        self.db_file = 'test_db.sqlite'
        self.database = Database(db_file=self.db_file)
        # Create necessary tables
        # NOTE(review): f-string prefix has no placeholders — plain string would do.
        sql_create_table_1 = f"""
            CREATE TABLE IF NOT EXISTS test_table (
                id INTEGER PRIMARY KEY,
                market_id INTEGER,
                open_time INTEGER UNIQUE ON CONFLICT IGNORE,
                open REAL NOT NULL,
                high REAL NOT NULL,
                low REAL NOT NULL,
                close REAL NOT NULL,
                volume REAL NOT NULL,
                FOREIGN KEY (market_id) REFERENCES market (id)
            )"""
        sql_create_table_2 = """
            CREATE TABLE IF NOT EXISTS exchange (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT UNIQUE
            )"""
        sql_create_table_3 = """
            CREATE TABLE IF NOT EXISTS markets (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                symbol TEXT,
                exchange_id INTEGER,
                FOREIGN KEY (exchange_id) REFERENCES exchange(id)
            )"""
        with SQLite(db_file=self.db_file) as con:
            con.execute(sql_create_table_1)
            con.execute(sql_create_table_2)
            con.execute(sql_create_table_3)
        self.data = DataCache(self.exchanges)
        # Point the cache at the throwaway test database.
        self.data.db = self.database
        # [symbol, timeframe, exchange, user] as expected by DataCache calls.
        self.ex_details = ['BTC/USD', '2h', 'binance', 'test_guy']
        # Table/cache key format: <symbol>_<timeframe>_<exchange>.
        self.key = f'{self.ex_details[0]}_{self.ex_details[1]}_{self.ex_details[2]}'

    def tearDown(self):
        """Remove the temporary database file after each test."""
        if os.path.exists(self.db_file):
            os.remove(self.db_file)

    def test_set_cache(self):
        """set_cache() stores, overwrites, and honours the do_not_overwrite flag."""
        print('\nTesting set_cache() method without no-overwrite flag:')
        self.data.set_cache(data='data', key=self.key)
        attr = self.data.__getattribute__('cached_data')
        self.assertEqual(attr[self.key], 'data')
        print(' - Set cache without no-overwrite flag passed.')
        print('Testing set_cache() once again with new data without no-overwrite flag:')
        self.data.set_cache(data='more_data', key=self.key)
        attr = self.data.__getattribute__('cached_data')
        self.assertEqual(attr[self.key], 'more_data')
        print(' - Set cache with new data without no-overwrite flag passed.')
        print('Testing set_cache() method once again with more data with no-overwrite flag set:')
        # With do_not_overwrite=True the previously cached value must survive.
        self.data.set_cache(data='even_more_data', key=self.key, do_not_overwrite=True)
        attr = self.data.__getattribute__('cached_data')
        self.assertEqual(attr[self.key], 'more_data')
        print(' - Set cache with no-overwrite flag passed.')

    def test_cache_exists(self):
        """cache_exists() reports False before and True after a set_cache()."""
        print('Testing cache_exists() method:')
        self.assertFalse(self.data.cache_exists(key=self.key))
        print(' - Check for non-existent cache passed.')
        self.data.set_cache(data='data', key=self.key)
        self.assertTrue(self.data.cache_exists(key=self.key))
        print(' - Check for existent cache passed.')

    def test_update_candle_cache(self):
        """update_candle_cache() appends new candle rows after the cached ones."""
        print('Testing update_candle_cache() method:')
        # Initialize the DataGenerator with the 5-minute timeframe
        data_gen = DataGenerator('5m')
        # Create initial DataFrame and insert into cache
        df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0))
        print(f'Inserting this table into cache:\n{df_initial}\n')
        self.data.set_cache(data=df_initial, key=self.key)
        # Create new DataFrame to be added to cache
        df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0))
        print(f'Updating cache with this table:\n{df_new}\n')
        self.data.update_candle_cache(more_records=df_new, key=self.key)
        # Retrieve the resulting DataFrame from cache
        result = self.data.get_cache(key=self.key)
        print(f'The resulting table in cache is:\n{result}\n')
        # Create the expected DataFrame
        expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0))
        print(f'The expected open_time values are:\n{expected["open_time"].tolist()}\n')
        # Assert that the open_time values in the result match those in the expected DataFrame, in order
        assert result['open_time'].tolist() == expected['open_time'].tolist(), \
            f"open_time values in result are {result['open_time'].tolist()}" \
            f" but expected {expected['open_time'].tolist()}"
        print(f'The results open_time values match:\n{result["open_time"].tolist()}\n')
        print(' - Update cache with new records passed.')

    def test_update_cached_dict(self):
        """update_cached_dict() inserts a key/value into a dict stored in cache."""
        print('Testing update_cached_dict() method:')
        self.data.set_cache(data={}, key=self.key)
        self.data.update_cached_dict(cache_key=self.key, dict_key='sub_key', data='value')
        cache = self.data.get_cache(key=self.key)
        self.assertEqual(cache['sub_key'], 'value')
        print(' - Update dictionary in cache passed.')

    def test_get_cache(self):
        """get_cache() returns exactly what set_cache() stored."""
        print('Testing get_cache() method:')
        self.data.set_cache(data='data', key=self.key)
        result = self.data.get_cache(key=self.key)
        self.assertEqual(result, 'data')
        print(' - Retrieve cache passed.')

    def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None, num_rec=None, ex_details=None,
                                simulate_scenarios=None):
        """
        Test the get_records_since() method by generating a table of simulated data,
        inserting it into cache and/or database, and then querying the records.

        Parameters:
            set_cache (bool): If True, the generated table is inserted into the cache.
            set_db (bool): If True, the generated table is inserted into the database.
            query_offset (int, optional): The offset in the timeframe units for the query.
            num_rec (int, optional): The number of records to generate in the simulated table.
            ex_details (list, optional): Exchange details to generate the cache key.
            simulate_scenarios (str, optional): The type of scenario to simulate. Options are:
                - 'not_enough_data': The table data doesn't go far enough back.
                - 'incomplete_data': The table doesn't have enough records to satisfy the query.
                - 'missing_section': The table has missing records in the middle.
        """
        print('Testing get_records_since() method:')
        # Use provided ex_details or fallback to the class attribute.
        ex_details = ex_details or self.ex_details
        # Generate a cache/database key using exchange details.
        key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}'
        # Set default number of records if not provided.
        num_rec = num_rec or 12
        table_timeframe = ex_details[1]  # Extract timeframe from exchange details.
        # Initialize DataGenerator with the given timeframe.
        data_gen = DataGenerator(table_timeframe)
        if simulate_scenarios == 'not_enough_data':
            # Set query_offset to a time earlier than the start of the table data.
            query_offset = (num_rec + 5) * data_gen.timeframe_amount
        else:
            # Default to querying for 1 record length less than the table duration.
            query_offset = query_offset or (num_rec - 1) * data_gen.timeframe_amount
        if simulate_scenarios == 'incomplete_data':
            # Set start time to generate fewer records than required.
            start_time_for_data = data_gen.x_time_ago(num_rec * data_gen.timeframe_amount)
            num_rec = 5  # Set a smaller number of records to simulate incomplete data.
        else:
            # No specific start time for data generation.
            start_time_for_data = None
        # Create the initial data table.
        df_initial = data_gen.create_table(num_rec, start=start_time_for_data)
        if simulate_scenarios == 'missing_section':
            # Simulate missing section in the data by dropping records.
            df_initial = data_gen.generate_missing_section(df_initial, drop_start=2, drop_end=5)
        # Convert 'open_time' to datetime for better readability.
        temp_df = df_initial.copy()
        temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms')
        print(f'Table Created:\n{temp_df}')
        if set_cache:
            # Insert the generated table into cache.
            print('Inserting table into cache.')
            self.data.set_cache(data=df_initial, key=key)
        if set_db:
            # Insert the generated table into the database.
            print('Inserting table into database.')
            with SQLite(self.db_file) as con:
                df_initial.to_sql(key, con, if_exists='replace', index=False)
        # Calculate the start time for querying the records.
        start_datetime = data_gen.x_time_ago(query_offset)
        # Defaults to current time if not provided to get_records_since()
        query_end_time = dt.datetime.utcnow()
        print(f'Requesting records from {start_datetime} to {query_end_time}')
        # Query the records since the calculated start time.
        result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
        # Filter the initial data table to match the query time.
        expected = df_initial[df_initial['open_time'] >= data_gen.unix_time_millis(start_datetime)].reset_index(
            drop=True)
        temp_df = expected.copy()
        temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms')
        print(f'Expected table:\n{temp_df}')
        # Print the result from the query for comparison.
        temp_df = result.copy()
        temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms')
        print(f'Resulting table:\n{temp_df}')
        if simulate_scenarios in ['not_enough_data', 'incomplete_data', 'missing_section']:
            # Check that the result has more rows than the expected incomplete data.
            assert result.shape[0] > expected.shape[
                0], "Result has fewer or equal rows compared to the incomplete data."
            print("\nThe returned DataFrames has filled in the missing data!")
        else:
            # Ensure the result and expected dataframes match in shape and content.
            assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}"
            pd.testing.assert_series_equal(result['open_time'], expected['open_time'], check_dtype=False)
            print("\nThe DataFrames have the same shape and the 'open_time' columns match.")
        # Verify that the oldest timestamp in the result is within the allowed time difference.
        oldest_timestamp = pd.to_datetime(result['open_time'].min(), unit='ms')
        time_diff = oldest_timestamp - start_datetime
        # NOTE(review): timedelta has no 'months'/'years' kwargs, so this line
        # would raise for 'M'/'Y' timeframes — only s/m/h/d/w are exercised.
        max_allowed_time_diff = dt.timedelta(**{data_gen.timeframe_unit: data_gen.timeframe_amount})
        assert dt.timedelta(0) <= time_diff <= max_allowed_time_diff, \
            f"Oldest timestamp {oldest_timestamp} is not within " \
            f"{data_gen.timeframe_amount} {data_gen.timeframe_unit} of {start_datetime}"
        print(f'The first timestamp is {time_diff} from {start_datetime}')
        # Verify that the newest timestamp in the result is within the allowed time difference.
        newest_timestamp = pd.to_datetime(result['open_time'].max(), unit='ms')
        time_diff_end = abs(query_end_time - newest_timestamp)
        assert dt.timedelta(0) <= time_diff_end <= max_allowed_time_diff, \
            f"Newest timestamp {newest_timestamp} is not within {data_gen.timeframe_amount} " \
            f"{data_gen.timeframe_unit} of {query_end_time}"
        print(f'The last timestamp is {time_diff_end} from {query_end_time}')
        print(' - Fetch records within the specified time range passed.')

    def test_get_records_since(self):
        """Drive _test_get_records_since() through every supported scenario."""
        print('\nTest get_records_since with records set in cache')
        self._test_get_records_since()
        print('\nTest get_records_since with records not in cache')
        self._test_get_records_since(set_cache=False)
        print('\nTest get_records_since with records not in database')
        self._test_get_records_since(set_cache=False, set_db=False)
        print('\nTest get_records_since with a different timeframe')
        self._test_get_records_since(query_offset=None, num_rec=None,
                                     ex_details=['BTC/USD', '15m', 'binance', 'test_guy'])
        print('\nTest get_records_since where data does not go far enough back')
        self._test_get_records_since(simulate_scenarios='not_enough_data')
        print('\nTest get_records_since with incomplete data')
        self._test_get_records_since(simulate_scenarios='incomplete_data')
        print('\nTest get_records_since with missing section in data')
        self._test_get_records_since(simulate_scenarios='missing_section')

    def test_populate_db(self):
        """_populate_db() writes generated candles into the keyed table."""
        print('Testing _populate_db() method:')
        # Create a table of candle records.
        data_gen = DataGenerator(self.ex_details[1])
        data = data_gen.create_table(num_rec=5)
        self.data._populate_db(ex_details=self.ex_details, data=data)
        with SQLite(self.db_file) as con:
            result = pd.read_sql(f'SELECT * FROM "{self.key}"', con)
        self.assertFalse(result.empty)
        print(' - Populate database with data passed.')

    def test_fetch_candles_from_exchange(self):
        """_fetch_candles_from_exchange() returns a non-empty candle DataFrame."""
        print('Testing _fetch_candles_from_exchange() method:')
        # NOTE(review): hits the live exchange API — requires network access.
        start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
        end_time = dt.datetime.utcnow()
        result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance',
                                                        user_name='test_guy', start_datetime=start_time,
                                                        end_datetime=end_time)
        self.assertIsInstance(result, pd.DataFrame)
        self.assertFalse(result.empty)
        print(' - Fetch candle data from exchange passed.')
# Run the suite with unittest's CLI runner when executed as a script.
if __name__ == '__main__':
    unittest.main()