# brighter-trading/src/DataCache_v3.py
# Row- and table-based data caches with TTL expiration, snapshots, and
# database-backed fallthrough.

import copy
import pickle
import time
import logging
import datetime as dt
from collections import deque
from typing import Any, Tuple, List, Optional
import pandas as pd
import numpy as np
from indicators import Indicator, indicators_registry
from shared_utilities import unix_time_millis
from Database import Database
# Configure logging
logger = logging.getLogger(__name__)
# Helper Methods
def timeframe_to_timedelta(timeframe: str) -> pd.Timedelta | pd.DateOffset:
digits = int("".join([i if i.isdigit() else "" for i in timeframe]))
unit = "".join([i if i.isalpha() else "" for i in timeframe])
if unit == 's': # Add support for seconds
return pd.Timedelta(seconds=digits)
elif unit == 'm':
return pd.Timedelta(minutes=digits)
elif unit == 'h':
return pd.Timedelta(hours=digits)
elif unit == 'd':
return pd.Timedelta(days=digits)
elif unit == 'w':
return pd.Timedelta(weeks=digits)
elif unit == 'M':
return pd.DateOffset(months=digits)
elif unit == 'Y':
return pd.DateOffset(years=digits)
else:
raise ValueError(f"Invalid timeframe unit: {unit}")
def estimate_record_count(start_time, end_time, timeframe: str) -> int:
    """
    Estimate the number of records expected between start_time and end_time based on the given timeframe.
    Accepts datetime objects or Unix timestamps in milliseconds, and the two
    arguments may now be of different kinds (one datetime, one timestamp).

    :param start_time: Timezone-aware datetime or Unix timestamp in milliseconds.
    :param end_time: Timezone-aware datetime or Unix timestamp in milliseconds.
    :param timeframe: Timeframe string accepted by timeframe_to_timedelta (e.g. '5m', '3M').
    :return: Estimated record count; 0 if the interval is empty or negative.
    :raises ValueError: For invalid types, naive datetimes, or an unknown timeframe.
    """
    def _as_utc(value, name: str) -> dt.datetime:
        # Normalize either input form to a UTC-aware datetime. Converting each
        # argument independently fixes the original bug where a mixed
        # (timestamp, datetime) pair passed validation but skipped conversion,
        # then crashed on the <= comparison below.
        if isinstance(value, dt.datetime):
            if value.tzinfo is None:
                raise ValueError("start_time and end_time must be timezone-aware.")
            return value.astimezone(dt.timezone.utc)
        if isinstance(value, (int, float, np.integer)):
            # fromtimestamp(tz=...) replaces the deprecated utcfromtimestamp.
            return dt.datetime.fromtimestamp(value / 1000, tz=dt.timezone.utc)
        raise ValueError(f"Invalid type for {name}: {type(value)}. Expected datetime or timestamp.")

    start_time = _as_utc(start_time, "start_time")
    end_time = _as_utc(end_time, "end_time")
    # A negative or empty interval yields no records.
    if end_time <= start_time:
        return 0
    delta = timeframe_to_timedelta(timeframe)
    # Fixed-length intervals: plain division of elapsed seconds.
    if isinstance(delta, dt.timedelta):
        return int((end_time - start_time).total_seconds() // delta.total_seconds())
    # Calendar-length intervals ('M'/'Y') arrive as DateOffset.
    if isinstance(delta, pd.DateOffset):
        if 'months' in delta.kwds:
            months = (end_time.year - start_time.year) * 12 + (end_time.month - start_time.month)
            # Fix: divide by the offset size so '3M' counts quarters, not months.
            return months // delta.kwds['months']
        if 'years' in delta.kwds:
            return (end_time.year - start_time.year) // delta.kwds['years']
    raise ValueError(f"Invalid timeframe: {timeframe}")
class CacheEntryMetadata:
    """Tracks the creation time and optional time-to-live of one cache entry."""

    def __init__(self, expiration_time: Optional[int] = None):
        # Wall-clock second at which the entry was created.
        self.creation_time = time.time()
        # TTL in seconds: None -> never expires, 0 -> expires immediately.
        self.expiration_time = expiration_time
        # Absolute expiry moment; None when no (truthy) TTL was supplied.
        if expiration_time:
            self.expiration_timestamp = self.creation_time + expiration_time
        else:
            self.expiration_timestamp = None

    def is_expired(self) -> bool:
        """Return True once this entry's TTL has elapsed."""
        ttl = self.expiration_time
        if ttl is None:
            return False  # No TTL configured: the entry never expires
        if ttl == 0:
            return True  # A zero TTL expires immediately
        return time.time() > self.expiration_timestamp
class CacheEntry:
    """Pairs a cached payload with its expiration metadata."""

    def __init__(self, data: Any, expiration_time: Optional[int] = None):
        # Arbitrary payload; callers frequently store DataFrames here.
        self.data = data
        # Per-entry TTL bookkeeping (None TTL means the entry never expires).
        self.metadata = CacheEntryMetadata(expiration_time)
class RowBasedCache:
    """Cache for storing individual rows, where each entry has a unique key.

    Entries live in a dict; an access-order deque implements LRU-style
    eviction, and expired entries are purged lazily every `purge_threshold`
    accesses.
    """

    def __init__(self, default_expiration: Optional[int] = None, size_limit: Optional[int] = None,
                 eviction_policy: str = "evict", purge_threshold: int = 10):
        """
        :param default_expiration: TTL in seconds used when an entry supplies none.
        :param size_limit: Maximum number of entries (None = unlimited).
        :param eviction_policy: 'evict' drops the LRU entry; 'deny' rejects new entries.
        :param purge_threshold: Accesses between lazy purges of expired entries.
        """
        self.cache = {}
        self.default_expiration = default_expiration
        self.size_limit = size_limit
        self.eviction_policy = eviction_policy
        self.access_order = deque()  # Tracks the order of access for eviction
        self.access_counter = 0  # Counter to track accesses
        self.purge_threshold = purge_threshold  # Define how often to trigger purge

    def _drop_entry(self, key: str):
        """Delete an entry and its access-order record together.

        Fixes a bug where expired entries were removed from the dict only,
        leaving stale keys in access_order for evict() to KeyError on later.
        """
        del self.cache[key]
        if key in self.access_order:
            self.access_order.remove(key)

    def add_entry(self, key: str, data: Any, expiration_time: Optional[int] = None):
        """Add an entry; DataFrame entries with an existing key are appended to."""
        self._check_purge()  # Check if purge is needed
        if self.size_limit is not None and len(self.cache) >= self.size_limit:
            if self.eviction_policy == "evict":
                self.evict()
            elif self.eviction_policy == "deny":
                return "Cache limit reached. Entry not added."
        if key in self.cache and isinstance(self.cache[key].data, pd.DataFrame):
            # Existing DataFrame entry: append the new rows.
            # (The original repeated the isinstance check here, leaving an
            # unreachable 'replace' branch; that dead code is removed.)
            self.cache[key].data = pd.concat([self.cache[key].data, data], ignore_index=True)
        else:
            # New entry, or non-DataFrame data: create/replace wholesale.
            expiration_time = expiration_time or self.default_expiration
            self.cache[key] = CacheEntry(data, expiration_time)
        # Move the key to the most-recently-used end of the access order.
        if key in self.access_order:
            self.access_order.remove(key)
        self.access_order.append(key)

    def get_entry(self, key: str) -> Any:
        """Retrieve an entry by key, ensuring expired entries are ignored."""
        self._check_purge()  # Check if purge is needed
        if key in self.cache:
            if not self.cache[key].metadata.is_expired():
                self.access_order.remove(key)
                self.access_order.append(key)  # Mark key as most recently used
                return self.cache[key].data
            # Expired: remove the entry AND its access-order record.
            self._drop_entry(key)
        return None

    def query(self, conditions: List[Tuple[str, Any]]) -> pd.DataFrame:
        """Query cache entries by conditions, ignoring expired entries.

        The 'tbl_key' condition selects the entry; any remaining conditions
        filter rows of a DataFrame entry.
        """
        self._check_purge()  # Check if purge is needed
        # Get the value of tbl_key out of the list of key-value pairs.
        key_value = next((value for key, value in conditions if key == 'tbl_key'), None)
        if key_value is None or key_value not in self.cache:
            return pd.DataFrame()  # Key not found
        entry = self.cache[key_value]
        if entry.metadata.is_expired():
            # Expired: remove the entry AND its access-order record.
            self._drop_entry(key_value)
            return pd.DataFrame()
        data = entry.data
        if isinstance(data, pd.DataFrame):
            # Build a parameterized pandas query from the non-key conditions.
            query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions if col != 'tbl_key'])
            query_vars = {f'val_{col}': val for col, val in conditions if col != 'tbl_key'}
            return data.query(query_conditions, local_dict=query_vars) if query_conditions else data
        return pd.DataFrame([data])  # Wrap non-DataFrame data as a one-row frame

    def is_attr_taken(self, column: str, value: Any) -> bool:
        """Check if a column contains the specified value in any cached DataFrame."""
        self._check_purge()  # Check if purge is needed
        for entry in self.cache.values():  # Keys were unused; iterate values only
            if isinstance(entry.data, pd.DataFrame) and column in entry.data.columns:
                matches = entry.data.query(f'`{column}` == @value', local_dict={'value': value})
                if not matches.empty:
                    return True
        return False

    def evict(self):
        """Evict the least recently accessed entry."""
        oldest_key = self.access_order.popleft()
        del self.cache[oldest_key]

    def _check_purge(self):
        """Increment the access counter and trigger purge if threshold is reached."""
        self.access_counter += 1
        if self.access_counter >= self.purge_threshold:
            self._purge_expired()
            self.access_counter = 0  # Reset the counter after purging

    def _purge_expired(self):
        """Remove expired entries from the cache."""
        expired_keys = [key for key, entry in self.cache.items() if entry.metadata.is_expired()]
        for key in expired_keys:
            self._drop_entry(key)

    def get_all_items(self) -> dict[str, Any]:
        """Retrieve all non-expired items in the cache."""
        self._check_purge()  # Ensure expired entries are purged as needed
        return {key: entry.data for key, entry in self.cache.items() if not entry.metadata.is_expired()}

    def remove_item(self, conditions: List[Tuple[str, Any]]) -> bool:
        """Remove an item from the cache using key-value conditions.

        'tbl_key' identifies the entry; extra conditions delete only matching
        DataFrame rows. Returns True when something was removed.
        """
        key_value = next((value for key, value in conditions if key == 'tbl_key'), None)
        if key_value is None or key_value not in self.cache:
            return False  # Key not found, so nothing to remove
        if len(conditions) == 1:
            # Only 'tbl_key' was given: drop the whole entry.
            self._drop_entry(key_value)
            return True
        entry = self.cache[key_value]
        if isinstance(entry.data, pd.DataFrame):
            # Keep only rows NOT matching the extra conditions.
            query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions if col != 'tbl_key'])
            query_vars = {f'val_{col}': val for col, val in conditions if col != 'tbl_key'}
            remaining_data = entry.data.query(f'not ({query_conditions})', local_dict=query_vars)
            if remaining_data.empty:
                self._drop_entry(key_value)  # Every row matched: drop the entry
            else:
                entry.data = remaining_data
        else:
            # Non-DataFrame data: remove the whole entry.
            self._drop_entry(key_value)
        return True  # Successfully removed the item
class TableBasedCache:
    """Cache for storing entire tables with expiration applied to rows.

    All rows live in one DataFrame whose 'metadata' column carries a
    CacheEntryMetadata per row; an optional 'tbl_key' column groups rows.
    """

    def __init__(self, default_expiration: Optional[int] = None, size_limit: Optional[int] = None,
                 eviction_policy: str = "evict"):
        """
        :param default_expiration: TTL in seconds used when add_table supplies none.
        :param size_limit: Maximum number of rows (None/0 = unlimited).
        :param eviction_policy: 'evict' drops the oldest rows; 'deny' rejects additions.
        """
        self.cache = pd.DataFrame()  # The DataFrame where both data and metadata are stored
        self.default_expiration = default_expiration
        self.size_limit = size_limit
        self.eviction_policy = eviction_policy
        self.access_order = deque()  # Tracks the order of access for eviction

    def _check_size_limit(self):
        """Check and enforce the size limit; returns False when adds are denied."""
        if self.size_limit and len(self.cache) > self.size_limit:
            if self.eviction_policy == "evict":
                excess = len(self.cache) - self.size_limit
                self.evict(excess)  # Evict excess rows
            elif self.eviction_policy == "deny":
                return False  # Don't allow more rows to be added
        return True

    def add_table(self, df: pd.DataFrame, expiration_time: Optional[int] = None, overwrite: Optional[str] = None,
                  key: Optional[str] = None):
        """
        Adds a DataFrame to the cache, attaching metadata to each row.
        Optionally overwrites rows based on a column value.
        :param df: The DataFrame to add.
        :param expiration_time: Optional expiration time for the rows.
        :param overwrite: Column name to use for identifying rows to overwrite.
        :param key: Optional value stored in a 'tbl_key' column on every added row.
        """
        expiration_time = expiration_time or self.default_expiration
        # Fix: under the "deny" policy the original added the rows first and
        # only then reported rejection; refuse up front instead.
        if (self.eviction_policy == "deny" and self.size_limit
                and len(self.cache) + len(df) > self.size_limit):
            return "Cache limit reached. Table not added."
        # CacheEntryMetadata(None) is equivalent to CacheEntryMetadata(), so a
        # single construction path suffices.
        metadata = [CacheEntryMetadata(expiration_time) for _ in range(len(df))]
        df_with_metadata = df.copy()
        df_with_metadata['metadata'] = metadata
        if key is not None:
            df_with_metadata['tbl_key'] = key
        if self.cache.empty:
            # Replace outright. (The original tested `cache is None`, which is
            # never true after __init__, and concatenating with an all-empty
            # frame is deprecated behaviour in newer pandas anyway.)
            self.cache = df_with_metadata
        else:
            self.cache = pd.concat([self.cache, df_with_metadata], ignore_index=True)
        if overwrite:
            # Drop duplicates based on the overwrite column, keeping the last occurrence (new data)
            self.cache = self.cache.drop_duplicates(subset=overwrite, keep='last')
        # Enforce size limit (evict path; deny was handled above).
        if not self._check_size_limit():
            return "Cache limit reached. Table not added."

    def _purge_expired(self):
        """Remove expired rows from the cache."""
        # Fix: a never-populated cache has no 'metadata' column, which made
        # query()/get_all_items() raise KeyError on an empty cache.
        if self.cache.empty:
            return
        try:
            # Keep rows whose metadata has not expired; columns are preserved.
            is_valid = self.cache['metadata'].apply(lambda meta: not meta.is_expired())
            self.cache = self.cache.loc[is_valid].reindex(self.cache.columns, axis=1).reset_index(drop=True)
        except KeyError:
            raise KeyError("The 'metadata' column is missing from the cache.")
        except AttributeError as e:
            raise AttributeError(f"Error in metadata processing: {e}")

    def query(self, conditions: List[Tuple[str, Any]]) -> pd.DataFrame:
        """Query rows based on conditions and return valid (non-expired) entries."""
        self._purge_expired()  # Remove expired rows before querying
        result = self.cache.copy()
        if not result.empty:
            # Build a parameterized pandas query from the conditions.
            query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions])
            query_vars = {f'val_{col}': val for col, val in conditions}
            result = result.query(query_conditions, local_dict=query_vars) if query_conditions else result
        # Strip the internal metadata column from the result.
        return result.drop(columns=['metadata'], errors='ignore')

    def is_attr_taken(self, column: str, value: Any) -> bool:
        """Check if a column contains the specified value in the Table-Based Cache."""
        self._purge_expired()  # Ensure expired entries are removed
        if column not in self.cache.columns:
            return False  # Column does not exist
        query_result = self.cache.query(f'`{column}` == @value', local_dict={'value': value})
        return not query_result.empty

    def evict(self, num_rows: int = 1):
        """Evict the oldest rows (insertion order — access_order is not consulted)."""
        if len(self.cache) == 0:
            return
        self.cache = self.cache.iloc[num_rows:].reset_index(drop=True)

    def get_all_items(self) -> pd.DataFrame:
        """Retrieve all non-expired rows from the table-based cache."""
        self._purge_expired()  # Ensure expired rows are removed
        return self.cache

    def remove_item(self, conditions: List[Tuple[str, Any]]) -> bool:
        """Remove rows from the table-based cache that match the key-value conditions."""
        self._purge_expired()  # Ensure expired entries are removed
        if self.cache.empty:
            return False  # Cache is empty
        # Keep only rows NOT matching all conditions.
        query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions])
        query_vars = {f'val_{col}': val for col, val in conditions}
        remaining_data = self.cache.query(f'not ({query_conditions})', local_dict=query_vars)
        if len(remaining_data) == len(self.cache):
            return False  # No rows matched the conditions, so nothing was removed
        self.cache = remaining_data
        return True  # Successfully removed matching rows
class CacheManager:
    """Manages different cache types (row-based and table-based)."""

    def __init__(self):
        # Maps cache name -> RowBasedCache | TableBasedCache instance.
        self.caches = {}

    def create_cache(self, name: str, cache_type: str,
                     size_limit: Optional[int] = None,
                     eviction_policy: str = 'evict',
                     default_expiration: Optional[dt.timedelta] = None,
                     columns: Optional[list] = None) -> TableBasedCache | RowBasedCache:
        """
        Creates a new cache with the given parameters.
        :param name: The name of the cache.
        :param cache_type: The type of cache ('row' or 'table').
        :param size_limit: Maximum number of items allowed in the cache.
        :param eviction_policy: Policy for evicting items when cache limit is reached.
        :param default_expiration: A timedelta object representing the expiration time.
        :param columns: Optional list of column names to initialize an empty DataFrame for a table-based cache.
        :return: The created cache.
        :raises ValueError: If cache_type is not 'row' or 'table'.
        """
        # Convert a timedelta TTL to seconds; pass None (and other values)
        # through unchanged. (The original's `not in [None, 0]` check never
        # matched 0 for timedeltas anyway, since timedelta(0) != int 0.)
        expiration_in_seconds = (default_expiration.total_seconds()
                                 if isinstance(default_expiration, dt.timedelta)
                                 else default_expiration)
        if cache_type == 'row':
            self.caches[name] = RowBasedCache(size_limit=size_limit, eviction_policy=eviction_policy,
                                              default_expiration=expiration_in_seconds)
        elif cache_type == 'table':
            self.caches[name] = TableBasedCache(size_limit=size_limit, eviction_policy=eviction_policy,
                                                default_expiration=expiration_in_seconds)
            # Initialize the DataFrame with provided columns if specified.
            if columns:
                self.caches[name].add_table(df=pd.DataFrame(columns=columns))
                # Consistency fix: use the module logger, not the root logger.
                logger.info(f"Table-based cache '{name}' initialized with columns: {columns}")
        else:
            raise ValueError(f"Unsupported cache type: {cache_type}")
        logger.info(f"Cache '{name}' of type '{cache_type}' created with expiration: {default_expiration}")
        return self.caches.get(name)

    def get_cache(self, name: str) -> RowBasedCache | TableBasedCache:
        """Retrieve a cache by name; raises KeyError if it does not exist."""
        if name in self.caches:
            return self.caches[name]
        raise KeyError(f"Cache: {name}, does not exist.")

    def get_rows_from_cache(self, cache_name: str, filter_vals: list[tuple[str, Any]]) -> pd.DataFrame | None:
        """
        Retrieves rows from the cache if available.
        :param cache_name: The key used to identify the cache.
        :param filter_vals: A list of tuples, each containing a column name and the value(s) to filter by.
        :return: A DataFrame containing the requested rows; empty when nothing matches
            or the cache does not exist. (Docstring fix: an empty DataFrame is
            returned on a miss, never None.)
        :raises ValueError: If the named cache is not a supported cache type.
        """
        if cache_name in self.caches:
            cache = self.get_cache(cache_name)
            if not isinstance(cache, (TableBasedCache, RowBasedCache)):
                raise ValueError(f"Cache '{cache_name}' does not contain TableBasedCache or RowBasedCache.")
            filtered_cache = cache.query(filter_vals)  # Pass the list of filters
            if not filtered_cache.empty:
                return filtered_cache
        # No result: return an empty DataFrame.
        return pd.DataFrame()

    def get_cache_item(self, item_name: str, cache_name: str, filter_vals: tuple[str, any]) -> any:
        """
        Retrieves a specific item from the cache.
        :param item_name: The name of the column to retrieve.
        :param cache_name: The name used to identify the cache (also the name of the database table).
        :param filter_vals: A tuple containing the column name and the value to filter by.
        :return: The value of the requested item, or None when no row matches.
        :raises ValueError: If the column does not exist in the matching rows.
        """
        rows = self.get_rows_from_cache(cache_name=cache_name, filter_vals=[filter_vals])
        if rows is not None and not rows.empty:
            if item_name not in rows.columns:
                raise ValueError(f"Column '{item_name}' does not exist in the cache '{cache_name}'.")
            # Return the specific item from the first matching row.
            return rows.iloc[0][item_name]
        return None

    def insert_row_into_cache(self, cache_name: str, columns: tuple, values: tuple, key: str = None) -> None:
        """
        Inserts a single row into the specified cache.
        :param cache_name: The name of the cache where the row should be inserted.
        :param columns: A tuple of column names corresponding to the values.
        :param values: A tuple of values to insert into the specified columns.
        :param key: Optional key for the cache item (required for row-based caches).
        :raises ValueError: If a row-based cache is targeted without a key.
        """
        new_row_df = pd.DataFrame([values], columns=list(columns))
        cache = self.get_cache(cache_name)
        if isinstance(cache, RowBasedCache):
            if key is None:
                raise ValueError('A key must be provided for row based cache.')
            cache.add_entry(key=key, data=new_row_df)
        elif isinstance(cache, TableBasedCache):
            cache.add_table(df=new_row_df)
        else:
            raise ValueError(f"Unknown cache type for {cache_name}")

    def insert_df_into_cache(self, df: pd.DataFrame, cache_name: str) -> None:
        """
        Inserts data from a DataFrame into the specified cache.
        :param df: The DataFrame containing the data to insert.
        :param cache_name: The name of the cache where the data should be inserted.
        """
        cache = self.get_cache(cache_name)
        if isinstance(cache, RowBasedCache):
            # Insert each row individually, keyed by its first column value.
            for idx, row in df.iterrows():
                # Fix: row[0] relied on deprecated positional fallback on a
                # labeled Series; use explicit positional access.
                key = str(row.iloc[0])
                cache.add_entry(key=key, data=row.to_frame().T)  # Row back to a 1-row DataFrame
        elif isinstance(cache, TableBasedCache):
            cache.add_table(df=df)
        else:
            raise ValueError(f"Unknown cache type for {cache_name}")

    def remove_row_from_cache(self, cache_name: str, filter_vals: List[tuple[str, Any]]) -> None:
        """
        Removes rows from the cache based on multiple filter criteria.
        :param cache_name: The name of the cache (or table) from which to remove rows.
        :param filter_vals: A list of tuples, each containing a column name and the value to filter by.
        :raises ValueError: If filter_vals is malformed or no valid cache is found.
        """
        if not isinstance(filter_vals, list) or not all(isinstance(item, tuple) for item in filter_vals):
            raise ValueError("filter_vals must be a list of tuples (column, value)")
        cache = self.get_cache(cache_name)
        if cache is None:
            raise ValueError(f"Cache '{cache_name}' not found.")
        cache.remove_item(filter_vals)

    def modify_cache_item(self, cache_name: str, filter_vals: List[Tuple[str, any]], field_name: str,
                          new_data: any) -> None:
        """
        Modifies a specific field in a row within the cache.
        :param cache_name: The name used to identify the cache.
        :param filter_vals: A list of tuples containing column names and values to filter by.
        :param field_name: The field to be updated.
        :param new_data: The new data to be set.
        :raises ValueError: If no row, or more than one row, matches the filter.
        """
        rows = self.get_rows_from_cache(cache_name=cache_name, filter_vals=filter_vals)
        if rows is None or rows.empty:
            raise ValueError(f"Row not found in cache for {filter_vals}")
        if len(rows) > 1:
            raise ValueError(f"Multiple rows found for {filter_vals}. Please provide a more specific filter.")
        # Work on an explicit copy: query results may be views of cache data.
        rows = rows.copy()
        rows[field_name] = new_data
        cache = self.get_cache(cache_name)
        if isinstance(cache, RowBasedCache):
            # For row-based cache, the 'tbl_key' must be in filter_vals.
            key_value = next((val for key, val in filter_vals if key == 'tbl_key'), None)
            if key_value is None:
                raise ValueError("'tbl_key' must be present in filter_vals for row-based caches.")
            cache.add_entry(key=key_value, data=rows)
        elif isinstance(cache, TableBasedCache):
            # Fix: remove the stale row before inserting the modified copy;
            # add_table alone left both the old and the new row in the cache.
            cache.remove_item(filter_vals)
            cache.add_table(rows)
        else:
            raise ValueError(f"Unsupported cache type for {cache_name}")

    @staticmethod
    def key_exists(cache, key):
        """Return True if *key* exists in the given cache (row- or table-based)."""
        if isinstance(cache, RowBasedCache):
            return key in cache.cache
        if isinstance(cache, TableBasedCache):
            return not cache.query([("tbl_key", key)]).empty
        # NOTE(review): unknown cache types fall through to None, matching the
        # original's implicit behavior — confirm callers only pass known types.
class SnapshotDataCache(CacheManager):
    """
    Extends CacheManager with snapshot functionality for both row-based and table-based caches.
    (Docstring fix: the base class is CacheManager, not "DataCacheBase".)
    Attributes:
        snapshots (dict): Maps cache names to (deep-copied cache, timestamp) tuples.
    Methods:
        snapshot_cache(cache_name: str): Takes a snapshot of the specified cache and stores it.
        get_snapshot(cache_name: str): Retrieves the most recent snapshot of the specified cache.
        list_snapshots() -> dict: Lists all available snapshots along with their timestamps.
    """

    def __init__(self):
        super().__init__()  # Call the constructor of CacheManager
        self.snapshots = {}  # cache name -> (deep copy of cache, timestamp)

    def snapshot_cache(self, cache_name: str):
        """Takes a snapshot of the specified cache and stores it with a timestamp."""
        if cache_name not in self.caches:
            raise ValueError(f"Cache '{cache_name}' does not exist.")
        # Deep copy so later mutations of the live cache cannot alter the snapshot.
        cache_data = copy.deepcopy(self.caches[cache_name])
        self.snapshots[cache_name] = (cache_data, dt.datetime.now())
        # Consistency fix: report via the module logger instead of print().
        logger.info(f"Snapshot taken for cache '{cache_name}' at {self.snapshots[cache_name][1]}.")

    def get_snapshot(self, cache_name: str):
        """Retrieves the most recent snapshot of the specified cache."""
        if cache_name not in self.snapshots:
            raise ValueError(f"No snapshot available for cache '{cache_name}'.")
        snapshot, timestamp = self.snapshots[cache_name]
        # Consistency fix: report via the module logger instead of print().
        logger.info(f"Returning snapshot of cache '{cache_name}' taken at {timestamp}.")
        return snapshot

    def list_snapshots(self) -> dict:
        """Lists all available snapshots along with their timestamps."""
        return {cache_name: timestamp for cache_name, (_, timestamp) in self.snapshots.items()}
class DatabaseInteractions(SnapshotDataCache):
"""
Extends SnapshotDataCache with additional functionality, including database interactions and cache management.
Attributes:
db (Database): A database connection instance for executing queries.
TYPECHECKING_ENABLED (bool): A class attribute to toggle type checking.
Methods:
get_or_fetch_rows(cache_name: str, filter_vals: tuple[str, any]) -> pd.DataFrame | None:
Retrieves rows from the cache or database.
fetch_item(item_name: str, cache_name: str, filter_vals: tuple[str, any]) -> any:
Retrieves a specific item from the cache or database.
insert_row(cache_name: str, columns: tuple, values: tuple, skip_cache: bool = False) -> None:
Inserts a row into the cache and database.
insert_df(df: pd.DataFrame, cache_name: str, skip_cache: bool = False) -> None:
Inserts data from a DataFrame into the cache and database.
remove_row(cache_name: str, filter_vals: tuple[str, any], additional_filter: tuple[str, any] = None,
remove_from_db: bool = True) -> None:
Removes a row from the cache and optionally the database.
is_attr_taken(cache_name: str, attr: str, val: any) -> bool:
Checks if a specific attribute has a given value in the cache.
modify_item(cache_name: str, filter_vals: tuple[str, any], field_name: str, new_data: any) -> None:
Modifies a field in a row within the cache and updates the database.
update_cached_dict(cache_name: str, cache_key: str, dict_key: str, data: any) -> None:
Updates a dictionary stored in the DataFrame cache.
Usage Example:
# Create a DatabaseInteractions instance
db_cache_manager = DatabaseInteractions(exchanges=exchanges)
# Set some data in the cache
db_cache_manager.set_cache_item(key='AAPL', data={'price': 100}, cache_name='stock_prices')
# Fetch or query rows from the cache or database
rows = db_cache_manager.get_or_fetch_rows(cache_name='stock_prices', filter_vals=('symbol', 'AAPL'))
print(f"Fetched Rows:\n{rows}")
# Fetch a specific item from the cache or database
price = db_cache_manager.fetch_item(item_name='price', cache_name='stock_prices',
filter_vals=('symbol', 'AAPL'))
print(f"Fetched Price: {price}")
# Insert a new row into the cache and database
db_cache_manager.insert_row(cache_name='stock_prices', columns=('symbol', 'price'), values=('TSLA', 800))
# Check if an attribute value is already taken
is_taken = db_cache_manager.is_attr_taken(cache_name='stock_prices', attr='symbol', val='AAPL')
print(f"Is Symbol Taken: {is_taken}")
# Modify an existing item in the cache and database
db_cache_manager.modify_item(cache_name='stock_prices', filter_vals=('symbol', 'AAPL'),
field_name='price', new_data=105)
# Update a dictionary within the cache
db_cache_manager.update_cached_dict(cache_name='stock_prices', cache_key='AAPL', dict_key='price', data=110)
"""
TYPECHECKING_ENABLED = True
def __init__(self):
    """Initialize the snapshot/cache layers and open the database connection."""
    super().__init__()
    # Database handle used as the fallback data source on cache misses.
    self.db = Database()
def get_rows_from_datacache(self, cache_name: str, filter_vals: list[tuple[str, Any]] = None,
                            key: str = None, include_tbl_key: bool = False) -> pd.DataFrame | None:
    """
    Retrieves rows from the cache if available; otherwise, queries the database and caches the result.
    :param cache_name: The key used to identify the cache (also the name of the database table).
    :param filter_vals: A list of tuples, each containing a column name and the value(s) to filter by.
        If a value is a list, it will use the SQL 'IN' clause.
    :param key: Optional key to filter by 'tbl_key'.
    :param include_tbl_key: If True, includes 'tbl_key' in the returned DataFrame.
    :return: A DataFrame containing the requested rows (empty when nothing matches).
    :raises ValueError: If neither 'filter_vals' nor 'key' is provided.
    """
    # Ensure at least one of filter_vals or key is provided.
    if not filter_vals and not key:
        raise ValueError("At least one of 'filter_vals' or 'key' must be provided.")
    # Fix: copy the filters before modifying them — the original called
    # filter_vals.insert(...) on the caller's list, mutating it in place.
    filter_vals = list(filter_vals) if filter_vals else []
    if key:
        filter_vals.insert(0, ('tbl_key', key))
    # Try the in-memory cache first.
    result = self.get_rows_from_cache(cache_name, filter_vals)
    # Fall back to the database on a cache miss.
    if result.empty:
        result = self._fetch_from_database(cache_name, filter_vals)
    # Retry with 'IN'-clause support only when any filter value is a list.
    if result.empty and any(isinstance(val, list) for _, val in filter_vals):
        result = self._fetch_from_database_with_list_support(cache_name, filter_vals)
    # Strip the internal 'tbl_key' column unless explicitly requested.
    if not include_tbl_key:
        result = result.drop(columns=['tbl_key'], errors='ignore')
    return result
def _fetch_from_database_with_list_support(self, cache_name: str,
                                           filter_vals: List[tuple[str, Any]]) -> pd.DataFrame:
    """
    Fetch rows from the database, supporting list values that require SQL 'IN' clauses.
    :param cache_name: The name of the table or key used to store/retrieve data.
    :param filter_vals: A list of tuples with the filter column and value, supporting lists for 'IN' clause.
    :return: A DataFrame with the fetched rows (may be empty/None depending on the db driver).
    """
    where_clauses = []
    params = []
    # Build a parameterized WHERE clause; list values expand to 'col IN (?, ?, ...)'.
    for col, val in filter_vals:
        if isinstance(val, list):
            placeholders = ', '.join('?' for _ in val)
            where_clauses.append(f"{col} IN ({placeholders})")
            params.extend(val)
        else:
            where_clauses.append(f"{col} = ?")
            params.append(val)
    where_clause = " AND ".join(where_clauses)
    # NOTE(review): the values are parameterized, but the table name and column
    # names are interpolated directly into the SQL string — this is only safe
    # if cache_name and filter columns never come from untrusted input; verify
    # at the call sites.
    sql_query = f"SELECT * FROM {cache_name} WHERE {where_clause}"
    # Execute the SQL query with the prepared parameters
    rows = self.db.get_rows_where(sql_query, params)
    # Cache the result (either row or table based)
    if rows is not None and not rows.empty:
        cache = self.get_cache(cache_name)
        if isinstance(cache, RowBasedCache):
            # Row caches are keyed per 'tbl_key'; assumes the result set always
            # contains that column — TODO confirm against the db schema.
            for _, row in rows.iterrows():
                cache.add_entry(key=row['tbl_key'], data=row)
        else:
            # Table caches absorb the frame, replacing rows that share a tbl_key.
            cache.add_table(rows, overwrite='tbl_key')
    return rows
    def _fetch_from_database(self, cache_name: str, filter_vals: List[tuple[str, Any]]) -> pd.DataFrame:
        """
        Fetch rows from the database and cache the result.
        :param cache_name: The name of the table or key used to store/retrieve data.
        :param filter_vals: A list of tuples with the filter column and value.
        :return: A DataFrame with the fetched rows, or an empty DataFrame if no data is found.
        """
        # Use db.get_rows_where, assuming it can handle multiple filters
        rows = self.db.get_rows_where(cache_name, filter_vals)
        if rows is not None and not rows.empty:
            # Cache the fetched data (let the caching system handle whether it's row or table-based)
            cache = self.get_cache(cache_name)
            if isinstance(cache, RowBasedCache):
                # For row-based cache, assume the first filter value is used as the key.
                # NOTE(review): this presumes the first filter identifies the row
                # uniquely (typically 'tbl_key') — verify against callers.
                key_value = filter_vals[0][1]  # Use the value of the first filter as the key
                cache.add_entry(key=key_value, data=rows)
            else:
                # For table-based cache, add the entire DataFrame to the cache
                cache.add_table(df=rows, overwrite='tbl_key')
            # Return the fetched rows
            return rows
        # If no rows are found, return an empty DataFrame (not None).
        return pd.DataFrame()
    def get_datacache_item(self, item_name: str, cache_name: str, filter_vals: Tuple[str, Any]) -> Any:
        """
        Retrieves a specific item from the cache or database, caching the result if necessary.
        :param item_name: The name of the column to retrieve.
        :param cache_name: The name used to identify the cache (also the name of the database table).
        :param filter_vals: A tuple containing the column name and the value to filter by.
        :return: The value of the requested item, or None if no matching row exists.
        :raises ValueError: If the column does not exist in the matched rows.
        """
        # Fetch the relevant rows from the cache or database
        rows = self.get_rows_from_datacache(cache_name=cache_name, filter_vals=[filter_vals])
        if rows is not None and not rows.empty:
            if item_name not in rows.columns:
                raise ValueError(f"Column '{item_name}' does not exist in the cache '{cache_name}'.")
            # Return the specific item from the first matching row.
            return rows.iloc[0][item_name]
        # The item was not found.
        return None
def insert_row_into_datacache(self, cache_name: str, columns: tuple, values: tuple, key: str = None,
skip_cache: bool = False) -> None:
"""
Inserts a single row into the specified cache and database, with an option to skip cache insertion.
:param cache_name: The name of the cache (and database table) where the row should be inserted.
:param columns: A tuple of column names corresponding to the values.
:param values: A tuple of values to insert into the specified columns.
:param key: Optional key for the cache item. If None, the auto-incremented ID from the database will be used.
:param skip_cache: If True, skips inserting the row into the cache. Default is False.
"""
if key:
columns, values = columns + ('tbl_key',), values + (key,)
# Insert the row into the database
last_inserted_id = self.db.insert_row(table=cache_name, columns=columns, values=values)
# Now include the 'id' in the columns and values when inserting into the cache
columns = ('id',) + columns
values = (last_inserted_id,) + values
# Insert the row into the cache
if skip_cache:
return
self.insert_row_into_cache(cache_name, columns, values, key)
def insert_df_into_datacache(self, df: pd.DataFrame, cache_name: str, skip_cache: bool = False) -> None:
"""
Inserts data from a DataFrame into the specified cache and database, with an option to skip cache insertion.
:param df: The DataFrame containing the data to insert.
:param cache_name: The name of the cache (and database table) where the data should be inserted.
:param skip_cache: If True, skips inserting the data into the cache. Default is False.
"""
# Insert the data into the database
self.db.insert_dataframe(df=df, table=cache_name)
if not skip_cache:
self.insert_df_into_cache(df, cache_name)
def remove_row_from_datacache(self, cache_name: str, filter_vals: List[tuple[str, Any]],
remove_from_db: bool = True, key: str = None) -> None:
"""
Removes rows from the cache and optionally from the database based on multiple filter criteria.
This method is specifically designed for caches stored as DataFrames.
:param key: Optional key
:param cache_name: The name of the cache (or table) from which to remove rows.
:param filter_vals: A list of tuples, each containing a column name and the value to filter by.
:param remove_from_db: If True, also removes the rows from the database. Default is True.
:raises ValueError: If the cache is not a DataFrame or if no valid cache is found.
"""
if key:
filter_vals.insert(0, ('tbl_key', key))
self.remove_row_from_cache(cache_name, filter_vals)
# Remove from the database if required
if remove_from_db:
# Construct SQL to remove the rows from the database based on filter_vals
sql = f"DELETE FROM {cache_name} WHERE " + " AND ".join([f"{col} = ?" for col, _ in filter_vals])
params = [val for _, val in filter_vals]
# Execute the SQL query to remove the row from the database
self.db.execute_sql(sql, params)
    def modify_datacache_item(self, cache_name: str, filter_vals: List[Tuple[str, Any]], field_names: Tuple[str, ...],
                              new_values: Tuple[Any, ...], key: str = None, overwrite: str = None) -> None:
        """
        Modifies specific fields in a row within the cache and updates the database accordingly.
        :param cache_name: The name used to identify the cache.
        :param filter_vals: A list of tuples containing column names and values to filter by.
        :param field_names: A tuple of field names to be updated.
        :param new_values: A tuple of new values corresponding to field_names.
        :param key: Optional key to identify the entry.
        :param overwrite: Column name(s) to use for overwriting in the cache.
        :raises ValueError: If the row is not found or multiple rows are returned.
        """
        if key:
            # NOTE(review): list.insert mutates the caller's filter_vals in place — confirm intended.
            filter_vals.insert(0, ('tbl_key', key))
        # Retrieve the row from the cache or database
        rows = self.get_rows_from_datacache(cache_name=cache_name, filter_vals=filter_vals)
        if rows is None or rows.empty:
            raise ValueError(f"Row not found in cache or database for {filter_vals}")
        # Check if multiple rows are returned
        if len(rows) > 1:
            raise ValueError(f"Multiple rows found for {filter_vals}. Please provide more specific filter.")
        # Update the DataFrame with the new values
        for field_name, new_value in zip(field_names, new_values):
            rows[field_name] = new_value
        # Get the cache instance
        cache = self.get_cache(cache_name)
        # Set the updated row in the cache
        if isinstance(cache, RowBasedCache):
            # Update the cache entry with the modified row.
            # NOTE(review): if key is None, the entry is stored under key None —
            # confirm callers always pass key for row-based caches.
            cache.add_entry(key=key, data=rows)
        elif isinstance(cache, TableBasedCache):
            # Use 'overwrite' to ensure correct row is updated
            cache.add_table(rows, overwrite=overwrite)
        else:
            raise ValueError(f"Unsupported cache type for {cache_name}")
        # Update the values in the database (cache updated first, then DB).
        set_clause = ", ".join([f"{field} = ?" for field in field_names])
        where_clause = " AND ".join([f"{col} = ?" for col, _ in filter_vals])
        sql_update = f"UPDATE {cache_name} SET {set_clause} WHERE {where_clause}"
        params = list(new_values) + [val for _, val in filter_vals]
        # Execute the SQL update to modify the database
        self.db.execute_sql(sql_update, params)
    def modify_multiple_datacache_items(self, cache_name: str, filter_vals: List[Tuple[str, Any]],
                                        field_names: Tuple[str, ...], new_values: Tuple[Any, ...],
                                        key: str = None, overwrite: str = None) -> None:
        """
        Modifies specific fields in multiple rows within the cache and updates the database accordingly.
        :param cache_name: The name used to identify the cache.
        :param filter_vals: A list of tuples containing column names and values to filter by.
                            If a filter value is a list, it will be used with the 'IN' clause.
        :param field_names: A tuple of field names to be updated.
        :param new_values: A tuple of new values corresponding to field_names.
        :param key: Optional key to identify the entry.
        :param overwrite: Column name(s) to use for overwriting in the cache.
        :raises ValueError: If no rows are found.
        """
        if key:
            # NOTE(review): list.insert mutates the caller's filter_vals in place — confirm intended.
            filter_vals.insert(0, ('tbl_key', key))
        # Prepare the SQL query
        where_clauses = []
        query_params = []
        for col, val in filter_vals:
            if isinstance(val, list):
                # Use the 'IN' clause if the value is a list
                placeholders = ', '.join('?' for _ in val)
                where_clauses.append(f"{col} IN ({placeholders})")
                query_params.extend(val)
            else:
                where_clauses.append(f"{col} = ?")
                query_params.append(val)
        # Build the SQL query string
        where_clause = " AND ".join(where_clauses)
        set_clause = ", ".join([f"{field} = ?" for field in field_names])
        sql_update = f"UPDATE {cache_name} SET {set_clause} WHERE {where_clause}"
        # Add the new values to the parameters list
        query_params = list(new_values) + query_params
        # Execute the SQL update to modify the database
        self.db.execute_sql(sql_update, query_params)
        # Retrieve the rows from the cache to update the cache.
        # NOTE(review): the database has already been updated above; if this
        # lookup fails, the ValueError below leaves DB and cache out of sync.
        rows = self.get_rows_from_datacache(cache_name=cache_name, filter_vals=filter_vals)
        if rows is None or rows.empty:
            raise ValueError(f"Rows not found in cache or database for {filter_vals}")
        # Update the cache with the new values
        for field_name, new_value in zip(field_names, new_values):
            rows[field_name] = new_value
        # Get the cache instance
        cache = self.get_cache(cache_name)
        if isinstance(cache, RowBasedCache):
            for _, row in rows.iterrows():
                key_value = row['tbl_key']
                cache.add_entry(key=key_value, data=row)
        elif isinstance(cache, TableBasedCache):
            cache.add_table(rows, overwrite=overwrite)
        else:
            raise ValueError(f"Unsupported cache type for {cache_name}")
def serialized_datacache_insert(self, cache_name: str, data: Any, key: str = None,
do_not_overwrite: bool = False):
"""
Stores an item in the cache, with custom serialization for object instances.
If the data is not a DataFrame, the entire object is serialized and stored under a column named 'data'.
:param cache_name: The name of the cache.
:param data: Any object to store in the cache, but should be a DataFrame with one row for normal operations.
:param key: The key for row-based caches, used to identify the entry. Required for row-based caches.
:param do_not_overwrite: If True, prevents overwriting existing entries in the cache.
"""
# Retrieve the cache
cache = self.get_cache(cache_name)
# If overwrite is disabled and the key already exists, prevent overwriting
if do_not_overwrite and self.key_exists(cache, key):
logging.warning(f"Key '{key}' already exists in cache '{cache_name}'. Overwrite prevented.")
return
# If the data is a DataFrame, ensure it contains exactly one row
if isinstance(data, pd.DataFrame):
if len(data) != 1:
raise ValueError('This method is for inserting a DataFrame with exactly one row.')
# Ensure key is provided for RowBasedCache
if isinstance(cache, RowBasedCache) and key is None:
raise ValueError("RowBasedCache requires a key to store the data.")
# List of types to exclude from serialization
excluded_objects = (str, int, float, bool, type(None), bytes)
# Process and serialize non-excluded objects in the row
row = data.iloc[0] # Access the first (and only) row
row_values = []
for col_value in row:
# Serialize column value if it's not one of the excluded types
if not isinstance(col_value, excluded_objects):
col_value = pickle.dumps(col_value)
row_values.append(col_value)
# Insert the row into the cache and database (key is handled in insert_row_into_datacache)
self.insert_row_into_datacache(cache_name=cache_name, columns=tuple(data.columns),
values=tuple(row_values), key=key)
else:
# For non-DataFrame data, serialize the entire object
serialized_data = pickle.dumps(data)
# Insert the serialized object under a column named 'data'
self.insert_row_into_datacache(cache_name=cache_name, columns=('data',),
values=(serialized_data,), key=key)
return
def get_serialized_datacache(self,
cache_name: str,
filter_vals: List[Tuple[str, Any]] = None,
key: str = None) -> pd.DataFrame | Any:
"""
Retrieves an item from the specified cache and deserializes object columns if necessary.
If the stored data is a serialized object (not a DataFrame), it returns the deserialized object.
:param key: The key to identify the cache entry.
:param filter_vals: List of column filters (name, value) for the cache query.
:param cache_name: The name of the cache.
:return Any: Cached data with deserialized objects or the original non-DataFrame object, or None if not found.
"""
# Ensure at least one of filter_vals or key is provided
if not filter_vals and not key:
raise ValueError("At least one of 'filter_vals' or 'key' must be provided.")
# Prepare filter values
filter_vals = filter_vals or []
if key:
filter_vals.insert(0, ('tbl_key', key))
# Retrieve rows from the cache using the key
data = self.get_rows_from_datacache(cache_name=cache_name, filter_vals=filter_vals)
# Return None if no data is found
if data is None or (isinstance(data, pd.DataFrame) and data.empty):
logging.info(f"No data found in cache '{cache_name}' for key: {key}")
return pd.DataFrame()
# Handle non-DataFrame data
if not isinstance(data, pd.DataFrame):
logging.warning(f"Unexpected data format from cache '{cache_name}'. Returning None.")
return pd.DataFrame()
# Check for single column 'data' (serialized object case)
if 'data' in data.columns and len(data.columns) == 1:
return self._deserialize_object(data.iloc[0]['data'], cache_name)
# Handle deserialization of DataFrame columns
return self._deserialize_dataframe_row(data, cache_name)
@staticmethod
def _deserialize_object(serialized_data: Any, cache_name: str) -> Any:
"""
Deserializes an object stored as a serialized byte stream in the cache.
:param serialized_data: Serialized byte data to deserialize.
:param cache_name: The name of the cache (used for logging).
:return: Deserialized object, or the raw bytes if deserialization fails.
"""
if not isinstance(serialized_data, bytes):
logging.warning(f"Expected bytes for deserialization in cache '{cache_name}', got {type(serialized_data)}.")
return serialized_data
try:
deserialized_data = pickle.loads(serialized_data)
logging.info(f"Serialized object retrieved and deserialized from cache '{cache_name}'")
return deserialized_data
except (pickle.PickleError, TypeError) as e:
logging.warning(f"Failed to deserialize object from cache '{cache_name}': {e}")
return serialized_data # Fallback to the raw serialized data
@staticmethod
def _deserialize_dataframe_row(data: pd.DataFrame, cache_name: str) -> pd.DataFrame:
"""
Deserializes any serialized columns in a DataFrame row.
:param data: The DataFrame containing serialized columns.
:param cache_name: The name of the cache (used for logging).
:return: DataFrame with deserialized values.
"""
row = data.iloc[0] # Assuming we only retrieve one row
deserialized_row = []
for col_value in row:
if isinstance(col_value, bytes):
try:
deserialized_col_value = pickle.loads(col_value)
deserialized_row.append(deserialized_col_value)
except (pickle.PickleError, TypeError) as e:
logging.warning(f"Failed to deserialize column value in cache '{cache_name}': {e}")
deserialized_row.append(col_value) # Fallback to the original value
else:
deserialized_row.append(col_value)
deserialized_data = pd.DataFrame([deserialized_row], columns=data.columns)
logging.info(f"Data retrieved and deserialized from cache '{cache_name}'")
return deserialized_data
class ServerInteractions(DatabaseInteractions):
    """
    Extends DataCache to specialize in handling candle (OHLC) data and server interactions.
    """

    def __init__(self):
        super().__init__()
        # Exchange interface container; must be provided via set_exchange()
        # before any server-side fetches (_get_from_server) are attempted.
        self.exchanges = None

    def set_exchange(self, exchanges):
        """
        Sets an exchange interface for this class to use.
        :param exchanges: ExchangeInterface obj
        :return: none.
        """
        self.exchanges = exchanges

    @staticmethod
    def _make_key(ex_details: list[str]) -> str:
        """
        Generates a unique key string based on the exchange details provided.
        :param ex_details: A list of strings containing symbol, timeframe, exchange, etc.
        :return: A key string composed of the concatenated details.
        """
        # The fourth element (user name) is deliberately excluded from the key,
        # so candle data is shared across users of the same market.
        symbol, timeframe, exchange, _ = ex_details
        key = f'{symbol}_{timeframe}_{exchange}'
        return key

    def _update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
        """
        Store fetched candle records in the 'candles' cache under *key*.
        :param more_records: Candle records to store.
        :param key: Cache key built by _make_key().
        """
        logger.debug('Updating data with new records.')
        # Retrieve the 'candles' cache
        candles = self.get_cache('candles')
        # Store the updated records back in the cache
        candles.add_entry(key=key, data=more_records)

    def get_records_since(self, start_datetime: dt.datetime, ex_details: list[str]) -> pd.DataFrame:
        """
        This gets up-to-date records from a specified market and exchange.
        :param start_datetime: The approximate time the first record should represent.
        :param ex_details: The user exchange and market.
        :return: The records.
        :raises TypeError: If the optional type checks fail.
        :raises ValueError: If start_datetime is timezone naive.
        """
        if self.TYPECHECKING_ENABLED:
            if not isinstance(start_datetime, dt.datetime):
                raise TypeError("start_datetime must be a datetime object")
            if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
                raise TypeError("ex_details must be a list of strings")
            if not len(ex_details) == 4:
                raise TypeError("ex_details must include [asset, timeframe, exchange, user_name]")
        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
        # NOTE(review): dt.datetime.utcnow() is deprecated as of Python 3.12 —
        # consider dt.datetime.now(dt.timezone.utc).
        end_datetime = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
        try:
            args = {
                'start_datetime': start_datetime,
                'end_datetime': end_datetime,
                'ex_details': ex_details,
            }
            # Start at the top of the resource stack (in-memory cache).
            return self.get_or_fetch_from(target='data', **args)
        except Exception as e:
            logger.error(f"An error occurred: {str(e)}")
            raise

    def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame:
        """
        Fetches market records from a resource stack (data, database, exchange).
        Fills incomplete request by fetching down the stack then updates the rest.
        :param target: Starting point for the fetch. ['data', 'database', 'exchange']
        :param kwargs: Details and credentials for the request.
        :return: Records in a dataframe.
        :raises ValueError: For naive datetimes, missing arguments, or an unknown target.
        """
        start_datetime = kwargs.get('start_datetime')
        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
        end_datetime = kwargs.get('end_datetime')
        if end_datetime.tzinfo is None:
            raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
        ex_details = kwargs.get('ex_details')
        # ex_details layout is [symbol, timeframe, exchange, user_name].
        timeframe = ex_details[1]
        if self.TYPECHECKING_ENABLED:
            if not isinstance(start_datetime, dt.datetime):
                raise TypeError("start_datetime must be a datetime object")
            if not isinstance(end_datetime, dt.datetime):
                raise TypeError("end_datetime must be a datetime object")
            if not isinstance(timeframe, str):
                raise TypeError("record_length must be a string representing the timeframe")
            if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
                raise TypeError("ex_details must be a list of strings")
        if not all([start_datetime, end_datetime, timeframe, ex_details]):
            raise ValueError("Missing required arguments")
        request_criteria = {
            'start_datetime': start_datetime,
            'end_datetime': end_datetime,
            'timeframe': timeframe,
        }
        key = self._make_key(ex_details=ex_details)
        combined_data = pd.DataFrame()
        # Resource stack: each fetcher is tried in order until the accumulated
        # data satisfies the request.
        if target == 'data':
            resources = [self._get_candles_from_cache, self._get_from_database, self._get_from_server]
        elif target == 'database':
            resources = [self._get_from_database, self._get_from_server]
        elif target == 'server':
            resources = [self._get_from_server]
        else:
            raise ValueError('Not a valid Target!')
        for fetch_method in resources:
            result = fetch_method(**kwargs)
            if result is not None and not result.empty:
                # Drop the 'id' column if it exists in the result
                if 'id' in result.columns:
                    result = result.drop(columns=['id'])
                # Concatenate, drop duplicates based on 'time', and sort
                combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='time').sort_values(
                    by='time')
            is_complete, request_criteria = self._data_complete(data=combined_data, **request_criteria)
            if is_complete:
                # Back-fill the faster layers with what the slower ones returned.
                if fetch_method in [self._get_from_database, self._get_from_server]:
                    self._update_candle_cache(more_records=combined_data, key=key)
                if fetch_method == self._get_from_server:
                    self._populate_db(ex_details=ex_details, data=combined_data)
                return combined_data
            kwargs.update(request_criteria)  # Update kwargs with new start/end times for next fetch attempt
        logger.error('Unable to fetch the requested data.')
        return combined_data if not combined_data.empty else pd.DataFrame()

    def _get_candles_from_cache(self, **kwargs) -> pd.DataFrame:
        """
        Return cached candles for the requested key, filtered to the requested
        [start_datetime, end_datetime] window (inclusive).
        :return: Filtered DataFrame; empty DataFrame when nothing is cached.
        :raises ValueError: If any required kwarg is missing.
        """
        start_datetime = kwargs.get('start_datetime')
        end_datetime = kwargs.get('end_datetime')
        ex_details = kwargs.get('ex_details')
        if not all([start_datetime, end_datetime, ex_details]):
            raise ValueError("Missing required arguments for candle data retrieval.")
        key = self._make_key(ex_details=ex_details)
        logger.debug('Getting records from candles cache.')
        df = self.get_cache('candles').get_entry(key=key)
        if df is None or df.empty:
            logger.debug("No cached records found.")
            return pd.DataFrame()
        # 'time' is stored as epoch milliseconds; compare against the request bounds.
        df_filtered = df[(df['time'] >= unix_time_millis(start_datetime)) &
                         (df['time'] <= unix_time_millis(end_datetime))].reset_index(drop=True)
        return df_filtered

    def _get_from_database(self, **kwargs) -> pd.DataFrame:
        """
        Fetch candles from the local database table for the requested window.
        :return: DataFrame of records; empty DataFrame when the table does not exist.
        :raises ValueError: For naive datetimes or missing arguments.
        """
        start_datetime = kwargs.get('start_datetime')
        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
        end_datetime = kwargs.get('end_datetime')
        if end_datetime.tzinfo is None:
            raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
        ex_details = kwargs.get('ex_details')
        if self.TYPECHECKING_ENABLED:
            if not isinstance(start_datetime, dt.datetime):
                raise TypeError("start_datetime must be a datetime object")
            if not isinstance(end_datetime, dt.datetime):
                raise TypeError("end_datetime must be a datetime object")
            if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
                raise TypeError("ex_details must be a list of strings")
        if not all([start_datetime, end_datetime, ex_details]):
            raise ValueError("Missing required arguments")
        # The table name doubles as the cache key (symbol_timeframe_exchange).
        table_name = self._make_key(ex_details=ex_details)
        if not self.db.table_exists(table_name):
            logger.debug('Records not in database.')
            return pd.DataFrame()
        logger.debug('Getting records from database.')
        return self.db.get_timestamped_records(table_name=table_name, timestamp_field='time', st=start_datetime,
                                               et=end_datetime)

    def _get_from_server(self, **kwargs) -> pd.DataFrame:
        """
        Fetch candles directly from the exchange for the requested window.
        :return: DataFrame of candle records from the exchange.
        :raises ValueError: For naive datetimes.
        """
        symbol = kwargs.get('ex_details')[0]
        interval = kwargs.get('ex_details')[1]
        exchange_name = kwargs.get('ex_details')[2]
        user_name = kwargs.get('ex_details')[3]
        start_datetime = kwargs.get('start_datetime')
        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
        end_datetime = kwargs.get('end_datetime')
        if end_datetime.tzinfo is None:
            raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
        if self.TYPECHECKING_ENABLED:
            if not isinstance(symbol, str):
                raise TypeError("symbol must be a string")
            if not isinstance(interval, str):
                raise TypeError("interval must be a string")
            if not isinstance(exchange_name, str):
                raise TypeError("exchange_name must be a string")
            if not isinstance(user_name, str):
                raise TypeError("user_name must be a string")
            if not isinstance(start_datetime, dt.datetime):
                raise TypeError("start_datetime must be a datetime object")
            if not isinstance(end_datetime, dt.datetime):
                raise TypeError("end_datetime must be a datetime object")
        logger.debug('Getting records from server.')
        return self._fetch_candles_from_exchange(symbol=symbol, interval=interval, exchange_name=exchange_name,
                                                 user_name=user_name, start_datetime=start_datetime,
                                                 end_datetime=end_datetime)

    @staticmethod
    def _data_complete(data: pd.DataFrame, **kwargs) -> Tuple[bool, dict]:
        """
        Checks if the data completely satisfies the request.
        :param data: DataFrame containing the records.
        :param kwargs: Arguments required for completeness check.
        :return: A tuple (is_complete, updated_request_criteria) where is_complete is True if the data is complete,
                 False otherwise, and updated_request_criteria contains adjusted start/end times if data is incomplete.
        :raises ValueError: If start/end datetimes are timezone naive.
        """
        if data.empty:
            logger.debug("Data is empty.")
            return False, kwargs  # No data at all, proceed with the full original request
        start_datetime: dt.datetime = kwargs.get('start_datetime')
        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
        end_datetime: dt.datetime = kwargs.get('end_datetime')
        if end_datetime.tzinfo is None:
            raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
        timeframe: str = kwargs.get('timeframe')
        temp_data = data.copy()
        # Convert 'time' to datetime with unit='ms' and localize to UTC
        temp_data['time_dt'] = pd.to_datetime(temp_data['time'],
                                              unit='ms', errors='coerce').dt.tz_localize('UTC')
        min_timestamp = temp_data['time_dt'].min()
        max_timestamp = temp_data['time_dt'].max()
        logger.debug(f"Data time range: {min_timestamp} to {max_timestamp}")
        logger.debug(f"Expected time range: {start_datetime} to {end_datetime}")
        # Small time slack so a candle landing just past the boundary still counts.
        tolerance = pd.Timedelta(seconds=5)
        # Initialize updated request criteria
        updated_request_criteria = kwargs.copy()
        # Check if data covers the required time range with tolerance
        if min_timestamp > start_datetime + timeframe_to_timedelta(timeframe) + tolerance:
            logger.debug("Data does not start early enough, even with tolerance.")
            updated_request_criteria['end_datetime'] = min_timestamp  # Fetch the missing earlier data
            return False, updated_request_criteria
        if max_timestamp < end_datetime - timeframe_to_timedelta(timeframe) - tolerance:
            logger.debug("Data does not extend late enough, even with tolerance.")
            updated_request_criteria['start_datetime'] = max_timestamp  # Fetch the missing later data
            return False, updated_request_criteria
        # Filter data between start_datetime and end_datetime
        mask = (temp_data['time_dt'] >= start_datetime) & (temp_data['time_dt'] <= end_datetime)
        data_in_range = temp_data.loc[mask]
        expected_count = estimate_record_count(start_datetime, end_datetime, timeframe)
        actual_count = len(data_in_range)
        logger.debug(f"Expected record count: {expected_count}, Actual record count: {actual_count}")
        # Record-count slack of one candle to absorb boundary effects.
        tolerance = 1
        if actual_count < (expected_count - tolerance):
            logger.debug("Insufficient records within the specified time range, even with tolerance.")
            return False, updated_request_criteria
        logger.debug("Data completeness check passed.")
        return True, kwargs

    def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
                                     start_datetime: dt.datetime = None,
                                     end_datetime: dt.datetime = None) -> pd.DataFrame:
        """
        Download historical klines from the exchange and patch obvious gaps.
        :param symbol: Market symbol to fetch.
        :param interval: Candle timeframe string (e.g. '5m').
        :param exchange_name: Exchange identifier used to resolve the client.
        :param user_name: User whose exchange credentials are used.
        :param start_datetime: Window start; defaults to 2017-01-01 UTC.
        :param end_datetime: Window end; defaults to "now" in UTC.
        :return: DataFrame of candles (possibly gap-filled), or an empty OHLCV frame.
        :raises ValueError: If the window is inverted or timestamps are not in milliseconds.
        """
        if start_datetime is None:
            start_datetime = dt.datetime(year=2017, month=1, day=1, tzinfo=dt.timezone.utc)
        if end_datetime is None:
            # NOTE(review): utcnow() is deprecated in Python 3.12; consider
            # dt.datetime.now(dt.timezone.utc).
            end_datetime = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
        if start_datetime > end_datetime:
            raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")
        exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
        expected_records = estimate_record_count(start_datetime, end_datetime, interval)
        logger.info(
            f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {expected_records}')
        # An equal start/end is passed to the exchange as an open-ended request.
        if start_datetime == end_datetime:
            end_datetime = None
        candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime,
                                                 end_dt=end_datetime)
        num_rec_records = len(candles.index)
        if num_rec_records == 0:
            logger.warning(f"No OHLCV data returned for {symbol}.")
            return pd.DataFrame(columns=['time', 'open', 'high', 'low', 'close', 'volume'])
        logger.info(f'{num_rec_records} candles retrieved from the exchange.')
        open_times = candles.time
        min_open_time = open_times.min()
        max_open_time = open_times.max()
        # Timestamps below 1e10 would be seconds rather than milliseconds.
        if min_open_time < 1e10:
            raise ValueError('Records are not in milliseconds')
        estimated_num_records = estimate_record_count(min_open_time, max_open_time, interval) + 1
        logger.info(f'Estimated number of records: {estimated_num_records}')
        if num_rec_records < estimated_num_records:
            logger.info('Detected gaps in the data, attempting to fill missing records.')
            candles = self._fill_data_holes(records=candles, interval=interval)
        return candles

    @staticmethod
    def _fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
        """
        Insert synthetic rows wherever consecutive timestamps are further apart
        than one interval.
        :param records: Candle DataFrame sorted by 'time' (epoch milliseconds).
        :param interval: Timeframe string used to size the expected gap.
        :return: A new DataFrame with synthetic rows inserted into the gaps.
        """
        # NOTE(review): timeframe_to_timedelta returns a pd.DateOffset for 'M'/'Y'
        # intervals, which has no total_seconds() — confirm only sub-month
        # intervals ever reach this method.
        time_span = timeframe_to_timedelta(interval).total_seconds() / 60
        last_timestamp = None
        filled_records = []
        logger.info(f"Starting to fill data holes for interval: {interval}")
        for index, row in records.iterrows():
            time_stamp = row['time']
            if last_timestamp is None:
                last_timestamp = time_stamp
                filled_records.append(row)
                logger.debug(f"First timestamp: {time_stamp}")
                continue
            delta_ms = time_stamp - last_timestamp
            delta_minutes = (delta_ms / 1000) / 60
            logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}")
            if delta_minutes > time_span:
                num_missing_rec = int(delta_minutes / time_span)
                step = int(delta_ms / num_missing_rec)
                logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}")
                # Synthetic rows copy the *following* real candle's values and
                # only rewrite the timestamp.
                for ts in range(int(last_timestamp) + step, int(time_stamp), step):
                    new_row = row.copy()
                    new_row['time'] = ts
                    filled_records.append(new_row)
                    logger.debug(f"Filled timestamp: {ts}")
            filled_records.append(row)
            last_timestamp = time_stamp
        logger.info("Data holes filled successfully.")
        return pd.DataFrame(filled_records)

    def _populate_db(self, ex_details: list[str], data: pd.DataFrame = None) -> None:
        """
        Persist fetched candles into their per-market database table.
        :param ex_details: [symbol, timeframe, exchange, user] details.
        :param data: Candle DataFrame to insert; no-op when None or empty.
        """
        if data is None or data.empty:
            logger.debug(f'No records to insert {data}')
            return
        table_name = self._make_key(ex_details=ex_details)
        symbol, _, exchange, _ = ex_details
        self.db.insert_candles_into_db(candlesticks=data, table_name=table_name, symbol=symbol, exchange_name=exchange)
        logger.info(f'Data inserted into table {table_name}')
class IndicatorCache(ServerInteractions):
    """
    IndicatorCache extends ServerInteractions and manages caching of both instantiated indicator objects
    and their calculated data. This class optimizes the indicator calculation process by caching and
    reusing data where possible, minimizing redundant computations.
    Attributes:
        indicator_registry (dict): A dictionary mapping indicator types (e.g., 'SMA', 'EMA') to their classes.
    """

    def __init__(self):
        """
        Initialize the IndicatorCache with caches for indicators and their calculated data.
        """
        super().__init__()
        # Registry of available indicators, mapping type names to indicator
        # classes (imported from the indicators module).
        self.indicator_registry = indicators_registry
@staticmethod
def _make_indicator_key(symbol: str, timeframe: str, exchange_name: str, indicator_type: str, period: int) -> str:
"""
Generates a unique cache key for caching the indicator data based on the input properties.
"""
return f"{symbol}_{timeframe}_{exchange_name}_{indicator_type}_{period}"
def get_indicator_instance(self, indicator_type: str, properties: dict) -> Indicator:
"""
Retrieves or creates an instance of an indicator based on its type and properties.
"""
if indicator_type not in self.indicator_registry:
raise ValueError(f"Unsupported indicator type: {indicator_type}")
indicator_class = self.indicator_registry[indicator_type]
return indicator_class(name=indicator_type, indicator_type=indicator_type, properties=properties)
def set_user_indicator_properties(self, user_id: str, indicator_type: str, symbol: str, timeframe: str,
exchange_name: str, display_properties: dict):
"""
Stores or updates user-specific display properties for an indicator.
"""
if not isinstance(display_properties, dict):
raise ValueError("display_properties must be a dictionary")
user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}"
self.serialized_datacache_insert(key=user_cache_key, data=display_properties,
cache_name='user_display_properties')
def get_user_indicator_properties(self, user_id: str, indicator_type: str, symbol: str, timeframe: str,
exchange_name: str) -> dict:
"""
Retrieves user-specific display properties for an indicator.
"""
# Ensure the arguments are valid
if not isinstance(user_id, str) or not isinstance(indicator_type, str) or \
not isinstance(symbol, str) or not isinstance(timeframe, str) or \
not isinstance(exchange_name, str):
raise TypeError("All arguments must be of type str")
user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}"
return self.get_rows_from_datacache(key=user_cache_key, cache_name='user_display_properties')
def find_gaps_in_intervals(self, cached_data, start_idx, end_idx, timeframe, min_gap_size=None):
"""
Recursively searches for missing intervals in cached data based on time gaps.
:param cached_data: DataFrame containing cached time-series data.
:param start_idx: Start index in the DataFrame.
:param end_idx: End index in the DataFrame.
:param timeframe: Expected interval between consecutive time points (e.g., '5m', '1h').
:param min_gap_size: The minimum allowable gap (in terms of the number of missing records) to consider.
:return: A list of tuples representing the missing intervals (start, end).
"""
if cached_data.empty or end_idx - start_idx <= 1:
return []
start_time = cached_data['time'].iloc[start_idx]
end_time = cached_data['time'].iloc[end_idx]
expected_records = estimate_record_count(start_time, end_time, timeframe)
actual_records = end_idx - start_idx + 1
if expected_records == actual_records:
return []
if min_gap_size and (expected_records - actual_records) < min_gap_size:
return []
mid_idx = (start_idx + end_idx) // 2
left_missing = self.find_gaps_in_intervals(cached_data, start_idx, mid_idx, timeframe, min_gap_size)
right_missing = self.find_gaps_in_intervals(cached_data, mid_idx, end_idx, timeframe, min_gap_size)
return left_missing + right_missing
def calculate_indicator(self, user_name: str, symbol: str, timeframe: str, exchange_name: str, indicator_type: str,
start_datetime: dt.datetime, end_datetime: dt.datetime, properties: dict) -> dict:
"""
Fetches or calculates the indicator data for the specified range, caches the shared results,
and combines them with user-specific display properties or indicator defaults.
"""
# Ensure start_datetime and end_datetime are timezone-aware
if start_datetime.tzinfo is None or end_datetime.tzinfo is None:
raise ValueError("Datetime objects must be timezone-aware")
# Step 1: Create cache key and indicator instance
cache_key = self._make_indicator_key(symbol, timeframe, exchange_name, indicator_type, properties['period'])
indicator_instance = self.get_indicator_instance(indicator_type, properties)
# Step 2: Check cache for existing data
calculated_data = self._fetch_cached_data(cache_key, start_datetime, end_datetime)
missing_intervals = self._determine_missing_intervals(calculated_data, start_datetime, end_datetime, timeframe)
# Step 3: Fetch and calculate data for missing intervals
if missing_intervals:
calculated_data = self._fetch_and_calculate_missing_data(
missing_intervals, calculated_data, indicator_instance, user_name, symbol, timeframe, exchange_name
)
# Step 4: Cache the newly calculated data
self.serialized_datacache_insert(key=cache_key, data=calculated_data, cache_name='indicator_data')
# Step 5: Retrieve and merge user-specific display properties with defaults
merged_properties = self._get_merged_properties(user_name, indicator_type, symbol, timeframe, exchange_name,
indicator_instance)
return {
'calculation_data': calculated_data.to_dict('records'),
'display_properties': merged_properties
}
def _fetch_cached_data(self, cache_key: str, start_datetime: dt.datetime,
end_datetime: dt.datetime) -> pd.DataFrame:
"""
Fetches cached data for the given time range.
"""
# Retrieve cached data (expected to be a DataFrame with 'time' in Unix ms)
cached_df = self.get_rows_from_datacache(key=cache_key, cache_name='indicator_data')
# If no cached data, return an empty DataFrame
if cached_df is None or cached_df.empty:
return pd.DataFrame()
# Convert start and end datetime to Unix timestamps (milliseconds)
start_timestamp = int(start_datetime.timestamp() * 1000)
end_timestamp = int(end_datetime.timestamp() * 1000)
# Convert cached 'time' column to Unix timestamps for comparison
cached_df['time'] = cached_df['time'].astype('int64') // 10 ** 6
# Return only the data within the specified time range
return cached_df[(cached_df['time'] >= start_timestamp) & (cached_df['time'] <= end_timestamp)]
def _determine_missing_intervals(self, cached_data: pd.DataFrame, start_datetime: dt.datetime,
end_datetime: dt.datetime, timeframe: str):
"""
Determines missing intervals in the cached data by comparing the requested time range with cached data.
:param cached_data: Cached DataFrame with time-series data in Unix timestamp (ms).
:param start_datetime: Start datetime of the requested data range.
:param end_datetime: End datetime of the requested data range.
:param timeframe: Expected interval between consecutive time points (e.g., '5m', '1h').
:return: A list of tuples representing the missing intervals (start, end).
"""
# Convert start and end datetime to Unix timestamps (milliseconds)
# start_timestamp = int(start_datetime.timestamp() * 1000)
end_timestamp = int(end_datetime.timestamp() * 1000)
if cached_data is not None and not cached_data.empty:
# Find the last (most recent) time in the cached data (which is in Unix timestamp ms)
cached_end = cached_data['time'].max()
# If the cached data ends before the requested end time, add that range to missing intervals
if cached_end < end_timestamp:
# Convert cached_end back to datetime for consistency in returned intervals
cached_end_dt = pd.to_datetime(cached_end, unit='ms').to_pydatetime()
return [(cached_end_dt, end_datetime)]
# Otherwise, check for gaps within the cached data itself
return self.find_gaps_in_intervals(cached_data, 0, len(cached_data) - 1, timeframe)
# If no cached data exists, the entire range is missing
return [(start_datetime, end_datetime)]
def _fetch_and_calculate_missing_data(self, missing_intervals, calculated_data, indicator_instance, user_name,
symbol, timeframe, exchange_name):
"""
Fetches and calculates missing data for the specified intervals.
"""
for interval_start, interval_end in missing_intervals:
# Ensure interval_start and interval_end are timezone-aware
if interval_start.tzinfo is None:
interval_start = interval_start.replace(tzinfo=dt.timezone.utc)
if interval_end.tzinfo is None:
interval_end = interval_end.replace(tzinfo=dt.timezone.utc)
ohlc_data = self.get_or_fetch_from('data', start_datetime=interval_start, end_datetime=interval_end,
ex_details=[symbol, timeframe, exchange_name, user_name])
if ohlc_data.empty or 'close' not in ohlc_data.columns:
continue
new_data = indicator_instance.calculate(candles=ohlc_data, user_name=user_name)
calculated_data = pd.concat([calculated_data, new_data], ignore_index=True)
return calculated_data
def _get_merged_properties(self, user_name, indicator_type, symbol, timeframe, exchange_name, indicator_instance):
"""
Retrieves and merges user-specific properties with default indicator properties.
"""
user_properties = self.get_user_indicator_properties(user_name, indicator_type,
symbol, timeframe, exchange_name)
if not user_properties:
user_properties = {
key: value for key,
value in indicator_instance.properties.items() if key.startswith('color') or key.startswith('thickness')
}
return {**indicator_instance.properties, **user_properties}
class DataCache(IndicatorCache):
    """
    Concrete cache entry point. Subclasses IndicatorCache, which presumably mixes in
    ServerInteractions, DatabaseInteractions, SnapshotDataCache, and DataCacheBase --
    confirm against the full class hierarchy -- to create a fully functional
    cache management system that can store and manage custom caching objects of type InMemoryCache. The following
    illustrates the complete structure of data, leveraging all parts of the system.
    `caches` is a dictionary of InMemoryCache objects indexed by name. An InMemoryCache is a custom object that contains
    attributes such as `limit`, `eviction_policy`, and a `cache`, which is a DataFrame containing the actual data. This
    structure allows the cache management system to leverage pandas' powerful querying and filtering capabilities to
    manipulate and query the data efficiently while maintaining the cache's overall structure.
    Example Structure:
        caches = {
            [name1]: {
                'limit': [limit],
                'eviction_policy': [eviction_policy],
                'cache': pd.DataFrame({
                    'tbl_key': [key],
                    'creation_time': [creation_time],
                    'expire_delta': [expire_delta],
                    'data': pd.DataFrame({
                        'column1': [data],
                        'column2': [data],
                        'column3': [data]
                    })
                })
            }
        }
    # The following is an example of the structure in use (illustrative only --
    # `json` and `datetime` here refer to the standard-library modules):
        caches = {
            'users': {
                'limit': 100,
                'eviction_policy': 'deny',
                'cache': pd.DataFrame({
                    'tbl_key': ['user1', 'user2'],
                    'creation_time': [
                        datetime.datetime(2024, 8, 24, 12, 45, 31, 117684),
                        datetime.datetime(2023, 8, 12, 6, 35, 21, 113644)
                    ],
                    'expire_delta': [
                        datetime.timedelta(seconds=600),
                        None
                    ],
                    'data': [
                        pd.DataFrame({
                            'user_name': ['Billy'],
                            'password': [1234],
                            'exchanges': [['ex1', 'ex2', 'ex3']]
                        }),
                        pd.DataFrame({
                            'user_name': ['Patty'],
                            'password': [5678],
                            'exchanges': [['ex1', 'ex3', 'ex5']]
                        })
                    ]
                })
            },
            'strategies': {
                'limit': 100,
                'eviction_policy': 'expire',
                'cache': pd.DataFrame({
                    'tbl_key': ['strategy1', 'strategy2'],
                    'creation_time': [
                        datetime.datetime(2024, 3, 13, 12, 45, 31, 117684),
                        datetime.datetime(2023, 4, 12, 6, 35, 21, 113644)
                    ],
                    'expire_delta': [
                        datetime.timedelta(seconds=600),
                        datetime.timedelta(seconds=600)
                    ],
                    'data': [
                        pd.DataFrame({
                            'strategy_name': ['Awesome_v1.2'],
                            'code': [json.dumps({"params": {"t": "ex1", "m": 100, "ex": True}})],
                            'max_loss': [100]
                        }),
                        pd.DataFrame({
                            'strategy_name': ['Awesome_v2.0'],
                            'code': [json.dumps({"params": {"t": "ex2", "m": 67, "ex": False}})],
                            'max_loss': [40]
                        })
                    ]
                })
            },
            ...
        }
    Usage Examples:
        # Extract all user data DataFrames
        user_data = pd.concat(caches['users']['cache']['data'].tolist(), ignore_index=True)
        # Check if a username is taken
        username_to_check = 'Billy'
        is_taken = not user_data[user_data['user_name'] == username_to_check].empty
        # Extract all user data DataFrames
        user_data = pd.concat(caches['users']['cache']['data'].tolist(), ignore_index=True)
        # Find users with a specific exchange
        exchange_to_check = 'ex3'
        users_with_exchange = user_data[user_data['exchanges'].apply(lambda x: exchange_to_check in x)]
        # Filter out the expired users
        current_time = datetime.datetime.now(datetime.timezone.utc)
        non_expired_users = caches['users']['cache'].apply(
            lambda row: current_time <= row['creation_time'] + row['expire_delta'] if row['expire_delta'] else True,
            axis=1
        )
        caches['users']['cache'] = caches['users']['cache'][non_expired_users]
    """
    def __init__(self) -> None:
        # All cache setup happens in the parent classes; this subclass only
        # records successful construction in the log.
        super().__init__()
        logger.info("DataCache initialized.")