Rewrote DataCache.py, implemented tests, and got them all passing.
This commit is contained in:
parent
c398a423a3
commit
d288eebbec
323
src/DataCache.py
323
src/DataCache.py
|
|
@ -5,13 +5,12 @@ from Database import Database
|
|||
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataCache:
|
||||
"""
|
||||
Fetches manages data limits and optimises memory storage.
|
||||
Fetches and manages data limits and optimizes memory storage.
|
||||
Handles connections and operations for the given exchanges.
|
||||
|
||||
Example usage:
|
||||
|
|
@ -19,156 +18,229 @@ class DataCache:
|
|||
db = DataCache(exchanges=some_exchanges_object)
|
||||
"""
|
||||
|
||||
# Disable during production for improved performance.
|
||||
TYPECHECKING_ENABLED = True
|
||||
|
||||
NO_RECORDS_FOUND = float('nan')
|
||||
|
||||
def __init__(self, exchanges):
    """
    Initialize the cache with its limits, storage and service handles.

    :param exchanges: The exchanges object handling communication with connected exchanges.
    """
    # Cap on how many tables may be cached at any given time.
    self.max_tables = 50
    # Cap on how many records each cached table may hold.
    self.max_records = 1000
    # key -> cached records mapping.
    self.cached_data = {}
    # Database access layer.
    self.db = Database()
    # Exchange access layer.
    self.exchanges = exchanges
|
||||
|
||||
def cache_exists(self, key: str) -> bool:
    """
    Report whether a cache entry exists for *key*.

    :param key: The access key.
    :return: True if cache exists, False otherwise.
    """
    found = key in self.cached_data
    return found
|
||||
|
||||
def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
    """
    Merge new candle records into the existing cache entry for *key*.

    :param more_records: The new records to be added.
    :param key: The access key.
    :return: None.
    """
    # Combine the new candles with the cached frame, drop duplicate
    # open_time rows from overlap (new rows win), then restore
    # chronological order before writing the frame back.
    merged = (
        pd.concat([more_records, self.get_cache(key)], axis=0, ignore_index=True)
        .drop_duplicates(subset="open_time", keep='first')
        .sort_values(by='open_time')
        .reset_index(drop=True)
    )
    self.set_cache(data=merged, key=key)
    return
|
||||
|
||||
def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
    """
    Store *data* in the cache under *key*.

    Todo: This is where this data will be passed to a cache server.

    :param data: The records to insert into cache.
    :param key: The index key for the data.
    :param do_not_overwrite: Flag to prevent overwriting existing data.
    :return: None
    """
    # Respect an existing entry when overwriting is disallowed.
    already_cached = key in self.cached_data
    if not (do_not_overwrite and already_cached):
        self.cached_data[key] = data
|
||||
|
||||
def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None:
    """
    Update a dictionary stored in cache.

    Todo: This is where this data will be passed to a cache server.

    :param data: The data to insert into cache.
    :param cache_key: The cache index key for the dictionary.
    :param dict_key: The dictionary key for the data.
    :return: None
    """
    # Fetch the cached dictionary and merge the new entry into it in place.
    target_dict = self.cached_data[cache_key]
    target_dict.update({dict_key: data})
|
||||
|
||||
def get_cache(self, key: str) -> Any:
    """
    Return the data indexed by *key*.

    Todo: This is where data will be retrieved from a cache server.

    :param key: The index key for the data.
    :return: Any|None - The requested data or None on key error.
    """
    # EAFP: attempt the lookup and report a miss only when it fails.
    try:
        return self.cached_data[key]
    except KeyError:
        logger.warning(f"The requested cache key({key}) doesn't exist!")
        return None
|
||||
|
||||
def improved_get_records_since(self, key: str, start_datetime: dt.datetime, record_length: int,
                               ex_details: List[str]) -> pd.DataFrame:
    """
    Fetch records since the specified start datetime, trying the cache first.

    :param key: The cache key.
    :param start_datetime: The start datetime to fetch records from.
    :param record_length: The required number of records.
    :param ex_details: Exchange details.
    :return: DataFrame containing the records.
    :raises Exception: Re-raises any error from the underlying fetch after logging it.
    """
    try:
        args = {
            'key': key,
            'start_datetime': start_datetime,
            # The query window always ends at the current time.
            'end_datetime': dt.datetime.utcnow(),
            'record_length': record_length,
            'ex_details': ex_details,
        }
        # BUG FIX: the fetched dataframe was previously assigned to a local
        # and never returned, so callers always received None.
        return self.get_or_fetch_from(target='cache', **args)
    except Exception as e:
        logger.error(f"An error occurred: {str(e)}")
        raise
|
||||
|
||||
def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame:
    """
    Resolve a data request from a chain of resources: cache -> database -> server.

    :param target: One of 'cache', 'database' or 'server' — the resource to try first.
    :param kwargs: key, start_datetime, end_datetime, record_length, ex_details.
    :return: DataFrame containing the records.
    :raises TypeError: If an argument fails validation (when TYPECHECKING_ENABLED).
    :raises ValueError: If a required argument is missing or target is invalid.
    """
    key = kwargs.get('key')
    start_datetime = kwargs.get('start_datetime')
    # BUG FIX: previously read 'start_datetime' again, so end_datetime was
    # silently set to the start of the window.
    end_datetime = kwargs.get('end_datetime')
    record_length = kwargs.get('record_length')
    ex_details = kwargs.get('ex_details')

    if self.TYPECHECKING_ENABLED:
        # Type checking
        if not isinstance(key, str):
            raise TypeError("key must be a string")
        if not isinstance(start_datetime, dt.datetime):
            raise TypeError("start_datetime must be a datetime object")
        if not isinstance(end_datetime, dt.datetime):
            raise TypeError("end_datetime must be a datetime object")
        if not isinstance(record_length, int):
            raise TypeError("record_length must be an integer")
        if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
            raise TypeError("ex_details must be a list of strings")

        # Ensure all required arguments are provided
        if not all([key, start_datetime, record_length, ex_details]):
            raise ValueError("Missing required arguments")

    def get_from_cache():
        # TODO: placeholder — returns an empty frame until implemented.
        # BUG FIX: previously returned the pd.DataFrame class itself,
        # not an instance.
        return pd.DataFrame()

    def get_from_database():
        # TODO: placeholder — returns an empty frame until implemented.
        return pd.DataFrame()

    def get_from_server():
        # TODO: placeholder — returns an empty frame until implemented.
        return pd.DataFrame()

    def data_complete(data, **criteria) -> bool:
        """Check if a dataframe completely satisfied a request."""
        sd = criteria.get('start_datetime')
        # BUG FIX: previously read 'start_datetime' for the end bound too.
        ed = criteria.get('end_datetime')
        rl = criteria.get('record_length')

        # TODO: placeholder — assume complete until real checks are written.
        is_complete = True
        return is_complete

    request_criteria = {
        'start_datetime': start_datetime,
        'end_datetime': end_datetime,
        'record_length': record_length,
    }

    if target == 'cache':
        result = get_from_cache()
        if data_complete(result, **request_criteria):
            return result
        # BUG FIX: the recursive fallback results were previously discarded
        # (no return), so callers received None on any fallback.
        return self.get_or_fetch_from('database', **kwargs)
    elif target == 'database':
        result = get_from_database()
        if data_complete(result, **request_criteria):
            return result
        return self.get_or_fetch_from('server', **kwargs)
    elif target == 'server':
        result = get_from_server()
        if data_complete(result, **request_criteria):
            return result
        logger.error('Unable to fetch the requested data.')
    else:
        raise ValueError(f'Not a valid target: {target}')
|
||||
|
||||
def get_records_since(self, key: str, start_datetime: dt.datetime, record_length: int,
|
||||
ex_details: List[str]) -> pd.DataFrame:
|
||||
"""
|
||||
Return any records from the cache indexed by table_name that are newer than start_datetime.
|
||||
Fetches records since the specified start datetime.
|
||||
|
||||
:param ex_details: List[str] - Details to pass to the server
|
||||
:param key: str - The dictionary table_name of the records.
|
||||
:param start_datetime: dt.datetime - The datetime of the first record requested.
|
||||
:param record_length: int - The timespan of the records.
|
||||
|
||||
:return: pd.DataFrame - The Requested records
|
||||
|
||||
Example:
|
||||
--------
|
||||
records = data_cache.get_records_since('BTC/USD_2h_binance', dt.datetime.utcnow() - dt.timedelta(minutes=60), 60, ['BTC/USD', '2h', 'binance'])
|
||||
:param key: The cache key.
|
||||
:param start_datetime: The start datetime to fetch records from.
|
||||
:param record_length: The required number of records.
|
||||
:param ex_details: Exchange details.
|
||||
:return: DataFrame containing the records.
|
||||
"""
|
||||
try:
|
||||
# End time of query defaults to the current time.
|
||||
end_datetime = dt.datetime.utcnow()
|
||||
|
||||
if self.cache_exists(key=key):
|
||||
logger.debug('Getting records from cache.')
|
||||
# If the records exist, retrieve them from the cache.
|
||||
records = self.get_cache(key)
|
||||
else:
|
||||
# If they don't exist in cache, get them from the database.
|
||||
logger.debug(
|
||||
f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
|
||||
records = self.get_records_since_from_db(table_name=key, st=start_datetime,
|
||||
et=end_datetime, rl=record_length, ex_details=ex_details)
|
||||
records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_datetime,
|
||||
rl=record_length, ex_details=ex_details)
|
||||
logger.debug(f'Got {len(records.index)} records from DB.')
|
||||
self.set_cache(data=records, key=key)
|
||||
|
||||
# Check if the records in the cache go far enough back to satisfy the query.
|
||||
first_timestamp = query_satisfied(start_datetime=start_datetime,
|
||||
records=records,
|
||||
first_timestamp = query_satisfied(start_datetime=start_datetime, records=records,
|
||||
r_length_min=record_length)
|
||||
if first_timestamp:
|
||||
# The records didn't go far enough back if a timestamp was returned.
|
||||
if pd.isna(first_timestamp):
|
||||
logger.debug('No records found to satisfy the query, continuing to fetch more records.')
|
||||
additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_datetime,
|
||||
rl=record_length, ex_details=ex_details)
|
||||
elif first_timestamp is not None:
|
||||
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
|
||||
logger.debug(f'Requesting additional records from {start_datetime} to {end_time}')
|
||||
# Request additional records from the database.
|
||||
additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime,
|
||||
et=end_time, rl=record_length,
|
||||
ex_details=ex_details)
|
||||
additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_time,
|
||||
rl=record_length, ex_details=ex_details)
|
||||
if not additional_records.empty:
|
||||
logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
|
||||
if not additional_records.empty:
|
||||
# If more records were received, update the cache.
|
||||
self.update_candle_cache(additional_records, key)
|
||||
self.update_candle_cache(additional_records, key)
|
||||
|
||||
# Check if the records received are up-to-date.
|
||||
last_timestamp = query_uptodate(records=records, r_length_min=record_length)
|
||||
|
||||
if last_timestamp:
|
||||
# The query was not up-to-date if a timestamp was returned.
|
||||
if last_timestamp is not None:
|
||||
start_time = dt.datetime.utcfromtimestamp(last_timestamp)
|
||||
logger.debug(f'Requesting additional records from {start_time} to {end_datetime}')
|
||||
# Request additional records from the database.
|
||||
additional_records = self.get_records_since_from_db(table_name=key, st=start_time,
|
||||
et=end_datetime, rl=record_length,
|
||||
ex_details=ex_details)
|
||||
additional_records = self.get_records_since_from_db(table_name=key, st=start_time, et=end_datetime,
|
||||
rl=record_length, ex_details=ex_details)
|
||||
logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
|
||||
if not additional_records.empty:
|
||||
self.update_candle_cache(additional_records, key)
|
||||
|
||||
# Create a UTC timestamp.
|
||||
_timestamp = unix_time_millis(start_datetime)
|
||||
logger.debug(f"Start timestamp in human-readable form: {dt.datetime.utcfromtimestamp(_timestamp / 1000.0)}")
|
||||
|
||||
# Return all records equal to or newer than the timestamp.
|
||||
result = self.get_cache(key).query('open_time >= @_timestamp')
|
||||
return result
|
||||
|
||||
|
|
@ -176,64 +248,61 @@ class DataCache:
|
|||
logger.error(f"An error occurred: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_records_since_from_db(self, table_name: str, st: dt.datetime, et: dt.datetime, rl: float,
                              ex_details: List[str]) -> pd.DataFrame:
    """
    Fetch records from the database since the specified start datetime,
    backfilling from the exchange when the stored data is incomplete.

    :param table_name: The name of the table in the database.
    :param st: The start datetime to fetch records from.
    :param et: The end datetime to fetch records until.
    :param rl: The required number of records.
    :param ex_details: Exchange details.
    :return: DataFrame containing the records.
    """

    def add_data(data: pd.DataFrame, tn: str, start_t: dt.datetime, end_t: dt.datetime) -> pd.DataFrame:
        # Pull extra records from the exchange and fold them into *data*,
        # deduplicating on open_time (existing rows win).
        new_records = self._populate_db(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details)
        logger.debug(f'Got {len(new_records.index)} records from exchange_name')
        if not new_records.empty:
            data = pd.concat([data, new_records], axis=0, ignore_index=True)
            data = data.drop_duplicates(subset="open_time", keep='first')
        return data

    if self.db.table_exists(table_name=table_name):
        logger.debug('Table existed retrieving records from DB')
        logger.debug(f'Requesting from {st} to {et}')
        records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=st, et=et)
        logger.debug(f'Got {len(records.index)} records from db')
    else:
        logger.debug(f"Table didn't exist fetching from {ex_details[2]}")
        temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
        logger.debug(f'Requesting from {st} to {et}, Should be {temp} records')
        records = self._populate_db(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details)
        logger.debug(f'Got {len(records.index)} records from {ex_details[2]}')

    # A truthy return means the records did not reach far enough back
    # (NaN means no records were found at all).
    first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
    if first_timestamp:
        if pd.isna(first_timestamp):
            logger.debug('No records found to satisfy the query, continuing to fetch more records.')
            records = add_data(data=records, tn=table_name, start_t=st, end_t=et)
        else:
            logger.debug(f'Records did not go far enough back. Requesting from {ex_details[2]}')
            logger.debug(f'First ts on record is: {first_timestamp}')
            end_time = dt.datetime.utcfromtimestamp(first_timestamp)
            logger.debug(f'Requesting from {st} to {end_time}')
            records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)

    # A truthy return means the records are not up to date at the new end.
    last_timestamp = query_uptodate(records=records, r_length_min=rl)
    if last_timestamp:
        logger.debug(f'Records were not updated. Requesting from {ex_details[2]}.')
        logger.debug(f'The last record on file is: {last_timestamp}')
        start_time = dt.datetime.utcfromtimestamp(last_timestamp)
        logger.debug(f'Requesting from {start_time} to {et}')
        records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et)

    return records
|
||||
|
||||
def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: list,
|
||||
def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: List[str],
|
||||
end_time: dt.datetime = None) -> pd.DataFrame:
|
||||
"""
|
||||
Populates a database table with records from the exchange.
|
||||
|
|
@ -243,10 +312,6 @@ class DataCache:
|
|||
:param end_time: End time to fetch the records until (optional).
|
||||
:param ex_details: Exchange details [symbol, interval, exchange_name, user_name].
|
||||
:return: DataFrame of the data downloaded.
|
||||
|
||||
Example:
|
||||
--------
|
||||
records = db._populate_table('test_table', start_time, ['BTC/USDT', '1m', 'binance', 'user1'])
|
||||
"""
|
||||
if end_time is None:
|
||||
end_time = dt.datetime.utcnow()
|
||||
|
|
@ -256,7 +321,7 @@ class DataCache:
|
|||
if not records.empty:
|
||||
self.db.insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex)
|
||||
else:
|
||||
print(f'No records inserted {records}')
|
||||
logger.debug(f'No records inserted {records}')
|
||||
return records
|
||||
|
||||
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
|
||||
|
|
@ -272,10 +337,6 @@ class DataCache:
|
|||
:param start_datetime: Start datetime for fetching data (optional).
|
||||
:param end_datetime: End datetime for fetching data (optional).
|
||||
:return: DataFrame of candle data.
|
||||
|
||||
Example:
|
||||
--------
|
||||
candles = db._fetch_candles_from_exchange('BTC/USDT', '1m', 'binance', 'user1', start_time, end_time)
|
||||
"""
|
||||
|
||||
def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
|
||||
|
|
@ -285,10 +346,6 @@ class DataCache:
|
|||
:param records: DataFrame containing the original records.
|
||||
:param interval: Interval of the data (e.g., '1m', '5m').
|
||||
:return: DataFrame with gaps filled.
|
||||
|
||||
Example:
|
||||
--------
|
||||
filled_records = fill_data_holes(df, '1m')
|
||||
"""
|
||||
time_span = timeframe_to_minutes(interval)
|
||||
last_timestamp = None
|
||||
|
|
@ -299,20 +356,17 @@ class DataCache:
|
|||
for index, row in records.iterrows():
|
||||
time_stamp = row['open_time']
|
||||
|
||||
# If last_timestamp is None, this is the first record
|
||||
if last_timestamp is None:
|
||||
last_timestamp = time_stamp
|
||||
filled_records.append(row)
|
||||
logger.debug(f"First timestamp: {time_stamp}")
|
||||
continue
|
||||
|
||||
# Calculate the difference in milliseconds and minutes
|
||||
delta_ms = time_stamp - last_timestamp
|
||||
delta_minutes = (delta_ms / 1000) / 60
|
||||
|
||||
logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}")
|
||||
|
||||
# If the gap is larger than the time span of the interval, fill the gap
|
||||
if delta_minutes > time_span:
|
||||
num_missing_rec = int(delta_minutes / time_span)
|
||||
step = int(delta_ms / num_missing_rec)
|
||||
|
|
@ -330,44 +384,47 @@ class DataCache:
|
|||
logger.info("Data holes filled successfully.")
|
||||
return pd.DataFrame(filled_records)
|
||||
|
||||
# Default start date if not provided
|
||||
if start_datetime is None:
|
||||
start_datetime = dt.datetime(year=2017, month=1, day=1)
|
||||
|
||||
# Default end date if not provided
|
||||
if end_datetime is None:
|
||||
end_datetime = dt.datetime.utcnow()
|
||||
|
||||
# Check if start date is greater than end date
|
||||
if start_datetime > end_datetime:
|
||||
raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")
|
||||
|
||||
# Get the exchange object
|
||||
exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
|
||||
|
||||
# Calculate the expected number of records
|
||||
temp = (((unix_time_millis(end_datetime) - unix_time_millis(
|
||||
expected_records = (((unix_time_millis(end_datetime) - unix_time_millis(
|
||||
start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval)
|
||||
logger.info(
|
||||
f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {expected_records}')
|
||||
|
||||
logger.info(f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {temp}')
|
||||
|
||||
# If start and end times are the same, set end_datetime to None
|
||||
if start_datetime == end_datetime:
|
||||
end_datetime = None
|
||||
|
||||
# Fetch historical candlestick data from the exchange
|
||||
candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime,
|
||||
end_dt=end_datetime)
|
||||
num_rec_records = len(candles.index)
|
||||
|
||||
logger.info(f'{num_rec_records} candles retrieved from the exchange.')
|
||||
|
||||
# Check if the retrieved data covers the expected time range
|
||||
open_times = candles.open_time
|
||||
estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval)
|
||||
min_open_time = open_times.min()
|
||||
max_open_time = open_times.max()
|
||||
|
||||
if min_open_time < 1e10:
|
||||
raise ValueError('Records are not in milliseconds')
|
||||
|
||||
max_open_time /= 1000
|
||||
min_open_time /= 1000
|
||||
|
||||
estimated_num_records = ((max_open_time - min_open_time) / 60) / timeframe_to_minutes(interval)
|
||||
|
||||
logger.info(f'Estimated number of records: {estimated_num_records}')
|
||||
|
||||
# Fill in any missing data if the retrieved data is less than expected
|
||||
if num_rec_records < estimated_num_records:
|
||||
logger.info('Detected gaps in the data, attempting to fill missing records.')
|
||||
candles = fill_data_holes(candles, interval)
|
||||
|
||||
return candles
|
||||
|
|
|
|||
|
|
@ -0,0 +1,454 @@
|
|||
from typing import List, Any
|
||||
import pandas as pd
|
||||
import datetime as dt
|
||||
import logging
|
||||
from shared_utilities import unix_time_millis
|
||||
from Database import Database
|
||||
import numpy as np
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def timeframe_to_timedelta(timeframe: str) -> pd.Timedelta | pd.DateOffset:
|
||||
digits = int("".join([i if i.isdigit() else "" for i in timeframe]))
|
||||
unit = "".join([i if i.isalpha() else "" for i in timeframe])
|
||||
|
||||
if unit == 'm':
|
||||
return pd.Timedelta(minutes=digits)
|
||||
elif unit == 'h':
|
||||
return pd.Timedelta(hours=digits)
|
||||
elif unit == 'd':
|
||||
return pd.Timedelta(days=digits)
|
||||
elif unit == 'w':
|
||||
return pd.Timedelta(weeks=digits)
|
||||
elif unit == 'M':
|
||||
return pd.DateOffset(months=digits)
|
||||
elif unit == 'Y':
|
||||
return pd.DateOffset(years=digits)
|
||||
else:
|
||||
raise ValueError(f"Invalid timeframe unit: {unit}")
|
||||
|
||||
|
||||
def estimate_record_count(start_time, end_time, timeframe: str) -> int:
    """
    Estimate the number of records expected between start_time and end_time based on the given timeframe.
    Accepts either datetime objects or Unix timestamps in milliseconds.

    :param start_time: Start of the window (datetime, or Unix ms timestamp).
    :param end_time: End of the window (datetime, or Unix ms timestamp).
    :param timeframe: Timeframe string, e.g. '1m', '4h', '1M'.
    :return: Estimated record count (floor of window length / timeframe length).
    :raises ValueError: If the inputs are not both datetimes or both timestamps.
    """
    # Check if the input is in milliseconds (timestamp)
    if isinstance(start_time, (int, float, np.integer)) and isinstance(end_time, (int, float, np.integer)):
        # Convert timestamps from milliseconds to seconds for calculation
        start_datetime = dt.datetime.utcfromtimestamp(int(start_time) / 1000)
        end_datetime = dt.datetime.utcfromtimestamp(int(end_time) / 1000)
    elif isinstance(start_time, dt.datetime) and isinstance(end_time, dt.datetime):
        start_datetime = start_time
        end_datetime = end_time
    else:
        raise ValueError("start_time and end_time must be either both "
                         "datetime objects or both Unix timestamps in milliseconds.")

    delta = timeframe_to_timedelta(timeframe)
    if isinstance(delta, pd.Timedelta):
        total_seconds = (end_datetime - start_datetime).total_seconds()
        return int(total_seconds // delta.total_seconds())

    # BUG FIX: pd.DateOffset (monthly/yearly timeframes) has no total_seconds(),
    # so this previously raised AttributeError. Count whole calendar steps instead.
    count = 0
    cursor = start_datetime + delta
    while cursor <= end_datetime:
        count += 1
        cursor += delta
    return count
|
||||
|
||||
|
||||
def generate_expected_timestamps(start_datetime: dt.datetime, end_datetime: dt.datetime,
                                 timeframe: str) -> pd.DatetimeIndex:
    """
    Build the full grid of timestamps expected between the two datetimes
    for the given timeframe.

    :param start_datetime: First expected timestamp (inclusive).
    :param end_datetime: Upper bound for the grid (inclusive).
    :param timeframe: Timeframe string, e.g. '1m', '1d', '1M'.
    :return: DatetimeIndex of expected timestamps.
    """
    step = timeframe_to_timedelta(timeframe)
    # Fixed-length steps map directly onto pandas' date_range.
    if isinstance(step, pd.Timedelta):
        return pd.date_range(start=start_datetime, end=end_datetime, freq=step)
    # Calendar-length steps (months/years) are walked manually.
    if isinstance(step, pd.DateOffset):
        grid = []
        cursor = start_datetime
        while cursor <= end_datetime:
            grid.append(cursor)
            cursor += step
        return pd.DatetimeIndex(grid)
|
||||
|
||||
|
||||
class DataCache:
|
||||
TYPECHECKING_ENABLED = True
|
||||
|
||||
def __init__(self, exchanges):
    """
    Initialize the cache with a database handle and the exchanges gateway.

    :param exchanges: The exchanges object handling communication with connected exchanges.
    """
    # Database access layer.
    self.db = Database()
    # Exchange access layer.
    self.exchanges = exchanges
    # key -> cached records mapping.
    self.cached_data = {}
    logger.info("DataCache initialized.")
|
||||
|
||||
def get_records_since(self, start_datetime: dt.datetime, ex_details: List[str]) -> pd.DataFrame:
    """
    Fetch records from start_datetime up to the current time, trying the cache first.

    :param start_datetime: The start datetime to fetch records from.
    :param ex_details: Exchange details [asset, timeframe, exchange, user_name].
    :return: DataFrame containing the records.
    :raises TypeError: If the arguments fail validation (when TYPECHECKING_ENABLED).
    """
    if self.TYPECHECKING_ENABLED:
        if not isinstance(start_datetime, dt.datetime):
            raise TypeError("start_datetime must be a datetime object")
        if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
            raise TypeError("ex_details must be a list of strings")
        if len(ex_details) != 4:
            raise TypeError("ex_details must include [asset, timeframe, exchange, user_name]")

    try:
        # The request window always ends at the current time.
        return self.get_or_fetch_from('cache',
                                      start_datetime=start_datetime,
                                      end_datetime=dt.datetime.utcnow(),
                                      ex_details=ex_details)
    except Exception as e:
        logger.error(f"An error occurred: {str(e)}")
        raise
|
||||
|
||||
def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame:
    """
    Resolve a candle request against a chain of resources (cache -> database -> server),
    starting at *target* and accumulating partial results until the request is satisfied.

    :param target: One of 'cache', 'database' or 'server' — the first resource to try.
    :param kwargs: start_datetime, end_datetime, ex_details.
    :return: DataFrame of records; whatever partial data was gathered (possibly empty)
             if no resource could fully satisfy the request.
    :raises TypeError: If an argument fails validation (when TYPECHECKING_ENABLED).
    :raises ValueError: If a required argument is missing or target is invalid.
    """
    start_datetime = kwargs.get('start_datetime')
    end_datetime = kwargs.get('end_datetime')
    ex_details = kwargs.get('ex_details')

    if self.TYPECHECKING_ENABLED:
        if not isinstance(start_datetime, dt.datetime):
            raise TypeError("start_datetime must be a datetime object")
        if not isinstance(end_datetime, dt.datetime):
            raise TypeError("end_datetime must be a datetime object")
        if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
            raise TypeError("ex_details must be a list of strings")

    # BUG FIX: the timeframe was previously read via kwargs.get('ex_details')[1]
    # before any validation, raising an opaque TypeError when ex_details was
    # missing. Dereference only after the friendly type check above.
    timeframe = ex_details[1]

    if self.TYPECHECKING_ENABLED:
        if not isinstance(timeframe, str):
            raise TypeError("record_length must be a string representing the timeframe")
        if not all([start_datetime, end_datetime, timeframe, ex_details]):
            raise ValueError("Missing required arguments")

    request_criteria = {
        'start_datetime': start_datetime,
        'end_datetime': end_datetime,
        'timeframe': timeframe,
    }

    key = self._make_key(ex_details)
    combined_data = pd.DataFrame()

    # Resource chain: start at *target* and fall through to slower sources.
    if target == 'cache':
        resources = [self.get_candles_from_cache, self.get_from_database, self.get_from_server]
    elif target == 'database':
        resources = [self.get_from_database, self.get_from_server]
    elif target == 'server':
        resources = [self.get_from_server]
    else:
        raise ValueError('Not a valid Target!')

    for fetch_method in resources:
        result = fetch_method(**kwargs)

        if not result.empty:
            combined_data = pd.concat([combined_data, result]).drop_duplicates()

        if not combined_data.empty and 'open_time' in combined_data.columns:
            combined_data = combined_data.sort_values(by='open_time').drop_duplicates(subset='open_time',
                                                                                     keep='first')

        is_complete, request_criteria = self.data_complete(combined_data, **request_criteria)
        if is_complete:
            # Write anything fetched from a slower tier back into the faster tiers.
            if fetch_method in [self.get_from_database, self.get_from_server]:
                self.update_candle_cache(combined_data, key)
            if fetch_method == self.get_from_server:
                self._populate_db(ex_details, combined_data)
            return combined_data

        kwargs.update(request_criteria)  # Update kwargs with new start/end times for next fetch attempt

    logger.error('Unable to fetch the requested data.')
    return combined_data if not combined_data.empty else pd.DataFrame()
|
||||
|
||||
def get_candles_from_cache(self, **kwargs) -> pd.DataFrame:
    """
    Return cached candles for the request window, or an empty frame on a cache miss.

    :param kwargs: start_datetime, end_datetime, ex_details.
    :return: DataFrame filtered to [start_datetime, end_datetime] by open_time.
    :raises TypeError: If an argument fails validation (when TYPECHECKING_ENABLED).
    :raises ValueError: If a required argument is missing.
    """
    start_datetime = kwargs.get('start_datetime')
    end_datetime = kwargs.get('end_datetime')
    ex_details = kwargs.get('ex_details')

    if self.TYPECHECKING_ENABLED:
        if not isinstance(start_datetime, dt.datetime):
            raise TypeError("start_datetime must be a datetime object")
        if not isinstance(end_datetime, dt.datetime):
            raise TypeError("end_datetime must be a datetime object")
        if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
            raise TypeError("ex_details must be a list of strings")
        if not all([start_datetime, end_datetime, ex_details]):
            raise ValueError("Missing required arguments")

    key = self._make_key(ex_details)
    logger.debug('Getting records from cache.')
    cached = self.get_cache(key)
    if cached is None:
        logger.debug("Cache records didn't exist.")
        return pd.DataFrame()

    logger.debug('Filtering records.')
    # Keep only rows inside the requested window (bounds in Unix ms).
    lower = unix_time_millis(start_datetime)
    upper = unix_time_millis(end_datetime)
    in_window = (cached['open_time'] >= lower) & (cached['open_time'] <= upper)
    return cached[in_window].reset_index(drop=True)
|
||||
|
||||
def get_from_database(self, **kwargs) -> pd.DataFrame:
|
||||
start_datetime = kwargs.get('start_datetime')
|
||||
end_datetime = kwargs.get('end_datetime')
|
||||
ex_details = kwargs.get('ex_details')
|
||||
|
||||
if self.TYPECHECKING_ENABLED:
|
||||
if not isinstance(start_datetime, dt.datetime):
|
||||
raise TypeError("start_datetime must be a datetime object")
|
||||
if not isinstance(end_datetime, dt.datetime):
|
||||
raise TypeError("end_datetime must be a datetime object")
|
||||
if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
|
||||
raise TypeError("ex_details must be a list of strings")
|
||||
if not all([start_datetime, end_datetime, ex_details]):
|
||||
raise ValueError("Missing required arguments")
|
||||
|
||||
table_name = self._make_key(ex_details)
|
||||
if not self.db.table_exists(table_name):
|
||||
logger.debug('Records not in database.')
|
||||
return pd.DataFrame()
|
||||
|
||||
logger.debug('Getting records from database.')
|
||||
return self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=start_datetime,
|
||||
et=end_datetime)
|
||||
|
||||
def get_from_server(self, **kwargs) -> pd.DataFrame:
|
||||
symbol = kwargs.get('ex_details')[0]
|
||||
interval = kwargs.get('ex_details')[1]
|
||||
exchange_name = kwargs.get('ex_details')[2]
|
||||
user_name = kwargs.get('ex_details')[3]
|
||||
start_datetime = kwargs.get('start_datetime')
|
||||
end_datetime = kwargs.get('end_datetime')
|
||||
|
||||
if self.TYPECHECKING_ENABLED:
|
||||
if not isinstance(symbol, str):
|
||||
raise TypeError("symbol must be a string")
|
||||
if not isinstance(interval, str):
|
||||
raise TypeError("interval must be a string")
|
||||
if not isinstance(exchange_name, str):
|
||||
raise TypeError("exchange_name must be a string")
|
||||
if not isinstance(user_name, str):
|
||||
raise TypeError("user_name must be a string")
|
||||
if not isinstance(start_datetime, dt.datetime):
|
||||
raise TypeError("start_datetime must be a datetime object")
|
||||
if not isinstance(end_datetime, dt.datetime):
|
||||
raise TypeError("end_datetime must be a datetime object")
|
||||
|
||||
logger.debug('Getting records from server.')
|
||||
return self._fetch_candles_from_exchange(symbol, interval, exchange_name, user_name, start_datetime,
|
||||
end_datetime)
|
||||
|
||||
@staticmethod
|
||||
def data_complete(data: pd.DataFrame, **kwargs) -> (bool, dict):
|
||||
"""
|
||||
Checks if the data completely satisfies the request.
|
||||
|
||||
:param data: DataFrame containing the records.
|
||||
:param kwargs: Arguments required for completeness check.
|
||||
:return: A tuple (is_complete, updated_request_criteria) where is_complete is True if the data is complete,
|
||||
False otherwise, and updated_request_criteria contains adjusted start/end times if data is incomplete.
|
||||
"""
|
||||
start_datetime: dt.datetime = kwargs.get('start_datetime')
|
||||
end_datetime: dt.datetime = kwargs.get('end_datetime')
|
||||
timeframe: str = kwargs.get('timeframe')
|
||||
|
||||
if data.empty:
|
||||
logger.debug("Data is empty.")
|
||||
return False, kwargs # No data at all, proceed with the full original request
|
||||
|
||||
temp_data = data.copy()
|
||||
|
||||
# Ensure 'open_time' is in datetime format
|
||||
if temp_data['open_time'].dtype != '<M8[ns]':
|
||||
temp_data['open_time_dt'] = pd.to_datetime(temp_data['open_time'], unit='ms')
|
||||
else:
|
||||
temp_data['open_time_dt'] = temp_data['open_time']
|
||||
|
||||
min_timestamp = temp_data['open_time_dt'].min()
|
||||
max_timestamp = temp_data['open_time_dt'].max()
|
||||
|
||||
logger.debug(f"Data time range: {min_timestamp} to {max_timestamp}")
|
||||
logger.debug(f"Expected time range: {start_datetime} to {end_datetime}")
|
||||
|
||||
tolerance = pd.Timedelta(seconds=5)
|
||||
|
||||
# Initialize updated request criteria
|
||||
updated_request_criteria = kwargs.copy()
|
||||
|
||||
# Check if data covers the required time range with tolerance
|
||||
if min_timestamp > start_datetime + timeframe_to_timedelta(timeframe) + tolerance:
|
||||
logger.debug("Data does not start early enough, even with tolerance.")
|
||||
updated_request_criteria['end_datetime'] = min_timestamp # Fetch the missing earlier data
|
||||
return False, updated_request_criteria
|
||||
|
||||
if max_timestamp < end_datetime - timeframe_to_timedelta(timeframe) - tolerance:
|
||||
logger.debug("Data does not extend late enough, even with tolerance.")
|
||||
updated_request_criteria['start_datetime'] = max_timestamp # Fetch the missing later data
|
||||
return False, updated_request_criteria
|
||||
|
||||
# Filter data between start_datetime and end_datetime
|
||||
mask = (temp_data['open_time_dt'] >= start_datetime) & (temp_data['open_time_dt'] <= end_datetime)
|
||||
data_in_range = temp_data.loc[mask]
|
||||
|
||||
expected_count = estimate_record_count(start_datetime, end_datetime, timeframe)
|
||||
actual_count = len(data_in_range)
|
||||
|
||||
logger.debug(f"Expected record count: {expected_count}, Actual record count: {actual_count}")
|
||||
|
||||
tolerance = 1
|
||||
if actual_count < (expected_count - tolerance):
|
||||
logger.debug("Insufficient records within the specified time range, even with tolerance.")
|
||||
return False, updated_request_criteria
|
||||
|
||||
logger.debug("Data completeness check passed.")
|
||||
return True, kwargs
|
||||
|
||||
def cache_exists(self, key: str) -> bool:
|
||||
return key in self.cached_data
|
||||
|
||||
def get_cache(self, key: str) -> Any | None:
|
||||
if key not in self.cached_data:
|
||||
logger.warning(f"The requested cache key({key}) doesn't exist!")
|
||||
return None
|
||||
return self.cached_data[key]
|
||||
|
||||
def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
|
||||
logger.debug('Updating cache with new records.')
|
||||
# Concatenate the new records with the existing cache
|
||||
records = pd.concat([self.get_cache(key), more_records], axis=0, ignore_index=True)
|
||||
# Drop duplicates based on 'open_time' and keep the first occurrence
|
||||
records = records.drop_duplicates(subset="open_time", keep='first')
|
||||
# Sort the records by 'open_time'
|
||||
records = records.sort_values(by='open_time').reset_index(drop=True)
|
||||
# Reindex 'id' to ensure the expected order
|
||||
records['id'] = range(1, len(records) + 1)
|
||||
# Set the updated DataFrame back to cache
|
||||
self.set_cache(data=records, key=key)
|
||||
|
||||
def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None:
|
||||
"""
|
||||
Updates a dictionary stored in cache.
|
||||
|
||||
:param data: The data to insert into cache.
|
||||
:param cache_key: The cache index key for the dictionary.
|
||||
:param dict_key: The dictionary key for the data.
|
||||
:return: None
|
||||
"""
|
||||
self.cached_data[cache_key].update({dict_key: data})
|
||||
|
||||
def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
|
||||
if do_not_overwrite and key in self.cached_data:
|
||||
return
|
||||
self.cached_data[key] = data
|
||||
logger.debug(f'Cache set for key: {key}')
|
||||
|
||||
    def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
                                     start_datetime: dt.datetime = None,
                                     end_datetime: dt.datetime = None) -> pd.DataFrame:
        """
        Fetches historical candles for *symbol* from the named exchange and
        fills any detected gaps.

        :param symbol: Trading pair symbol as the exchange expects it.
        :param interval: Candle timeframe string (e.g. '15m').
        :param exchange_name: Exchange to fetch from.
        :param user_name: User whose exchange connection is used.
        :param start_datetime: Window start; defaults to 2017-01-01 when None.
        :param end_datetime: Window end; defaults to "now" (UTC) when None.
        :return: DataFrame of candles; an empty frame with the standard OHLCV
            columns when the exchange returns nothing.
        :raises ValueError: If start is after end, or returned timestamps are
            not in milliseconds.
        """
        # NOTE(review): estimate_record_count is not in the visible import
        # list — confirm it is imported at module top.
        if start_datetime is None:
            start_datetime = dt.datetime(year=2017, month=1, day=1)

        if end_datetime is None:
            end_datetime = dt.datetime.utcnow()

        if start_datetime > end_datetime:
            raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")

        exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)

        expected_records = estimate_record_count(start_datetime, end_datetime, interval)
        logger.info(
            f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {expected_records}')

        # An equal start/end is passed to the exchange as an open-ended end
        # (presumably "from start_datetime onward") — TODO confirm intent.
        if start_datetime == end_datetime:
            end_datetime = None

        candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime,
                                                 end_dt=end_datetime)

        num_rec_records = len(candles.index)
        if num_rec_records == 0:
            logger.warning(f"No OHLCV data returned for {symbol}.")
            # Keep the column layout stable for downstream callers.
            return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume'])

        logger.info(f'{num_rec_records} candles retrieved from the exchange.')

        open_times = candles.open_time
        min_open_time = open_times.min()
        max_open_time = open_times.max()

        # Sanity check: Unix-second timestamps are ~1e9, milliseconds ~1e12,
        # so anything below 1e10 cannot be a millisecond timestamp.
        if min_open_time < 1e10:
            raise ValueError('Records are not in milliseconds')

        estimated_num_records = estimate_record_count(min_open_time, max_open_time, interval) + 1
        logger.info(f'Estimated number of records: {estimated_num_records}')

        # Fewer rows than the span implies means there are holes in the data.
        if num_rec_records < estimated_num_records:
            logger.info('Detected gaps in the data, attempting to fill missing records.')
            candles = self.fill_data_holes(candles, interval)

        return candles
|
||||
|
||||
@staticmethod
|
||||
def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
|
||||
time_span = timeframe_to_timedelta(interval).total_seconds() / 60
|
||||
last_timestamp = None
|
||||
filled_records = []
|
||||
|
||||
logger.info(f"Starting to fill data holes for interval: {interval}")
|
||||
|
||||
for index, row in records.iterrows():
|
||||
time_stamp = row['open_time']
|
||||
|
||||
if last_timestamp is None:
|
||||
last_timestamp = time_stamp
|
||||
filled_records.append(row)
|
||||
logger.debug(f"First timestamp: {time_stamp}")
|
||||
continue
|
||||
|
||||
delta_ms = time_stamp - last_timestamp
|
||||
delta_minutes = (delta_ms / 1000) / 60
|
||||
|
||||
logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}")
|
||||
|
||||
if delta_minutes > time_span:
|
||||
num_missing_rec = int(delta_minutes / time_span)
|
||||
step = int(delta_ms / num_missing_rec)
|
||||
logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}")
|
||||
|
||||
for ts in range(int(last_timestamp) + step, int(time_stamp), step):
|
||||
new_row = row.copy()
|
||||
new_row['open_time'] = ts
|
||||
filled_records.append(new_row)
|
||||
logger.debug(f"Filled timestamp: {ts}")
|
||||
|
||||
filled_records.append(row)
|
||||
last_timestamp = time_stamp
|
||||
|
||||
logger.info("Data holes filled successfully.")
|
||||
return pd.DataFrame(filled_records)
|
||||
|
||||
def _populate_db(self, ex_details: List[str], data: pd.DataFrame = None) -> None:
|
||||
if data is None or data.empty:
|
||||
logger.debug(f'No records to insert {data}')
|
||||
return
|
||||
|
||||
table_name = self._make_key(ex_details)
|
||||
sym, _, ex, _ = ex_details
|
||||
|
||||
self.db.insert_candles_into_db(data, table_name=table_name, symbol=sym, exchange_name=ex)
|
||||
logger.info(f'Data inserted into table {table_name}')
|
||||
|
||||
@staticmethod
|
||||
def _make_key(ex_details: List[str]) -> str:
|
||||
sym, tf, ex, _ = ex_details
|
||||
key = f'{sym}_{tf}_{ex}'
|
||||
return key
|
||||
|
||||
# Example usage
|
||||
# args = {
|
||||
# 'start_datetime': dt.datetime.now() - dt.timedelta(hours=1), # Example start time
|
||||
# 'ex_details': ['BTCUSDT', '15m', 'Binance', 'user1'],
|
||||
# }
|
||||
#
|
||||
# exchanges = ExchangeHandler()
|
||||
# data = DataCache(exchanges)
|
||||
# df = data.get_records_since(**args)
|
||||
#
|
||||
# # Disabling type checking for a specific instance
|
||||
# data.TYPECHECKING_ENABLED = False
|
||||
163
src/Database.py
163
src/Database.py
|
|
@ -1,6 +1,6 @@
|
|||
import sqlite3
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import config
|
||||
import datetime as dt
|
||||
import pandas as pd
|
||||
|
|
@ -19,7 +19,7 @@ class SQLite:
|
|||
cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||
"""
|
||||
|
||||
def __init__(self, db_file=None):
|
||||
def __init__(self, db_file: str = None):
|
||||
self.db_file = db_file if db_file else config.DB_FILE
|
||||
self.connection = sqlite3.connect(self.db_file)
|
||||
|
||||
|
|
@ -41,11 +41,11 @@ class HDict(dict):
|
|||
hash(hdict)
|
||||
"""
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(frozenset(self.items()))
|
||||
|
||||
|
||||
def make_query(item: str, table: str, columns: list) -> str:
|
||||
def make_query(item: str, table: str, columns: List[str]) -> str:
|
||||
"""
|
||||
Creates a SQL select query string with the required number of placeholders.
|
||||
|
||||
|
|
@ -53,40 +53,22 @@ def make_query(item: str, table: str, columns: list) -> str:
|
|||
:param table: The table to select from.
|
||||
:param columns: List of columns for the where clause.
|
||||
:return: The query string.
|
||||
|
||||
Example:
|
||||
--------
|
||||
query = make_query('id', 'test_table', ['name', 'age'])
|
||||
# Result: 'SELECT id FROM test_table WHERE name = ? AND age = ?;'
|
||||
"""
|
||||
an_itr = iter(columns)
|
||||
k = next(an_itr)
|
||||
where_str = f"SELECT {item} FROM {table} WHERE {k} = ?"
|
||||
where_str += "".join([f" AND {k} = ?" for k in an_itr]) + ';'
|
||||
return where_str
|
||||
placeholders = " AND ".join([f"{col} = ?" for col in columns])
|
||||
return f"SELECT {item} FROM {table} WHERE {placeholders};"
|
||||
|
||||
|
||||
def make_insert(table: str, values: tuple) -> str:
|
||||
def make_insert(table: str, columns: Tuple[str, ...]) -> str:
|
||||
"""
|
||||
Creates a SQL insert query string with the required number of placeholders.
|
||||
|
||||
:param table: The table to insert into.
|
||||
:param values: Tuple of values to insert.
|
||||
:param columns: Tuple of column names.
|
||||
:return: The query string.
|
||||
|
||||
Example:
|
||||
--------
|
||||
insert = make_insert('test_table', ('name', 'age'))
|
||||
# Result: "INSERT INTO test_table ('name', 'age') VALUES(?, ?);"
|
||||
"""
|
||||
itr1 = iter(values)
|
||||
itr2 = iter(values)
|
||||
k1 = next(itr1)
|
||||
_ = next(itr2)
|
||||
insert_str = f"INSERT INTO {table} ('{k1}'"
|
||||
insert_str += "".join([f", '{k1}'" for k1 in itr1]) + ") VALUES(?" + "".join(
|
||||
[", ?" for _ in enumerate(itr2)]) + ");"
|
||||
return insert_str
|
||||
col_names = ", ".join([f"'{col}'" for col in columns])
|
||||
placeholders = ", ".join(["?" for _ in columns])
|
||||
return f"INSERT INTO {table} ({col_names}) VALUES ({placeholders});"
|
||||
|
||||
|
||||
class Database:
|
||||
|
|
@ -96,16 +78,11 @@ class Database:
|
|||
|
||||
Example usage:
|
||||
--------------
|
||||
db = Database(exchanges=some_exchanges_object, db_file='test_db.sqlite')
|
||||
db = Database(db_file='test_db.sqlite')
|
||||
db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||
"""
|
||||
|
||||
def __init__(self, db_file=None):
|
||||
"""
|
||||
Initializes the Database class.
|
||||
|
||||
:param db_file: Optional database file name.
|
||||
"""
|
||||
def __init__(self, db_file: str = None):
|
||||
self.db_file = db_file
|
||||
|
||||
def execute_sql(self, sql: str) -> None:
|
||||
|
|
@ -113,17 +90,12 @@ class Database:
|
|||
Executes a raw SQL statement.
|
||||
|
||||
:param sql: SQL statement to execute.
|
||||
|
||||
Example:
|
||||
--------
|
||||
db = Database(exchanges=some_exchanges_object, db_file='test_db.sqlite')
|
||||
db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||
"""
|
||||
with SQLite(self.db_file) as con:
|
||||
cur = con.cursor()
|
||||
cur.execute(sql)
|
||||
|
||||
def get_item_where(self, item_name: str, table_name: str, filter_vals: tuple) -> int:
|
||||
def get_item_where(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any:
|
||||
"""
|
||||
Returns an item from a table where the filter results should isolate a single row.
|
||||
|
||||
|
|
@ -131,42 +103,29 @@ class Database:
|
|||
:param table_name: Name of the table.
|
||||
:param filter_vals: Tuple of column name and value to filter by.
|
||||
:return: The item.
|
||||
|
||||
Example:
|
||||
--------
|
||||
item = db.get_item_where('name', 'test_table', ('id', 1))
|
||||
# Fetches the 'name' from 'test_table' where 'id' is 1
|
||||
"""
|
||||
with SQLite(self.db_file) as con:
|
||||
cur = con.cursor()
|
||||
qry = make_query(item_name, table_name, [filter_vals[0]])
|
||||
cur.execute(qry, (filter_vals[1],))
|
||||
if user_id := cur.fetchone():
|
||||
return user_id[0]
|
||||
if result := cur.fetchone():
|
||||
return result[0]
|
||||
else:
|
||||
error = f"Couldn't fetch item {item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}"
|
||||
raise ValueError(error)
|
||||
|
||||
def get_rows_where(self, table: str, filter_vals: tuple) -> pd.DataFrame | None:
|
||||
def get_rows_where(self, table: str, filter_vals: Tuple[str, Any]) -> pd.DataFrame | None:
|
||||
"""
|
||||
Returns a DataFrame containing all rows of a table that meet the filter criteria.
|
||||
|
||||
:param table: Name of the table.
|
||||
:param filter_vals: Tuple of column name and value to filter by.
|
||||
:return: DataFrame of the query result or None if empty.
|
||||
|
||||
Example:
|
||||
--------
|
||||
rows = db.get_rows_where('test_table', ('name', 'test'))
|
||||
# Fetches all rows from 'test_table' where 'name' is 'test'
|
||||
"""
|
||||
with SQLite(self.db_file) as con:
|
||||
qry = f"SELECT * FROM {table} WHERE {filter_vals[0]}='{filter_vals[1]}'"
|
||||
result = pd.read_sql(qry, con=con)
|
||||
if not result.empty:
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?"
|
||||
result = pd.read_sql(qry, con, params=(filter_vals[1],))
|
||||
return result if not result.empty else None
|
||||
|
||||
def insert_dataframe(self, df: pd.DataFrame, table: str) -> None:
|
||||
"""
|
||||
|
|
@ -174,30 +133,21 @@ class Database:
|
|||
|
||||
:param df: DataFrame to insert.
|
||||
:param table: Name of the table.
|
||||
|
||||
Example:
|
||||
--------
|
||||
df = pd.DataFrame({'id': [1], 'name': ['test']})
|
||||
db.insert_dataframe(df, 'test_table')
|
||||
"""
|
||||
with SQLite(self.db_file) as con:
|
||||
df.to_sql(name=table, con=con, index=False, if_exists='append')
|
||||
|
||||
def insert_row(self, table: str, columns: tuple, values: tuple) -> None:
|
||||
def insert_row(self, table: str, columns: Tuple[str, ...], values: Tuple[Any, ...]) -> None:
|
||||
"""
|
||||
Inserts a row into a specified table.
|
||||
|
||||
:param table: Name of the table.
|
||||
:param columns: Tuple of column names.
|
||||
:param values: Tuple of values to insert.
|
||||
|
||||
Example:
|
||||
--------
|
||||
db.insert_row('test_table', ('id', 'name'), (1, 'test'))
|
||||
"""
|
||||
with SQLite(self.db_file) as conn:
|
||||
cursor = conn.cursor()
|
||||
sql = make_insert(table=table, values=columns)
|
||||
sql = make_insert(table=table, columns=columns)
|
||||
cursor.execute(sql, values)
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
|
||||
|
|
@ -206,11 +156,6 @@ class Database:
|
|||
|
||||
:param table_name: Name of the table.
|
||||
:return: True if the table exists, False otherwise.
|
||||
|
||||
Example:
|
||||
--------
|
||||
exists = db._table_exists('test_table')
|
||||
# Checks if 'test_table' exists in the database
|
||||
"""
|
||||
with SQLite(self.db_file) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
|
@ -220,7 +165,7 @@ class Database:
|
|||
return result is not None
|
||||
|
||||
@lru_cache(maxsize=1000)
|
||||
def get_from_static_table(self, item: str, table: str, indexes: HDict, create_id: bool = False) -> Any:
|
||||
def get_from_static_table(self, item: str, table: str, indexes: dict, create_id: bool = False) -> Any:
|
||||
"""
|
||||
Returns the row id of an item from a table specified. If the item isn't listed in the table,
|
||||
it will insert the item into a new row and return the autoincremented id. The item received as a hashable
|
||||
|
|
@ -231,23 +176,21 @@ class Database:
|
|||
:param indexes: Hashable dictionary of indexing columns and their values.
|
||||
:param create_id: If True, create a row if it doesn't exist and return the autoincrement ID.
|
||||
:return: The content of the field.
|
||||
|
||||
Example:
|
||||
--------
|
||||
exchange_id = db.get_from_static_table('id', 'exchange', HDict({'name': 'binance'}), create_id=True)
|
||||
"""
|
||||
with SQLite(self.db_file) as conn:
|
||||
cursor = conn.cursor()
|
||||
sql = make_query(item, table, list(indexes.keys()))
|
||||
cursor.execute(sql, tuple(indexes.values()))
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result is None and create_id:
|
||||
sql = make_insert(table, tuple(indexes.keys()))
|
||||
cursor.execute(sql, tuple(indexes.values()))
|
||||
sql = make_query(item, table, list(indexes.keys()))
|
||||
cursor.execute(sql, tuple(indexes.values()))
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
result = cursor.lastrowid # Get the last inserted row ID
|
||||
else:
|
||||
result = result[0] if result else None
|
||||
|
||||
return result
|
||||
|
||||
def _fetch_exchange_id(self, exchange_name: str) -> int:
|
||||
"""
|
||||
|
|
@ -255,25 +198,17 @@ class Database:
|
|||
|
||||
:param exchange_name: Name of the exchange.
|
||||
:return: Primary ID of the exchange.
|
||||
|
||||
Example:
|
||||
--------
|
||||
exchange_id = db._fetch_exchange_id('binance')
|
||||
"""
|
||||
return self.get_from_static_table(item='id', table='exchange', create_id=True,
|
||||
indexes=HDict({'name': exchange_name}))
|
||||
|
||||
def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
|
||||
"""
|
||||
Returns the market ID for a trading pair listed in the database.
|
||||
Returns the markets ID for a trading pair listed in the database.
|
||||
|
||||
:param symbol: Symbol of the trading pair.
|
||||
:param exchange_name: Name of the exchange.
|
||||
:return: Market ID.
|
||||
|
||||
Example:
|
||||
--------
|
||||
market_id = db._fetch_market_id('BTC/USDT', 'binance')
|
||||
"""
|
||||
exchange_id = self._fetch_exchange_id(exchange_name)
|
||||
market_id = self.get_from_static_table(item='id', table='markets', create_id=True,
|
||||
|
|
@ -289,21 +224,17 @@ class Database:
|
|||
:param table_name: Name of the table to insert into.
|
||||
:param symbol: Symbol of the trading pair.
|
||||
:param exchange_name: Name of the exchange.
|
||||
|
||||
Example:
|
||||
--------
|
||||
df = pd.DataFrame({
|
||||
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
||||
'open': [1.0],
|
||||
'high': [1.0],
|
||||
'low': [1.0],
|
||||
'close': [1.0],
|
||||
'volume': [1.0]
|
||||
})
|
||||
db._insert_candles_into_db(df, 'test_table', 'BTC/USDT', 'binance')
|
||||
"""
|
||||
market_id = self._fetch_market_id(symbol, exchange_name)
|
||||
candlesticks.insert(0, 'market_id', market_id)
|
||||
|
||||
# Check if 'market_id' column already exists in the DataFrame
|
||||
if 'market_id' in candlesticks.columns:
|
||||
# If it exists, set its value to the fetched market_id
|
||||
candlesticks['market_id'] = market_id
|
||||
else:
|
||||
# If it doesn't exist, insert it as the first column
|
||||
candlesticks.insert(0, 'market_id', market_id)
|
||||
|
||||
sql_create = f"""
|
||||
CREATE TABLE IF NOT EXISTS '{table_name}' (
|
||||
id INTEGER PRIMARY KEY,
|
||||
|
|
@ -332,21 +263,19 @@ class Database:
|
|||
:param st: Start datetime.
|
||||
:param et: End datetime (optional).
|
||||
:return: DataFrame of records.
|
||||
|
||||
Example:
|
||||
--------
|
||||
records = db.get_timestamped_records('test_table', 'open_time', start_time, end_time)
|
||||
"""
|
||||
with SQLite(self.db_file) as conn:
|
||||
start_stamp = unix_time_millis(st)
|
||||
if et is not None:
|
||||
end_stamp = unix_time_millis(et)
|
||||
q_str = (
|
||||
f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp} "
|
||||
f"AND {timestamp_field} <= {end_stamp};"
|
||||
f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= ? "
|
||||
f"AND {timestamp_field} <= ?;"
|
||||
)
|
||||
records = pd.read_sql(q_str, conn, params=(start_stamp, end_stamp))
|
||||
else:
|
||||
q_str = f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp};"
|
||||
records = pd.read_sql(q_str, conn)
|
||||
records = records.drop('id', axis=1)
|
||||
q_str = f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= ?;"
|
||||
records = pd.read_sql(q_str, conn, params=(start_stamp,))
|
||||
|
||||
# records = records.drop('id', axis=1) Todo: Reminder I may need to put this back later.
|
||||
return records
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import ccxt
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Tuple, Dict, List, Union, Any
|
||||
import time
|
||||
import logging
|
||||
|
|
@ -84,6 +84,8 @@ class Exchange:
|
|||
Returns:
|
||||
int: The Unix timestamp in milliseconds.
|
||||
"""
|
||||
if dt.tzinfo is None:
|
||||
raise ValueError("datetime object must be timezone-aware or in UTC.")
|
||||
return int(dt.timestamp() * 1000)
|
||||
|
||||
def _fetch_historical_klines(self, symbol: str, interval: str,
|
||||
|
|
@ -103,6 +105,12 @@ class Exchange:
|
|||
if end_dt is None:
|
||||
end_dt = datetime.utcnow()
|
||||
|
||||
# Convert start_dt and end_dt to UTC if they are naive
|
||||
if start_dt.tzinfo is None:
|
||||
start_dt = start_dt.replace(tzinfo=timezone.utc)
|
||||
if end_dt.tzinfo is None:
|
||||
end_dt = end_dt.replace(tzinfo=timezone.utc)
|
||||
|
||||
max_interval = timedelta(days=200)
|
||||
data_frames = []
|
||||
current_start = start_dt
|
||||
|
|
@ -122,7 +130,7 @@ class Exchange:
|
|||
|
||||
df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume']
|
||||
candles_df = pd.DataFrame(candles, columns=df_columns)
|
||||
candles_df['open_time'] = candles_df['open_time'] // 1000
|
||||
|
||||
data_frames.append(candles_df)
|
||||
|
||||
current_start = current_end
|
||||
|
|
@ -515,39 +523,41 @@ class Exchange:
|
|||
return []
|
||||
|
||||
|
||||
# Usage Examples
|
||||
|
||||
# Example 1: Initializing the Exchange class
|
||||
api_keys = {
|
||||
'key': 'your_api_key',
|
||||
'secret': 'your_api_secret'
|
||||
}
|
||||
exchange = Exchange(name='Binance', api_keys=api_keys, exchange_id='binance')
|
||||
|
||||
# Example 2: Fetching historical data
|
||||
start_date = datetime(2022, 1, 1)
|
||||
end_date = datetime(2022, 6, 1)
|
||||
historical_data = exchange.get_historical_klines(symbol='BTC/USDT', interval='1d',
|
||||
start_dt=start_date, end_dt=end_date)
|
||||
print(historical_data)
|
||||
|
||||
# Example 3: Fetching the current price of a symbol
|
||||
current_price = exchange.get_price(symbol='BTC/USDT')
|
||||
print(f"Current price of BTC/USDT: {current_price}")
|
||||
|
||||
# Example 4: Placing a limit buy order
|
||||
order_result, order_details = exchange.place_order(symbol='BTC/USDT', side='buy', type='limit',
|
||||
timeInForce='GTC', quantity=0.001, price=30000)
|
||||
print(order_result, order_details)
|
||||
|
||||
# Example 5: Getting account balances
|
||||
balances = exchange.get_balances()
|
||||
print(balances)
|
||||
|
||||
# Example 6: Fetching open orders
|
||||
open_orders = exchange.get_open_orders()
|
||||
print(open_orders)
|
||||
|
||||
# Example 7: Fetching active trades
|
||||
active_trades = exchange.get_active_trades()
|
||||
print(active_trades)
|
||||
#
|
||||
# # Usage Examples
|
||||
#
|
||||
# # Example 1: Initializing the Exchange class
|
||||
# api_keys = {
|
||||
# 'key': 'your_api_key',
|
||||
# 'secret': 'your_api_secret'
|
||||
# }
|
||||
# exchange = Exchange(name='Binance', api_keys=api_keys, exchange_id='binance')
|
||||
#
|
||||
# # Example 2: Fetching historical data
|
||||
# start_date = datetime(2022, 1, 1)
|
||||
# end_date = datetime(2022, 6, 1)
|
||||
# historical_data = exchange.get_historical_klines(symbol='BTC/USDT', interval='1d',
|
||||
# start_dt=start_date, end_dt=end_date)
|
||||
# print(historical_data)
|
||||
#
|
||||
# # Example 3: Fetching the current price of a symbol
|
||||
# current_price = exchange.get_price(symbol='BTC/USDT')
|
||||
# print(f"Current price of BTC/USDT: {current_price}")
|
||||
#
|
||||
# # Example 4: Placing a limit buy order
|
||||
# order_result, order_details = exchange.place_order(symbol='BTC/USDT', side='buy', type='limit',
|
||||
# timeInForce='GTC', quantity=0.001, price=30000)
|
||||
# print(order_result, order_details)
|
||||
#
|
||||
# # Example 5: Getting account balances
|
||||
# balances = exchange.get_balances()
|
||||
# print(balances)
|
||||
#
|
||||
# # Example 6: Fetching open orders
|
||||
# open_orders = exchange.get_open_orders()
|
||||
# print(open_orders)
|
||||
#
|
||||
# # Example 7: Fetching active trades
|
||||
# active_trades = exchange.get_active_trades()
|
||||
# print(active_trades)
|
||||
#
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import logging
|
||||
import json
|
||||
from typing import List, Any, Dict
|
||||
import pandas as pd
|
||||
import requests
|
||||
import ccxt
|
||||
from Exchange import Exchange
|
||||
|
||||
|
|
@ -46,7 +44,7 @@ class ExchangeInterface:
|
|||
self.add_exchange(user_name, exchange)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}")
|
||||
logger.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}")
|
||||
return False
|
||||
|
||||
def add_exchange(self, user_name: str, exchange: Exchange):
|
||||
|
|
@ -60,7 +58,7 @@ class ExchangeInterface:
|
|||
row = {'user': user_name, 'name': exchange.name, 'reference': exchange, 'balances': exchange.balances}
|
||||
self.exchange_data = add_row(self.exchange_data, row)
|
||||
except Exception as e:
|
||||
logging.error(f"Couldn't create an instance of the exchange! {str(e)}")
|
||||
logger.error(f"Couldn't create an instance of the exchange! {str(e)}")
|
||||
raise
|
||||
|
||||
def get_exchange(self, ename: str, uname: str) -> Exchange:
|
||||
|
|
@ -144,12 +142,12 @@ class ExchangeInterface:
|
|||
elif fetch_type == 'orders':
|
||||
data = reference.get_open_orders()
|
||||
else:
|
||||
logging.error(f"Invalid fetch type: {fetch_type}")
|
||||
logger.error(f"Invalid fetch type: {fetch_type}")
|
||||
return {}
|
||||
|
||||
data_dict[name] = data
|
||||
except Exception as e:
|
||||
logging.error(f"Error retrieving data for {name}: {str(e)}")
|
||||
logger.error(f"Error retrieving data for {name}: {str(e)}")
|
||||
|
||||
return data_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -43,12 +43,8 @@ class Candles:
|
|||
# Calculate the approximate start_datetime the first of n record will have.
|
||||
start_datetime = ts_of_n_minutes_ago(n=num_candles, candle_length=minutes_per_candle)
|
||||
|
||||
# Table name format is: <symbol>_<timeframe>_<exchange_name>. Example: "BTCUSDT_15m_binance_spot"
|
||||
key = f'{asset}_{timeframe}_{exchange}'
|
||||
|
||||
# Fetch records older than start_datetime.
|
||||
candles = self.data.get_records_since(key=key, start_datetime=start_datetime,
|
||||
record_length=minutes_per_candle,
|
||||
candles = self.data.get_records_since(start_datetime=start_datetime,
|
||||
ex_details=[asset, timeframe, exchange, user_name])
|
||||
if len(candles.index) < num_candles:
|
||||
timesince = dt.datetime.utcnow() - start_datetime
|
||||
|
|
|
|||
|
|
@ -1,20 +1,38 @@
|
|||
import ccxt
import pandas as pd
import datetime


def fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5):
    """
    Fetch OHLCV candles from an exchange and print them for inspection.

    :param exchange_name: Name of the ccxt exchange class, e.g. 'binance'.
    :param symbol: Market symbol, e.g. 'BTC/USDT'.
    :param timeframe: Candle timeframe string, e.g. '5m'.
    :param since: Start time as a Unix timestamp in milliseconds.
    :param limit: Maximum number of candles to fetch.
    """
    # Initialize the exchange dynamically from its ccxt class name.
    exchange_class = getattr(ccxt, exchange_name)
    exchange = exchange_class()

    try:
        # Fetch historical candlestick data with a limit
        ohlcv = exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit)
    except ccxt.BaseError as e:
        # Network/exchange errors are expected in an ad-hoc inspection script;
        # report and bail out rather than crash with a traceback.
        print(f"Error fetching OHLCV data: {str(e)}")
        return

    # Convert to DataFrame for better readability
    df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])

    # Print the first few rows of the DataFrame
    print("First few rows of the fetched OHLCV data:")
    print(df.head())

    # Print the timestamps in human-readable format
    df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
    print("\nFirst few timestamps in human-readable format:")
    print(df[['timestamp', 'datetime']].head())

    # Confirm the format of the timestamps
    print("\nTimestamp format confirmation:")
    for ts in df['timestamp']:
        print(f"{ts} (milliseconds since Unix epoch)")


if __name__ == "__main__":
    # Example usage — guarded so importing this module has no side effects.
    exchange_name = 'binance'  # Change this to your exchange
    symbol = 'BTC/USDT'
    timeframe = '5m'
    since = int((datetime.datetime(2024, 8, 1) - datetime.datetime(1970, 1, 1)).total_seconds() * 1000)

    fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5)
|
|
|||
|
|
@ -36,19 +36,37 @@ def query_uptodate(records: pd.DataFrame, r_length_min: float) -> Union[float, N
|
|||
tolerance_minutes = 10 / 60 # 10 seconds tolerance in minutes
|
||||
if minutes_since_update > (r_length_min - tolerance_minutes):
|
||||
# Return the last timestamp in seconds
|
||||
return ms_to_seconds(last_timestamp)
|
||||
return last_timestamp
|
||||
return None
|
||||
|
||||
|
||||
def ms_to_seconds(timestamp: float) -> float:
    """
    Convert a millisecond timestamp into seconds.

    :param timestamp: The timestamp expressed in milliseconds.
    :return: The same instant expressed in seconds.
    """
    in_seconds = timestamp / 1000
    return in_seconds
|
||||
|
||||
|
||||
def unix_time_seconds(d_time: dt.datetime) -> float:
    """
    Convert a datetime object to a Unix timestamp in seconds.

    :param d_time: The datetime object to convert.
    :return: The Unix timestamp in seconds.
    """
    elapsed = d_time - epoch
    return elapsed.total_seconds()
|
||||
|
||||
|
||||
def unix_time_millis(d_time: dt.datetime) -> float:
    """
    Convert a datetime object to a Unix timestamp in milliseconds.

    :param d_time: The datetime object to convert.
    :return: The Unix timestamp in milliseconds.
    """
    elapsed = d_time - epoch
    return elapsed.total_seconds() * 1000.0
|
||||
|
||||
|
||||
|
|
@ -72,6 +90,10 @@ def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length
|
|||
start_timestamp = unix_time_millis(start_datetime)
|
||||
print(f'Start timestamp: {start_timestamp}')
|
||||
|
||||
if records.empty:
|
||||
print('No records found. Query cannot be satisfied.')
|
||||
return float('nan')
|
||||
|
||||
# Get the oldest timestamp from the records passed in
|
||||
first_timestamp = float(records.open_time.min())
|
||||
print(f'First timestamp in records: {first_timestamp}')
|
||||
|
|
@ -84,17 +106,17 @@ def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length
|
|||
if start_timestamp <= first_timestamp + total_duration:
|
||||
return None
|
||||
|
||||
return first_timestamp / 1000 # Return in seconds
|
||||
return first_timestamp
|
||||
|
||||
|
||||
@lru_cache(maxsize=500)
|
||||
def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime.timestamp:
|
||||
def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime:
|
||||
"""
|
||||
Returns the approximate datetime for the start of a candle that was 'n' candles ago.
|
||||
|
||||
:param n: int - The number of candles ago to calculate.
|
||||
:param candle_length: float - The length of each candle in minutes.
|
||||
:return: datetime - The approximate datetime for the start of the 'n'-th candle ago.
|
||||
:param n: The number of candles ago to calculate.
|
||||
:param candle_length: The length of each candle in minutes.
|
||||
:return: The approximate datetime for the start of the 'n'-th candle ago.
|
||||
"""
|
||||
# Increment 'n' by 1 to ensure we account for the time that has passed since the last candle closed.
|
||||
n += 1
|
||||
|
|
@ -113,12 +135,12 @@ def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime.timestamp:
|
|||
|
||||
|
||||
@lru_cache(maxsize=20)
|
||||
def timeframe_to_minutes(timeframe):
|
||||
def timeframe_to_minutes(timeframe: str) -> int:
|
||||
"""
|
||||
Converts a string representing a timeframe into an integer representing the approximate minutes.
|
||||
|
||||
:param timeframe: str - Timeframe format is [multiplier:focus]. eg '15m', '4h', '1d'
|
||||
:return: int - Minutes the timeframe represents ex. '2h'-> 120(minutes).
|
||||
:param timeframe: Timeframe format is [multiplier:focus]. e.g., '15m', '4h', '1d'
|
||||
:return: Minutes the timeframe represents, e.g., '2h' -> 120 (minutes).
|
||||
"""
|
||||
# Extract the numerical part of the timeframe param.
|
||||
digits = int("".join([i if i.isdigit() else "" for i in timeframe]))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,478 @@
|
|||
from DataCache_v2 import DataCache
|
||||
from ExchangeInterface import ExchangeInterface
|
||||
import unittest
|
||||
import pandas as pd
|
||||
import datetime as dt
|
||||
import os
|
||||
from Database import SQLite, Database
|
||||
|
||||
|
||||
class DataGenerator:
    """
    Generates simulated OHLCV candle tables for tests.

    A timeframe string such as '5m', '2h' or '1M' fixes the spacing between
    generated records. Months and years are approximated as 30 and 365 days
    respectively (consistently with _get_seconds_per_unit).
    """

    # Maps the internal unit name back to the single-character code used by
    # round_down_datetime(). 'M' (months) must stay distinct from 'm' (minutes);
    # previously timeframe_unit[0] was used, so 'months' collided with minutes.
    _UNIT_CHARS = {'seconds': 's', 'minutes': 'm', 'hours': 'h', 'days': 'd',
                   'weeks': 'w', 'months': 'M', 'years': 'y'}

    def __init__(self, timeframe_str):
        """
        Initialize the DataGenerator with a timeframe string like '2h', '5m', '1d', '1w', '1M', or '1y'.
        """
        # Initialize attributes with placeholder values
        self.timeframe_amount = None
        self.timeframe_unit = None
        # Set the actual timeframe
        self.set_timeframe(timeframe_str)

    def set_timeframe(self, timeframe_str):
        """
        Set the timeframe unit and amount based on a string like '2h', '5m', '1d', '1w', '1M', or '1y'.

        :raises ValueError: If the unit character is not one of s, m, h, d, w, M, Y/y.
        """
        self.timeframe_amount = int(timeframe_str[:-1])
        unit = timeframe_str[-1]

        # 'm' means minutes while 'M' means months; both 'Y' and 'y' are
        # accepted for years (the docstring advertises '1y').
        units = {'s': 'seconds', 'm': 'minutes', 'h': 'hours', 'd': 'days',
                 'w': 'weeks', 'M': 'months', 'Y': 'years', 'y': 'years'}
        if unit not in units:
            raise ValueError(
                "Unsupported timeframe unit. Use 's,m,h,d,w,M,Y'.")
        self.timeframe_unit = units[unit]

    def _interval_seconds(self):
        """Length of one record interval in seconds (months/years approximated)."""
        return self.timeframe_amount * self._get_seconds_per_unit(self.timeframe_unit)

    def create_table(self, num_rec=None, start=None, end=None):
        """
        Create a table with simulated data. If both start and end are provided, num_rec is derived from the interval.
        If neither are provided the table will have num_rec records and end at the current time.

        Parameters:
            num_rec (int, optional): The number of records to generate.
            start (datetime, optional): The start time for the first record.
            end (datetime, optional): The end time for the last record.

        Returns:
            pd.DataFrame: A DataFrame with the simulated data.

        Raises:
            ValueError: If num_rec cannot be derived from the given arguments.
        """
        interval_seconds = self._interval_seconds()

        # If neither start nor end are provided, anchor the table to 'now'.
        if start is None and end is None:
            end = dt.datetime.utcnow()
            if num_rec is None:
                raise ValueError("num_rec must be provided if both start and end are not specified.")

        # If both are provided, derive num_rec from the interval.
        if start is not None and end is not None:
            total_duration = (end - start).total_seconds()
            num_rec = int(total_duration // interval_seconds) + 1

        # If only end is provided, derive start from num_rec.
        if end is not None and start is None:
            if num_rec is None:
                raise ValueError("num_rec must be provided if both start and end are not specified.")
            start = end - dt.timedelta(seconds=(num_rec - 1) * interval_seconds)

        # Only start supplied: num_rec is still required to size the table.
        if num_rec is None:
            raise ValueError("num_rec must be provided when only start is specified.")

        # Align start to the timeframe interval. Use the canonical unit
        # character so months ('M') are not mistaken for minutes ('m').
        start = self.round_down_datetime(start, self._UNIT_CHARS[self.timeframe_unit], self.timeframe_amount)

        # Generate evenly spaced open_time values in Unix milliseconds.
        times = [self.unix_time_millis(start + self._delta(i)) for i in range(num_rec)]

        df = pd.DataFrame({
            'market_id': 1,
            'open_time': times,
            'open': [100 + i for i in range(num_rec)],
            'high': [110 + i for i in range(num_rec)],
            'low': [90 + i for i in range(num_rec)],
            'close': [105 + i for i in range(num_rec)],
            'volume': [1000 + i for i in range(num_rec)]
        })

        return df

    @staticmethod
    def _get_seconds_per_unit(unit):
        """Helper method to convert timeframe units to seconds."""
        units_in_seconds = {
            'seconds': 1,
            'minutes': 60,
            'hours': 3600,
            'days': 86400,
            'weeks': 604800,
            'months': 2592000,  # Assuming 30 days per month
            'years': 31536000  # Assuming 365 days per year
        }
        if unit not in units_in_seconds:
            raise ValueError(f"Unsupported timeframe unit: {unit}")
        return units_in_seconds[unit]

    def generate_incomplete_data(self, query_offset, num_rec=5):
        """
        Generate data that is incomplete, i.e., starts before the query but doesn't fully satisfy it.
        """
        query_start_time = self.x_time_ago(query_offset)
        start_time_for_data = self.get_start_time(query_start_time)
        return self.create_table(num_rec, start=start_time_for_data)

    @staticmethod
    def generate_missing_section(df, drop_start=5, drop_end=8):
        """
        Generate data with a missing section.
        """
        df = df.drop(df.index[drop_start:drop_end]).reset_index(drop=True)
        return df

    def get_start_time(self, query_start_time):
        """Return a start time a small margin (2 intervals) before the query start."""
        margin = 2
        # Convert to seconds so months/years work too (timedelta has no
        # 'months'/'years' keywords).
        return query_start_time - dt.timedelta(seconds=margin * self._interval_seconds())

    def x_time_ago(self, offset):
        """
        Returns a datetime object representing the current time minus the offset in the specified units.
        """
        # Convert to seconds: exact for s/m/h/d/w, approximate for months/years.
        seconds = offset * self._get_seconds_per_unit(self.timeframe_unit)
        return dt.datetime.utcnow() - dt.timedelta(seconds=seconds)

    def _delta(self, i):
        """
        Returns a timedelta object for the ith increment based on the timeframe unit and amount.
        """
        seconds = i * self.timeframe_amount * self._get_seconds_per_unit(self.timeframe_unit)
        return dt.timedelta(seconds=seconds)

    @staticmethod
    def unix_time_millis(dt_obj):
        """
        Convert a datetime object to Unix time in milliseconds.
        """
        epoch = dt.datetime(1970, 1, 1)
        return int((dt_obj - epoch).total_seconds() * 1000)

    @staticmethod
    def round_down_datetime(dt_obj: dt.datetime, unit: str, interval: int) -> dt.datetime:
        """
        Round dt_obj down to the nearest multiple of `interval` in the given
        single-character unit ('s', 'm', 'h', 'd', 'w', 'M', 'y').
        """
        if unit == 's':  # Round down to the nearest interval of seconds
            seconds = (dt_obj.second // interval) * interval
            dt_obj = dt_obj.replace(second=seconds, microsecond=0)
        elif unit == 'm':  # Round down to the nearest interval of minutes
            minutes = (dt_obj.minute // interval) * interval
            dt_obj = dt_obj.replace(minute=minutes, second=0, microsecond=0)
        elif unit == 'h':  # Round down to the nearest interval of hours
            hours = (dt_obj.hour // interval) * interval
            dt_obj = dt_obj.replace(hour=hours, minute=0, second=0, microsecond=0)
        elif unit == 'd':  # Round down to the nearest interval of days
            # Day-of-month is 1-based; the previous formula could produce
            # day 0 (ValueError) whenever dt_obj.day < interval.
            days = ((dt_obj.day - 1) // interval) * interval + 1
            dt_obj = dt_obj.replace(day=days, hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'w':  # Round down to the nearest interval of weeks
            dt_obj -= dt.timedelta(days=dt_obj.weekday() % (interval * 7))
            dt_obj = dt_obj.replace(hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'M':  # Round down to the nearest interval of months
            months = ((dt_obj.month - 1) // interval) * interval + 1
            dt_obj = dt_obj.replace(month=months, day=1, hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'y':  # Round down to the nearest interval of years
            years = (dt_obj.year // interval) * interval
            dt_obj = dt_obj.replace(year=years, month=1, day=1, hour=0, minute=0, second=0, microsecond=0)
        return dt_obj
||||
|
||||
|
||||
class TestDataCacheV2(unittest.TestCase):
    """
    Integration-style tests for DataCache backed by a throwaway SQLite file.

    NOTE(review): setUp connects to a real 'binance' exchange via
    ExchangeInterface, so these tests appear to require network access —
    confirm before running in CI.
    """

    def setUp(self):
        """Create the exchange connection, test database, schema and the DataCache under test."""
        # Set up database and exchanges
        self.exchanges = ExchangeInterface()
        self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)
        self.db_file = 'test_db.sqlite'
        self.database = Database(db_file=self.db_file)

        # Create necessary tables
        sql_create_table_1 = f"""
            CREATE TABLE IF NOT EXISTS test_table (
                id INTEGER PRIMARY KEY,
                market_id INTEGER,
                open_time INTEGER UNIQUE ON CONFLICT IGNORE,
                open REAL NOT NULL,
                high REAL NOT NULL,
                low REAL NOT NULL,
                close REAL NOT NULL,
                volume REAL NOT NULL,
                FOREIGN KEY (market_id) REFERENCES market (id)
            )"""
        sql_create_table_2 = """
            CREATE TABLE IF NOT EXISTS exchange (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT UNIQUE
            )"""
        sql_create_table_3 = """
            CREATE TABLE IF NOT EXISTS markets (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                symbol TEXT,
                exchange_id INTEGER,
                FOREIGN KEY (exchange_id) REFERENCES exchange(id)
            )"""
        with SQLite(db_file=self.db_file) as con:
            con.execute(sql_create_table_1)
            con.execute(sql_create_table_2)
            con.execute(sql_create_table_3)
        self.data = DataCache(self.exchanges)
        # Point the cache at the test database instead of the default one.
        self.data.db = self.database

        # [symbol, timeframe, exchange, user] — key format mirrors DataCache:
        # <symbol>_<timeframe>_<exchange>.
        self.ex_details = ['BTC/USD', '2h', 'binance', 'test_guy']
        self.key = f'{self.ex_details[0]}_{self.ex_details[1]}_{self.ex_details[2]}'

    def tearDown(self):
        """Delete the on-disk SQLite file so each test starts clean."""
        if os.path.exists(self.db_file):
            os.remove(self.db_file)

    def test_set_cache(self):
        """set_cache() stores data, overwrites it, and honours the do_not_overwrite flag."""
        print('\nTesting set_cache() method without no-overwrite flag:')
        self.data.set_cache(data='data', key=self.key)
        attr = self.data.__getattribute__('cached_data')
        self.assertEqual(attr[self.key], 'data')
        print(' - Set cache without no-overwrite flag passed.')

        print('Testing set_cache() once again with new data without no-overwrite flag:')
        self.data.set_cache(data='more_data', key=self.key)
        attr = self.data.__getattribute__('cached_data')
        self.assertEqual(attr[self.key], 'more_data')
        print(' - Set cache with new data without no-overwrite flag passed.')

        print('Testing set_cache() method once again with more data with no-overwrite flag set:')
        self.data.set_cache(data='even_more_data', key=self.key, do_not_overwrite=True)
        attr = self.data.__getattribute__('cached_data')
        # Existing value must be preserved when do_not_overwrite is set.
        self.assertEqual(attr[self.key], 'more_data')
        print(' - Set cache with no-overwrite flag passed.')

    def test_cache_exists(self):
        """cache_exists() reports presence/absence of a key."""
        print('Testing cache_exists() method:')
        self.assertFalse(self.data.cache_exists(key=self.key))
        print(' - Check for non-existent cache passed.')

        self.data.set_cache(data='data', key=self.key)
        self.assertTrue(self.data.cache_exists(key=self.key))
        print(' - Check for existent cache passed.')

    def test_update_candle_cache(self):
        """update_candle_cache() appends new candle rows to an existing cached table."""
        print('Testing update_candle_cache() method:')

        # Initialize the DataGenerator with the 5-minute timeframe
        data_gen = DataGenerator('5m')

        # Create initial DataFrame and insert into cache
        df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0))
        print(f'Inserting this table into cache:\n{df_initial}\n')
        self.data.set_cache(data=df_initial, key=self.key)

        # Create new DataFrame to be added to cache
        df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0))
        print(f'Updating cache with this table:\n{df_new}\n')
        self.data.update_candle_cache(more_records=df_new, key=self.key)

        # Retrieve the resulting DataFrame from cache
        result = self.data.get_cache(key=self.key)
        print(f'The resulting table in cache is:\n{result}\n')

        # Create the expected DataFrame
        expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0))
        print(f'The expected open_time values are:\n{expected["open_time"].tolist()}\n')

        # Assert that the open_time values in the result match those in the expected DataFrame, in order
        assert result['open_time'].tolist() == expected['open_time'].tolist(), \
            f"open_time values in result are {result['open_time'].tolist()}" \
            f" but expected {expected['open_time'].tolist()}"

        print(f'The results open_time values match:\n{result["open_time"].tolist()}\n')
        print(' - Update cache with new records passed.')

    def test_update_cached_dict(self):
        """update_cached_dict() sets a sub-key inside a cached dictionary."""
        print('Testing update_cached_dict() method:')
        self.data.set_cache(data={}, key=self.key)
        self.data.update_cached_dict(cache_key=self.key, dict_key='sub_key', data='value')

        cache = self.data.get_cache(key=self.key)
        self.assertEqual(cache['sub_key'], 'value')
        print(' - Update dictionary in cache passed.')

    def test_get_cache(self):
        """get_cache() returns what was stored for a key."""
        print('Testing get_cache() method:')
        self.data.set_cache(data='data', key=self.key)
        result = self.data.get_cache(key=self.key)
        self.assertEqual(result, 'data')
        print(' - Retrieve cache passed.')

    def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None, num_rec=None, ex_details=None,
                                simulate_scenarios=None):
        """
        Test the get_records_since() method by generating a table of simulated data,
        inserting it into cache and/or database, and then querying the records.

        Parameters:
            set_cache (bool): If True, the generated table is inserted into the cache.
            set_db (bool): If True, the generated table is inserted into the database.
            query_offset (int, optional): The offset in the timeframe units for the query.
            num_rec (int, optional): The number of records to generate in the simulated table.
            ex_details (list, optional): Exchange details to generate the cache key.
            simulate_scenarios (str, optional): The type of scenario to simulate. Options are:
                - 'not_enough_data': The table data doesn't go far enough back.
                - 'incomplete_data': The table doesn't have enough records to satisfy the query.
                - 'missing_section': The table has missing records in the middle.
        """

        print('Testing get_records_since() method:')

        # Use provided ex_details or fallback to the class attribute.
        ex_details = ex_details or self.ex_details
        # Generate a cache/database key using exchange details.
        key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}'

        # Set default number of records if not provided.
        num_rec = num_rec or 12
        table_timeframe = ex_details[1]  # Extract timeframe from exchange details.

        # Initialize DataGenerator with the given timeframe.
        data_gen = DataGenerator(table_timeframe)

        if simulate_scenarios == 'not_enough_data':
            # Set query_offset to a time earlier than the start of the table data.
            query_offset = (num_rec + 5) * data_gen.timeframe_amount
        else:
            # Default to querying for 1 record length less than the table duration.
            query_offset = query_offset or (num_rec - 1) * data_gen.timeframe_amount

        if simulate_scenarios == 'incomplete_data':
            # Set start time to generate fewer records than required.
            start_time_for_data = data_gen.x_time_ago(num_rec * data_gen.timeframe_amount)
            num_rec = 5  # Set a smaller number of records to simulate incomplete data.
        else:
            # No specific start time for data generation.
            start_time_for_data = None

        # Create the initial data table.
        df_initial = data_gen.create_table(num_rec, start=start_time_for_data)

        if simulate_scenarios == 'missing_section':
            # Simulate missing section in the data by dropping records.
            df_initial = data_gen.generate_missing_section(df_initial, drop_start=2, drop_end=5)

        # Convert 'open_time' to datetime for better readability.
        temp_df = df_initial.copy()
        temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms')
        print(f'Table Created:\n{temp_df}')

        if set_cache:
            # Insert the generated table into cache.
            print('Inserting table into cache.')
            self.data.set_cache(data=df_initial, key=key)

        if set_db:
            # Insert the generated table into the database.
            print('Inserting table into database.')
            with SQLite(self.db_file) as con:
                df_initial.to_sql(key, con, if_exists='replace', index=False)

        # Calculate the start time for querying the records.
        start_datetime = data_gen.x_time_ago(query_offset)
        # Defaults to current time if not provided to get_records_since()
        query_end_time = dt.datetime.utcnow()
        print(f'Requesting records from {start_datetime} to {query_end_time}')

        # Query the records since the calculated start time.
        result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)

        # Filter the initial data table to match the query time.
        expected = df_initial[df_initial['open_time'] >= data_gen.unix_time_millis(start_datetime)].reset_index(
            drop=True)
        temp_df = expected.copy()
        temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms')
        print(f'Expected table:\n{temp_df}')

        # Print the result from the query for comparison.
        temp_df = result.copy()
        temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms')
        print(f'Resulting table:\n{temp_df}')

        if simulate_scenarios in ['not_enough_data', 'incomplete_data', 'missing_section']:
            # Check that the result has more rows than the expected incomplete data.
            assert result.shape[0] > expected.shape[
                0], "Result has fewer or equal rows compared to the incomplete data."
            print("\nThe returned DataFrames has filled in the missing data!")
        else:
            # Ensure the result and expected dataframes match in shape and content.
            assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}"
            pd.testing.assert_series_equal(result['open_time'], expected['open_time'], check_dtype=False)
            print("\nThe DataFrames have the same shape and the 'open_time' columns match.")

        # Verify that the oldest timestamp in the result is within the allowed time difference.
        oldest_timestamp = pd.to_datetime(result['open_time'].min(), unit='ms')
        time_diff = oldest_timestamp - start_datetime
        max_allowed_time_diff = dt.timedelta(**{data_gen.timeframe_unit: data_gen.timeframe_amount})

        assert dt.timedelta(0) <= time_diff <= max_allowed_time_diff, \
            f"Oldest timestamp {oldest_timestamp} is not within " \
            f"{data_gen.timeframe_amount} {data_gen.timeframe_unit} of {start_datetime}"

        print(f'The first timestamp is {time_diff} from {start_datetime}')

        # Verify that the newest timestamp in the result is within the allowed time difference.
        newest_timestamp = pd.to_datetime(result['open_time'].max(), unit='ms')
        time_diff_end = abs(query_end_time - newest_timestamp)

        assert dt.timedelta(0) <= time_diff_end <= max_allowed_time_diff, \
            f"Newest timestamp {newest_timestamp} is not within {data_gen.timeframe_amount} " \
            f"{data_gen.timeframe_unit} of {query_end_time}"

        print(f'The last timestamp is {time_diff_end} from {query_end_time}')
        print(' - Fetch records within the specified time range passed.')

    def test_get_records_since(self):
        """Exercise get_records_since() across cache/db presence and gap scenarios."""
        print('\nTest get_records_since with records set in cache')
        self._test_get_records_since()

        print('\nTest get_records_since with records not in cache')
        self._test_get_records_since(set_cache=False)

        print('\nTest get_records_since with records not in database')
        self._test_get_records_since(set_cache=False, set_db=False)

        print('\nTest get_records_since with a different timeframe')
        self._test_get_records_since(query_offset=None, num_rec=None,
                                     ex_details=['BTC/USD', '15m', 'binance', 'test_guy'])

        print('\nTest get_records_since where data does not go far enough back')
        self._test_get_records_since(simulate_scenarios='not_enough_data')

        print('\nTest get_records_since with incomplete data')
        self._test_get_records_since(simulate_scenarios='incomplete_data')

        print('\nTest get_records_since with missing section in data')
        self._test_get_records_since(simulate_scenarios='missing_section')

    def test_populate_db(self):
        """_populate_db() writes a candle table into the SQLite database."""
        print('Testing _populate_db() method:')
        # Create a table of candle records.
        data_gen = DataGenerator(self.ex_details[1])
        data = data_gen.create_table(num_rec=5)

        self.data._populate_db(ex_details=self.ex_details, data=data)

        with SQLite(self.db_file) as con:
            result = pd.read_sql(f'SELECT * FROM "{self.key}"', con)
        self.assertFalse(result.empty)
        print(' - Populate database with data passed.')

    def test_fetch_candles_from_exchange(self):
        """_fetch_candles_from_exchange() returns a non-empty DataFrame of candles.

        NOTE(review): this performs a live exchange request — confirm it is
        intended to run outside a sandboxed environment.
        """
        print('Testing _fetch_candles_from_exchange() method:')
        start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
        end_time = dt.datetime.utcnow()

        result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance',
                                                        user_name='test_guy', start_datetime=start_time,
                                                        end_datetime=end_time)
        self.assertIsInstance(result, pd.DataFrame)
        self.assertFalse(result.empty)
        print(' - Fetch candle data from exchange passed.')
|
||||
|
||||
|
||||
# Run the full test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
|
||||
Loading…
Reference in New Issue