699 lines
32 KiB
Python
699 lines
32 KiB
Python
import json
|
||
from typing import List, Any, Tuple
|
||
import pandas as pd
|
||
import datetime as dt
|
||
import logging
|
||
from shared_utilities import unix_time_millis
|
||
from Database import Database
|
||
import numpy as np
|
||
|
||
# Configure logging
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def timeframe_to_timedelta(timeframe: str) -> pd.Timedelta | pd.DateOffset:
|
||
digits = int("".join([i if i.isdigit() else "" for i in timeframe]))
|
||
unit = "".join([i if i.isalpha() else "" for i in timeframe])
|
||
|
||
if unit == 'm':
|
||
return pd.Timedelta(minutes=digits)
|
||
elif unit == 'h':
|
||
return pd.Timedelta(hours=digits)
|
||
elif unit == 'd':
|
||
return pd.Timedelta(days=digits)
|
||
elif unit == 'w':
|
||
return pd.Timedelta(weeks=digits)
|
||
elif unit == 'M':
|
||
return pd.DateOffset(months=digits)
|
||
elif unit == 'Y':
|
||
return pd.DateOffset(years=digits)
|
||
else:
|
||
raise ValueError(f"Invalid timeframe unit: {unit}")
|
||
|
||
|
||
def estimate_record_count(start_time, end_time, timeframe: str) -> int:
    """
    Estimate the number of records expected between start_time and end_time
    for the given timeframe.

    Accepts either timezone-aware datetime objects or Unix timestamps in
    milliseconds (both arguments must be the same kind).

    :param start_time: Start of the interval (aware datetime or ms timestamp).
    :param end_time: End of the interval (aware datetime or ms timestamp).
    :param timeframe: Timeframe string such as '5m', '1h', '1d', '1M'.
    :return: Estimated (floor) number of records in the interval.
    :raises ValueError: on mixed/naive inputs or an invalid timeframe.
    """
    if isinstance(start_time, (int, float, np.integer)) and isinstance(end_time, (int, float, np.integer)):
        # Timestamps arrive in milliseconds; fromtimestamp(tz=...) replaces the
        # deprecated utcfromtimestamp() and yields aware datetimes directly.
        start_datetime = dt.datetime.fromtimestamp(int(start_time) / 1000, tz=dt.timezone.utc)
        end_datetime = dt.datetime.fromtimestamp(int(end_time) / 1000, tz=dt.timezone.utc)
    elif isinstance(start_time, dt.datetime) and isinstance(end_time, dt.datetime):
        if start_time.tzinfo is None:
            raise ValueError("start_time is timezone naive. Please provide a timezone-aware datetime.")
        if end_time.tzinfo is None:
            raise ValueError("end_time is timezone naive. Please provide a timezone-aware datetime.")
        start_datetime = start_time
        end_datetime = end_time
    else:
        raise ValueError("start_time and end_time must be either both "
                         "datetime objects or both Unix timestamps in milliseconds.")

    delta = timeframe_to_timedelta(timeframe)
    if isinstance(delta, pd.Timedelta):
        step_seconds = delta.total_seconds()
    else:
        # 'M'/'Y' timeframes yield a pd.DateOffset, which has no total_seconds()
        # (the old code crashed with AttributeError here). Approximate with the
        # mean Gregorian month/year — adequate for an estimate.
        step_seconds = (delta.kwds.get('months', 0) * 30.436875 * 86400
                        + delta.kwds.get('years', 0) * 365.2425 * 86400)

    total_seconds = (end_datetime - start_datetime).total_seconds()
    return int(total_seconds // step_seconds)
|
||
|
||
|
||
# def generate_expected_timestamps(start_datetime: dt.datetime, end_datetime: dt.datetime,
|
||
# timeframe: str) -> pd.DatetimeIndex:
|
||
# """
|
||
# What it says. Todo: confirm this is unused and archive.
|
||
#
|
||
# :param start_datetime:
|
||
# :param end_datetime:
|
||
# :param timeframe:
|
||
# :return:
|
||
# """
|
||
# if start_datetime.tzinfo is None:
|
||
# raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
# if end_datetime.tzinfo is None:
|
||
# raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
#
|
||
# delta = timeframe_to_timedelta(timeframe)
|
||
# if isinstance(delta, pd.Timedelta):
|
||
# return pd.date_range(start=start_datetime, end=end_datetime, freq=delta)
|
||
# elif isinstance(delta, pd.DateOffset):
|
||
# current = start_datetime
|
||
# timestamps = []
|
||
# while current <= end_datetime:
|
||
# timestamps.append(current)
|
||
# current += delta
|
||
# return pd.DatetimeIndex(timestamps)
|
||
|
||
|
||
class DataCache:
|
||
TYPECHECKING_ENABLED = True
|
||
|
||
def __init__(self, exchanges):
|
||
self.db = Database()
|
||
self.exchanges = exchanges
|
||
# Single DataFrame for all cached data
|
||
self.caches = {}
|
||
logger.info("DataCache initialized.")
|
||
|
||
def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
|
||
"""
|
||
Sets or updates an entry in the cache with the provided key. If the key already exists, the existing entry
|
||
is replaced unless `do_not_overwrite` is True. In that case, the existing entry is preserved.
|
||
|
||
Parameters:
|
||
data: The data to be cached. This can be of any type.
|
||
key: The unique key used to identify the cached data.
|
||
do_not_overwrite : The default is False, meaning that the existing entry will be replaced.
|
||
"""
|
||
if do_not_overwrite and key in self.cache['key'].values:
|
||
return
|
||
|
||
# Construct a new DataFrame row with the key and data
|
||
new_row = pd.DataFrame({'key': [key], 'data': [data]})
|
||
|
||
# If the key already exists in the cache, remove the old entry
|
||
self.cache = self.cache[self.cache['key'] != key]
|
||
|
||
# Append the new row to the cache
|
||
self.cache = pd.concat([self.cache, new_row], ignore_index=True)
|
||
|
||
print(f'Current Cache: {self.cache}')
|
||
logger.debug(f'Cache set for key: {key}')
|
||
|
||
def get_or_fetch_rows(self, table: str, filter_vals: Tuple[str, Any]) -> pd.DataFrame | None:
|
||
"""
|
||
Retrieves rows from the cache if available; otherwise, queries the database and caches the result.
|
||
|
||
:param table: Name of the database table to query.
|
||
:param filter_vals: A tuple containing the column name and the value to filter by.
|
||
:return: A DataFrame containing the requested rows, or None if no matching rows are found.
|
||
"""
|
||
# Construct a filter condition for the cache based on the table name and filter values.
|
||
cache_filter = (self.cache['table'] == table) & (self.cache[filter_vals[0]] == filter_vals[1])
|
||
cached_rows = self.cache[cache_filter]
|
||
|
||
# If the data is found in the cache, return it.
|
||
if not cached_rows.empty:
|
||
return cached_rows
|
||
|
||
# If the data is not found in the cache, query the database.
|
||
rows = self.db.get_rows_where(table, filter_vals)
|
||
if rows is not None:
|
||
# Tag the rows with the table name and add them to the cache.
|
||
rows['table'] = table
|
||
self.cache = pd.concat([self.cache, rows])
|
||
return rows
|
||
|
||
def remove_row(self, filter_vals: Tuple[str, Any], additional_filter: Tuple[str, Any] = None,
|
||
remove_from_db: bool = True, table: str = None) -> None:
|
||
"""
|
||
Removes a specific row from the cache and optionally from the database based on filter criteria.
|
||
|
||
:param filter_vals: A tuple containing the column name and the value to filter by.
|
||
:param additional_filter: An optional additional filter to apply.
|
||
:param remove_from_db: If True, also removes the row from the database. Default is True.
|
||
:param table: The name of the table from which to remove the row in the database (optional).
|
||
"""
|
||
logger.debug(
|
||
f"Removing row from cache: filter={filter_vals},"
|
||
f" additional_filter={additional_filter}, remove_from_db={remove_from_db}, table={table}")
|
||
|
||
# Construct the filter condition for the cache
|
||
cache_filter = (self.cache[filter_vals[0]] == filter_vals[1])
|
||
|
||
if additional_filter:
|
||
cache_filter = cache_filter & (self.cache[additional_filter[0]] == additional_filter[1])
|
||
|
||
# Remove the row from the cache
|
||
self.cache = self.cache.drop(self.cache[cache_filter].index)
|
||
logger.info(f"Row removed from cache: filter={filter_vals}")
|
||
|
||
if remove_from_db and table:
|
||
# Construct the SQL query to delete from the database
|
||
sql = f"DELETE FROM {table} WHERE {filter_vals[0]} = ?"
|
||
params = [filter_vals[1]]
|
||
|
||
if additional_filter:
|
||
sql += f" AND {additional_filter[0]} = ?"
|
||
params.append(additional_filter[1])
|
||
|
||
# Execute the SQL query to remove the row from the database
|
||
self.db.execute_sql(sql, params)
|
||
logger.info(
|
||
f"Row removed from database: table={table}, filter={filter_vals},"
|
||
f" additional_filter={additional_filter}")
|
||
|
||
def is_attr_taken(self, table: str, attr: str, val: Any) -> bool:
|
||
"""
|
||
Checks if a specific attribute in a table is already taken.
|
||
|
||
:param table: The name of the table to check.
|
||
:param attr: The attribute to check (e.g., 'user_name', 'email').
|
||
:param val: The value of the attribute to check.
|
||
:return: True if the attribute is already taken, False otherwise.
|
||
"""
|
||
# Fetch rows from the specified table where the attribute matches the given value
|
||
result = self.get_or_fetch_rows(table=table, filter_vals=(attr, val))
|
||
return result is not None and not result.empty
|
||
|
||
def fetch_item(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any:
|
||
"""
|
||
Retrieves a specific item from the cache or database, caching the result if necessary.
|
||
|
||
:param item_name: The name of the column to retrieve.
|
||
:param table_name: The name of the table where the item is stored.
|
||
:param filter_vals: A tuple containing the column name and the value to filter by.
|
||
:return: The value of the requested item.
|
||
:raises ValueError: If the item is not found in either the cache or the database.
|
||
"""
|
||
# Fetch the relevant rows from the cache or database
|
||
rows = self.get_or_fetch_rows(table_name, filter_vals)
|
||
if rows is not None and not rows.empty:
|
||
# Return the specific item from the first matching row.
|
||
return rows.iloc[0][item_name]
|
||
|
||
# If the item is not found, raise an error.
|
||
raise ValueError(f"Item '{item_name}' not found in '{table_name}' where {filter_vals[0]} = {filter_vals[1]}")
|
||
|
||
def modify_item(self, table: str, filter_vals: Tuple[str, Any], field_name: str, new_data: Any) -> None:
|
||
"""
|
||
Modifies a specific field in a row within the cache and updates the database accordingly.
|
||
|
||
:param table: The name of the table where the data is stored.
|
||
:param filter_vals: A tuple containing the column name and the value to filter by.
|
||
:param field_name: The field to be updated.
|
||
:param new_data: The new data to be set.
|
||
"""
|
||
# Retrieve the row from the cache or database
|
||
row = self.get_or_fetch_rows(table, filter_vals)
|
||
|
||
if row is None or row.empty:
|
||
raise ValueError(f"Row not found in cache or database for {filter_vals[0]} = {filter_vals[1]}")
|
||
|
||
# Modify the specified field
|
||
if isinstance(new_data, str):
|
||
row.loc[0, field_name] = new_data
|
||
else:
|
||
# If new_data is not a string, it’s converted to a JSON string before being inserted into the DataFrame.
|
||
row.loc[0, field_name] = json.dumps(new_data)
|
||
|
||
# Update the cache by removing the old entry and adding the modified row
|
||
self.cache = self.cache.drop(
|
||
self.cache[(self.cache['table'] == table) & (self.cache[filter_vals[0]] == filter_vals[1])].index
|
||
)
|
||
self.cache = pd.concat([self.cache, row])
|
||
|
||
# Update the database with the modified row
|
||
self.db.insert_dataframe(row.drop(columns='id'), table)
|
||
|
||
def insert_df(self, df: pd.DataFrame, table: str, skip_cache: bool = False) -> None:
|
||
"""
|
||
Inserts data into the specified table in the database, with an option to skip cache insertion.
|
||
|
||
:param df: The DataFrame containing the data to insert.
|
||
:param table: The name of the table where the data should be inserted.
|
||
:param skip_cache: If True, skips inserting the data into the cache. Default is False.
|
||
"""
|
||
# Insert the data into the database
|
||
self.db.insert_dataframe(df=df, table=table)
|
||
|
||
# Optionally insert the data into the cache
|
||
if not skip_cache:
|
||
df['table'] = table # Add table name for cache identification
|
||
self.cache = pd.concat([self.cache, df])
|
||
|
||
def insert_row(self, table: str, columns: tuple, values: tuple) -> None:
|
||
"""
|
||
Inserts a single row into the specified table in the database.
|
||
|
||
:param table: The name of the table where the row should be inserted.
|
||
:param columns: A tuple of column names corresponding to the values.
|
||
:param values: A tuple of values to insert into the specified columns.
|
||
"""
|
||
self.db.insert_row(table=table, columns=columns, values=values)
|
||
|
||
def get_records_since(self, start_datetime: dt.datetime, ex_details: List[str]) -> pd.DataFrame:
|
||
"""
|
||
This gets up-to-date records from a specified market and exchange.
|
||
|
||
:param start_datetime: The approximate time the first record should represent.
|
||
:param ex_details: The user exchange and market.
|
||
:return: The records.
|
||
"""
|
||
if self.TYPECHECKING_ENABLED:
|
||
if not isinstance(start_datetime, dt.datetime):
|
||
raise TypeError("start_datetime must be a datetime object")
|
||
if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
|
||
raise TypeError("ex_details must be a list of strings")
|
||
if not len(ex_details) == 4:
|
||
raise TypeError("ex_details must include [asset, timeframe, exchange, user_name]")
|
||
|
||
if start_datetime.tzinfo is None:
|
||
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
end_datetime = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
|
||
|
||
try:
|
||
args = {
|
||
'start_datetime': start_datetime,
|
||
'end_datetime': end_datetime,
|
||
'ex_details': ex_details,
|
||
}
|
||
return self._get_or_fetch_from('data', **args)
|
||
except Exception as e:
|
||
logger.error(f"An error occurred: {str(e)}")
|
||
raise
|
||
|
||
def _get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame:
|
||
"""
|
||
Fetches market records from a resource stack (data, database, exchange).
|
||
fills incomplete request by fetching down the stack then updates the rest.
|
||
|
||
:param target: Starting point for the fetch. ['data', 'database', 'exchange']
|
||
:param kwargs: Details and credentials for the request.
|
||
:return: Records in a dataframe.
|
||
"""
|
||
start_datetime = kwargs.get('start_datetime')
|
||
if start_datetime.tzinfo is None:
|
||
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
|
||
end_datetime = kwargs.get('end_datetime')
|
||
if end_datetime.tzinfo is None:
|
||
raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
|
||
ex_details = kwargs.get('ex_details')
|
||
timeframe = kwargs.get('ex_details')[1]
|
||
|
||
if self.TYPECHECKING_ENABLED:
|
||
if not isinstance(start_datetime, dt.datetime):
|
||
raise TypeError("start_datetime must be a datetime object")
|
||
if not isinstance(end_datetime, dt.datetime):
|
||
raise TypeError("end_datetime must be a datetime object")
|
||
if not isinstance(timeframe, str):
|
||
raise TypeError("record_length must be a string representing the timeframe")
|
||
if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
|
||
raise TypeError("ex_details must be a list of strings")
|
||
if not all([start_datetime, end_datetime, timeframe, ex_details]):
|
||
raise ValueError("Missing required arguments")
|
||
|
||
request_criteria = {
|
||
'start_datetime': start_datetime,
|
||
'end_datetime': end_datetime,
|
||
'timeframe': timeframe,
|
||
}
|
||
|
||
key = self._make_key(ex_details)
|
||
combined_data = pd.DataFrame()
|
||
|
||
if target == 'data':
|
||
resources = [self._get_candles_from_cache, self._get_from_database, self._get_from_server]
|
||
elif target == 'database':
|
||
resources = [self._get_from_database, self._get_from_server]
|
||
elif target == 'server':
|
||
resources = [self._get_from_server]
|
||
else:
|
||
raise ValueError('Not a valid Target!')
|
||
|
||
for fetch_method in resources:
|
||
result = fetch_method(**kwargs)
|
||
|
||
if not result.empty:
|
||
# Drop the 'id' column if it exists in the result
|
||
if 'id' in result.columns:
|
||
result = result.drop(columns=['id'])
|
||
|
||
# Concatenate, drop duplicates based on 'time', and sort
|
||
combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='time').sort_values(
|
||
by='time')
|
||
|
||
is_complete, request_criteria = self._data_complete(combined_data, **request_criteria)
|
||
if is_complete:
|
||
if fetch_method in [self._get_from_database, self._get_from_server]:
|
||
self._update_candle_cache(combined_data, key)
|
||
if fetch_method == self._get_from_server:
|
||
self._populate_db(ex_details, combined_data)
|
||
return combined_data
|
||
|
||
kwargs.update(request_criteria) # Update kwargs with new start/end times for next fetch attempt
|
||
|
||
logger.error('Unable to fetch the requested data.')
|
||
return combined_data if not combined_data.empty else pd.DataFrame()
|
||
|
||
def _get_candles_from_cache(self, **kwargs) -> pd.DataFrame:
|
||
start_datetime = kwargs.get('start_datetime')
|
||
if start_datetime.tzinfo is None:
|
||
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
|
||
end_datetime = kwargs.get('end_datetime')
|
||
if end_datetime.tzinfo is None:
|
||
raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
|
||
ex_details = kwargs.get('ex_details')
|
||
|
||
if self.TYPECHECKING_ENABLED:
|
||
if not isinstance(start_datetime, dt.datetime):
|
||
raise TypeError("start_datetime must be a datetime object")
|
||
if not isinstance(end_datetime, dt.datetime):
|
||
raise TypeError("end_datetime must be a datetime object")
|
||
if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
|
||
raise TypeError("ex_details must be a list of strings")
|
||
if not all([start_datetime, end_datetime, ex_details]):
|
||
raise ValueError("Missing required arguments")
|
||
|
||
key = self._make_key(ex_details)
|
||
logger.debug('Getting records from data.')
|
||
df = self.get_cache(key)
|
||
if df is None:
|
||
logger.debug("Cache records didn't exist.")
|
||
return pd.DataFrame()
|
||
logger.debug('Filtering records.')
|
||
df_filtered = df[(df['time'] >= unix_time_millis(start_datetime)) & (
|
||
df['time'] <= unix_time_millis(end_datetime))].reset_index(drop=True)
|
||
return df_filtered
|
||
|
||
def _get_from_database(self, **kwargs) -> pd.DataFrame:
|
||
start_datetime = kwargs.get('start_datetime')
|
||
if start_datetime.tzinfo is None:
|
||
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
|
||
end_datetime = kwargs.get('end_datetime')
|
||
if end_datetime.tzinfo is None:
|
||
raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
|
||
ex_details = kwargs.get('ex_details')
|
||
|
||
if self.TYPECHECKING_ENABLED:
|
||
if not isinstance(start_datetime, dt.datetime):
|
||
raise TypeError("start_datetime must be a datetime object")
|
||
if not isinstance(end_datetime, dt.datetime):
|
||
raise TypeError("end_datetime must be a datetime object")
|
||
if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
|
||
raise TypeError("ex_details must be a list of strings")
|
||
if not all([start_datetime, end_datetime, ex_details]):
|
||
raise ValueError("Missing required arguments")
|
||
|
||
table_name = self._make_key(ex_details)
|
||
if not self.db.table_exists(table_name):
|
||
logger.debug('Records not in database.')
|
||
return pd.DataFrame()
|
||
|
||
logger.debug('Getting records from database.')
|
||
return self.db.get_timestamped_records(table_name=table_name, timestamp_field='time', st=start_datetime,
|
||
et=end_datetime)
|
||
|
||
def _get_from_server(self, **kwargs) -> pd.DataFrame:
|
||
symbol = kwargs.get('ex_details')[0]
|
||
interval = kwargs.get('ex_details')[1]
|
||
exchange_name = kwargs.get('ex_details')[2]
|
||
user_name = kwargs.get('ex_details')[3]
|
||
start_datetime = kwargs.get('start_datetime')
|
||
if start_datetime.tzinfo is None:
|
||
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
|
||
end_datetime = kwargs.get('end_datetime')
|
||
if end_datetime.tzinfo is None:
|
||
raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
|
||
if self.TYPECHECKING_ENABLED:
|
||
if not isinstance(symbol, str):
|
||
raise TypeError("symbol must be a string")
|
||
if not isinstance(interval, str):
|
||
raise TypeError("interval must be a string")
|
||
if not isinstance(exchange_name, str):
|
||
raise TypeError("exchange_name must be a string")
|
||
if not isinstance(user_name, str):
|
||
raise TypeError("user_name must be a string")
|
||
if not isinstance(start_datetime, dt.datetime):
|
||
raise TypeError("start_datetime must be a datetime object")
|
||
if not isinstance(end_datetime, dt.datetime):
|
||
raise TypeError("end_datetime must be a datetime object")
|
||
|
||
logger.debug('Getting records from server.')
|
||
return self._fetch_candles_from_exchange(symbol, interval, exchange_name, user_name, start_datetime,
|
||
end_datetime)
|
||
|
||
@staticmethod
|
||
def _data_complete(data: pd.DataFrame, **kwargs) -> (bool, dict):
|
||
"""
|
||
Checks if the data completely satisfies the request.
|
||
|
||
:param data: DataFrame containing the records.
|
||
:param kwargs: Arguments required for completeness check.
|
||
:return: A tuple (is_complete, updated_request_criteria) where is_complete is True if the data is complete,
|
||
False otherwise, and updated_request_criteria contains adjusted start/end times if data is incomplete.
|
||
"""
|
||
if data.empty:
|
||
logger.debug("Data is empty.")
|
||
return False, kwargs # No data at all, proceed with the full original request
|
||
|
||
start_datetime: dt.datetime = kwargs.get('start_datetime')
|
||
if start_datetime.tzinfo is None:
|
||
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
|
||
end_datetime: dt.datetime = kwargs.get('end_datetime')
|
||
if end_datetime.tzinfo is None:
|
||
raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
||
|
||
timeframe: str = kwargs.get('timeframe')
|
||
|
||
temp_data = data.copy()
|
||
|
||
# Convert 'time' to datetime with unit='ms' and localize to UTC
|
||
temp_data['time_dt'] = pd.to_datetime(temp_data['time'],
|
||
unit='ms', errors='coerce').dt.tz_localize('UTC')
|
||
|
||
min_timestamp = temp_data['time_dt'].min()
|
||
max_timestamp = temp_data['time_dt'].max()
|
||
|
||
logger.debug(f"Data time range: {min_timestamp} to {max_timestamp}")
|
||
logger.debug(f"Expected time range: {start_datetime} to {end_datetime}")
|
||
|
||
tolerance = pd.Timedelta(seconds=5)
|
||
|
||
# Initialize updated request criteria
|
||
updated_request_criteria = kwargs.copy()
|
||
|
||
# Check if data covers the required time range with tolerance
|
||
if min_timestamp > start_datetime + timeframe_to_timedelta(timeframe) + tolerance:
|
||
logger.debug("Data does not start early enough, even with tolerance.")
|
||
updated_request_criteria['end_datetime'] = min_timestamp # Fetch the missing earlier data
|
||
return False, updated_request_criteria
|
||
|
||
if max_timestamp < end_datetime - timeframe_to_timedelta(timeframe) - tolerance:
|
||
logger.debug("Data does not extend late enough, even with tolerance.")
|
||
updated_request_criteria['start_datetime'] = max_timestamp # Fetch the missing later data
|
||
return False, updated_request_criteria
|
||
|
||
# Filter data between start_datetime and end_datetime
|
||
mask = (temp_data['time_dt'] >= start_datetime) & (temp_data['time_dt'] <= end_datetime)
|
||
data_in_range = temp_data.loc[mask]
|
||
|
||
expected_count = estimate_record_count(start_datetime, end_datetime, timeframe)
|
||
actual_count = len(data_in_range)
|
||
|
||
logger.debug(f"Expected record count: {expected_count}, Actual record count: {actual_count}")
|
||
|
||
tolerance = 1
|
||
if actual_count < (expected_count - tolerance):
|
||
logger.debug("Insufficient records within the specified time range, even with tolerance.")
|
||
return False, updated_request_criteria
|
||
|
||
logger.debug("Data completeness check passed.")
|
||
return True, kwargs
|
||
|
||
def cache_exists(self, key: str) -> bool:
|
||
return key in self.cache['key'].values
|
||
|
||
def get_cache(self, key: str) -> Any | None:
|
||
# Check if the key exists in the cache
|
||
if key not in self.cache['key'].values:
|
||
logger.warning(f"The requested data key ({key}) doesn't exist!")
|
||
return None
|
||
|
||
# Retrieve the data associated with the key
|
||
result = self.cache[self.cache['key'] == key]['data'].iloc[0]
|
||
return result
|
||
|
||
def _update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
|
||
logger.debug('Updating data with new records.')
|
||
# Concatenate the new records with the existing data
|
||
records = pd.concat([self.get_cache(key), more_records], axis=0, ignore_index=True)
|
||
# Drop duplicates based on 'time' and keep the first occurrence
|
||
records = records.drop_duplicates(subset="time", keep='first')
|
||
# Sort the records by 'time'
|
||
records = records.sort_values(by='time').reset_index(drop=True)
|
||
# Reindex 'id' to ensure the expected order
|
||
records['id'] = range(1, len(records) + 1)
|
||
# Set the updated DataFrame back to data
|
||
self.set_cache(data=records, key=key)
|
||
|
||
def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None:
|
||
"""
|
||
Updates a dictionary stored in the DataFrame cache.
|
||
|
||
:param data: The data to insert into the dictionary.
|
||
:param cache_key: The cache key for the dictionary.
|
||
:param dict_key: The key within the dictionary to update.
|
||
:return: None
|
||
"""
|
||
# Locate the row in the DataFrame that matches the cache_key
|
||
cache_index = self.cache.index[self.cache['key'] == cache_key]
|
||
|
||
if not cache_index.empty:
|
||
# Update the dictionary stored in the 'data' column
|
||
cache_dict = self.cache.at[cache_index[0], 'data']
|
||
|
||
if isinstance(cache_dict, dict):
|
||
cache_dict[dict_key] = data
|
||
|
||
# Ensure the DataFrame is updated with the new dictionary
|
||
self.cache.at[cache_index[0], 'data'] = cache_dict
|
||
else:
|
||
raise ValueError(f"Expected a dictionary in cache, but found {type(cache_dict)}.")
|
||
else:
|
||
raise KeyError(f"Cache key '{cache_key}' not found.")
|
||
|
||
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
|
||
start_datetime: dt.datetime = None,
|
||
end_datetime: dt.datetime = None) -> pd.DataFrame:
|
||
if start_datetime is None:
|
||
start_datetime = dt.datetime(year=2017, month=1, day=1, tzinfo=dt.timezone.utc)
|
||
|
||
if end_datetime is None:
|
||
end_datetime = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
|
||
|
||
if start_datetime > end_datetime:
|
||
raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")
|
||
|
||
exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
|
||
|
||
expected_records = estimate_record_count(start_datetime, end_datetime, interval)
|
||
logger.info(
|
||
f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {expected_records}')
|
||
|
||
if start_datetime == end_datetime:
|
||
end_datetime = None
|
||
|
||
candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime,
|
||
end_dt=end_datetime)
|
||
|
||
num_rec_records = len(candles.index)
|
||
if num_rec_records == 0:
|
||
logger.warning(f"No OHLCV data returned for {symbol}.")
|
||
return pd.DataFrame(columns=['time', 'open', 'high', 'low', 'close', 'volume'])
|
||
|
||
logger.info(f'{num_rec_records} candles retrieved from the exchange.')
|
||
|
||
open_times = candles.time
|
||
min_open_time = open_times.min()
|
||
max_open_time = open_times.max()
|
||
|
||
if min_open_time < 1e10:
|
||
raise ValueError('Records are not in milliseconds')
|
||
|
||
estimated_num_records = estimate_record_count(min_open_time, max_open_time, interval) + 1
|
||
logger.info(f'Estimated number of records: {estimated_num_records}')
|
||
|
||
if num_rec_records < estimated_num_records:
|
||
logger.info('Detected gaps in the data, attempting to fill missing records.')
|
||
candles = self._fill_data_holes(candles, interval)
|
||
|
||
return candles
|
||
|
||
@staticmethod
|
||
def _fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
|
||
time_span = timeframe_to_timedelta(interval).total_seconds() / 60
|
||
last_timestamp = None
|
||
filled_records = []
|
||
|
||
logger.info(f"Starting to fill data holes for interval: {interval}")
|
||
|
||
for index, row in records.iterrows():
|
||
time_stamp = row['time']
|
||
|
||
if last_timestamp is None:
|
||
last_timestamp = time_stamp
|
||
filled_records.append(row)
|
||
logger.debug(f"First timestamp: {time_stamp}")
|
||
continue
|
||
|
||
delta_ms = time_stamp - last_timestamp
|
||
delta_minutes = (delta_ms / 1000) / 60
|
||
|
||
logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}")
|
||
|
||
if delta_minutes > time_span:
|
||
num_missing_rec = int(delta_minutes / time_span)
|
||
step = int(delta_ms / num_missing_rec)
|
||
logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}")
|
||
|
||
for ts in range(int(last_timestamp) + step, int(time_stamp), step):
|
||
new_row = row.copy()
|
||
new_row['time'] = ts
|
||
filled_records.append(new_row)
|
||
logger.debug(f"Filled timestamp: {ts}")
|
||
|
||
filled_records.append(row)
|
||
last_timestamp = time_stamp
|
||
|
||
logger.info("Data holes filled successfully.")
|
||
return pd.DataFrame(filled_records)
|
||
|
||
def _populate_db(self, ex_details: List[str], data: pd.DataFrame = None) -> None:
|
||
if data is None or data.empty:
|
||
logger.debug(f'No records to insert {data}')
|
||
return
|
||
|
||
table_name = self._make_key(ex_details)
|
||
sym, _, ex, _ = ex_details
|
||
|
||
self.db.insert_candles_into_db(data, table_name=table_name, symbol=sym, exchange_name=ex)
|
||
logger.info(f'Data inserted into table {table_name}')
|
||
|
||
@staticmethod
|
||
def _make_key(ex_details: List[str]) -> str:
|
||
sym, tf, ex, _ = ex_details
|
||
key = f'{sym}_{tf}_{ex}'
|
||
return key
|