# brighter-trading/src/DataCache.py
from typing import Any, List
import pandas as pd
import datetime as dt
from Database import Database
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes
import logging
logger = logging.getLogger(__name__)
class DataCache:
    """
    In-memory cache for exchange market data.

    Fetches records on demand, bounds what is kept in memory, and
    coordinates the database and exchange layers that back the cache.

    Example usage:
    --------------
    db = DataCache(exchanges=some_exchanges_object)
    """

    def __init__(self, exchanges):
        """
        Initializes the DataCache class.

        :param exchanges: The exchanges object handling communication with connected exchanges.
        """
        # Upper bound on the number of tables cached at any given time.
        self.max_tables = 50
        # Upper bound on the number of records kept per table.
        self.max_records = 1000
        # Mapping of cache key -> cached records.
        self.cached_data = {}
        # Database access layer.
        self.db = Database()
        # Exchange access layer.
        self.exchanges = exchanges
def cache_exists(self, key: str) -> bool:
"""
Return False if a cache doesn't exist for this key.
"""
return key in self.cached_data
def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
"""
:param more_records: Adds records to existing cache.
:param key: The access key.
:return: None.
"""
# Combine the new candles with the previously cached dataframe.
records = pd.concat([more_records, self.get_cache(key)], axis=0, ignore_index=True)
# Drop any duplicates from overlap.
records = records.drop_duplicates(subset="open_time", keep='first')
# Sort the records by open_time.
records = records.sort_values(by='open_time').reset_index(drop=True)
# Replace the incomplete dataframe with the modified one.
self.set_cache(data=records, key=key)
return
def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
"""
Creates a new cache key and inserts some data.
Todo: This is where this data will be passed to a cache server.
:param data: The records to insert into cache.
:param key: The index key for the data.
:param do_not_overwrite: - Flag to prevent overwriting existing data.
:return: None
"""
# If the flag is set don't overwrite existing data.
if do_not_overwrite and key in self.cached_data:
return
# Assign the data
self.cached_data[key] = data
def update_cached_dict(self, cache_key, dict_key: str, data: Any) -> None:
"""
Updates a dictionary stored in cache.
Todo: This is where this data will be passed to a cache server.
:param data: - The records to insert into cache.
:param cache_key: - The cache index key for the dictionary.
:param dict_key: - The dictionary key for the data.
:return: None
"""
# Assign the data
self.cached_data[cache_key].update({dict_key: data})
def get_cache(self, key: str) -> Any:
"""
Returns data indexed by key.
Todo: This is where data will be retrieved from a cache server.
:param key: The index key for the data.
:return: Any|None - The requested data or None on key error.
"""
if key not in self.cached_data:
print(f"[WARNING: DataCache.py] The requested cache key({key}) doesn't exist!")
return None
return self.cached_data[key]
    def get_records_since(self, key: str, start_datetime: dt.datetime, record_length: int,
                          ex_details: List[str]) -> pd.DataFrame:
        """
        Return any records from the cache indexed by table_name that are newer than start_datetime.

        Ensures the cached span both reaches back to start_datetime and is
        up-to-date, pulling any missing records from the database first.

        :param ex_details: List[str] - Details to pass to the server.
            NOTE(review): downstream _populate_db unpacks four items
            [symbol, interval, exchange_name, user_name] — confirm callers
            pass all four.
        :param key: str - The dictionary table_name of the records.
        :param start_datetime: dt.datetime - The datetime of the first record requested.
        :param record_length: int - The timespan of the records (in minutes).
        :return: pd.DataFrame - The Requested records

        Example:
        --------
        records = data_cache.get_records_since('BTC/USD_2h_binance', dt.datetime.utcnow() - dt.timedelta(minutes=60), 60, ['BTC/USD', '2h', 'binance', 'user1'])
        """
        try:
            # End time of query defaults to the current time.
            end_datetime = dt.datetime.utcnow()
            if self.cache_exists(key=key):
                logger.debug('Getting records from cache.')
                # If the records exist, retrieve them from the cache.
                records = self.get_cache(key)
            else:
                # If they don't exist in cache, get them from the database.
                logger.debug(
                    f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
                records = self.get_records_since_from_db(table_name=key, st=start_datetime,
                                                         et=end_datetime, rl=record_length, ex_details=ex_details)
                logger.debug(f'Got {len(records.index)} records from DB.')
                self.set_cache(data=records, key=key)
            # Check if the records in the cache go far enough back to satisfy the query.
            # A truthy return is treated as a unix timestamp in seconds (it is fed to
            # utcfromtimestamp below); falsy means the span is already satisfied.
            first_timestamp = query_satisfied(start_datetime=start_datetime,
                                              records=records,
                                              r_length_min=record_length)
            if first_timestamp:
                # The records didn't go far enough back if a timestamp was returned.
                end_time = dt.datetime.utcfromtimestamp(first_timestamp)
                logger.debug(f'Requesting additional records from {start_datetime} to {end_time}')
                # Request additional records from the database.
                additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime,
                                                                    et=end_time, rl=record_length,
                                                                    ex_details=ex_details)
                logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
                if not additional_records.empty:
                    # If more records were received, update the cache.
                    self.update_candle_cache(additional_records, key)
            # Check if the records received are up-to-date.
            # NOTE(review): this checks the pre-merge `records` frame rather than
            # the refreshed cache; older back-filled rows should not move the
            # newest timestamp, but confirm that assumption.
            last_timestamp = query_uptodate(records=records, r_length_min=record_length)
            if last_timestamp:
                # The query was not up-to-date if a timestamp was returned.
                start_time = dt.datetime.utcfromtimestamp(last_timestamp)
                logger.debug(f'Requesting additional records from {start_time} to {end_datetime}')
                # Request additional records from the database.
                additional_records = self.get_records_since_from_db(table_name=key, st=start_time,
                                                                    et=end_datetime, rl=record_length,
                                                                    ex_details=ex_details)
                logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
                if not additional_records.empty:
                    self.update_candle_cache(additional_records, key)
            # Create a UTC timestamp (milliseconds) used as the filter cut-off.
            _timestamp = unix_time_millis(start_datetime)
            logger.debug(f"Start timestamp in human-readable form: {dt.datetime.utcfromtimestamp(_timestamp / 1000.0)}")
            # Return all records equal to or newer than the timestamp.
            result = self.get_cache(key).query('open_time >= @_timestamp')
            return result
        except Exception as e:
            # Broad catch by design: log the failure, then re-raise so the
            # caller still sees the original exception.
            logger.error(f"An error occurred: {str(e)}")
            raise
def get_records_since_from_db(self, table_name: str, st: dt.datetime,
et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame:
"""
Returns records from a specified table that meet a criteria and ensures the records are complete.
If the records do not go back far enough or are not up-to-date, it fetches additional records
from an exchange and populates the table.
:param table_name: Database table name.
:param st: Start datetime.
:param et: End datetime.
:param rl: Timespan in minutes each record represents.
:param ex_details: Exchange details [symbol, interval, exchange_name].
:return: DataFrame of records.
Example:
--------
records = db.get_records_since('test_table', start_time, end_time, 1, ['BTC/USDT', '1m', 'binance'])
"""
def add_data(data, tn, start_t, end_t):
new_records = self._populate_db(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details)
print(f'Got {len(new_records.index)} records from exchange_name')
if not new_records.empty:
data = pd.concat([data, new_records], axis=0, ignore_index=True)
data = data.drop_duplicates(subset="open_time", keep='first')
return data
if self.db.table_exists(table_name=table_name):
print('\nTable existed retrieving records from DB')
print(f'Requesting from {st} to {et}')
records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=st, et=et)
print(f'Got {len(records.index)} records from db')
else:
print(f'\nTable didnt exist fetching from {ex_details[2]}')
temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
print(f'Requesting from {st} to {et}, Should be {temp} records')
records = self._populate_db(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details)
print(f'Got {len(records.index)} records from {ex_details[2]}')
first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
if first_timestamp:
print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}')
print(f'first ts on record is: {first_timestamp}')
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
print(f'Requesting from {st} to {end_time}')
records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)
last_timestamp = query_uptodate(records=records, r_length_min=rl)
if last_timestamp:
print(f'\nRecords were not updated. Requesting from {ex_details[2]}.')
print(f'the last record on file is: {last_timestamp}')
start_time = dt.datetime.utcfromtimestamp(last_timestamp)
print(f'Requesting from {start_time} to {et}')
records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et)
return records
def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: list,
end_time: dt.datetime = None) -> pd.DataFrame:
"""
Populates a database table with records from the exchange.
:param table_name: Name of the table in the database.
:param start_time: Start time to fetch the records from.
:param end_time: End time to fetch the records until (optional).
:param ex_details: Exchange details [symbol, interval, exchange_name, user_name].
:return: DataFrame of the data downloaded.
Example:
--------
records = db._populate_table('test_table', start_time, ['BTC/USDT', '1m', 'binance', 'user1'])
"""
if end_time is None:
end_time = dt.datetime.utcnow()
sym, inter, ex, un = ex_details
records = self._fetch_candles_from_exchange(symbol=sym, interval=inter, exchange_name=ex, user_name=un,
start_datetime=start_time, end_datetime=end_time)
if not records.empty:
self.db.insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex)
else:
print(f'No records inserted {records}')
return records
    def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
                                     start_datetime: dt.datetime = None,
                                     end_datetime: dt.datetime = None) -> pd.DataFrame:
        """
        Fetches and returns all candles from the specified market, timeframe, and exchange.

        :param symbol: Symbol of the market.
        :param interval: Timeframe in the format '<int><alpha>' (e.g., '15m', '4h').
        :param exchange_name: Name of the exchange.
        :param user_name: Name of the user.
        :param start_datetime: Start datetime for fetching data (optional, defaults to 2017-01-01).
        :param end_datetime: End datetime for fetching data (optional, defaults to utcnow).
        :return: DataFrame of candle data.
        :raises ValueError: If start_datetime is after end_datetime.

        Example:
        --------
        candles = db._fetch_candles_from_exchange('BTC/USDT', '1m', 'binance', 'user1', start_time, end_time)
        """
        def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
            """
            Fills gaps in the data by inserting synthetic rows for the missing periods.

            NOTE(review): each synthetic row copies the row at the END of the
            gap (`row.copy()` below), not the last known data point before it —
            confirm this forward-copy is intentional.

            :param records: DataFrame containing the original records;
                            'open_time' is treated as milliseconds here.
            :param interval: Interval of the data (e.g., '1m', '5m').
            :return: DataFrame with gaps filled.

            Example:
            --------
            filled_records = fill_data_holes(df, '1m')
            """
            time_span = timeframe_to_minutes(interval)
            last_timestamp = None
            filled_records = []
            logger.info(f"Starting to fill data holes for interval: {interval}")
            for index, row in records.iterrows():
                time_stamp = row['open_time']
                # If last_timestamp is None, this is the first record
                if last_timestamp is None:
                    last_timestamp = time_stamp
                    filled_records.append(row)
                    logger.debug(f"First timestamp: {time_stamp}")
                    continue
                # Calculate the difference in milliseconds and minutes
                delta_ms = time_stamp - last_timestamp
                delta_minutes = (delta_ms / 1000) / 60
                logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}")
                # If the gap is larger than the time span of the interval, fill the gap
                if delta_minutes > time_span:
                    num_missing_rec = int(delta_minutes / time_span)
                    # Evenly spaced synthetic timestamps across the gap.
                    step = int(delta_ms / num_missing_rec)
                    logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}")
                    for ts in range(int(last_timestamp) + step, int(time_stamp), step):
                        new_row = row.copy()
                        new_row['open_time'] = ts
                        filled_records.append(new_row)
                        logger.debug(f"Filled timestamp: {ts}")
                filled_records.append(row)
                last_timestamp = time_stamp
            logger.info("Data holes filled successfully.")
            return pd.DataFrame(filled_records)

        # Default start date if not provided
        if start_datetime is None:
            start_datetime = dt.datetime(year=2017, month=1, day=1)
        # Default end date if not provided
        if end_datetime is None:
            end_datetime = dt.datetime.utcnow()
        # Check if start date is greater than end date
        if start_datetime > end_datetime:
            raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")
        # Get the exchange object
        exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
        # Calculate the expected number of records
        temp = (((unix_time_millis(end_datetime) - unix_time_millis(
            start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval)
        logger.info(f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {temp}')
        # If start and end times are the same, set end_datetime to None
        # NOTE(review): presumably a sentinel telling the exchange client to
        # fetch up to "now" — confirm against get_historical_klines.
        if start_datetime == end_datetime:
            end_datetime = None
        # Fetch historical candlestick data from the exchange
        candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime,
                                                 end_dt=end_datetime)
        num_rec_records = len(candles.index)
        logger.info(f'{num_rec_records} candles retrieved from the exchange.')
        # Check if the retrieved data covers the expected time range
        # NOTE(review): on an empty frame, open_times.max()/min() yield NaN, the
        # comparison below is False, and no fill is attempted — verify that an
        # empty result is acceptable to callers.
        open_times = candles.open_time
        estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval)
        # Fill in any missing data if the retrieved data is less than expected
        if num_rec_records < estimated_num_records:
            candles = fill_data_holes(candles, interval)
        return candles