# NOTE: extraction metadata residue from the original paste
# (file was reported as 1920 lines / 88 KiB / Python).
import copy
|
|
import pickle
|
|
import time
|
|
import logging
|
|
import datetime as dt
|
|
from collections import deque
|
|
from typing import Any, Tuple, List, Optional
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
from indicators import Indicator, indicators_registry
|
|
from shared_utilities import unix_time_millis
|
|
from Database import Database
|
|
|
|
# Configure logging
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Helper Methods
|
|
def timeframe_to_timedelta(timeframe: str) -> pd.Timedelta | pd.DateOffset:
|
|
digits = int("".join([i if i.isdigit() else "" for i in timeframe]))
|
|
unit = "".join([i if i.isalpha() else "" for i in timeframe])
|
|
|
|
if unit == 's': # Add support for seconds
|
|
return pd.Timedelta(seconds=digits)
|
|
elif unit == 'm':
|
|
return pd.Timedelta(minutes=digits)
|
|
elif unit == 'h':
|
|
return pd.Timedelta(hours=digits)
|
|
elif unit == 'd':
|
|
return pd.Timedelta(days=digits)
|
|
elif unit == 'w':
|
|
return pd.Timedelta(weeks=digits)
|
|
elif unit == 'M':
|
|
return pd.DateOffset(months=digits)
|
|
elif unit == 'Y':
|
|
return pd.DateOffset(years=digits)
|
|
else:
|
|
raise ValueError(f"Invalid timeframe unit: {unit}")
|
|
|
|
|
|
def estimate_record_count(start_time, end_time, timeframe: str) -> int:
    """
    Estimate the number of records expected between start_time and end_time
    based on the given timeframe.

    Accepts either timezone-aware datetime objects or Unix timestamps in
    milliseconds; both bounds must use the same representation.

    :param start_time: Interval start (aware datetime or epoch milliseconds).
    :param end_time: Interval end (aware datetime or epoch milliseconds).
    :param timeframe: Timeframe string understood by timeframe_to_timedelta.
    :return: Estimated record count; 0 when end_time <= start_time.
    :raises ValueError: On invalid types, naive datetimes, mixed
        representations, or an invalid timeframe.
    """
    numeric_types = (int, float, np.integer)

    # Check for invalid inputs
    if not isinstance(start_time, numeric_types + (dt.datetime,)):
        raise ValueError(f"Invalid type for start_time: {type(start_time)}. Expected datetime or timestamp.")
    if not isinstance(end_time, numeric_types + (dt.datetime,)):
        raise ValueError(f"Invalid type for end_time: {type(end_time)}. Expected datetime or timestamp.")

    if isinstance(start_time, numeric_types) and isinstance(end_time, numeric_types):
        # Interpret numbers as Unix epoch milliseconds in UTC.
        # (datetime.utcfromtimestamp is deprecated since Python 3.12;
        # fromtimestamp with an explicit tz is the supported equivalent.)
        start_time = dt.datetime.fromtimestamp(start_time / 1000, tz=dt.timezone.utc)
        end_time = dt.datetime.fromtimestamp(end_time / 1000, tz=dt.timezone.utc)
    elif isinstance(start_time, dt.datetime) and isinstance(end_time, dt.datetime):
        if start_time.tzinfo is None or end_time.tzinfo is None:
            raise ValueError("start_time and end_time must be timezone-aware.")
        # Normalize both to UTC to avoid time zone differences affecting results
        start_time = start_time.astimezone(dt.timezone.utc)
        end_time = end_time.astimezone(dt.timezone.utc)
    else:
        # BUG FIX: a datetime mixed with a timestamp previously fell through
        # unconverted and crashed on comparison with a TypeError.
        raise ValueError("start_time and end_time must both be datetimes or both be timestamps.")

    # Empty or negative interval yields no records.
    if end_time <= start_time:
        return 0

    delta = timeframe_to_timedelta(timeframe)

    # Fixed-length intervals: pure arithmetic (pd.Timedelta subclasses
    # datetime.timedelta, so this branch covers s/m/h/d/w).
    if isinstance(delta, dt.timedelta):
        total_seconds = (end_time - start_time).total_seconds()
        return int(total_seconds // delta.total_seconds())

    # Calendar-based intervals ('M'/'Y') come back as a DateOffset.
    elif isinstance(delta, pd.DateOffset):
        if 'months' in delta.kwds:
            month_span = (end_time.year - start_time.year) * 12 + (end_time.month - start_time.month)
            # BUG FIX: honor multi-month timeframes ('2M' etc.); the span was
            # previously returned without dividing by the interval size.
            return month_span // delta.kwds['months']
        elif 'years' in delta.kwds:
            # Same fix for multi-year timeframes.
            return (end_time.year - start_time.year) // delta.kwds['years']

    raise ValueError(f"Invalid timeframe: {timeframe}")
|
|
|
|
|
|
class CacheEntryMetadata:
    """Stores metadata for a cache entry (row)."""

    def __init__(self, expiration_time: Optional[int] = None):
        # Wall-clock time at which this entry was created.
        self.creation_time = time.time()
        # Time-to-live in seconds; None means the entry never expires.
        self.expiration_time = expiration_time
        # Absolute expiry timestamp; only set when a truthy TTL was supplied.
        self.expiration_timestamp = None
        if expiration_time:
            self.expiration_timestamp = self.creation_time + expiration_time

    def is_expired(self) -> bool:
        """Return True when this entry's TTL has elapsed."""
        ttl = self.expiration_time
        if ttl is None:
            # No TTL configured: the entry never expires.
            return False
        if ttl == 0:
            # A zero TTL means "expire immediately".
            return True
        return time.time() > self.expiration_timestamp
|
|
|
|
|
|
class CacheEntry:
    """Stores data and its expiration metadata."""

    def __init__(self, data: Any, expiration_time: Optional[int] = None):
        # The cached payload itself (often a DataFrame, but any object works).
        self.data = data
        # Per-entry expiry bookkeeping, delegated to CacheEntryMetadata.
        self.metadata = CacheEntryMetadata(expiration_time)
|
|
|
|
|
|
class RowBasedCache:
    """Cache for storing individual rows, where each entry has a unique key."""

    def __init__(self, default_expiration: Optional[int] = None, size_limit: Optional[int] = None,
                 eviction_policy: str = "evict", purge_threshold: int = 10):
        """
        :param default_expiration: Default TTL in seconds for new entries (None = never expire).
        :param size_limit: Maximum number of entries, or None for unlimited.
        :param eviction_policy: 'evict' drops the least-recently-used entry when
            full; 'deny' refuses new entries instead.
        :param purge_threshold: Number of cache accesses between expiry sweeps.
        """
        self.cache = {}
        self.default_expiration = default_expiration
        self.size_limit = size_limit
        self.eviction_policy = eviction_policy
        self.access_order = deque()  # Tracks the order of access for eviction
        self.access_counter = 0  # Counter to track accesses
        self.purge_threshold = purge_threshold  # Define how often to trigger purge

    def add_entry(self, key: str, data: Any, expiration_time: Optional[int] = None):
        """Add an entry to the cache.

        When the key already holds a DataFrame, the new data is appended to
        it; otherwise the entry is created (or replaced).
        """
        self._check_purge()  # Check if purge is needed
        if self.size_limit is not None and len(self.cache) >= self.size_limit:
            if self.eviction_policy == "evict":
                self.evict()
            elif self.eviction_policy == "deny":
                return "Cache limit reached. Entry not added."

        if key in self.cache and isinstance(self.cache[key].data, pd.DataFrame):
            # Append the new rows to the existing DataFrame entry.
            # (A redundant inner isinstance check — whose else branch was
            # unreachable — has been removed here.)
            self.cache[key].data = pd.concat([self.cache[key].data, data], ignore_index=True)
        else:
            # New key, or a non-DataFrame entry being replaced outright.
            expiration_time = expiration_time or self.default_expiration
            self.cache[key] = CacheEntry(data, expiration_time)

        # Move the key to the most-recently-used end of the access order.
        if key in self.access_order:
            self.access_order.remove(key)
        self.access_order.append(key)

    def get_entry(self, key: str) -> Any:
        """Retrieve an entry by key; expired entries are deleted and yield None."""
        self._check_purge()  # Check if purge is needed
        if key in self.cache:
            if not self.cache[key].metadata.is_expired():
                self.access_order.remove(key)
                self.access_order.append(key)  # Update access order for eviction
                return self.cache[key].data
            # BUG FIX: the expired key must also leave access_order, otherwise
            # a later evict() would KeyError on the already-deleted key.
            self._drop_key(key)
        return None

    def query(self, conditions: List[Tuple[str, Any]]) -> pd.DataFrame:
        """Query cache entries by conditions, ignoring expired entries.

        The 'tbl_key' condition selects the entry; all other conditions are
        applied as column filters when the entry holds a DataFrame.
        """
        self._check_purge()  # Check if purge is needed

        # Get the value of tbl_key out of the list of key-value pairs.
        key_value = next((value for key, value in conditions if key == 'tbl_key'), None)
        if key_value is None or key_value not in self.cache:
            return pd.DataFrame()  # Return an empty DataFrame if key is not found

        entry = self.cache[key_value]

        if entry.metadata.is_expired():
            # BUG FIX: drop the expired key from the access order too, so a
            # later evict() cannot hit a missing key.
            self._drop_key(key_value)
            return pd.DataFrame()

        data = entry.data

        # If the data is a DataFrame, apply the conditions using pandas .query()
        if isinstance(data, pd.DataFrame):
            # Construct the query string and prepare local variables for the query
            query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions if col != 'tbl_key'])
            query_vars = {f'val_{col}': val for col, val in conditions if col != 'tbl_key'}
            return data.query(query_conditions, local_dict=query_vars) if query_conditions else data

        return pd.DataFrame([data])  # Return the non-DataFrame data as a single row DataFrame if possible

    def is_attr_taken(self, column: str, value: Any) -> bool:
        """Check if a column contains the specified value in the Row-Based Cache."""
        self._check_purge()  # Check if purge is needed
        for key, entry in self.cache.items():
            if isinstance(entry.data, pd.DataFrame):  # Only apply to DataFrames
                if column in entry.data.columns:
                    # Use DataFrame.query to check if the column contains the value
                    query_result = entry.data.query(f'`{column}` == @value', local_dict={'value': value})
                    if not query_result.empty:
                        return True  # Value exists in a DataFrame row
        return False  # No match found

    def evict(self):
        """Evict the oldest accessed entry (no-op when nothing is tracked)."""
        # BUG FIX: popleft() on an empty deque raised IndexError.
        if not self.access_order:
            return
        oldest_key = self.access_order.popleft()
        # The key may already have been dropped by an expiry path; pop safely.
        self.cache.pop(oldest_key, None)

    def _check_purge(self):
        """Increment the access counter and trigger purge if threshold is reached."""
        self.access_counter += 1
        if self.access_counter >= self.purge_threshold:
            self._purge_expired()
            self.access_counter = 0  # Reset the counter after purging

    def _purge_expired(self):
        """Remove expired entries from the cache and the access order."""
        expired_keys = [key for key, entry in self.cache.items() if entry.metadata.is_expired()]
        for key in expired_keys:
            self._drop_key(key)

    def _drop_key(self, key: str):
        """Delete a key from both the cache dict and the access-order deque."""
        del self.cache[key]
        if key in self.access_order:
            self.access_order.remove(key)

    def get_all_items(self) -> dict[str, Any]:
        """Retrieve all non-expired items in the cache."""
        self._check_purge()  # Ensure expired entries are purged as needed
        return {key: entry.data for key, entry in self.cache.items() if not entry.metadata.is_expired()}

    def remove_item(self, conditions: List[Tuple[str, Any]]) -> bool:
        """Remove an item from the cache using key-value conditions.

        In the row cache, 'tbl_key' identifies the entry; any additional
        conditions select DataFrame rows to drop within that entry.
        """
        # Find the value of 'tbl_key' from the conditions
        key_value = next((value for key, value in conditions if key == 'tbl_key'), None)
        if key_value is None or key_value not in self.cache:
            return False  # Key not found, so nothing to remove

        # If no additional conditions are provided, remove the entire entry by key
        if len(conditions) == 1:
            self._drop_key(key_value)
            return True

        entry = self.cache[key_value]
        # If the data is a DataFrame, apply additional filtering
        if isinstance(entry.data, pd.DataFrame):
            # Construct the query string and prepare local variables for the query
            query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions if col != 'tbl_key'])
            query_vars = {f'val_{col}': val for col, val in conditions if col != 'tbl_key'}

            # Keep only the rows that do NOT match the conditions.
            remaining_data = entry.data.query(f'not ({query_conditions})', local_dict=query_vars)
            if remaining_data.empty:
                # All rows matched: delete the entire entry.
                self._drop_key(key_value)
            else:
                # Update the entry with the remaining rows
                entry.data = remaining_data
        else:
            # Non-DataFrame data: a matching 'tbl_key' removes the whole entry.
            self._drop_key(key_value)
        return True  # Successfully removed the item
|
|
|
|
|
|
class TableBasedCache:
    """Cache for storing entire tables with expiration applied to rows."""

    def __init__(self, default_expiration: Optional[int] = None, size_limit: Optional[int] = None,
                 eviction_policy: str = "evict"):
        """
        :param default_expiration: Default TTL in seconds for newly added rows.
        :param size_limit: Maximum number of rows, or None for unlimited.
        :param eviction_policy: 'evict' drops the oldest rows when over the
            limit; 'deny' refuses the addition instead.
        """
        self.cache = pd.DataFrame()  # The DataFrame where both data and metadata are stored
        self.default_expiration = default_expiration
        self.size_limit = size_limit
        self.eviction_policy = eviction_policy
        self.access_order = deque()  # Tracks the order of access for eviction

    def _check_size_limit(self):
        """Check and enforce the size limit; returns False when denied."""
        if self.size_limit and len(self.cache) > self.size_limit:
            if self.eviction_policy == "evict":
                excess = len(self.cache) - self.size_limit
                self.evict(excess)  # Evict excess rows
            elif self.eviction_policy == "deny":
                return False  # Don't allow more rows to be added
        return True

    def add_table(self, df: pd.DataFrame, expiration_time: Optional[int] = None, overwrite: Optional[str] = None,
                  key: Optional[str] = None):
        """
        Adds a DataFrame to the cache, attaching metadata to each row.
        Optionally overwrites rows based on a column value.

        :param df: The DataFrame to add.
        :param expiration_time: Optional expiration time (seconds) for the rows.
        :param overwrite: Column name to use for identifying rows to overwrite.
        :param key: Optional value stored in a 'tbl_key' column for each row.
        :return: An error message string when the 'deny' policy refuses the
            addition; None otherwise.
        """
        expiration_time = expiration_time or self.default_expiration
        if expiration_time is not None:
            metadata = [CacheEntryMetadata(expiration_time) for _ in range(len(df))]
        else:
            metadata = [CacheEntryMetadata() for _ in range(len(df))]

        # Add metadata to each row of the DataFrame
        df_with_metadata = df.copy()
        df_with_metadata['metadata'] = metadata

        # If a key is provided, add a 'tbl_key' column to the DataFrame
        if key is not None:
            df_with_metadata['tbl_key'] = key

        # Keep the prior state so a 'deny' outcome can be rolled back.
        previous_cache = self.cache

        if len(self.cache.columns) == 0:
            # Brand-new cache: adopt the incoming frame directly. (The old
            # `getattr(self, 'cache', None) is None` test never fired because
            # __init__ always builds a DataFrame.)
            self.cache = df_with_metadata
        else:
            # Append the new rows
            self.cache = pd.concat([self.cache, df_with_metadata], ignore_index=True)

        if overwrite:
            # Drop duplicates based on the overwrite column, keeping the last occurrence (new data)
            self.cache = self.cache.drop_duplicates(subset=overwrite, keep='last')

        # Enforce size limit
        if not self._check_size_limit():
            # BUG FIX: with the 'deny' policy the rows had already been
            # concatenated before the message was returned; restore the
            # previous state so "not added" is actually true.
            self.cache = previous_cache
            return "Cache limit reached. Table not added."

    def _purge_expired(self):
        """Remove expired rows from the cache (no-op for a fresh/empty cache)."""
        # BUG FIX: a freshly constructed cache has no 'metadata' column yet,
        # so querying before any add_table call used to raise KeyError.
        if self.cache.empty:
            return
        try:
            # Filter rows where metadata is not expired, keep columns even if no valid rows
            is_valid = self.cache['metadata'].apply(lambda meta: not meta.is_expired())

            # Filter DataFrame, ensuring columns are always kept
            self.cache = self.cache.loc[is_valid].reindex(self.cache.columns, axis=1).reset_index(drop=True)
        except KeyError:
            raise KeyError("The 'metadata' column is missing from the cache.")
        except AttributeError as e:
            raise AttributeError(f"Error in metadata processing: {e}")

    def query(self, conditions: List[Tuple[str, Any]]) -> pd.DataFrame:
        """Query rows based on conditions and return valid (non-expired) entries."""
        self._purge_expired()  # Remove expired rows before querying

        # Start with the entire cache
        result = self.cache.copy()

        # Apply conditions using pandas .query()
        if not result.empty:
            query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions])
            query_vars = {f'val_{col}': val for col, val in conditions}

            # Use pandas .query() with local_dict to pass the variables
            result = result.query(query_conditions, local_dict=query_vars) if query_conditions else result

        # Remove the internal metadata column from the result
        return result.drop(columns=['metadata'], errors='ignore')

    def is_attr_taken(self, column: str, value: Any) -> bool:
        """Check if a column contains the specified value in the Table-Based Cache."""
        self._purge_expired()  # Ensure expired entries are removed
        if column not in self.cache.columns:
            return False  # Column does not exist

        # Use DataFrame.query to check if the column contains the value
        query_result = self.cache.query(f'`{column}` == @value', local_dict={'value': value})
        return not query_result.empty  # Return True if the value exists, otherwise False

    def evict(self, num_rows: int = 1):
        """Evict the first num_rows rows (insertion order, oldest first)."""
        if len(self.cache) == 0:
            return

        # Drop the leading rows; row position approximates insertion age here.
        self.cache = self.cache.iloc[num_rows:].reset_index(drop=True)

    def get_all_items(self) -> pd.DataFrame:
        """Retrieve all non-expired rows from the table-based cache."""
        self._purge_expired()  # Ensure expired rows are removed
        return self.cache

    def remove_item(self, conditions: List[Tuple[str, Any]]) -> bool:
        """Remove rows from the table-based cache that match the key-value conditions."""
        self._purge_expired()  # Ensure expired entries are removed
        if self.cache.empty:
            return False  # Cache is empty

        # Construct the query string and prepare local variables for the query
        query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions])
        query_vars = {f'val_{col}': val for col, val in conditions}

        # Keep only the rows that do NOT match the conditions.
        remaining_data = self.cache.query(f'not ({query_conditions})', local_dict=query_vars)
        if len(remaining_data) == len(self.cache):
            return False  # No rows matched the conditions, so nothing was removed

        # Update the cache with the remaining data
        self.cache = remaining_data
        return True  # Successfully removed matching rows
|
|
|
|
|
|
class CacheManager:
    """Manages different cache types (row-based and table-based)."""

    def __init__(self):
        # Maps cache name -> RowBasedCache | TableBasedCache instance.
        self.caches = {}

    def create_cache(self, name: str, cache_type: str,
                     size_limit: Optional[int] = None,
                     eviction_policy: str = 'evict',
                     default_expiration: Optional[dt.timedelta] = None,
                     columns: Optional[list] = None) -> TableBasedCache | RowBasedCache:
        """
        Creates a new cache with the given parameters.

        :param name: The name of the cache.
        :param cache_type: The type of cache ('row' or 'table').
        :param size_limit: Maximum number of items allowed in the cache.
        :param eviction_policy: Policy for evicting items when cache limit is reached.
        :param default_expiration: A timedelta object representing the expiration time.
        :param columns: Optional list of column names to initialize an empty DataFrame
            for a table-based cache.
        :return: The created cache.
        :raises ValueError: If cache_type is unsupported, or columns are given
            for a non-table cache.
        """
        # Convert default_expiration timedelta to seconds (None/0 pass through unchanged).
        expiration_in_seconds = default_expiration.total_seconds() if \
            default_expiration not in [None, 0] else default_expiration

        # Create cache using expiration_in_seconds
        if cache_type == 'row':
            self.caches[name] = RowBasedCache(size_limit=size_limit, eviction_policy=eviction_policy,
                                              default_expiration=expiration_in_seconds)
        elif cache_type == 'table':
            self.caches[name] = TableBasedCache(size_limit=size_limit, eviction_policy=eviction_policy,
                                                default_expiration=expiration_in_seconds)
        else:
            # BUG FIX: this 'else' was previously attached to the `if columns:`
            # check below, so every create_cache call without columns raised
            # "Unsupported cache type" even for valid types.
            raise ValueError(f"Unsupported cache type: {cache_type}")

        # Initialize the DataFrame with provided columns if specified
        if columns:
            # Only table caches support add_table; previously a row cache with
            # columns would fail with an AttributeError.
            if cache_type != 'table':
                raise ValueError("'columns' is only supported for table-based caches.")
            self.caches[name].add_table(df=pd.DataFrame(columns=columns))
            logger.info(f"Table-based cache '{name}' initialized with columns: {columns}")

        # Use the module logger rather than the root logger for consistency.
        logger.info(f"Cache '{name}' of type '{cache_type}' created with expiration: {default_expiration}")
        return self.caches.get(name)

    def get_cache(self, name: str) -> RowBasedCache | TableBasedCache:
        """Retrieve a cache by name, raising KeyError when it does not exist."""
        if name in self.caches:
            return self.caches[name]
        raise KeyError(f"Cache: {name}, does not exist.")

    def get_rows_from_cache(self, cache_name: str, filter_vals: list[tuple[str, Any]]) -> pd.DataFrame | None:
        """
        Retrieves rows from the cache if available.

        :param cache_name: The key used to identify the cache.
        :param filter_vals: A list of tuples, each containing a column name and the value(s) to filter by.
        :return: A DataFrame containing the matching rows; empty when the cache
            does not exist or nothing matches.
        :raises ValueError: If the named cache is not a supported cache type.
        """
        # Check if the cache exists (a missing cache silently yields no rows)
        if cache_name in self.caches:
            cache = self.get_cache(cache_name)

            # Ensure the cache is a supported type (required for querying)
            if not isinstance(cache, (TableBasedCache, RowBasedCache)):
                raise ValueError(f"Cache '{cache_name}' does not contain TableBasedCache or RowBasedCache.")

            # Perform the query on the cache using filter_vals
            filtered_cache = cache.query(filter_vals)

            # If data is found in the cache, return it
            if not filtered_cache.empty:
                return filtered_cache

        # No result: return an empty DataFrame
        return pd.DataFrame()

    def get_cache_item(self, item_name: str, cache_name: str, filter_vals: tuple[str, any]) -> any:
        """
        Retrieves a specific item from the cache.

        :param item_name: The name of the column to retrieve.
        :param cache_name: The name used to identify the cache (also the name of the database table).
        :param filter_vals: A tuple containing the column name and the value to filter by.
        :return: The value of the requested item, or None when nothing matched.
        :raises ValueError: If the column does not exist in the matched rows.
        """
        # Fetch the relevant rows from the cache
        rows = self.get_rows_from_cache(cache_name=cache_name, filter_vals=[filter_vals])
        if rows is not None and not rows.empty:
            if item_name not in rows.columns:
                raise ValueError(f"Column '{item_name}' does not exist in the cache '{cache_name}'.")
            # Return the specific item from the first matching row.
            return rows.iloc[0][item_name]

        # No item found in the cache that satisfied the query.
        return None

    def insert_row_into_cache(self, cache_name: str, columns: tuple, values: tuple, key: str = None) -> None:
        """
        Inserts a single row into the specified cache.

        :param cache_name: The name of the cache where the row should be inserted.
        :param columns: A tuple of column names corresponding to the values.
        :param values: A tuple of values to insert into the specified columns.
        :param key: Optional key for the cache item (required for row-based caches).
        :raises ValueError: If a row-based cache is targeted without a key, or
            the cache type is unknown.
        """
        # Create a DataFrame for the new row
        new_row_df = pd.DataFrame([values], columns=list(columns))

        # Determine if the cache is row-based or table-based, and insert accordingly
        cache = self.get_cache(cache_name)

        if isinstance(cache, RowBasedCache):
            if key is None:
                raise ValueError('A key must be provided for row based cache.')
            # For row-based cache, insert the new row as a new cache entry using the key
            cache.add_entry(key=key, data=new_row_df)
        elif isinstance(cache, TableBasedCache):
            # For table-based cache, append the new row to the existing DataFrame
            cache.add_table(df=new_row_df)
        else:
            raise ValueError(f"Unknown cache type for {cache_name}")

    def insert_df_into_cache(self, df: pd.DataFrame, cache_name: str) -> None:
        """
        Inserts data from a DataFrame into the specified cache.

        :param df: The DataFrame containing the data to insert.
        :param cache_name: The name of the cache where the data should be inserted.
        :raises ValueError: If the cache type is unknown.
        """
        cache = self.get_cache(cache_name)

        if isinstance(cache, RowBasedCache):
            # Row-based: insert each row individually, keyed by its first column.
            for _, row in df.iterrows():
                # .iloc[0] replaces the deprecated positional row[0] lookup.
                row_key = str(row.iloc[0])
                cache.add_entry(key=row_key, data=row.to_frame().T)  # Convert row back to DataFrame
        elif isinstance(cache, TableBasedCache):
            # For table-based cache, insert the entire DataFrame
            cache.add_table(df=df)
        else:
            raise ValueError(f"Unknown cache type for {cache_name}")

    def remove_row_from_cache(self, cache_name: str, filter_vals: List[tuple[str, Any]]) -> None:
        """
        Removes rows from the cache based on multiple filter criteria.

        :param cache_name: The name of the cache (or table) from which to remove rows.
        :param filter_vals: A list of tuples, each containing a column name and the value to filter by.
        :raises ValueError: If filter_vals is malformed or no valid cache is found.
        """
        # Ensure filter_vals is a list of tuples
        if not isinstance(filter_vals, list) or not all(isinstance(item, tuple) for item in filter_vals):
            raise ValueError("filter_vals must be a list of tuples (column, value)")

        cache = self.get_cache(cache_name)
        if cache is None:
            raise ValueError(f"Cache '{cache_name}' not found.")

        # Call the cache system to remove the filtered rows
        cache.remove_item(filter_vals)

    def modify_cache_item(self, cache_name: str, filter_vals: List[Tuple[str, any]], field_name: str,
                          new_data: any) -> None:
        """
        Modifies a specific field in a row within the cache.

        :param cache_name: The name used to identify the cache.
        :param filter_vals: A list of tuples containing column names and values to filter by.
        :param field_name: The field to be updated.
        :param new_data: The new data to be set.
        :raises ValueError: If the row is not found, multiple rows match, or
            the cache type is unsupported.
        """
        # Retrieve the row from the cache
        rows = self.get_rows_from_cache(cache_name=cache_name, filter_vals=filter_vals)

        if rows is None or rows.empty:
            raise ValueError(f"Row not found in cache for {filter_vals}")

        # The update below is only well-defined for a single matching row.
        if len(rows) > 1:
            raise ValueError(f"Multiple rows found for {filter_vals}. Please provide a more specific filter.")

        # Update the DataFrame with the new value
        rows[field_name] = new_data

        cache = self.get_cache(cache_name)

        if isinstance(cache, RowBasedCache):
            # For row-based cache, the 'tbl_key' must be in filter_vals
            key_value = next((val for key, val in filter_vals if key == 'tbl_key'), None)
            if key_value is None:
                raise ValueError("'tbl_key' must be present in filter_vals for row-based caches.")
            # Update the cache entry with the modified row
            cache.add_entry(key=key_value, data=rows)
        elif isinstance(cache, TableBasedCache):
            # BUG FIX: remove the stale row before re-inserting the updated
            # copy; add_table alone left the old row behind as a duplicate.
            cache.remove_item(filter_vals)
            cache.add_table(rows)
        else:
            raise ValueError(f"Unsupported cache type for {cache_name}")

    @staticmethod
    def key_exists(cache, key):
        """Return True when `key` exists in the given row- or table-based cache."""
        if isinstance(cache, RowBasedCache):
            return key in cache.cache
        if isinstance(cache, TableBasedCache):
            return not cache.query([("tbl_key", key)]).empty
        # Unknown cache type: treated as "key not present".
        return False
|
|
|
|
|
|
class SnapshotDataCache(CacheManager):
    """
    Extends CacheManager with snapshot functionality for both row-based and table-based caches.

    Attributes:
        snapshots (dict): A dictionary to store snapshots, with cache names as keys and
            (deep-copied cache, timestamp) tuples as values.

    Methods:
        snapshot_cache(cache_name: str): Takes a snapshot of the specified cache and stores it.
        get_snapshot(cache_name: str): Retrieves the most recent snapshot of the specified cache.
        list_snapshots() -> dict: Lists all available snapshots along with their timestamps.
    """

    def __init__(self):
        super().__init__()  # Call the constructor of CacheManager
        self.snapshots = {}  # cache name -> (deep-copied cache, timestamp)

    def snapshot_cache(self, cache_name: str):
        """Takes a snapshot of the specified cache and stores it with a timestamp."""
        if cache_name not in self.caches:
            raise ValueError(f"Cache '{cache_name}' does not exist.")

        # Deep copy so later mutations of the live cache don't alter the snapshot.
        cache_data = copy.deepcopy(self.caches[cache_name])

        # Store the snapshot with a timestamp
        self.snapshots[cache_name] = (cache_data, dt.datetime.now())
        # Use the module logger instead of print() for consistency with the
        # rest of the module.
        logger.info(f"Snapshot taken for cache '{cache_name}' at {self.snapshots[cache_name][1]}.")

    def get_snapshot(self, cache_name: str):
        """Retrieves the most recent snapshot of the specified cache."""
        if cache_name not in self.snapshots:
            raise ValueError(f"No snapshot available for cache '{cache_name}'.")

        snapshot, timestamp = self.snapshots[cache_name]
        logger.info(f"Returning snapshot of cache '{cache_name}' taken at {timestamp}.")
        return snapshot

    def list_snapshots(self) -> dict:
        """Lists all available snapshots along with their timestamps."""
        return {cache_name: timestamp for cache_name, (_, timestamp) in self.snapshots.items()}
|
|
|
|
|
|
class DatabaseInteractions(SnapshotDataCache):
|
|
"""
|
|
Extends SnapshotDataCache with additional functionality, including database interactions and cache management.
|
|
|
|
Attributes:
|
|
db (Database): A database connection instance for executing queries.
|
|
TYPECHECKING_ENABLED (bool): A class attribute to toggle type checking.
|
|
|
|
Methods:
|
|
get_or_fetch_rows(cache_name: str, filter_vals: tuple[str, any]) -> pd.DataFrame | None:
|
|
Retrieves rows from the cache or database.
|
|
fetch_item(item_name: str, cache_name: str, filter_vals: tuple[str, any]) -> any:
|
|
Retrieves a specific item from the cache or database.
|
|
insert_row(cache_name: str, columns: tuple, values: tuple, skip_cache: bool = False) -> None:
|
|
Inserts a row into the cache and database.
|
|
insert_df(df: pd.DataFrame, cache_name: str, skip_cache: bool = False) -> None:
|
|
Inserts data from a DataFrame into the cache and database.
|
|
remove_row(cache_name: str, filter_vals: tuple[str, any], additional_filter: tuple[str, any] = None,
|
|
remove_from_db: bool = True) -> None:
|
|
Removes a row from the cache and optionally the database.
|
|
is_attr_taken(cache_name: str, attr: str, val: any) -> bool:
|
|
Checks if a specific attribute has a given value in the cache.
|
|
modify_item(cache_name: str, filter_vals: tuple[str, any], field_name: str, new_data: any) -> None:
|
|
Modifies a field in a row within the cache and updates the database.
|
|
update_cached_dict(cache_name: str, cache_key: str, dict_key: str, data: any) -> None:
|
|
Updates a dictionary stored in the DataFrame cache.
|
|
|
|
Usage Example:
|
|
# Create a DatabaseInteractions instance
|
|
db_cache_manager = DatabaseInteractions(exchanges=exchanges)
|
|
|
|
# Set some data in the cache
|
|
db_cache_manager.set_cache_item(key='AAPL', data={'price': 100}, cache_name='stock_prices')
|
|
|
|
# Fetch or query rows from the cache or database
|
|
rows = db_cache_manager.get_or_fetch_rows(cache_name='stock_prices', filter_vals=('symbol', 'AAPL'))
|
|
print(f"Fetched Rows:\n{rows}")
|
|
|
|
# Fetch a specific item from the cache or database
|
|
price = db_cache_manager.fetch_item(item_name='price', cache_name='stock_prices',
|
|
filter_vals=('symbol', 'AAPL'))
|
|
print(f"Fetched Price: {price}")
|
|
|
|
# Insert a new row into the cache and database
|
|
db_cache_manager.insert_row(cache_name='stock_prices', columns=('symbol', 'price'), values=('TSLA', 800))
|
|
|
|
# Check if an attribute value is already taken
|
|
is_taken = db_cache_manager.is_attr_taken(cache_name='stock_prices', attr='symbol', val='AAPL')
|
|
print(f"Is Symbol Taken: {is_taken}")
|
|
|
|
# Modify an existing item in the cache and database
|
|
db_cache_manager.modify_item(cache_name='stock_prices', filter_vals=('symbol', 'AAPL'),
|
|
field_name='price', new_data=105)
|
|
|
|
# Update a dictionary within the cache
|
|
db_cache_manager.update_cached_dict(cache_name='stock_prices', cache_key='AAPL', dict_key='price', data=110)
|
|
"""
|
|
TYPECHECKING_ENABLED = True
|
|
|
|
def __init__(self):
    """Initialize the cache/snapshot layers and open a database connection."""
    super().__init__()
    # Database connection used as the fallback/persistence layer behind the caches.
    self.db = Database()
|
|
|
|
def get_rows_from_datacache(self, cache_name: str, filter_vals: list[tuple[str, Any]] = None,
|
|
key: str = None, include_tbl_key: bool = False) -> pd.DataFrame | None:
|
|
"""
|
|
Retrieves rows from the cache if available; otherwise, queries the database and caches the result.
|
|
|
|
:param include_tbl_key: If True, includes 'tbl_key' in the returned DataFrame.
|
|
:param key: Optional key to filter by 'tbl_key'.
|
|
:param cache_name: The key used to identify the cache (also the name of the database table).
|
|
:param filter_vals: A list of tuples, each containing a column name and the value(s) to filter by.
|
|
If a value is a list, it will use the SQL 'IN' clause.
|
|
:return: A DataFrame containing the requested rows, or None if no matching rows are found.
|
|
:raises ValueError: If the cache is not a DataFrame or does not contain DataFrames in the 'data' column.
|
|
"""
|
|
# Ensure at least one of filter_vals or key is provided
|
|
if not filter_vals and not key:
|
|
raise ValueError("At least one of 'filter_vals' or 'key' must be provided.")
|
|
|
|
# Use an empty list if filter_vals is None
|
|
filter_vals = filter_vals or []
|
|
|
|
# Insert the key if provided
|
|
if key:
|
|
filter_vals.insert(0, ('tbl_key', key))
|
|
|
|
# Convert filter values that are lists to 'IN' clauses for cache filtering
|
|
result = self.get_rows_from_cache(cache_name, filter_vals)
|
|
|
|
# Fallback to database if no result found in cache
|
|
if result.empty:
|
|
result = self._fetch_from_database(cache_name, filter_vals)
|
|
|
|
# Only use _fetch_from_database_with_list_support if any filter values are lists
|
|
if result.empty and any(isinstance(val, list) for _, val in filter_vals):
|
|
result = self._fetch_from_database_with_list_support(cache_name, filter_vals)
|
|
|
|
# Remove 'tbl_key' unless include_tbl_key is True
|
|
if not include_tbl_key:
|
|
result = result.drop(columns=['tbl_key'], errors='ignore')
|
|
|
|
return result
|
|
|
|
def _fetch_from_database_with_list_support(self, cache_name: str,
|
|
filter_vals: List[tuple[str, Any]]) -> pd.DataFrame:
|
|
"""
|
|
Fetch rows from the database, supporting list values that require SQL 'IN' clauses.
|
|
|
|
:param cache_name: The name of the table or key used to store/retrieve data.
|
|
:param filter_vals: A list of tuples with the filter column and value, supporting lists for 'IN' clause.
|
|
:return: A DataFrame with the fetched rows, or None if no data is found.
|
|
"""
|
|
where_clauses = []
|
|
params = []
|
|
|
|
for col, val in filter_vals:
|
|
if isinstance(val, list):
|
|
placeholders = ', '.join('?' for _ in val)
|
|
where_clauses.append(f"{col} IN ({placeholders})")
|
|
params.extend(val)
|
|
else:
|
|
where_clauses.append(f"{col} = ?")
|
|
params.append(val)
|
|
|
|
where_clause = " AND ".join(where_clauses)
|
|
sql_query = f"SELECT * FROM {cache_name} WHERE {where_clause}"
|
|
|
|
# Execute the SQL query with the prepared parameters
|
|
rows = self.db.get_rows_where(sql_query, params)
|
|
|
|
# Cache the result (either row or table based)
|
|
if rows is not None and not rows.empty:
|
|
cache = self.get_cache(cache_name)
|
|
|
|
if isinstance(cache, RowBasedCache):
|
|
for _, row in rows.iterrows():
|
|
cache.add_entry(key=row['tbl_key'], data=row)
|
|
else:
|
|
cache.add_table(rows, overwrite='tbl_key')
|
|
|
|
return rows
|
|
|
|
    def _fetch_from_database(self, cache_name: str, filter_vals: List[tuple[str, Any]]) -> pd.DataFrame:
        """
        Fetch rows from the database and cache the result.

        :param cache_name: The name of the table or key used to store/retrieve data.
        :param filter_vals: A list of tuples with the filter column and value (equality matches only).
        :return: A DataFrame with the fetched rows, or an empty DataFrame if no data is found.
        """
        # Use db.get_rows_where, assuming it can handle multiple filters
        rows = self.db.get_rows_where(cache_name, filter_vals)

        if rows is not None and not rows.empty:
            # Cache the fetched data (let the caching system handle whether it's row or table-based)
            cache = self.get_cache(cache_name)

            if isinstance(cache, RowBasedCache):
                # For row-based cache, assume the first filter value is used as the key
                key_value = filter_vals[0][1]  # Use the value of the first filter as the key
                cache.add_entry(key=key_value, data=rows)
            else:
                # For table-based cache, add the entire DataFrame to the cache
                cache.add_table(df=rows, overwrite='tbl_key')

            # Return the fetched rows
            return rows

        # If no rows are found, return an empty DataFrame (not None).
        return pd.DataFrame()
|
|
|
|
    def get_datacache_item(self, item_name: str, cache_name: str, filter_vals: tuple[str, Any]) -> Any:
        """
        Retrieves a specific item from the cache or database, caching the result if necessary.

        :param item_name: The name of the column to retrieve.
        :param cache_name: The name used to identify the cache (also the name of the database table).
        :param filter_vals: A tuple containing the column name and the value to filter by.
        :return: The value of the requested item, or None if no matching row exists.
        :raises ValueError: If a matching row exists but the requested column does not.
        """
        # Fetch the relevant rows from the cache or database
        rows = self.get_rows_from_datacache(cache_name=cache_name, filter_vals=[filter_vals])
        if rows is not None and not rows.empty:
            if item_name not in rows.columns:
                raise ValueError(f"Column '{item_name}' does not exist in the cache '{cache_name}'.")
            # Return the specific item from the first matching row.
            return rows.iloc[0][item_name]

        # The item was not found.
        return None
|
|
|
|
def insert_row_into_datacache(self, cache_name: str, columns: tuple, values: tuple, key: str = None,
|
|
skip_cache: bool = False) -> None:
|
|
"""
|
|
Inserts a single row into the specified cache and database, with an option to skip cache insertion.
|
|
|
|
:param cache_name: The name of the cache (and database table) where the row should be inserted.
|
|
:param columns: A tuple of column names corresponding to the values.
|
|
:param values: A tuple of values to insert into the specified columns.
|
|
:param key: Optional key for the cache item. If None, the auto-incremented ID from the database will be used.
|
|
:param skip_cache: If True, skips inserting the row into the cache. Default is False.
|
|
"""
|
|
if key:
|
|
columns, values = columns + ('tbl_key',), values + (key,)
|
|
|
|
# Insert the row into the database
|
|
last_inserted_id = self.db.insert_row(table=cache_name, columns=columns, values=values)
|
|
|
|
# Now include the 'id' in the columns and values when inserting into the cache
|
|
columns = ('id',) + columns
|
|
values = (last_inserted_id,) + values
|
|
|
|
# Insert the row into the cache
|
|
if skip_cache:
|
|
return
|
|
self.insert_row_into_cache(cache_name, columns, values, key)
|
|
|
|
def insert_df_into_datacache(self, df: pd.DataFrame, cache_name: str, skip_cache: bool = False) -> None:
|
|
"""
|
|
Inserts data from a DataFrame into the specified cache and database, with an option to skip cache insertion.
|
|
|
|
:param df: The DataFrame containing the data to insert.
|
|
:param cache_name: The name of the cache (and database table) where the data should be inserted.
|
|
:param skip_cache: If True, skips inserting the data into the cache. Default is False.
|
|
"""
|
|
# Insert the data into the database
|
|
self.db.insert_dataframe(df=df, table=cache_name)
|
|
|
|
if not skip_cache:
|
|
self.insert_df_into_cache(df, cache_name)
|
|
|
|
def remove_row_from_datacache(self, cache_name: str, filter_vals: List[tuple[str, Any]],
|
|
remove_from_db: bool = True, key: str = None) -> None:
|
|
"""
|
|
Removes rows from the cache and optionally from the database based on multiple filter criteria.
|
|
|
|
This method is specifically designed for caches stored as DataFrames.
|
|
|
|
:param key: Optional key
|
|
:param cache_name: The name of the cache (or table) from which to remove rows.
|
|
:param filter_vals: A list of tuples, each containing a column name and the value to filter by.
|
|
:param remove_from_db: If True, also removes the rows from the database. Default is True.
|
|
:raises ValueError: If the cache is not a DataFrame or if no valid cache is found.
|
|
"""
|
|
if key:
|
|
filter_vals.insert(0, ('tbl_key', key))
|
|
|
|
self.remove_row_from_cache(cache_name, filter_vals)
|
|
|
|
# Remove from the database if required
|
|
if remove_from_db:
|
|
# Construct SQL to remove the rows from the database based on filter_vals
|
|
sql = f"DELETE FROM {cache_name} WHERE " + " AND ".join([f"{col} = ?" for col, _ in filter_vals])
|
|
params = [val for _, val in filter_vals]
|
|
|
|
# Execute the SQL query to remove the row from the database
|
|
self.db.execute_sql(sql, params)
|
|
|
|
def modify_datacache_item(self, cache_name: str, filter_vals: List[Tuple[str, any]], field_names: Tuple[str, ...],
|
|
new_values: Tuple[Any, ...], key: str = None, overwrite: str = None) -> None:
|
|
"""
|
|
Modifies specific fields in a row within the cache and updates the database accordingly.
|
|
|
|
:param cache_name: The name used to identify the cache.
|
|
:param filter_vals: A list of tuples containing column names and values to filter by.
|
|
:param field_names: A tuple of field names to be updated.
|
|
:param new_values: A tuple of new values corresponding to field_names.
|
|
:param key: Optional key to identify the entry.
|
|
:param overwrite: Column name(s) to use for overwriting in the cache.
|
|
:raises ValueError: If the row is not found or multiple rows are returned.
|
|
"""
|
|
if key:
|
|
filter_vals.insert(0, ('tbl_key', key))
|
|
|
|
# Retrieve the row from the cache or database
|
|
rows = self.get_rows_from_datacache(cache_name=cache_name, filter_vals=filter_vals)
|
|
|
|
if rows is None or rows.empty:
|
|
raise ValueError(f"Row not found in cache or database for {filter_vals}")
|
|
|
|
# Check if multiple rows are returned
|
|
if len(rows) > 1:
|
|
raise ValueError(f"Multiple rows found for {filter_vals}. Please provide more specific filter.")
|
|
|
|
# Update the DataFrame with the new values
|
|
for field_name, new_value in zip(field_names, new_values):
|
|
rows[field_name] = new_value
|
|
|
|
# Get the cache instance
|
|
cache = self.get_cache(cache_name)
|
|
|
|
# Set the updated row in the cache
|
|
if isinstance(cache, RowBasedCache):
|
|
# Update the cache entry with the modified row
|
|
cache.add_entry(key=key, data=rows)
|
|
elif isinstance(cache, TableBasedCache):
|
|
# Use 'overwrite' to ensure correct row is updated
|
|
cache.add_table(rows, overwrite=overwrite)
|
|
else:
|
|
raise ValueError(f"Unsupported cache type for {cache_name}")
|
|
|
|
# Update the values in the database
|
|
set_clause = ", ".join([f"{field} = ?" for field in field_names])
|
|
where_clause = " AND ".join([f"{col} = ?" for col, _ in filter_vals])
|
|
sql_update = f"UPDATE {cache_name} SET {set_clause} WHERE {where_clause}"
|
|
params = list(new_values) + [val for _, val in filter_vals]
|
|
|
|
# Execute the SQL update to modify the database
|
|
self.db.execute_sql(sql_update, params)
|
|
|
|
def modify_multiple_datacache_items(self, cache_name: str, filter_vals: List[Tuple[str, any]],
|
|
field_names: Tuple[str, ...], new_values: Tuple[Any, ...],
|
|
key: str = None, overwrite: str = None) -> None:
|
|
"""
|
|
Modifies specific fields in multiple rows within the cache and updates the database accordingly.
|
|
|
|
:param cache_name: The name used to identify the cache.
|
|
:param filter_vals: A list of tuples containing column names and values to filter by.
|
|
If a filter value is a list, it will be used with the 'IN' clause.
|
|
:param field_names: A tuple of field names to be updated.
|
|
:param new_values: A tuple of new values corresponding to field_names.
|
|
:param key: Optional key to identify the entry.
|
|
:param overwrite: Column name(s) to use for overwriting in the cache.
|
|
:raises ValueError: If no rows are found.
|
|
"""
|
|
if key:
|
|
filter_vals.insert(0, ('tbl_key', key))
|
|
|
|
# Prepare the SQL query
|
|
where_clauses = []
|
|
query_params = []
|
|
|
|
for col, val in filter_vals:
|
|
if isinstance(val, list):
|
|
# Use the 'IN' clause if the value is a list
|
|
placeholders = ', '.join('?' for _ in val)
|
|
where_clauses.append(f"{col} IN ({placeholders})")
|
|
query_params.extend(val)
|
|
else:
|
|
where_clauses.append(f"{col} = ?")
|
|
query_params.append(val)
|
|
|
|
# Build the SQL query string
|
|
where_clause = " AND ".join(where_clauses)
|
|
set_clause = ", ".join([f"{field} = ?" for field in field_names])
|
|
sql_update = f"UPDATE {cache_name} SET {set_clause} WHERE {where_clause}"
|
|
|
|
# Add the new values to the parameters list
|
|
query_params = list(new_values) + query_params
|
|
|
|
# Execute the SQL update to modify the database
|
|
self.db.execute_sql(sql_update, query_params)
|
|
|
|
# Retrieve the rows from the cache to update the cache
|
|
rows = self.get_rows_from_datacache(cache_name=cache_name, filter_vals=filter_vals)
|
|
|
|
if rows is None or rows.empty:
|
|
raise ValueError(f"Rows not found in cache or database for {filter_vals}")
|
|
|
|
# Update the cache with the new values
|
|
for field_name, new_value in zip(field_names, new_values):
|
|
rows[field_name] = new_value
|
|
|
|
# Get the cache instance
|
|
cache = self.get_cache(cache_name)
|
|
|
|
if isinstance(cache, RowBasedCache):
|
|
for _, row in rows.iterrows():
|
|
key_value = row['tbl_key']
|
|
cache.add_entry(key=key_value, data=row)
|
|
elif isinstance(cache, TableBasedCache):
|
|
cache.add_table(rows, overwrite=overwrite)
|
|
else:
|
|
raise ValueError(f"Unsupported cache type for {cache_name}")
|
|
|
|
def serialized_datacache_insert(self, cache_name: str, data: Any, key: str = None,
|
|
do_not_overwrite: bool = False):
|
|
"""
|
|
Stores an item in the cache, with custom serialization for object instances.
|
|
If the data is not a DataFrame, the entire object is serialized and stored under a column named 'data'.
|
|
|
|
:param cache_name: The name of the cache.
|
|
:param data: Any object to store in the cache, but should be a DataFrame with one row for normal operations.
|
|
:param key: The key for row-based caches, used to identify the entry. Required for row-based caches.
|
|
:param do_not_overwrite: If True, prevents overwriting existing entries in the cache.
|
|
"""
|
|
|
|
# Retrieve the cache
|
|
cache = self.get_cache(cache_name)
|
|
|
|
# If overwrite is disabled and the key already exists, prevent overwriting
|
|
if do_not_overwrite and self.key_exists(cache, key):
|
|
logging.warning(f"Key '{key}' already exists in cache '{cache_name}'. Overwrite prevented.")
|
|
return
|
|
|
|
# If the data is a DataFrame, ensure it contains exactly one row
|
|
if isinstance(data, pd.DataFrame):
|
|
if len(data) != 1:
|
|
raise ValueError('This method is for inserting a DataFrame with exactly one row.')
|
|
|
|
# Ensure key is provided for RowBasedCache
|
|
if isinstance(cache, RowBasedCache) and key is None:
|
|
raise ValueError("RowBasedCache requires a key to store the data.")
|
|
|
|
# List of types to exclude from serialization
|
|
excluded_objects = (str, int, float, bool, type(None), bytes)
|
|
|
|
# Process and serialize non-excluded objects in the row
|
|
row = data.iloc[0] # Access the first (and only) row
|
|
row_values = []
|
|
for col_value in row:
|
|
# Serialize column value if it's not one of the excluded types
|
|
if not isinstance(col_value, excluded_objects):
|
|
col_value = pickle.dumps(col_value)
|
|
row_values.append(col_value)
|
|
|
|
# Insert the row into the cache and database (key is handled in insert_row_into_datacache)
|
|
self.insert_row_into_datacache(cache_name=cache_name, columns=tuple(data.columns),
|
|
values=tuple(row_values), key=key)
|
|
|
|
else:
|
|
# For non-DataFrame data, serialize the entire object
|
|
serialized_data = pickle.dumps(data)
|
|
|
|
# Insert the serialized object under a column named 'data'
|
|
self.insert_row_into_datacache(cache_name=cache_name, columns=('data',),
|
|
values=(serialized_data,), key=key)
|
|
|
|
return
|
|
|
|
def get_serialized_datacache(self,
|
|
cache_name: str,
|
|
filter_vals: List[Tuple[str, Any]] = None,
|
|
key: str = None) -> pd.DataFrame | Any:
|
|
"""
|
|
Retrieves an item from the specified cache and deserializes object columns if necessary.
|
|
If the stored data is a serialized object (not a DataFrame), it returns the deserialized object.
|
|
|
|
:param key: The key to identify the cache entry.
|
|
:param filter_vals: List of column filters (name, value) for the cache query.
|
|
:param cache_name: The name of the cache.
|
|
:return Any: Cached data with deserialized objects or the original non-DataFrame object, or None if not found.
|
|
"""
|
|
|
|
# Ensure at least one of filter_vals or key is provided
|
|
if not filter_vals and not key:
|
|
raise ValueError("At least one of 'filter_vals' or 'key' must be provided.")
|
|
|
|
# Prepare filter values
|
|
filter_vals = filter_vals or []
|
|
if key:
|
|
filter_vals.insert(0, ('tbl_key', key))
|
|
|
|
# Retrieve rows from the cache using the key
|
|
data = self.get_rows_from_datacache(cache_name=cache_name, filter_vals=filter_vals)
|
|
|
|
# Return None if no data is found
|
|
if data is None or (isinstance(data, pd.DataFrame) and data.empty):
|
|
logging.info(f"No data found in cache '{cache_name}' for key: {key}")
|
|
return pd.DataFrame()
|
|
|
|
# Handle non-DataFrame data
|
|
if not isinstance(data, pd.DataFrame):
|
|
logging.warning(f"Unexpected data format from cache '{cache_name}'. Returning None.")
|
|
return pd.DataFrame()
|
|
|
|
# Check for single column 'data' (serialized object case)
|
|
if 'data' in data.columns and len(data.columns) == 1:
|
|
return self._deserialize_object(data.iloc[0]['data'], cache_name)
|
|
|
|
# Handle deserialization of DataFrame columns
|
|
return self._deserialize_dataframe_row(data, cache_name)
|
|
|
|
@staticmethod
|
|
def _deserialize_object(serialized_data: Any, cache_name: str) -> Any:
|
|
"""
|
|
Deserializes an object stored as a serialized byte stream in the cache.
|
|
|
|
:param serialized_data: Serialized byte data to deserialize.
|
|
:param cache_name: The name of the cache (used for logging).
|
|
:return: Deserialized object, or the raw bytes if deserialization fails.
|
|
"""
|
|
if not isinstance(serialized_data, bytes):
|
|
logging.warning(f"Expected bytes for deserialization in cache '{cache_name}', got {type(serialized_data)}.")
|
|
return serialized_data
|
|
|
|
try:
|
|
deserialized_data = pickle.loads(serialized_data)
|
|
logging.info(f"Serialized object retrieved and deserialized from cache '{cache_name}'")
|
|
return deserialized_data
|
|
except (pickle.PickleError, TypeError) as e:
|
|
logging.warning(f"Failed to deserialize object from cache '{cache_name}': {e}")
|
|
return serialized_data # Fallback to the raw serialized data
|
|
|
|
@staticmethod
|
|
def _deserialize_dataframe_row(data: pd.DataFrame, cache_name: str) -> pd.DataFrame:
|
|
"""
|
|
Deserializes any serialized columns in a DataFrame row.
|
|
|
|
:param data: The DataFrame containing serialized columns.
|
|
:param cache_name: The name of the cache (used for logging).
|
|
:return: DataFrame with deserialized values.
|
|
"""
|
|
row = data.iloc[0] # Assuming we only retrieve one row
|
|
deserialized_row = []
|
|
|
|
for col_value in row:
|
|
if isinstance(col_value, bytes):
|
|
try:
|
|
deserialized_col_value = pickle.loads(col_value)
|
|
deserialized_row.append(deserialized_col_value)
|
|
except (pickle.PickleError, TypeError) as e:
|
|
logging.warning(f"Failed to deserialize column value in cache '{cache_name}': {e}")
|
|
deserialized_row.append(col_value) # Fallback to the original value
|
|
else:
|
|
deserialized_row.append(col_value)
|
|
|
|
deserialized_data = pd.DataFrame([deserialized_row], columns=data.columns)
|
|
logging.info(f"Data retrieved and deserialized from cache '{cache_name}'")
|
|
return deserialized_data
|
|
|
|
|
|
class ServerInteractions(DatabaseInteractions):
|
|
"""
|
|
Extends DataCache to specialize in handling candle (OHLC) data and server interactions.
|
|
"""
|
|
|
|
    def __init__(self):
        """Initialize the cache/database layer; the exchange interface is attached later via set_exchange()."""
        super().__init__()
        # No exchange interface until set_exchange() is called.
        self.exchanges = None
|
|
|
|
    def set_exchange(self, exchanges):
        """
        Sets an exchange interface for this class to use.

        :param exchanges: ExchangeInterface object providing per-user exchange clients
                          (queried later via exchanges.get_exchange(...)).
        :return: None.
        """
        self.exchanges = exchanges
|
|
|
|
@staticmethod
|
|
def _make_key(ex_details: list[str]) -> str:
|
|
"""
|
|
Generates a unique key string based on the exchange details provided.
|
|
|
|
:param ex_details: A list of strings containing symbol, timeframe, exchange, etc.
|
|
:return: A key string composed of the concatenated details.
|
|
"""
|
|
symbol, timeframe, exchange, _ = ex_details
|
|
key = f'{symbol}_{timeframe}_{exchange}'
|
|
return key
|
|
|
|
def _update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
|
|
logger.debug('Updating data with new records.')
|
|
|
|
# Retrieve the 'candles' cache
|
|
candles = self.get_cache('candles')
|
|
|
|
# Store the updated records back in the cache
|
|
candles.add_entry(key=key, data=more_records)
|
|
|
|
def get_records_since(self, start_datetime: dt.datetime, ex_details: list[str]) -> pd.DataFrame:
|
|
"""
|
|
This gets up-to-date records from a specified market and exchange.
|
|
|
|
:param start_datetime: The approximate time the first record should represent.
|
|
:param ex_details: The user exchange and market.
|
|
:return: The records.
|
|
"""
|
|
if self.TYPECHECKING_ENABLED:
|
|
if not isinstance(start_datetime, dt.datetime):
|
|
raise TypeError("start_datetime must be a datetime object")
|
|
if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
|
|
raise TypeError("ex_details must be a list of strings")
|
|
if not len(ex_details) == 4:
|
|
raise TypeError("ex_details must include [asset, timeframe, exchange, user_name]")
|
|
|
|
if start_datetime.tzinfo is None:
|
|
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
|
end_datetime = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
|
|
|
|
try:
|
|
args = {
|
|
'start_datetime': start_datetime,
|
|
'end_datetime': end_datetime,
|
|
'ex_details': ex_details,
|
|
}
|
|
return self.get_or_fetch_from(target='data', **args)
|
|
except Exception as e:
|
|
logger.error(f"An error occurred: {str(e)}")
|
|
raise
|
|
|
|
    def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame:
        """
        Fetches market records from a resource stack (cache data, database, exchange server).
        Fills an incomplete request by fetching further down the stack, then updates the
        faster resources with the combined result.

        :param target: Starting point for the fetch. ['data', 'database', 'server']
        :param kwargs: Details and credentials for the request ('start_datetime',
                       'end_datetime' — both timezone-aware — and 'ex_details').
        :return: Records in a dataframe (empty if nothing could be fetched).
        :raises ValueError: If target is invalid or a datetime is timezone naive.
        :raises TypeError: If the optional type checks fail.
        """
        start_datetime = kwargs.get('start_datetime')
        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")

        end_datetime = kwargs.get('end_datetime')
        if end_datetime.tzinfo is None:
            raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")

        ex_details = kwargs.get('ex_details')
        timeframe = ex_details[1]

        if self.TYPECHECKING_ENABLED:
            if not isinstance(start_datetime, dt.datetime):
                raise TypeError("start_datetime must be a datetime object")
            if not isinstance(end_datetime, dt.datetime):
                raise TypeError("end_datetime must be a datetime object")
            if not isinstance(timeframe, str):
                raise TypeError("record_length must be a string representing the timeframe")
            if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
                raise TypeError("ex_details must be a list of strings")
            if not all([start_datetime, end_datetime, timeframe, ex_details]):
                raise ValueError("Missing required arguments")

        # Criteria used by _data_complete; it returns a narrowed version of these
        # when only part of the window is missing.
        request_criteria = {
            'start_datetime': start_datetime,
            'end_datetime': end_datetime,
            'timeframe': timeframe,
        }

        key = self._make_key(ex_details=ex_details)
        combined_data = pd.DataFrame()

        # Build the fetch stack starting at the requested resource.
        if target == 'data':
            resources = [self._get_candles_from_cache, self._get_from_database, self._get_from_server]
        elif target == 'database':
            resources = [self._get_from_database, self._get_from_server]
        elif target == 'server':
            resources = [self._get_from_server]
        else:
            raise ValueError('Not a valid Target!')

        for fetch_method in resources:
            result = fetch_method(**kwargs)

            if result is not None and not result.empty:
                # Drop the 'id' column if it exists in the result (DB rows carry it; others don't)
                if 'id' in result.columns:
                    result = result.drop(columns=['id'])

                # Concatenate, drop duplicates based on 'time', and sort
                combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='time').sort_values(
                    by='time')

            # Completeness check also narrows request_criteria to the missing sub-window.
            is_complete, request_criteria = self._data_complete(data=combined_data, **request_criteria)
            if is_complete:
                # Propagate the result back up to the faster resources that missed.
                if fetch_method in [self._get_from_database, self._get_from_server]:
                    self._update_candle_cache(more_records=combined_data, key=key)
                if fetch_method == self._get_from_server:
                    self._populate_db(ex_details=ex_details, data=combined_data)
                return combined_data

            kwargs.update(request_criteria)  # Update kwargs with new start/end times for next fetch attempt

        logger.error('Unable to fetch the requested data.')
        return combined_data if not combined_data.empty else pd.DataFrame()
|
|
|
|
def _get_candles_from_cache(self, **kwargs) -> pd.DataFrame:
|
|
start_datetime = kwargs.get('start_datetime')
|
|
end_datetime = kwargs.get('end_datetime')
|
|
ex_details = kwargs.get('ex_details')
|
|
|
|
if not all([start_datetime, end_datetime, ex_details]):
|
|
raise ValueError("Missing required arguments for candle data retrieval.")
|
|
|
|
key = self._make_key(ex_details=ex_details)
|
|
logger.debug('Getting records from candles cache.')
|
|
|
|
df = self.get_cache('candles').get_entry(key=key)
|
|
if df is None or df.empty:
|
|
logger.debug("No cached records found.")
|
|
return pd.DataFrame()
|
|
|
|
df_filtered = df[(df['time'] >= unix_time_millis(start_datetime)) &
|
|
(df['time'] <= unix_time_millis(end_datetime))].reset_index(drop=True)
|
|
|
|
return df_filtered
|
|
|
|
def _get_from_database(self, **kwargs) -> pd.DataFrame:
|
|
start_datetime = kwargs.get('start_datetime')
|
|
if start_datetime.tzinfo is None:
|
|
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
|
|
|
end_datetime = kwargs.get('end_datetime')
|
|
if end_datetime.tzinfo is None:
|
|
raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
|
|
|
ex_details = kwargs.get('ex_details')
|
|
|
|
if self.TYPECHECKING_ENABLED:
|
|
if not isinstance(start_datetime, dt.datetime):
|
|
raise TypeError("start_datetime must be a datetime object")
|
|
if not isinstance(end_datetime, dt.datetime):
|
|
raise TypeError("end_datetime must be a datetime object")
|
|
if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details):
|
|
raise TypeError("ex_details must be a list of strings")
|
|
if not all([start_datetime, end_datetime, ex_details]):
|
|
raise ValueError("Missing required arguments")
|
|
|
|
table_name = self._make_key(ex_details=ex_details)
|
|
if not self.db.table_exists(table_name):
|
|
logger.debug('Records not in database.')
|
|
return pd.DataFrame()
|
|
|
|
logger.debug('Getting records from database.')
|
|
return self.db.get_timestamped_records(table_name=table_name, timestamp_field='time', st=start_datetime,
|
|
et=end_datetime)
|
|
|
|
def _get_from_server(self, **kwargs) -> pd.DataFrame:
|
|
symbol = kwargs.get('ex_details')[0]
|
|
interval = kwargs.get('ex_details')[1]
|
|
exchange_name = kwargs.get('ex_details')[2]
|
|
user_name = kwargs.get('ex_details')[3]
|
|
start_datetime = kwargs.get('start_datetime')
|
|
if start_datetime.tzinfo is None:
|
|
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
|
|
|
end_datetime = kwargs.get('end_datetime')
|
|
if end_datetime.tzinfo is None:
|
|
raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
|
|
|
|
if self.TYPECHECKING_ENABLED:
|
|
if not isinstance(symbol, str):
|
|
raise TypeError("symbol must be a string")
|
|
if not isinstance(interval, str):
|
|
raise TypeError("interval must be a string")
|
|
if not isinstance(exchange_name, str):
|
|
raise TypeError("exchange_name must be a string")
|
|
if not isinstance(user_name, str):
|
|
raise TypeError("user_name must be a string")
|
|
if not isinstance(start_datetime, dt.datetime):
|
|
raise TypeError("start_datetime must be a datetime object")
|
|
if not isinstance(end_datetime, dt.datetime):
|
|
raise TypeError("end_datetime must be a datetime object")
|
|
|
|
logger.debug('Getting records from server.')
|
|
return self._fetch_candles_from_exchange(symbol=symbol, interval=interval, exchange_name=exchange_name,
|
|
user_name=user_name, start_datetime=start_datetime,
|
|
end_datetime=end_datetime)
|
|
|
|
    @staticmethod
    def _data_complete(data: pd.DataFrame, **kwargs) -> tuple[bool, dict]:
        """
        Checks if the data completely satisfies the request.

        :param data: DataFrame containing the records (with an epoch-millisecond 'time' column).
        :param kwargs: 'start_datetime' / 'end_datetime' (timezone-aware) and 'timeframe'.
        :return: A tuple (is_complete, updated_request_criteria) where is_complete is True if the data is complete,
                 False otherwise, and updated_request_criteria contains adjusted start/end times if data is incomplete.
        :raises ValueError: If either datetime is timezone naive.
        """
        if data.empty:
            logger.debug("Data is empty.")
            return False, kwargs  # No data at all, proceed with the full original request

        start_datetime: dt.datetime = kwargs.get('start_datetime')
        if start_datetime.tzinfo is None:
            raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")

        end_datetime: dt.datetime = kwargs.get('end_datetime')
        if end_datetime.tzinfo is None:
            raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")

        timeframe: str = kwargs.get('timeframe')

        temp_data = data.copy()

        # Convert 'time' (epoch milliseconds) to tz-aware datetimes for range comparisons
        temp_data['time_dt'] = pd.to_datetime(temp_data['time'],
                                              unit='ms', errors='coerce').dt.tz_localize('UTC')

        min_timestamp = temp_data['time_dt'].min()
        max_timestamp = temp_data['time_dt'].max()

        logger.debug(f"Data time range: {min_timestamp} to {max_timestamp}")
        logger.debug(f"Expected time range: {start_datetime} to {end_datetime}")

        tolerance = pd.Timedelta(seconds=5)

        # Initialize updated request criteria
        updated_request_criteria = kwargs.copy()

        # Data must begin within one candle (plus tolerance) of the requested start.
        if min_timestamp > start_datetime + timeframe_to_timedelta(timeframe) + tolerance:
            logger.debug("Data does not start early enough, even with tolerance.")
            updated_request_criteria['end_datetime'] = min_timestamp  # Fetch the missing earlier data
            return False, updated_request_criteria

        # Data must reach within one candle (plus tolerance) of the requested end.
        if max_timestamp < end_datetime - timeframe_to_timedelta(timeframe) - tolerance:
            logger.debug("Data does not extend late enough, even with tolerance.")
            updated_request_criteria['start_datetime'] = max_timestamp  # Fetch the missing later data
            return False, updated_request_criteria

        # Filter data between start_datetime and end_datetime
        mask = (temp_data['time_dt'] >= start_datetime) & (temp_data['time_dt'] <= end_datetime)
        data_in_range = temp_data.loc[mask]

        expected_count = estimate_record_count(start_datetime, end_datetime, timeframe)
        actual_count = len(data_in_range)

        logger.debug(f"Expected record count: {expected_count}, Actual record count: {actual_count}")

        # Allow one missing record before declaring a gap (e.g. a candle still forming).
        tolerance = 1
        if actual_count < (expected_count - tolerance):
            logger.debug("Insufficient records within the specified time range, even with tolerance.")
            return False, updated_request_criteria

        logger.debug("Data completeness check passed.")
        return True, kwargs
|
|
|
|
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
                                 start_datetime: dt.datetime = None,
                                 end_datetime: dt.datetime = None) -> pd.DataFrame:
    """
    Fetch historical OHLCV candles for ``symbol`` from the user's exchange connection.

    :param symbol: Trading pair symbol (e.g. 'BTCUSDT').
    :param interval: Candle timeframe string (e.g. '5m', '1h').
    :param exchange_name: Name of the exchange to query.
    :param user_name: User whose exchange connection should be used.
    :param start_datetime: Inclusive start of the requested range; defaults to 2017-01-01 UTC.
    :param end_datetime: Inclusive end of the requested range; defaults to "now" in UTC.
    :return: DataFrame of candles; an empty DataFrame with the standard OHLCV columns
             when the exchange returns no data.
    :raises ValueError: If start_datetime is after end_datetime, or if the returned
                        'time' values are not in milliseconds.
    """
    if start_datetime is None:
        start_datetime = dt.datetime(year=2017, month=1, day=1, tzinfo=dt.timezone.utc)

    if end_datetime is None:
        # datetime.utcnow() is deprecated (Python 3.12+) and returns a naive value;
        # build the aware "now" directly instead of patching tzinfo afterwards.
        end_datetime = dt.datetime.now(dt.timezone.utc)

    if start_datetime > end_datetime:
        raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")

    exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)

    expected_records = estimate_record_count(start_datetime, end_datetime, interval)
    logger.info(
        f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {expected_records}')

    # A zero-length range is passed to the exchange as "open-ended from start".
    if start_datetime == end_datetime:
        end_datetime = None

    candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime,
                                             end_dt=end_datetime)

    num_rec_records = len(candles.index)
    if num_rec_records == 0:
        logger.warning(f"No OHLCV data returned for {symbol}.")
        return pd.DataFrame(columns=['time', 'open', 'high', 'low', 'close', 'volume'])

    logger.info(f'{num_rec_records} candles retrieved from the exchange.')

    open_times = candles.time
    min_open_time = open_times.min()
    max_open_time = open_times.max()

    # Millisecond epochs are >= ~1e12 for modern dates; anything below 1e10 is
    # almost certainly a seconds-based epoch, which downstream code cannot handle.
    if min_open_time < 1e10:
        raise ValueError('Records are not in milliseconds')

    # +1 because both endpoints are present in the returned data.
    estimated_num_records = estimate_record_count(min_open_time, max_open_time, interval) + 1
    logger.info(f'Estimated number of records: {estimated_num_records}')

    if num_rec_records < estimated_num_records:
        logger.info('Detected gaps in the data, attempting to fill missing records.')
        candles = self._fill_data_holes(records=candles, interval=interval)

    return candles
|
@staticmethod
def _fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
    """
    Patch gaps in a candle series by inserting synthetic rows.

    Walks the rows in order; whenever the time delta between two consecutive
    candles exceeds one interval, evenly spaced copies of the later candle are
    inserted (only their 'time' differs) so the series becomes contiguous.

    :param records: Candle DataFrame with a 'time' column in Unix milliseconds.
    :param interval: Timeframe string (e.g. '5m') defining the expected spacing.
    :return: A new DataFrame containing the original plus synthetic rows.
    """
    interval_minutes = timeframe_to_timedelta(interval).total_seconds() / 60
    prev_ts = None
    output_rows = []

    logger.info(f"Starting to fill data holes for interval: {interval}")

    for _, candle in records.iterrows():
        current_ts = candle['time']

        # The very first candle has nothing before it to compare against.
        if prev_ts is None:
            prev_ts = current_ts
            output_rows.append(candle)
            logger.debug(f"First timestamp: {current_ts}")
            continue

        gap_ms = current_ts - prev_ts
        gap_minutes = (gap_ms / 1000) / 60

        logger.debug(f"Timestamp: {current_ts}, Delta minutes: {gap_minutes}")

        if gap_minutes > interval_minutes:
            missing_count = int(gap_minutes / interval_minutes)
            stride = int(gap_ms / missing_count)
            logger.debug(f"Gap detected. Filling {missing_count} records with step: {stride}")

            # Synthesize rows strictly between prev_ts and current_ts, cloning
            # the current candle's values and overwriting only the timestamp.
            synthetic_ts = int(prev_ts) + stride
            while synthetic_ts < int(current_ts):
                filler = candle.copy()
                filler['time'] = synthetic_ts
                output_rows.append(filler)
                logger.debug(f"Filled timestamp: {synthetic_ts}")
                synthetic_ts += stride

        output_rows.append(candle)
        prev_ts = current_ts

    logger.info("Data holes filled successfully.")
    return pd.DataFrame(output_rows)
|
|
|
|
def _populate_db(self, ex_details: list[str], data: pd.DataFrame = None) -> None:
    """
    Persist candle rows into the database table derived from ``ex_details``.

    :param ex_details: [symbol, timeframe, exchange, user] descriptor list.
    :param data: Candle DataFrame to insert; a None or empty frame is a no-op.
    """
    if data is None or data.empty:
        logger.debug(f'No records to insert {data}')
        return

    symbol, _, exchange, _ = ex_details
    table_name = self._make_key(ex_details=ex_details)

    self.db.insert_candles_into_db(candlesticks=data, table_name=table_name,
                                   symbol=symbol, exchange_name=exchange)
    logger.info(f'Data inserted into table {table_name}')
|
|
|
|
|
|
class IndicatorCache(ServerInteractions):
    """
    IndicatorCache extends ServerInteractions and manages caching of both instantiated indicator objects
    and their calculated data. This class optimizes the indicator calculation process by caching and
    reusing data where possible, minimizing redundant computations.

    Attributes:
        indicator_registry (dict): A dictionary mapping indicator types (e.g., 'SMA', 'EMA') to their classes.
    """

    def __init__(self):
        """
        Initialize the IndicatorCache with caches for indicators and their calculated data.
        """
        super().__init__()

        # Registry of available indicators (type name -> indicator class)
        self.indicator_registry = indicators_registry

    @staticmethod
    def _make_indicator_key(symbol: str, timeframe: str, exchange_name: str, indicator_type: str, period: int) -> str:
        """
        Generates a unique cache key for caching the indicator data based on the input properties.
        """
        return f"{symbol}_{timeframe}_{exchange_name}_{indicator_type}_{period}"

    def get_indicator_instance(self, indicator_type: str, properties: dict) -> Indicator:
        """
        Retrieves or creates an instance of an indicator based on its type and properties.

        :raises ValueError: If indicator_type is not present in the registry.
        """
        if indicator_type not in self.indicator_registry:
            raise ValueError(f"Unsupported indicator type: {indicator_type}")
        indicator_class = self.indicator_registry[indicator_type]
        return indicator_class(name=indicator_type, indicator_type=indicator_type, properties=properties)

    def set_user_indicator_properties(self, user_id: str, indicator_type: str, symbol: str, timeframe: str,
                                      exchange_name: str, display_properties: dict):
        """
        Stores or updates user-specific display properties for an indicator.

        :raises ValueError: If display_properties is not a dictionary.
        """
        if not isinstance(display_properties, dict):
            raise ValueError("display_properties must be a dictionary")

        user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}"
        self.serialized_datacache_insert(key=user_cache_key, data=display_properties,
                                         cache_name='user_display_properties')

    def get_user_indicator_properties(self, user_id: str, indicator_type: str, symbol: str, timeframe: str,
                                      exchange_name: str) -> dict:
        """
        Retrieves user-specific display properties for an indicator.

        :raises TypeError: If any argument is not a string.
        """
        # Ensure the arguments are valid
        if not isinstance(user_id, str) or not isinstance(indicator_type, str) or \
                not isinstance(symbol, str) or not isinstance(timeframe, str) or \
                not isinstance(exchange_name, str):
            raise TypeError("All arguments must be of type str")
        user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}"
        return self.get_rows_from_datacache(key=user_cache_key, cache_name='user_display_properties')

    def find_gaps_in_intervals(self, cached_data, start_idx, end_idx, timeframe, min_gap_size=None):
        """
        Recursively searches for missing intervals in cached data based on time gaps.

        BUGFIX: previously every base case returned [] and the recursion only
        concatenated those empty lists, so this method could never report a gap
        (it always returned []). The two-adjacent-records base case now yields
        the hole as a (start, end) interval of naive datetimes, matching the
        interval shape produced by _determine_missing_intervals.

        :param cached_data: DataFrame containing cached time-series data ('time' in Unix ms).
        :param start_idx: Start index in the DataFrame.
        :param end_idx: End index in the DataFrame.
        :param timeframe: Expected interval between consecutive time points (e.g., '5m', '1h').
        :param min_gap_size: The minimum allowable gap (in terms of the number of missing records) to consider.
        :return: A list of tuples representing the missing intervals (start, end).
        """
        if cached_data.empty or end_idx <= start_idx:
            return []

        start_time = cached_data['time'].iloc[start_idx]
        end_time = cached_data['time'].iloc[end_idx]

        expected_records = estimate_record_count(start_time, end_time, timeframe)
        actual_records = end_idx - start_idx + 1

        if expected_records == actual_records:
            return []

        if min_gap_size and (expected_records - actual_records) < min_gap_size:
            return []

        if end_idx - start_idx == 1:
            # Two adjacent cached records with records missing between them:
            # report the hole itself instead of discarding it.
            gap_start = pd.to_datetime(start_time, unit='ms').to_pydatetime()
            gap_end = pd.to_datetime(end_time, unit='ms').to_pydatetime()
            return [(gap_start, gap_end)]

        mid_idx = (start_idx + end_idx) // 2
        left_missing = self.find_gaps_in_intervals(cached_data, start_idx, mid_idx, timeframe, min_gap_size)
        right_missing = self.find_gaps_in_intervals(cached_data, mid_idx, end_idx, timeframe, min_gap_size)

        return left_missing + right_missing

    def calculate_indicator(self, user_name: str, symbol: str, timeframe: str, exchange_name: str, indicator_type: str,
                            start_datetime: dt.datetime, end_datetime: dt.datetime, properties: dict) -> dict:
        """
        Fetches or calculates the indicator data for the specified range, caches the shared results,
        and combines them with user-specific display properties or indicator defaults.

        :raises ValueError: If start_datetime or end_datetime is naive (no tzinfo).
        :return: Dict with 'calculation_data' (list of row dicts) and 'display_properties'.
        """
        # Ensure start_datetime and end_datetime are timezone-aware
        if start_datetime.tzinfo is None or end_datetime.tzinfo is None:
            raise ValueError("Datetime objects must be timezone-aware")

        # Step 1: Create cache key and indicator instance
        cache_key = self._make_indicator_key(symbol, timeframe, exchange_name, indicator_type, properties['period'])
        indicator_instance = self.get_indicator_instance(indicator_type, properties)

        # Step 2: Check cache for existing data
        calculated_data = self._fetch_cached_data(cache_key, start_datetime, end_datetime)
        missing_intervals = self._determine_missing_intervals(calculated_data, start_datetime, end_datetime, timeframe)

        # Step 3: Fetch and calculate data for missing intervals
        if missing_intervals:
            calculated_data = self._fetch_and_calculate_missing_data(
                missing_intervals, calculated_data, indicator_instance, user_name, symbol, timeframe, exchange_name
            )

        # Step 4: Cache the newly calculated data
        self.serialized_datacache_insert(key=cache_key, data=calculated_data, cache_name='indicator_data')

        # Step 5: Retrieve and merge user-specific display properties with defaults
        merged_properties = self._get_merged_properties(user_name, indicator_type, symbol, timeframe, exchange_name,
                                                        indicator_instance)

        return {
            'calculation_data': calculated_data.to_dict('records'),
            'display_properties': merged_properties
        }

    def _fetch_cached_data(self, cache_key: str, start_datetime: dt.datetime,
                           end_datetime: dt.datetime) -> pd.DataFrame:
        """
        Fetches cached indicator data restricted to the given time range.

        :return: A DataFrame slice with 'time' in Unix milliseconds; empty when no cache exists.
        """
        # Retrieve cached data (expected to be a DataFrame with 'time' in Unix ms)
        cached_df = self.get_rows_from_datacache(key=cache_key, cache_name='indicator_data')

        # If no cached data, return an empty DataFrame
        if cached_df is None or cached_df.empty:
            return pd.DataFrame()

        # Convert start and end datetime to Unix timestamps (milliseconds)
        start_timestamp = int(start_datetime.timestamp() * 1000)
        end_timestamp = int(end_datetime.timestamp() * 1000)

        # BUGFIX: work on a copy so we never mutate the DataFrame held by the
        # cache in place — repeated calls would otherwise re-divide 'time'.
        cached_df = cached_df.copy()

        # NOTE(review): this conversion (ns -> ms) assumes the cached 'time'
        # column is datetime64[ns]; if it is already Unix ms this corrupts the
        # values — confirm the dtype produced by indicator calculation.
        cached_df['time'] = cached_df['time'].astype('int64') // 10 ** 6

        # Return only the data within the specified time range
        return cached_df[(cached_df['time'] >= start_timestamp) & (cached_df['time'] <= end_timestamp)]

    def _determine_missing_intervals(self, cached_data: pd.DataFrame, start_datetime: dt.datetime,
                                     end_datetime: dt.datetime, timeframe: str):
        """
        Determines missing intervals in the cached data by comparing the requested time range with cached data.

        :param cached_data: Cached DataFrame with time-series data in Unix timestamp (ms).
        :param start_datetime: Start datetime of the requested data range.
        :param end_datetime: End datetime of the requested data range.
        :param timeframe: Expected interval between consecutive time points (e.g., '5m', '1h').
        :return: A list of tuples representing the missing intervals (start, end).
        """
        # Convert the requested end to Unix milliseconds for comparison
        end_timestamp = int(end_datetime.timestamp() * 1000)

        if cached_data is not None and not cached_data.empty:
            # Find the last (most recent) time in the cached data (which is in Unix timestamp ms)
            cached_end = cached_data['time'].max()

            # If the cached data ends before the requested end time, add that range to missing intervals
            if cached_end < end_timestamp:
                # Convert cached_end back to datetime for consistency in returned intervals
                cached_end_dt = pd.to_datetime(cached_end, unit='ms').to_pydatetime()
                return [(cached_end_dt, end_datetime)]

            # Otherwise, check for gaps within the cached data itself
            return self.find_gaps_in_intervals(cached_data, 0, len(cached_data) - 1, timeframe)

        # If no cached data exists, the entire range is missing
        return [(start_datetime, end_datetime)]

    def _fetch_and_calculate_missing_data(self, missing_intervals, calculated_data, indicator_instance, user_name,
                                          symbol, timeframe, exchange_name):
        """
        Fetches OHLC data for each missing interval, runs the indicator over it,
        and appends the results to the already-calculated data.

        :return: The combined calculated-data DataFrame.
        """
        for interval_start, interval_end in missing_intervals:
            # Ensure interval_start and interval_end are timezone-aware
            if interval_start.tzinfo is None:
                interval_start = interval_start.replace(tzinfo=dt.timezone.utc)
            if interval_end.tzinfo is None:
                interval_end = interval_end.replace(tzinfo=dt.timezone.utc)

            ohlc_data = self.get_or_fetch_from('data', start_datetime=interval_start, end_datetime=interval_end,
                                               ex_details=[symbol, timeframe, exchange_name, user_name])
            # Skip intervals for which no usable candle data came back
            if ohlc_data.empty or 'close' not in ohlc_data.columns:
                continue

            new_data = indicator_instance.calculate(candles=ohlc_data, user_name=user_name)

            calculated_data = pd.concat([calculated_data, new_data], ignore_index=True)

        return calculated_data

    def _get_merged_properties(self, user_name, indicator_type, symbol, timeframe, exchange_name, indicator_instance):
        """
        Retrieves and merges user-specific properties with default indicator properties.

        User-set values win; when the user has none, styling defaults (color*/thickness*)
        are taken from the indicator instance itself.
        """
        user_properties = self.get_user_indicator_properties(user_name, indicator_type,
                                                             symbol, timeframe, exchange_name)
        if not user_properties:
            user_properties = {
                key: value for key,
                value in indicator_instance.properties.items() if key.startswith('color') or key.startswith('thickness')
            }
        return {**indicator_instance.properties, **user_properties}
|
|
|
|
|
|
class DataCache(IndicatorCache):
    """
    Concrete cache-management entry point built on top of ServerInteractions,
    DatabaseInteractions, SnapshotDataCache, and DataCacheBase. It stores and
    manages custom InMemoryCache objects and is the class applications
    instantiate to get the full caching system.

    Data layout
    -----------
    ``caches`` maps a cache name to an InMemoryCache. Each InMemoryCache holds
    a ``limit``, an ``eviction_policy``, and a ``cache`` DataFrame whose rows
    carry the actual payloads, so pandas querying/filtering can be used on the
    cache contents directly:

        caches = {
            [name1]: {
                'limit': [limit],
                'eviction_policy': [eviction_policy],
                'cache': pd.DataFrame({
                    'tbl_key': [key],
                    'creation_time': [creation_time],
                    'expire_delta': [expire_delta],
                    'data': pd.DataFrame({
                        'column1': [data],
                        'column2': [data],
                        'column3': [data]
                    })
                })
            }
        }

    A populated example:

        caches = {
            'users': {
                'limit': 100,
                'eviction_policy': 'deny',
                'cache': pd.DataFrame({
                    'tbl_key': ['user1', 'user2'],
                    'creation_time': [
                        datetime.datetime(2024, 8, 24, 12, 45, 31, 117684),
                        datetime.datetime(2023, 8, 12, 6, 35, 21, 113644)
                    ],
                    'expire_delta': [datetime.timedelta(seconds=600), None],
                    'data': [
                        pd.DataFrame({
                            'user_name': ['Billy'],
                            'password': [1234],
                            'exchanges': [['ex1', 'ex2', 'ex3']]
                        }),
                        pd.DataFrame({
                            'user_name': ['Patty'],
                            'password': [5678],
                            'exchanges': [['ex1', 'ex3', 'ex5']]
                        })
                    ]
                })
            },
            'strategies': {
                'limit': 100,
                'eviction_policy': 'expire',
                'cache': pd.DataFrame({
                    'tbl_key': ['strategy1', 'strategy2'],
                    'creation_time': [
                        datetime.datetime(2024, 3, 13, 12, 45, 31, 117684),
                        datetime.datetime(2023, 4, 12, 6, 35, 21, 113644)
                    ],
                    'expire_delta': [
                        datetime.timedelta(seconds=600),
                        datetime.timedelta(seconds=600)
                    ],
                    'data': [
                        pd.DataFrame({
                            'strategy_name': ['Awesome_v1.2'],
                            'code': [json.dumps({"params": {"t": "ex1", "m": 100, "ex": True}})],
                            'max_loss': [100]
                        }),
                        pd.DataFrame({
                            'strategy_name': ['Awesome_v2.0'],
                            'code': [json.dumps({"params": {"t": "ex2", "m": 67, "ex": False}})],
                            'max_loss': [40]
                        })
                    ]
                })
            },
            ...
        }

    Usage examples
    --------------
        # Flatten all user payloads into a single DataFrame
        user_data = pd.concat(caches['users']['cache']['data'].tolist(), ignore_index=True)

        # Check whether a username is already taken
        is_taken = not user_data[user_data['user_name'] == 'Billy'].empty

        # Find users registered with a specific exchange
        users_with_ex3 = user_data[user_data['exchanges'].apply(lambda x: 'ex3' in x)]

        # Drop expired rows (a missing expire_delta means "never expires")
        now = datetime.datetime.now(datetime.timezone.utc)
        alive = caches['users']['cache'].apply(
            lambda row: now <= row['creation_time'] + row['expire_delta'] if row['expire_delta'] else True,
            axis=1
        )
        caches['users']['cache'] = caches['users']['cache'][alive]
    """

    def __init__(self):
        """Initialize all inherited cache machinery and log readiness."""
        super().__init__()
        logger.info("DataCache initialized.")
|