From a16cc542d296fcf024dab1d4e078214dc78e4ff5 Mon Sep 17 00:00:00 2001 From: Rob Date: Sun, 25 Aug 2024 10:25:08 -0300 Subject: [PATCH] Refactored DataCache, again. Implemented more advance cache management. All DataCache tests pass. --- src/BrighterTrades.py | 4 +- src/DataCache_v2.py | 61 +- src/DataCache_v3.py | 1309 +++++++++++++++++ src/Database.py | 20 +- src/Strategies.py | 37 +- src/Users.py | 95 +- src/trade.py | 12 +- ...test_DataCache_v2.py => test_DataCache.py} | 525 +++++-- 8 files changed, 1817 insertions(+), 246 deletions(-) create mode 100644 src/DataCache_v3.py rename tests/{test_DataCache_v2.py => test_DataCache.py} (54%) diff --git a/src/BrighterTrades.py b/src/BrighterTrades.py index 7755dc7..33bd4d0 100644 --- a/src/BrighterTrades.py +++ b/src/BrighterTrades.py @@ -1,7 +1,7 @@ from typing import Any from Users import Users -from DataCache_v2 import DataCache +from DataCache_v3 import DataCache from Strategies import Strategies from backtesting import Backtester from candles import Candles @@ -503,8 +503,6 @@ class BrighterTrades: print(f'ERROR SETTING VALUE') print(f'The string received by the server was: /n{params}') - # Save any changes to storage - self.config.config_and_states('save') # Now that the state is changed reload price history. self.candles.set_cache(user_name=user_name) return diff --git a/src/DataCache_v2.py b/src/DataCache_v2.py index d9a2861..9bae91e 100644 --- a/src/DataCache_v2.py +++ b/src/DataCache_v2.py @@ -94,10 +94,35 @@ class DataCache: self.db = Database() self.exchanges = exchanges # Single DataFrame for all cached data - self.cache = pd.DataFrame(columns=['key', 'data']) # Assuming 'key' and 'data' are necessary + self.caches = {} logger.info("DataCache initialized.") - def fetch_cached_rows(self, table: str, filter_vals: Tuple[str, Any]) -> pd.DataFrame | None: + def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None: + """ + Sets or updates an entry in the cache with the provided key. If the key already exists, the existing entry + is replaced unless `do_not_overwrite` is True. In that case, the existing entry is preserved. + + Parameters: + data: The data to be cached. This can be of any type. + key: The unique key used to identify the cached data. + do_not_overwrite : The default is False, meaning that the existing entry will be replaced. + """ + if do_not_overwrite and key in self.cache['key'].values: + return + + # Construct a new DataFrame row with the key and data + new_row = pd.DataFrame({'key': [key], 'data': [data]}) + + # If the key already exists in the cache, remove the old entry + self.cache = self.cache[self.cache['key'] != key] + + # Append the new row to the cache + self.cache = pd.concat([self.cache, new_row], ignore_index=True) + + print(f'Current Cache: {self.cache}') + logger.debug(f'Cache set for key: {key}') + + def get_or_fetch_rows(self, table: str, filter_vals: Tuple[str, Any]) -> pd.DataFrame | None: """ Retrieves rows from the cache if available; otherwise, queries the database and caches the result. @@ -170,10 +195,10 @@ class DataCache: :return: True if the attribute is already taken, False otherwise. """ # Fetch rows from the specified table where the attribute matches the given value - result = self.fetch_cached_rows(table=table, filter_vals=(attr, val)) + result = self.get_or_fetch_rows(table=table, filter_vals=(attr, val)) return result is not None and not result.empty - def fetch_cached_item(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any: + def fetch_item(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any: """ Retrieves a specific item from the cache or database, caching the result if necessary. @@ -183,16 +208,16 @@ class DataCache: :return: The value of the requested item. :raises ValueError: If the item is not found in either the cache or the database. """ - # Fetch the relevant rows - rows = self.fetch_cached_rows(table_name, filter_vals) + # Fetch the relevant rows from the cache or database + rows = self.get_or_fetch_rows(table_name, filter_vals) if rows is not None and not rows.empty: # Return the specific item from the first matching row. return rows.iloc[0][item_name] # If the item is not found, raise an error. - raise ValueError(f"Item {item_name} not found in {table_name} where {filter_vals[0]} = {filter_vals[1]}") + raise ValueError(f"Item '{item_name}' not found in '{table_name}' where {filter_vals[0]} = {filter_vals[1]}") - def modify_cached_row(self, table: str, filter_vals: Tuple[str, Any], field_name: str, new_data: Any) -> None: + def modify_item(self, table: str, filter_vals: Tuple[str, Any], field_name: str, new_data: Any) -> None: """ Modifies a specific field in a row within the cache and updates the database accordingly. @@ -202,7 +227,7 @@ class DataCache: :param new_data: The new data to be set. """ # Retrieve the row from the cache or database - row = self.fetch_cached_rows(table, filter_vals) + row = self.get_or_fetch_rows(table, filter_vals) if row is None or row.empty: raise ValueError(f"Row not found in cache or database for {filter_vals[0]} = {filter_vals[1]}") @@ -223,7 +248,7 @@ class DataCache: # Update the database with the modified row self.db.insert_dataframe(row.drop(columns='id'), table) - def insert_data(self, df: pd.DataFrame, table: str, skip_cache: bool = False) -> None: + def insert_df(self, df: pd.DataFrame, table: str, skip_cache: bool = False) -> None: """ Inserts data into the specified table in the database, with an option to skip cache insertion. @@ -569,22 +594,6 @@ class DataCache: else: raise KeyError(f"Cache key '{cache_key}' not found.") - def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None: - if do_not_overwrite and key in self.cache['key'].values: - return - - # Corrected construction of the new row - new_row = pd.DataFrame({'key': [key], 'data': [data]}) - - # If the key already exists, drop the old entry - self.cache = self.cache[self.cache['key'] != key] - - # Append the new row to the cache - self.cache = pd.concat([self.cache, new_row], ignore_index=True) - - print(f'Current Cache: {self.cache}') - logger.debug(f'Cache set for key: {key}') - def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str, start_datetime: dt.datetime = None, end_datetime: dt.datetime = None) -> pd.DataFrame: diff --git a/src/DataCache_v3.py b/src/DataCache_v3.py new file mode 100644 index 0000000..fc593d8 --- /dev/null +++ b/src/DataCache_v3.py @@ -0,0 +1,1309 @@ +from abc import ABC, abstractmethod +import logging +import datetime as dt +import pandas as pd +import numpy as np +import json + +from shared_utilities import unix_time_millis +from Database import Database + +# Configure logging +logger = logging.getLogger(__name__) + + +# Helper Methods +def timeframe_to_timedelta(timeframe: str) -> pd.Timedelta | pd.DateOffset: + digits = int("".join([i if i.isdigit() else "" for i in timeframe])) + unit = "".join([i if i.isalpha() else "" for i in timeframe]) + + if unit == 'm': + return pd.Timedelta(minutes=digits) + elif unit == 'h': + return pd.Timedelta(hours=digits) + elif unit == 'd': + return pd.Timedelta(days=digits) + elif unit == 'w': + return pd.Timedelta(weeks=digits) + elif unit == 'M': + return pd.DateOffset(months=digits) + elif unit == 'Y': + return pd.DateOffset(years=digits) + else: + raise ValueError(f"Invalid timeframe unit: {unit}") + + +def estimate_record_count(start_time, end_time, timeframe: str) -> int: + """ + Estimate the number of records expected between start_time and end_time based on the given timeframe. + Accepts either datetime objects or Unix timestamps in milliseconds. + """ + # Check if the input is in milliseconds (timestamp) + if isinstance(start_time, (int, float, np.integer)) and isinstance(end_time, (int, float, np.integer)): + # Convert timestamps from milliseconds to seconds for calculation + start_time = int(start_time) / 1000 + end_time = int(end_time) / 1000 + start_datetime = dt.datetime.utcfromtimestamp(start_time).replace(tzinfo=dt.timezone.utc) + end_datetime = dt.datetime.utcfromtimestamp(end_time).replace(tzinfo=dt.timezone.utc) + elif isinstance(start_time, dt.datetime) and isinstance(end_time, dt.datetime): + if start_time.tzinfo is None: + raise ValueError("start_time is timezone naive. Please provide a timezone-aware datetime.") + if end_time.tzinfo is None: + raise ValueError("end_time is timezone naive. Please provide a timezone-aware datetime.") + start_datetime = start_time + end_datetime = end_time + else: + raise ValueError("start_time and end_time must be either both " + "datetime objects or both Unix timestamps in milliseconds.") + + delta = timeframe_to_timedelta(timeframe) + total_seconds = (end_datetime - start_datetime).total_seconds() + expected_records = total_seconds // delta.total_seconds() + return int(expected_records) + + +# Cache Interface +class Cache(ABC): + """ + Abstract base class that defines the interface for a cache. + """ + + @abstractmethod + def set_item(self, key: str, data: any, expire_delta: dt.timedelta = None): + pass + + @abstractmethod + def get_item(self, key: str) -> any: + pass + + @abstractmethod + def remove_item(self, key: str): + pass + + @abstractmethod + def clean_expired_items(self): + pass + + @abstractmethod + def get_all_items(self) -> pd.DataFrame: + pass + + +# In-Memory Cache Implementation +class InMemoryCache(Cache): + """ + In-memory storage with a size limit and customizable eviction policies. + + Attributes: + cache (pd.DataFrame): The in-memory storage for cache items. + limit (int): The maximum number of items allowed in the cache. None means no limit. + eviction_policy (str): The policy used when the cache reaches its limit. Options: 'evict', 'deny'. + + Methods: + set_item(key: str, data: any, expire_delta: dt.timedelta = None): Adds an item to the cache. + get_item(key: str) -> any: Retrieves an item from the cache by its key. + get_all_items() -> pd.DataFrame: Returns all items currently stored in the cache. + remove_item(key: str): Removes an item from the cache by its key. + clean_expired_items(): Cleans up expired items from the cache. + + Usage Example: + # Create a cache with a limit of 2 items and 'evict' policy + cached_users = InMemoryCache(limit=2, eviction_policy='evict') + + # Set some items in the cache. + cached_users.set_item("user_bob", "{password:'BobPass'}", expire_delta=dt.timedelta(seconds=10)) + cached_users.set_item("user_alice", "{password:'AlicePass'}", expire_delta=dt.timedelta(seconds=20)) + + # Retrieve an item + retrieved_item = cached_users.get_item('user_bob') + print(f"Retrieved: {retrieved_item}") # Output: Retrieved: {password:'BobPass'} + + # Add another item, causing the oldest item to be evicted + cached_users.set_item("user_billy", "{password:'BillyPass'}") + + # Attempt to retrieve the evicted item + evicted_item = cached_users.get_item('user_bob') + print(f"Evicted Item: {evicted_item}") # Output: Evicted Item: None + + # Retrieve the current items in the cache + all_items = cached_users.get_all_items() + print(all_items) + + # Clean expired items + cached_users.clean_expired_items() + """ + + def __init__(self, limit: int = None, eviction_policy: str = 'evict'): + """ + Initializes the InMemoryCache with an empty DataFrame, an optional size limit, and a specified eviction policy. + + :param limit: The maximum number of items allowed in the cache. If None, the cache size is unlimited. + :param eviction_policy: The policy used when the cache reaches its limit. Options: 'evict', 'deny'. + """ + self.cache = pd.DataFrame(columns=['key', 'data', 'creation_time', 'expire_delta']) + self.limit = limit + self.eviction_policy = eviction_policy + + def set_item(self, key: str, data: any, expire_delta: dt.timedelta = None): + """ + Adds an item to the cache, optionally specifying an expiration duration. + + :param key: The key associated with the cache item. + :param data: The data to be cached. + :param expire_delta: Optional duration after which the cache will expire. + """ + if self.limit is not None and len(self.cache) >= self.limit: + if self.eviction_policy == 'evict': + # Evict the oldest item (based on creation time) + self.cache = self.cache.sort_values(by='creation_time').iloc[1:] + elif self.eviction_policy == 'deny': + # Deny adding the new item if the limit is reached + print(f"Cache limit reached. Item with key '{key}' was not added.") + return + + creation_time = dt.datetime.now(dt.timezone.utc) + new_item = pd.DataFrame({ + 'key': [key], + 'data': [data], + 'creation_time': [creation_time], + 'expire_delta': [expire_delta] + }) + + # Remove any existing item with the same key + self.cache = self.cache[self.cache['key'] != key] + + # Add the new item + self.cache = pd.concat([self.cache, new_item], ignore_index=True) + + def get_item(self, key: str) -> any: + """ + Retrieves an item from the cache by its key. + + :param key: The key associated with the cache item. + :return Any: The cached data, or None if the key does not exist or the item is expired. + """ + item = self.cache[self.cache['key'] == key] + if item.empty: + return None + + current_time = dt.datetime.now(dt.timezone.utc) + creation_time = item['creation_time'].iloc[0] + expire_delta = item['expire_delta'].iloc[0] + + if pd.notna(expire_delta) and current_time > creation_time + expire_delta: + self.remove_item(key) # Remove expired item + return None + + return item['data'].iloc[0] + + def get_all_items(self) -> pd.DataFrame: + """ + Returns all items currently stored in the cache. + + :return pd.DataFrame: A DataFrame containing all cached items. + """ + return self.cache + + def remove_item(self, key: str): + """ + Removes an item from the cache by its key. + + :param key: The key associated with the cache item to be removed. + """ + self.cache = self.cache[self.cache['key'] != key] + + def clean_expired_items(self): + """ + Cleans up expired items from the cache. Items with no expiration time (expire_delta is None) are not removed. + """ + current_time = dt.datetime.now(dt.timezone.utc) + + # Mask for non-expiring items (where expire_delta is None) + non_expiring_mask = self.cache['expire_delta'].isna() + + # Mask for items that have not yet expired + not_expired_mask = ( + self.cache['creation_time'] + self.cache['expire_delta'].fillna(pd.Timedelta(0)) > current_time) + + # Combine the masks + mask_to_keep = non_expiring_mask | not_expired_mask + + # Apply the mask to filter the cache + self.cache = self.cache[mask_to_keep].reset_index(drop=True) + + +class DataCacheBase: + """ + Manages multiple caches, delegating cache operations to the appropriate cache instance. + + Attributes: + caches (dict[str, 'Cache']): A dictionary mapping cache names to cache instances. + + Methods: + create_cache(cache_name: str, cache_type: type = 'InMemoryCache', **kwargs): Creates a new cache with the + specified name and type. + set_cache_item(key: str, data: any, expire_delta: dt.timedelta = None, do_not_overwrite: bool = False, + cache_name: str = 'default_cache', limit: int = None, eviction_policy: str = 'evict'): + Sets an item in the specified cache, creating the cache if it doesn't exist. + cache_exists(cache_name: str, key: str) -> bool: Checks if a specific key exists in the specified cache. + get_cache_item(key: str, cache_name: str = 'default_cache') -> any: Retrieves an item from the specified cache. + get_all_cache_items(cache_name: str) -> pd.DataFrame: Returns all items from the specified cache. + remove_cache_item(cache_name: str, key: str): Removes an item from the specified cache. + clean_expired_items(cache_name: str = None): Cleans up expired items from the specified cache or all caches. + + Usage Example: + # Create a DataCacheBase instance + cache_manager = DataCacheBase() + + # Set some items in the default cache. The cache is created automatically with default settings. + cache_manager.set_cache_item('key1', 'data1', expire_delta=dt.timedelta(seconds=10)) + cache_manager.set_cache_item('key2', 'data2', expire_delta=dt.timedelta(seconds=20)) + + # Check if a key exists in the default cache. + exists = cache_manager.cache_exists('default_cache', 'key1') + print(f"Key1 exists: {exists}") # Output: Key1 exists: True + + # Add another item, causing the oldest item to be evicted. + cache_manager.set_cache_item('key3', 'data3', cache_name='default_cache') + + # Retrieve an item from the default cache. + item = cache_manager.get_cache_item('key2') + print(f"Retrieved Item: {item}") # Output: Retrieved Item: data2 + + # Attempt to retrieve the evicted item. + evicted_item = cache_manager.get_cache_item('key1') + print(f"Evicted Item: {evicted_item}") # Output: Evicted Item: None + + # Create a named cache with a limit and custom eviction policy. + cache_manager.set_cache_item('keyA', 'dataA', cache_name='my_cache', limit=3, eviction_policy='deny') + + # Set items in the named cache. + cache_manager.set_cache_item('keyB', 'dataB', cache_name='my_cache') + cache_manager.set_cache_item('keyC', 'dataC', cache_name='my_cache') + + # Retrieve all items in the named cache. + all_items = cache_manager.get_all_cache_items('my_cache') + print(all_items) + + # Remove an item from the named cache + cache_manager.remove_cache_item('my_cache', 'keyB') + + # Clean expired items in the named cache + cache_manager.clean_expired_items('my_cache') + + # Clean expired items in all caches + cache_manager.clean_expired_items() + """ + + def __init__(self): + self.caches: dict[str, 'Cache'] = {} + + def create_cache(self, cache_name: str, cache_type: type = InMemoryCache, **kwargs): + """ + Creates a new cache with the specified name and type. + + :param cache_name: The name of the cache. + :param cache_type: Optional type of cache to create (default is InMemoryCache). + :param kwargs: Additional arguments to pass to the cache constructor. + """ + if cache_name in self.caches: + raise ValueError(f"Cache with name '{cache_name}' already exists.") + self.caches[cache_name] = cache_type(**kwargs) + + def set_cache_item(self, key: str, data: any, expire_delta: dt.timedelta = None, do_not_overwrite: bool = False, + cache_name: str = 'default_cache', limit: int = None, eviction_policy: str = 'evict'): + """ + Sets or updates an entry in the specified cache. If the key already exists, the existing entry + is replaced unless `do_not_overwrite` is True. Automatically creates the cache if it doesn't exist. + + :param key: The key associated with the cache item. + :param data: The data to be cached. + :param expire_delta: The optional duration after which the cache will expire. + :param do_not_overwrite: If True, the existing entry will not be overwritten. Default is False. + :param cache_name: The name of the cache to use. Default is 'default_cache'. + :param limit: The maximum number of items allowed in the cache (only used if creating a new cache). + :param eviction_policy: The policy used when the cache reaches its limit (only used if creating a new cache). + """ + # Automatically create the cache if it doesn't exist + if cache_name not in self.caches: + print(f"Creating Cache '{cache_name}' because it does not exist.") + self.create_cache(cache_name, cache_type=InMemoryCache, limit=limit, eviction_policy=eviction_policy) + + # Check if the key exists and handle `do_not_overwrite` + existing_data = self.get_cache_item(key=key, cache_name=cache_name) + if do_not_overwrite and existing_data is not None: + print(f"Key '{key}' already exists in cache '{cache_name}' and" + f" `do_not_overwrite` is True. Skipping update.") + return + + # Set or overwrite the cache item + self._get_cache(cache_name).set_item(key, data, expire_delta) + + def cache_exists(self, cache_name: str, key: str) -> bool: + """ + Checks if a specific key exists in the specified cache. + + :param cache_name: The name of the cache to check. + :param key: The key to look for in the cache. + :return: True if the key exists in the cache, False otherwise. + """ + if cache_name not in self.caches: + return False + + cache_df = self.caches[cache_name].get_all_items() + return key in cache_df['key'].values + + def _get_cache(self, cache_name: str) -> Cache | None: + """ + Retrieves the cache instance associated with the given cache name. + + :param cache_name: The name of the cache. + :return Cache: The cache instance associated with the cache name. + :raises ValueError: If the cache with the given name does not exist. + """ + if cache_name not in self.caches: + return None + return self.caches[cache_name] + + def get_cache_item(self, key: str, cache_name: str = 'default_cache') -> any: + """ + Retrieves an item from the specified cache. + + :param cache_name: The name of the cache. + :param key: The key associated with the cache item. + :return Any: The cached data, or None if the key does not exist or the item is expired. + """ + cache = self._get_cache(cache_name) + if cache: + return cache.get_item(key) + else: + return None + + def get_all_cache_items(self, cache_name: str) -> pd.DataFrame: + """ + Returns all items from the specified cache. + + :param cache_name: The name of the cache. + :return pd.DataFrame: A DataFrame containing all cached items from the specified cache. + """ + return self._get_cache(cache_name).get_all_items() + + def remove_cache_item(self, cache_name: str, key: str): + """ + Removes an item from the specified cache. + + :param cache_name: The name of the cache. + :param key: The key associated with the cache item to be removed. + """ + self._get_cache(cache_name).remove_item(key) + + def clean_expired_items(self, cache_name: str = None): + """ + Cleans up expired items from the specified cache or all caches if no cache name is provided. + + :param cache_name: The name of the cache to clean, or None to clean all caches. + """ + if cache_name: + self._get_cache(cache_name).clean_expired_items() + else: + for cache in self.caches.values(): + cache.clean_expired_items() + + +class SnapshotDataCache(DataCacheBase): + """ + Extends DataCacheBase with snapshot functionality. + + Attributes: + snapshots (dict): A dictionary to store snapshots, with cache names as keys and snapshot data as values. + + Methods: + snapshot_cache(cache_name: str): Takes a snapshot of the specified cache and stores it. + get_snapshot(cache_name: str): Retrieves the most recent snapshot of the specified cache. + list_snapshots() -> dict: Lists all available snapshots along with their timestamps. + + Usage Example: + # Create a SnapshotDataCache instance + snapshot_cache_manager = SnapshotDataCache() + + # Create an in-memory cache with a limit of 2 items and 'evict' policy + snapshot_cache_manager.create_cache('my_cache', cache_type=InMemoryCache, limit=2, eviction_policy='evict') + + # Set some items in the cache + snapshot_cache_manager.set_cache_item('my_cache', 'key1', 'data1', expire_delta=dt.timedelta(seconds=10)) + snapshot_cache_manager.set_cache_item('my_cache', 'key2', 'data2', expire_delta=dt.timedelta(seconds=20)) + + # Take a snapshot of the current state of 'my_cache' + snapshot_cache_manager.snapshot_cache('my_cache') + + # Add another item, causing the oldest item to be evicted + snapshot_cache_manager.set_cache_item('my_cache', 'key3', 'data3') + + # Retrieve the most recent snapshot of 'my_cache' + snapshot = snapshot_cache_manager.get_snapshot('my_cache') + print(f"Snapshot Data:\n{snapshot}") + + # List all available snapshots with their timestamps + snapshots_list = snapshot_cache_manager.list_snapshots() + print(f"Snapshots List: {snapshots_list}") + """ + + def __init__(self): + super().__init__() + self.snapshots = {} # Dictionary to store snapshots + + def snapshot_cache(self, cache_name: str): + """ + Takes a snapshot of the specified cache and stores it for later retrieval. + + :param cache_name: The name of the cache to snapshot. + :raises ValueError: If the cache with the given name does not exist. + """ + if cache_name not in self.caches: + raise ValueError(f"Cache with name '{cache_name}' does not exist.") + + # Create a deep copy of the cache to store as a snapshot + snapshot = self.caches[cache_name].get_all_items().copy() + + # Store the snapshot in the snapshots dictionary with a timestamp + timestamp = dt.datetime.now(dt.timezone.utc).isoformat() + self.snapshots[cache_name] = {'timestamp': timestamp, 'data': snapshot} + + print(f"Snapshot of cache '{cache_name}' taken at {timestamp}.") + + def get_snapshot(self, cache_name: str): + """ + Retrieves the most recent snapshot of the specified cache. + + :param cache_name: The name of the cache whose snapshot is to be retrieved. + :return: A DataFrame containing the snapshot data, or None if no snapshot exists. + """ + if cache_name not in self.snapshots: + print(f"No snapshot available for cache '{cache_name}'.") + return None + + return self.snapshots[cache_name]['data'] + + def list_snapshots(self): + """ + Lists all available snapshots along with their timestamps. + + :return: A dictionary where keys are cache names and values are timestamps of the snapshots. + """ + return {cache: info['timestamp'] for cache, info in self.snapshots.items()} + + +class DatabaseInteractions(SnapshotDataCache): + """ + Extends SnapshotDataCache with additional functionality, including database interactions and cache management. + + Attributes: + db (Database): A database connection instance for executing queries. + exchanges (list): A list of exchanges or other relevant entities. + TYPECHECKING_ENABLED (bool): A class attribute to toggle type checking. + + Methods: + get_or_fetch_rows(cache_name: str, filter_vals: tuple[str, any]) -> pd.DataFrame | None: + Retrieves rows from the cache or database. + fetch_item(item_name: str, cache_name: str, filter_vals: tuple[str, any]) -> any: + Retrieves a specific item from the cache or database. + insert_row(cache_name: str, columns: tuple, values: tuple, skip_cache: bool = False) -> None: + Inserts a row into the cache and database. + insert_df(df: pd.DataFrame, cache_name: str, skip_cache: bool = False) -> None: + Inserts data from a DataFrame into the cache and database. + remove_row(cache_name: str, filter_vals: tuple[str, any], additional_filter: tuple[str, any] = None, + remove_from_db: bool = True) -> None: + Removes a row from the cache and optionally the database. + is_attr_taken(cache_name: str, attr: str, val: any) -> bool: + Checks if a specific attribute has a given value in the cache. + modify_item(cache_name: str, filter_vals: tuple[str, any], field_name: str, new_data: any) -> None: + Modifies a field in a row within the cache and updates the database. + update_cached_dict(cache_name: str, cache_key: str, dict_key: str, data: any) -> None: + Updates a dictionary stored in the DataFrame cache. + + Usage Example: + # Create a DatabaseInteractions instance + db_cache_manager = DatabaseInteractions(exchanges=exchanges) + + # Set some data in the cache + db_cache_manager.set_cache_item(key='AAPL', data={'price': 100}, cache_name='stock_prices') + + # Fetch or query rows from the cache or database + rows = db_cache_manager.get_or_fetch_rows(cache_name='stock_prices', filter_vals=('symbol', 'AAPL')) + print(f"Fetched Rows:\n{rows}") + + # Fetch a specific item from the cache or database + price = db_cache_manager.fetch_item(item_name='price', cache_name='stock_prices', + filter_vals=('symbol', 'AAPL')) + print(f"Fetched Price: {price}") + + # Insert a new row into the cache and database + db_cache_manager.insert_row(cache_name='stock_prices', columns=('symbol', 'price'), values=('TSLA', 800)) + + # Check if an attribute value is already taken + is_taken = db_cache_manager.is_attr_taken(cache_name='stock_prices', attr='symbol', val='AAPL') + print(f"Is Symbol Taken: {is_taken}") + + # Modify an existing item in the cache and database + db_cache_manager.modify_item(cache_name='stock_prices', filter_vals=('symbol', 'AAPL'), + field_name='price', new_data=105) + + # Update a dictionary within the cache + db_cache_manager.update_cached_dict(cache_name='stock_prices', cache_key='AAPL', dict_key='price', data=110) + """ + TYPECHECKING_ENABLED = True + + def __init__(self, exchanges): + super().__init__() + self.db = Database() + self.exchanges = exchanges + logger.info("DataCache initialized.") + + def get_or_fetch_rows(self, cache_name: str, filter_vals: tuple[str, any]) -> pd.DataFrame | None: + """ + Retrieves rows from the cache if available; otherwise, queries the database and caches the result. + + :param cache_name: The key used to identify the cache (also the name of the database table). + :param filter_vals: A tuple containing the column name and the value to filter by. + :return: A DataFrame containing the requested rows, or None if no matching rows are found. + :raises ValueError: If the cache is not a DataFrame or does not contain DataFrames in the 'data' column. + """ + if cache_name in self.caches: + + # Retrieve all items in the specified cache + cache_df = self.get_all_cache_items(cache_name=cache_name) + + if not isinstance(cache_df, pd.DataFrame): + raise ValueError(f"Cache '{cache_name}' is not a DataFrame and cannot be used with get_or_fetch_rows.") + + # Combine all the DataFrames in the 'data' column into a single DataFrame + combined_data = pd.concat(cache_df['data'].values.tolist(), ignore_index=True) + + # Filter the combined DataFrame + query_str = f"{filter_vals[0]} == @filter_vals[1]" + matching_rows = combined_data.query(query_str) + + if not matching_rows.empty: + return matching_rows + + # If no data is found in the cache, fetch from the database + rows = self.db.get_rows_where(cache_name, filter_vals) + if rows is not None and not rows.empty: + # Store the fetched rows in the cache for future use + self.set_cache_item(key=filter_vals[1], data=rows, cache_name=cache_name) + return rows + + return None + + def fetch_item(self, item_name: str, cache_name: str, filter_vals: tuple[str, any]) -> any: + """ + Retrieves a specific item from the cache or database, caching the result if necessary. + + :param item_name: The name of the column to retrieve. + :param cache_name: The name used to identify the cache (also the name of the database table). + :param filter_vals: A tuple containing the column name and the value to filter by. + :return: The value of the requested item. + :raises ValueError: If the item is not found in either the cache or the database, + or if the column does not exist. + """ + # Fetch the relevant rows from the cache or database + rows = self.get_or_fetch_rows(cache_name=cache_name, filter_vals=filter_vals) + if rows is not None and not rows.empty: + if item_name not in rows.columns: + raise ValueError(f"Column '{item_name}' does not exist in the cache '{cache_name}'.") + # Return the specific item from the first matching row. + return rows.iloc[0][item_name] + + # If the item is not found, raise an error. + raise ValueError( + f"Item '{item_name}' not found in cache or table '{cache_name}' where {filter_vals[0]} = {filter_vals[1]}") + + def insert_row(self, cache_name: str, columns: tuple, values: tuple, key: str = None, + skip_cache: bool = False) -> None: + """ + Inserts a single row into the specified cache and database, with an option to skip cache insertion. + + :param cache_name: The name of the cache (and database table) where the row should be inserted. + :param columns: A tuple of column names corresponding to the values. + :param values: A tuple of values to insert into the specified columns. + :param key: Optional key for the cache item. If None, the auto-incremented ID from the database will be used. + :param skip_cache: If True, skips inserting the row into the cache. Default is False. + """ + # Insert the row into the database and fetch the auto-incremented ID + auto_incremented_id = self.db.insert_row(table=cache_name, columns=columns, values=values) + + if not skip_cache: + # Create a DataFrame for the new row + new_row_df = pd.DataFrame([values], columns=list(columns)) + + # Use the auto-incremented ID as the key if none was provided + if key is None: + key = str(auto_incremented_id) + + # Check if there is already a cache item for this key + existing_data = self.get_cache_item(key=key, cache_name=cache_name) + + if existing_data is not None and isinstance(existing_data, pd.DataFrame): + # Append the new row to the existing DataFrame in the cache + combined_df = pd.concat([existing_data, new_row_df], ignore_index=True) + else: + # If no existing data, use the new DataFrame + combined_df = new_row_df + + # Set the combined DataFrame back into the cache + self.set_cache_item(cache_name=cache_name, key=key, data=combined_df) + + def insert_df(self, df: pd.DataFrame, cache_name: str, skip_cache: bool = False) -> None: + """ + Inserts data from a DataFrame into the specified cache and database, with an option to skip cache insertion. + + :param df: The DataFrame containing the data to insert. + :param cache_name: The name of the cache (and database table) where the data should be inserted. + :param skip_cache: If True, skips inserting the data into the cache. Default is False. + """ + # Insert the data into the database and fetch the auto-incremented ID + auto_incremented_id = self.db.insert_dataframe(df=df, table=cache_name) + + if not skip_cache: + # Use the auto-incremented ID as the key for the cache item + self.set_cache_item(cache_name=cache_name, key=str(auto_incremented_id), data=df) + + def remove_row(self, cache_name: str, filter_vals: tuple[str, any], additional_filter: tuple[str, any] = None, + remove_from_db: bool = True) -> None: + """ + Removes a specific row from the cache and optionally from the database based on filter criteria. + + This method is specifically designed for caches stored as DataFrames. + + :param cache_name: The name of the cache (or table) from which to remove the row. + :param filter_vals: A tuple containing the column name and the value to filter by. + :param additional_filter: An optional additional filter to apply. + :param remove_from_db: If True, also removes the row from the database. Default is True. + :raises ValueError: If the cache is not a DataFrame. + """ + if cache_name not in self.caches: + raise ValueError(f"Cache '{cache_name}' does not exist.") + + # Retrieve the cache object + cache_obj = self.caches[cache_name] + + # Retrieve all items in the specified cache + cache_df = cache_obj.get_all_items() + + if not isinstance(cache_df, pd.DataFrame): + raise ValueError(f"Cache '{cache_name}' is not a DataFrame and cannot be used with remove_row.") + + # Apply filtering on the 'data' column + condition = cache_df['data'].apply(lambda df: df[filter_vals[0]].eq(filter_vals[1])).any(axis=1) + + # If an additional filter is provided, apply it + if additional_filter: + condition &= cache_df['data'].apply(lambda df: df[additional_filter[0]].eq(additional_filter[1])).any( + axis=1) + + # Filter the cache DataFrame to exclude the rows that match the condition + updated_cache_df = cache_df[~condition].reset_index(drop=True) + + # Update the cache with the modified DataFrame + cache_obj.cache = updated_cache_df + + if remove_from_db: + sql = f"DELETE FROM {cache_name} WHERE {filter_vals[0]} = ?" + params = [filter_vals[1]] + + if additional_filter: + sql += f" AND {additional_filter[0]} = ?" + params.append(additional_filter[1]) + + self.db.execute_sql(sql, tuple(params)) + + def is_attr_taken(self, cache_name: str, attr: str, val: any) -> bool: + """ + Checks if a specific attribute in any of the DataFrames stored within the cache + (which is stored as a DataFrame in the 'data' column) has the given value. + + :param cache_name: The key used to identify the cache (also the name of the database table). + :param attr: The attribute/column name to check (e.g., 'username', 'email'). + :param val: The value of the attribute to check. + :return: True if the attribute value is found in any of the DataFrames in the cache, False otherwise. + """ + # Retrieve all items in the cache + all_items_df = self.get_all_cache_items(cache_name) + + if all_items_df.empty: + return False + + # Concatenate all DataFrames stored in the 'data' column into a single DataFrame + combined_df = pd.concat(all_items_df['data'].tolist(), ignore_index=True) + + # Check if the combined DataFrame contains the attribute and if the value matches + if attr in combined_df.columns and not combined_df[combined_df[attr] == val].empty: + return True + + return False + + def modify_item(self, cache_name: str, filter_vals: tuple[str, any], field_name: str, new_data: any) -> None: + """ + Modifies a specific field in a row within the cache and updates the database accordingly. + + :param cache_name: The name used to identify the cache (also the name of the database table). + :param filter_vals: A tuple containing the column name and the value to filter by. + :param field_name: The field to be updated. + :param new_data: The new data to be set. + :raises ValueError: If the row is not found in the cache or the database. + """ + # Retrieve the row from the cache or database + row = self.get_or_fetch_rows(cache_name=cache_name, filter_vals=filter_vals) + + if row is None or row.empty: + raise ValueError(f"Row not found in cache or database for {filter_vals[0]} = {filter_vals[1]}") + + # Modify the specified field + if isinstance(new_data, str): + row.loc[0, field_name] = new_data + else: + # If new_data is not a string, convert it to a JSON string before inserting into the DataFrame. + row.loc[0, field_name] = json.dumps(new_data) + + # Update the cache by removing the old entry and adding the modified row + self.remove_row(cache_name=cache_name, filter_vals=filter_vals) + self.set_cache_item(cache_name=cache_name, key=filter_vals[1], data=row) + + # Update the database with the modified row (excluding the 'id' column if necessary) + self.db.insert_dataframe(row.drop(columns='id', errors='ignore'), table=cache_name) + + def update_cached_dict(self, cache_name: str, cache_key: str, dict_key: str, data: any) -> None: + """ + Updates a dictionary stored in the DataFrame cache. + + :param cache_name: The name of the cache that holds the dictionary. + :param cache_key: The key in the cache corresponding to the dictionary. + :param dict_key: The key within the dictionary to update. + :param data: The data to insert into the dictionary. + :return: None + """ + # Retrieve the item from the cache + cache_item = self.get_cache_item(key=cache_key, cache_name=cache_name) + + if cache_item is not None: + # Ensure the item is a dictionary + if isinstance(cache_item, dict): + # Update the dictionary with the new data + cache_item[dict_key] = data + + # Save the updated dictionary back into the cache + self.set_cache_item(cache_name=cache_name, key=cache_key, data=cache_item) + else: + raise ValueError(f"Expected a dictionary in cache, but found {type(cache_item)}.") + else: + raise KeyError(f"Cache key '{cache_key}' not found in cache '{cache_name}'.") + + +class ServerInteractions(DatabaseInteractions): + """ + Extends DataCache to specialize in handling candle (OHLC) data and server interactions. + """ + + def __init__(self, exchanges): + super().__init__(exchanges) + + @staticmethod + def _make_key(ex_details: list[str]) -> str: + """ + Generates a unique key string based on the exchange details provided. + + :param ex_details: A list of strings containing symbol, timeframe, exchange, etc. + :return: A key string composed of the concatenated details. + """ + symbol, timeframe, exchange, _ = ex_details + key = f'{symbol}_{timeframe}_{exchange}' + return key + + def _update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None: + logger.debug('Updating data with new records.') + + existing_records = self.get_cache_item(cache_name='candles', key=key) + if existing_records is None or existing_records.empty: + existing_records = pd.DataFrame() + + records = pd.concat([existing_records, more_records], axis=0, ignore_index=True) + records = records.drop_duplicates(subset="open_time", keep='first') + records = records.sort_values(by='open_time').reset_index(drop=True) + records['id'] = range(1, len(records) + 1) + + self.set_cache_item(cache_name='candles', key=key, data=records) + + def get_records_since(self, start_datetime: dt.datetime, ex_details: list[str]) -> pd.DataFrame: + """ + This gets up-to-date records from a specified market and exchange. + + :param start_datetime: The approximate time the first record should represent. + :param ex_details: The user exchange and market. + :return: The records. + """ + if self.TYPECHECKING_ENABLED: + if not isinstance(start_datetime, dt.datetime): + raise TypeError("start_datetime must be a datetime object") + if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details): + raise TypeError("ex_details must be a list of strings") + if not len(ex_details) == 4: + raise TypeError("ex_details must include [asset, timeframe, exchange, user_name]") + + if start_datetime.tzinfo is None: + raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") + end_datetime = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc) + + try: + args = { + 'start_datetime': start_datetime, + 'end_datetime': end_datetime, + 'ex_details': ex_details, + } + return self._get_or_fetch_from(target='data', **args) + except Exception as e: + logger.error(f"An error occurred: {str(e)}") + raise + + def _get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame: + """ + Fetches market records from a resource stack (data, database, exchange). + Fills incomplete request by fetching down the stack then updates the rest. + + :param target: Starting point for the fetch. ['data', 'database', 'exchange'] + :param kwargs: Details and credentials for the request. + :return: Records in a dataframe. + """ + start_datetime = kwargs.get('start_datetime') + if start_datetime.tzinfo is None: + raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") + + end_datetime = kwargs.get('end_datetime') + if end_datetime.tzinfo is None: + raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.") + + ex_details = kwargs.get('ex_details') + timeframe = ex_details[1] + + if self.TYPECHECKING_ENABLED: + if not isinstance(start_datetime, dt.datetime): + raise TypeError("start_datetime must be a datetime object") + if not isinstance(end_datetime, dt.datetime): + raise TypeError("end_datetime must be a datetime object") + if not isinstance(timeframe, str): + raise TypeError("record_length must be a string representing the timeframe") + if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details): + raise TypeError("ex_details must be a list of strings") + if not all([start_datetime, end_datetime, timeframe, ex_details]): + raise ValueError("Missing required arguments") + + request_criteria = { + 'start_datetime': start_datetime, + 'end_datetime': end_datetime, + 'timeframe': timeframe, + } + + key = self._make_key(ex_details=ex_details) + combined_data = pd.DataFrame() + + if target == 'data': + resources = [self._get_candles_from_cache, self._get_from_database, self._get_from_server] + elif target == 'database': + resources = [self._get_from_database, self._get_from_server] + elif target == 'server': + resources = [self._get_from_server] + else: + raise ValueError('Not a valid Target!') + + for fetch_method in resources: + result = fetch_method(**kwargs) + + if result is not None and not result.empty: + # Drop the 'id' column if it exists in the result + if 'id' in result.columns: + result = result.drop(columns=['id']) + + # Concatenate, drop duplicates based on 'open_time', and sort + combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='open_time').sort_values( + by='open_time') + + is_complete, request_criteria = self._data_complete(data=combined_data, **request_criteria) + if is_complete: + if fetch_method in [self._get_from_database, self._get_from_server]: + self._update_candle_cache(more_records=combined_data, key=key) + if fetch_method == self._get_from_server: + self._populate_db(ex_details=ex_details, data=combined_data) + return combined_data + + kwargs.update(request_criteria) # Update kwargs with new start/end times for next fetch attempt + + logger.error('Unable to fetch the requested data.') + return combined_data if not combined_data.empty else pd.DataFrame() + + def _get_candles_from_cache(self, **kwargs) -> pd.DataFrame: + start_datetime = kwargs.get('start_datetime') + end_datetime = kwargs.get('end_datetime') + ex_details = kwargs.get('ex_details') + + if not all([start_datetime, end_datetime, ex_details]): + raise ValueError("Missing required arguments for candle data retrieval.") + + key = self._make_key(ex_details=ex_details) + logger.debug('Getting records from candles cache.') + + df = self.get_cache_item(cache_name='candles', key=key) + if df is None or df.empty: + logger.debug("No cached records found.") + return pd.DataFrame() + + df_filtered = df[(df['open_time'] >= unix_time_millis(start_datetime)) & + (df['open_time'] <= unix_time_millis(end_datetime))].reset_index(drop=True) + + return df_filtered + + def _get_from_database(self, **kwargs) -> pd.DataFrame: + start_datetime = kwargs.get('start_datetime') + if start_datetime.tzinfo is None: + raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") + + end_datetime = kwargs.get('end_datetime') + if end_datetime.tzinfo is None: + raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.") + + ex_details = kwargs.get('ex_details') + + if self.TYPECHECKING_ENABLED: + if not isinstance(start_datetime, dt.datetime): + raise TypeError("start_datetime must be a datetime object") + if not isinstance(end_datetime, dt.datetime): + raise TypeError("end_datetime must be a datetime object") + if not isinstance(ex_details, list) or not all(isinstance(i, str) for i in ex_details): + raise TypeError("ex_details must be a list of strings") + if not all([start_datetime, end_datetime, ex_details]): + raise ValueError("Missing required arguments") + + table_name = self._make_key(ex_details=ex_details) + if not self.db.table_exists(table_name): + logger.debug('Records not in database.') + return pd.DataFrame() + + logger.debug('Getting records from database.') + return self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=start_datetime, + et=end_datetime) + + def _get_from_server(self, **kwargs) -> pd.DataFrame: + symbol = kwargs.get('ex_details')[0] + interval = kwargs.get('ex_details')[1] + exchange_name = kwargs.get('ex_details')[2] + user_name = kwargs.get('ex_details')[3] + start_datetime = kwargs.get('start_datetime') + if start_datetime.tzinfo is None: + raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") + + end_datetime = kwargs.get('end_datetime') + if end_datetime.tzinfo is None: + raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.") + + if self.TYPECHECKING_ENABLED: + if not isinstance(symbol, str): + raise TypeError("symbol must be a string") + if not isinstance(interval, str): + raise TypeError("interval must be a string") + if not isinstance(exchange_name, str): + raise TypeError("exchange_name must be a string") + if not isinstance(user_name, str): + raise TypeError("user_name must be a string") + if not isinstance(start_datetime, dt.datetime): + raise TypeError("start_datetime must be a datetime object") + if not isinstance(end_datetime, dt.datetime): + raise TypeError("end_datetime must be a datetime object") + + logger.debug('Getting records from server.') + return self._fetch_candles_from_exchange(symbol=symbol, interval=interval, exchange_name=exchange_name, + user_name=user_name, start_datetime=start_datetime, + end_datetime=end_datetime) + + @staticmethod + def _data_complete(data: pd.DataFrame, **kwargs) -> (bool, dict): + """ + Checks if the data completely satisfies the request. + + :param data: DataFrame containing the records. + :param kwargs: Arguments required for completeness check. + :return: A tuple (is_complete, updated_request_criteria) where is_complete is True if the data is complete, + False otherwise, and updated_request_criteria contains adjusted start/end times if data is incomplete. + """ + if data.empty: + logger.debug("Data is empty.") + return False, kwargs # No data at all, proceed with the full original request + + start_datetime: dt.datetime = kwargs.get('start_datetime') + if start_datetime.tzinfo is None: + raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") + + end_datetime: dt.datetime = kwargs.get('end_datetime') + if end_datetime.tzinfo is None: + raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.") + + timeframe: str = kwargs.get('timeframe') + + temp_data = data.copy() + + # Convert 'open_time' to datetime with unit='ms' and localize to UTC + temp_data['open_time_dt'] = pd.to_datetime(temp_data['open_time'], + unit='ms', errors='coerce').dt.tz_localize('UTC') + + min_timestamp = temp_data['open_time_dt'].min() + max_timestamp = temp_data['open_time_dt'].max() + + logger.debug(f"Data time range: {min_timestamp} to {max_timestamp}") + logger.debug(f"Expected time range: {start_datetime} to {end_datetime}") + + tolerance = pd.Timedelta(seconds=5) + + # Initialize updated request criteria + updated_request_criteria = kwargs.copy() + + # Check if data covers the required time range with tolerance + if min_timestamp > start_datetime + timeframe_to_timedelta(timeframe) + tolerance: + logger.debug("Data does not start early enough, even with tolerance.") + updated_request_criteria['end_datetime'] = min_timestamp # Fetch the missing earlier data + return False, updated_request_criteria + + if max_timestamp < end_datetime - timeframe_to_timedelta(timeframe) - tolerance: + logger.debug("Data does not extend late enough, even with tolerance.") + updated_request_criteria['start_datetime'] = max_timestamp # Fetch the missing later data + return False, updated_request_criteria + + # Filter data between start_datetime and end_datetime + mask = (temp_data['open_time_dt'] >= start_datetime) & (temp_data['open_time_dt'] <= end_datetime) + data_in_range = temp_data.loc[mask] + + expected_count = estimate_record_count(start_datetime, end_datetime, timeframe) + actual_count = len(data_in_range) + + logger.debug(f"Expected record count: {expected_count}, Actual record count: {actual_count}") + + tolerance = 1 + if actual_count < (expected_count - tolerance): + logger.debug("Insufficient records within the specified time range, even with tolerance.") + return False, updated_request_criteria + + logger.debug("Data completeness check passed.") + return True, kwargs + + def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str, + start_datetime: dt.datetime = None, + end_datetime: dt.datetime = None) -> pd.DataFrame: + if start_datetime is None: + start_datetime = dt.datetime(year=2017, month=1, day=1, tzinfo=dt.timezone.utc) + + if end_datetime is None: + end_datetime = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc) + + if start_datetime > end_datetime: + raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.") + + exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name) + + expected_records = estimate_record_count(start_datetime, end_datetime, interval) + logger.info( + f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {expected_records}') + + if start_datetime == end_datetime: + end_datetime = None + + candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime, + end_dt=end_datetime) + + num_rec_records = len(candles.index) + if num_rec_records == 0: + logger.warning(f"No OHLCV data returned for {symbol}.") + return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume']) + + logger.info(f'{num_rec_records} candles retrieved from the exchange.') + + open_times = candles.open_time + min_open_time = open_times.min() + max_open_time = open_times.max() + + if min_open_time < 1e10: + raise ValueError('Records are not in milliseconds') + + estimated_num_records = estimate_record_count(min_open_time, max_open_time, interval) + 1 + logger.info(f'Estimated number of records: {estimated_num_records}') + + if num_rec_records < estimated_num_records: + logger.info('Detected gaps in the data, attempting to fill missing records.') + candles = self._fill_data_holes(records=candles, interval=interval) + + return candles + + @staticmethod + def _fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame: + time_span = timeframe_to_timedelta(interval).total_seconds() / 60 + last_timestamp = None + filled_records = [] + + logger.info(f"Starting to fill data holes for interval: {interval}") + + for index, row in records.iterrows(): + time_stamp = row['open_time'] + + if last_timestamp is None: + last_timestamp = time_stamp + filled_records.append(row) + logger.debug(f"First timestamp: {time_stamp}") + continue + + delta_ms = time_stamp - last_timestamp + delta_minutes = (delta_ms / 1000) / 60 + + logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}") + + if delta_minutes > time_span: + num_missing_rec = int(delta_minutes / time_span) + step = int(delta_ms / num_missing_rec) + logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}") + + for ts in range(int(last_timestamp) + step, int(time_stamp), step): + new_row = row.copy() + new_row['open_time'] = ts + filled_records.append(new_row) + logger.debug(f"Filled timestamp: {ts}") + + filled_records.append(row) + last_timestamp = time_stamp + + logger.info("Data holes filled successfully.") + return pd.DataFrame(filled_records) + + def _populate_db(self, ex_details: list[str], data: pd.DataFrame = None) -> None: + if data is None or data.empty: + logger.debug(f'No records to insert {data}') + return + + table_name = self._make_key(ex_details=ex_details) + symbol, _, exchange, _ = ex_details + + self.db.insert_candles_into_db(candlesticks=data, table_name=table_name, symbol=symbol, exchange_name=exchange) + logger.info(f'Data inserted into table {table_name}') + + +class DataCache(ServerInteractions): + """ + Extends ServerInteractions, DatabaseInteractions, SnapshotDataCache, and DataCacheBase to create a fully functional + cache management system that can store and manage custom caching objects of type InMemoryCache. The following + illustrates the complete structure of data, leveraging all parts of the system. + + `caches` is a dictionary of InMemoryCache objects indexed by name. An InMemoryCache is a custom object that contains + attributes such as `limit`, `eviction_policy`, and a `cache`, which is a DataFrame containing the actual data. This + structure allows the cache management system to leverage pandas' powerful querying and filtering capabilities to + manipulate and query the data efficiently while maintaining the cache's overall structure. + + Example Structure: + + caches = { + [name1]: { + 'limit': [limit], + 'eviction_policy': [eviction_policy], + 'cache': pd.DataFrame({ + 'key': [key], + 'creation_time': [creation_time], + 'expire_delta': [expire_delta], + 'data': pd.DataFrame({ + 'column1': [data], + 'column2': [data], + 'column3': [data] + }) + }) + } + } + + # The following is an example of the structure in use: + + caches = { + 'users': { + 'limit': 100, + 'eviction_policy': 'deny', + 'cache': pd.DataFrame({ + 'key': ['user1', 'user2'], + 'creation_time': [ + datetime.datetime(2024, 8, 24, 12, 45, 31, 117684), + datetime.datetime(2023, 8, 12, 6, 35, 21, 113644) + ], + 'expire_delta': [ + datetime.timedelta(seconds=600), + None + ], + 'data': [ + pd.DataFrame({ + 'user_name': ['Billy'], + 'password': [1234], + 'exchanges': [['ex1', 'ex2', 'ex3']] + }), + pd.DataFrame({ + 'user_name': ['Patty'], + 'password': [5678], + 'exchanges': [['ex1', 'ex3', 'ex5']] + }) + ] + }) + }, + 'strategies': { + 'limit': 100, + 'eviction_policy': 'expire', + 'cache': pd.DataFrame({ + 'key': ['strategy1', 'strategy2'], + 'creation_time': [ + datetime.datetime(2024, 3, 13, 12, 45, 31, 117684), + datetime.datetime(2023, 4, 12, 6, 35, 21, 113644) + ], + 'expire_delta': [ + datetime.timedelta(seconds=600), + datetime.timedelta(seconds=600) + ], + 'data': [ + pd.DataFrame({ + 'strategy_name': ['Awesome_v1.2'], + 'code': [json.dumps({"params": {"t": "ex1", "m": 100, "ex": True}})], + 'max_loss': [100] + }), + pd.DataFrame({ + 'strategy_name': ['Awesome_v2.0'], + 'code': [json.dumps({"params": {"t": "ex2", "m": 67, "ex": False}})], + 'max_loss': [40] + }) + ] + }) + }, + ... + } + + Usage Examples: + # Extract all user data DataFrames + user_data = pd.concat(caches['users']['cache']['data'].tolist(), ignore_index=True) + + # Check if a username is taken + username_to_check = 'Billy' + is_taken = not user_data[user_data['user_name'] == username_to_check].empty + + # Extract all user data DataFrames + user_data = pd.concat(caches['users']['cache']['data'].tolist(), ignore_index=True) + + # Find users with a specific exchange + exchange_to_check = 'ex3' + users_with_exchange = user_data[user_data['exchanges'].apply(lambda x: exchange_to_check in x)] + + # Filter out the expired users + current_time = datetime.datetime.now(datetime.timezone.utc) + non_expired_users = caches['users']['cache'].apply( + lambda row: current_time <= row['creation_time'] + row['expire_delta'] if row['expire_delta'] else True, + axis=1 + ) + caches['users']['cache'] = caches['users']['cache'][non_expired_users] + + """ + + def __init__(self, exchanges): + super().__init__(exchanges) + logger.info("DataCache initialized.") diff --git a/src/Database.py b/src/Database.py index b0477c1..198deec 100644 --- a/src/Database.py +++ b/src/Database.py @@ -133,29 +133,41 @@ class Database: print(f"Error querying table '{table}' for column '{filter_vals[0]}': {e}") return None - def insert_dataframe(self, df: pd.DataFrame, table: str) -> None: + def insert_dataframe(self, df: pd.DataFrame, table: str) -> int: """ - Inserts a DataFrame into a specified table. + Inserts a DataFrame into a specified table and returns the last inserted row's ID. :param df: DataFrame to insert. :param table: Name of the table. + :return: The auto-incremented ID of the last inserted row. """ with SQLite(self.db_file) as con: + # Insert the DataFrame into the specified table df.to_sql(name=table, con=con, index=False, if_exists='append') - def insert_row(self, table: str, columns: Tuple[str, ...], values: Tuple[Any, ...]) -> None: + # Fetch the last inserted row ID + cursor = con.execute('SELECT last_insert_rowid()') + last_id = cursor.fetchone()[0] + + return last_id + + def insert_row(self, table: str, columns: Tuple[str, ...], values: Tuple[Any, ...]) -> int: """ - Inserts a row into a specified table. + Inserts a row into a specified table and returns the auto-incremented ID. :param table: Name of the table. :param columns: Tuple of column names. :param values: Tuple of values to insert. + :return: The auto-incremented ID of the inserted row. """ with SQLite(self.db_file) as conn: cursor = conn.cursor() sql = make_insert(table=table, columns=columns) cursor.execute(sql, values) + # Return the auto-incremented ID + return cursor.lastrowid + def table_exists(self, table_name: str) -> bool: """ Checks if a table exists in the database. diff --git a/src/Strategies.py b/src/Strategies.py index fb0d303..850b519 100644 --- a/src/Strategies.py +++ b/src/Strategies.py @@ -1,12 +1,26 @@ import json from DataCache_v2 import DataCache + class Strategy: def __init__(self, **args): """ :param args: An object containing key_value pairs representing strategy attributes. Strategy format is defined in strategies.js """ + self.active = None + self.type = None + self.trade_amount = None + self.max_position = None + self.side = None + self.trd_in_conds = None + self.merged_loss = None + self.gross_loss = None + self.stop_loss = None + self.take_profit = None + self.gross_profit = None + self.merged_profit = None + self.name = None self.current_value = None self.opening_value = None self.gross_pl = None @@ -161,19 +175,22 @@ class Strategies: # Reference to the trades object that maintains all trading actions and data. self.trades = trades + self.strat_list = [] + def get_all_strategy_names(self) -> list | None: """Return a list of all strategies in the database""" - self.data._get_from_database() - # Load existing Strategies from file. - loaded_strategies = config.get_setting('strategies') - if loaded_strategies is None: - # Populate the list and file with defaults defined in this class. - loaded_strategies = self.get_strategy_defaults() - config.set_setting('strategies', loaded_strategies) + # # Load existing Strategies from file. + # loaded_strategies = self.data.get_setting('strategies') + # if loaded_strategies is None: + # # Populate the list and file with defaults defined in this class. + # loaded_strategies = self.get_strategy_defaults() + # config.set_setting('strategies', loaded_strategies) + # + # for entry in loaded_strategies: + # # Initialise all the strategy objects with data from file. + # self.strat_list.append(Strategy(**entry)) - for entry in loaded_strategies: - # Initialise all the strategy objects with data from file. - self.strat_list.append(Strategy(**entry)) return None + return None def new_strategy(self, data): # Create an instance of the new Strategy. diff --git a/src/Users.py b/src/Users.py index 6f450ec..113735b 100644 --- a/src/Users.py +++ b/src/Users.py @@ -4,7 +4,7 @@ import random from typing import Any from passlib.hash import bcrypt import pandas as pd -from DataCache_v2 import DataCache +from DataCache_v3 import DataCache class BaseUser: @@ -28,9 +28,9 @@ class BaseUser: :param user_name: The name of the user. :return: The ID of the user as an integer. """ - return self.data.fetch_cached_item( + return self.data.fetch_item( item_name='id', - table_name='users', + cache_name='users', filter_vals=('user_name', user_name) ) @@ -40,10 +40,10 @@ class BaseUser: :param user_name: The name of the user to remove from the cache. """ - # Use DataCache to remove the user from the cache only + # Remove the user from the cache only self.data.remove_row( - table='users', - filter_vals=('user_name', user_name) + cache_name='users', + filter_vals=('user_name', user_name), remove_from_db=False ) def delete_user(self, user_name: str) -> None: @@ -54,7 +54,7 @@ class BaseUser: """ self.data.remove_row( filter_vals=('user_name', user_name), - table='users' + cache_name='users' ) def get_user_data(self, user_name: str) -> pd.DataFrame | None: @@ -67,8 +67,8 @@ class BaseUser: :raises ValueError: If the user is not found in both the cache and the database. """ # Attempt to fetch the user data from the cache or database via DataCache - user = self.data.fetch_cached_rows( - table='users', + user = self.data.get_or_fetch_rows( + cache_name='users', filter_vals=('user_name', user_name) ) @@ -86,8 +86,8 @@ class BaseUser: :param new_data: The new data to be set. """ # Use DataCache to modify the user's data - self.data.modify_cached_row( - table='users', + self.data.modify_item( + cache_name='users', filter_vals=('user_name', username), field_name=field_name, new_data=new_data @@ -112,8 +112,8 @@ class UserAccountManagement(BaseUser): self.max_guests = max_guests # Maximum number of guests # Initialize data for guest suffixes and cached users - self.data.set_cache(data=[], key='guest_suffixes', do_not_overwrite=True) - self.data.set_cache(data={}, key='cached_users', do_not_overwrite=True) + self.data.set_cache_item(data=[], key='guest_suffixes', do_not_overwrite=True) + self.data.set_cache_item(data={}, key='cached_users', do_not_overwrite=True) def is_logged_in(self, user_name: str) -> bool: """ @@ -138,9 +138,9 @@ class UserAccountManagement(BaseUser): # If the user is logged in, check if they are a guest. if is_guest(user): # Update the guest suffix cache if the user is a guest. - guest_suffixes = self.data.get_cache('guest_suffixes') - guest_suffixes.append(user_name[1]) - self.data.set_cache(data=guest_suffixes, key='guest_suffixes') + guest_suffixes = self.data.get_cache_item(key='guest_suffixes') or [] + guest_suffixes.append(user_name.split('_')[1]) + self.data.set_cache_item(data=guest_suffixes, key='guest_suffixes') return True else: # If the user is not logged in, remove their data from the cache. @@ -159,7 +159,7 @@ class UserAccountManagement(BaseUser): :return: True if the password is correct, False otherwise. """ # Retrieve the hashed password using DataCache - user_data = self.data.fetch_cached_rows(table='users', filter_vals=('user_name', username)) + user_data = self.data.get_or_fetch_rows(cache_name='users', filter_vals=('user_name', username)) if user_data is None or user_data.empty: return False @@ -215,23 +215,24 @@ class UserAccountManagement(BaseUser): :param enforcement: 'soft' or 'hard' - Determines how strictly users are logged out. """ - if enforcement == 'soft': - self._soft_log_out_all_users() - elif enforcement == 'hard': - # Clear all user-related entries from the cache - for index, row in self.data.cache.iterrows(): - if 'user_name' in row: - self._remove_user_from_memory(row['user_name']) - - df = self.data.fetch_cached_rows(table='users', filter_vals=('status', 'logged_in')) - if df is not None: - df = df[df.user_name != 'guest'] - - # Update the status of all logged-in users to 'logged_out' - for user_name in df.user_name.values: - self.modify_user_data(username=user_name, field_name='status', new_data='logged_out') - else: - raise ValueError("Invalid enforcement type. Use 'soft' or 'hard'.") + # if enforcement == 'soft': + # self._soft_log_out_all_users() + # elif enforcement == 'hard': + # # Clear all user-related entries from the cache + # for index, row in self.data.cache.iterrows(): + # if 'user_name' in row: + # self._remove_user_from_memory(row['user_name']) + # + # df = self.data.get_or_fetch_rows(cache_name='users', filter_vals=('status', 'logged_in')) + # if df is not None: + # df = df[df.user_name != 'guest'] + # + # # Update the status of all logged-in users to 'logged_out' + # for user_name in df.user_name.values: + # self.modify_user_data(username=user_name, field_name='status', new_data='logged_out') + # else: + # raise ValueError("Invalid enforcement type. Use 'soft' or 'hard'.") + pass def _soft_log_out_all_users(self) -> None: """ @@ -250,7 +251,7 @@ class UserAccountManagement(BaseUser): :return: True if the attribute is already taken, False otherwise. """ # Use DataCache to check if the attribute is taken - return self.data.is_attr_taken(table='users', attr=attr, val=val) + return self.data.is_attr_taken(cache_name='users', attr=attr, val=val) def create_unique_guest_name(self) -> str | None: """ @@ -258,12 +259,16 @@ class UserAccountManagement(BaseUser): :return: A unique guest username or None if the guest limit is reached. """ - guest_suffixes = self.data.get_cache('guest_suffixes') - if len(guest_suffixes) > self.max_guests: + guest_suffixes = self.data.get_cache_item(key='guest_suffixes') or [] + if len(guest_suffixes) >= self.max_guests: return None - suffix = random.choice(range(0, (self.max_guests * 9))) + + suffix = random.choice(range(0, self.max_guests * 9)) while suffix in guest_suffixes: - suffix = random.choice(range(0, (self.max_guests * 9))) + suffix = random.choice(range(0, self.max_guests * 9)) + + guest_suffixes.append(suffix) + self.data.set_cache_item(key='guest_suffixes', data=guest_suffixes) return f'guest_{suffix}' def create_guest(self) -> str | None: @@ -292,7 +297,7 @@ class UserAccountManagement(BaseUser): raise ValueError("Attributes must be a tuple of single key-value pair dictionaries.") # Retrieve the default user template from the database using DataCache - default_user = self.data.fetch_cached_rows(table='users', filter_vals=('user_name', 'guest')) + default_user = self.data.get_or_fetch_rows(cache_name='users', filter_vals=('user_name', 'guest')) if default_user is None or default_user.empty: raise ValueError("Default user template not found in the database.") @@ -306,7 +311,7 @@ class UserAccountManagement(BaseUser): default_user = default_user.drop(columns='id') # Insert the modified user data into the database, skipping cache insertion - self.data.insert_data(df=default_user, table="users", skip_cache=True) + self.data.insert_df(df=default_user, cache_name="users", skip_cache=True) def create_new_user(self, username: str, email: str, password: str) -> bool: """ @@ -456,7 +461,7 @@ class UserIndicatorManagement(UserExchangeManagement): user_id = self.get_id(user_name) # Fetch the indicators from the database using DataCache - df = self.data.fetch_cached_rows(table='indicators', filter_vals=('creator', user_id)) + df = self.data.get_or_fetch_rows(cache_name='indicators', filter_vals=('creator', user_id)) # If indicators are found, process the JSON fields if df is not None and not df.empty: @@ -481,8 +486,8 @@ class UserIndicatorManagement(UserExchangeManagement): indicator['kind'], src_string, prop_string) columns = ('creator', 'name', 'visible', 'kind', 'source', 'properties') - # Insert the row into the database using DataCache - self.data.insert_row(table='indicators', columns=columns, values=values) + # Insert the row into the database and cache using DataCache + self.data.insert_row(cache_name='indicators', columns=columns, values=values) def remove_indicator(self, indicator_name: str, user_name: str) -> None: """ @@ -495,7 +500,7 @@ class UserIndicatorManagement(UserExchangeManagement): self.data.remove_row( filter_vals=('name', indicator_name), additional_filter=('creator', user_id), - table='indicators' + cache_name='indicators' ) def get_chart_view(self, user_name: str, prop: str | None = None): diff --git a/src/trade.py b/src/trade.py index 75b3976..08d30ce 100644 --- a/src/trade.py +++ b/src/trade.py @@ -1,6 +1,6 @@ import json import uuid - +from Users import Users import requests from datetime import datetime @@ -267,7 +267,7 @@ class Trade: class Trades: - def __init__(self, users): + def __init__(self, users: Users): """ This class receives, executes, tracks and stores all active_trades. :param users: A class that maintains users each user may have trades. @@ -291,10 +291,10 @@ class Trades: self.stats = {'num_trades': 0, 'total_position': 0, 'total_position_value': 0} # Load all trades. - loaded_trades = users.get_all_active_user_trades() - if loaded_trades is not None: - # Create the active_trades loaded from file. - self.load_trades(loaded_trades) + # loaded_trades = users.get_all_active_user_trades() + # if loaded_trades is not None: + # # Create the active_trades loaded from file. + # self.load_trades(loaded_trades) def connect_exchanges(self, exchanges): """ diff --git a/tests/test_DataCache_v2.py b/tests/test_DataCache.py similarity index 54% rename from tests/test_DataCache_v2.py rename to tests/test_DataCache.py index 77edae4..05bbe93 100644 --- a/tests/test_DataCache_v2.py +++ b/tests/test_DataCache.py @@ -1,5 +1,7 @@ +import time import pytz -from DataCache_v2 import DataCache, timeframe_to_timedelta, estimate_record_count +from DataCache_v3 import DataCache, timeframe_to_timedelta, estimate_record_count, InMemoryCache, DataCacheBase, \ + SnapshotDataCache from ExchangeInterface import ExchangeInterface import unittest import pandas as pd @@ -191,7 +193,7 @@ class DataGenerator: return dt_obj -class TestDataCacheV2(unittest.TestCase): +class TestDataCache(unittest.TestCase): def setUp(self): # Set up database and exchanges self.exchanges = ExchangeInterface() @@ -229,11 +231,16 @@ class TestDataCacheV2(unittest.TestCase): key TEXT PRIMARY KEY, data TEXT NOT NULL )""" - sql_create_table_5 = f""" - CREATE TABLE IF NOT EXISTS users ( - users_data TEXT PRIMARY KEY, - data TEXT NOT NULL - )""" + sql_create_table_5 = """ + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT, + age INTEGER, + users_data TEXT, + data TEXT, + password TEXT -- Moved to a new line and added a comma after 'data' + ) + """ with SQLite(db_file=self.db_file) as con: con.execute(sql_create_table_1) @@ -251,44 +258,188 @@ class TestDataCacheV2(unittest.TestCase): if os.path.exists(self.db_file): os.remove(self.db_file) - def test_set_cache(self): - print('\nTesting set_cache() method without no-overwrite flag:') - self.data.set_cache(data='data', key=self.key) + def test_InMemoryCache(self): + # Step 1: Create a cache with a limit of 2 items and 'evict' policy + print("Creating a cache with a limit of 2 items and 'evict' policy.") + cached_users = InMemoryCache(limit=2, eviction_policy='evict') - # Access the cache data using the DataFrame structure - cached_value = self.data.get_cache(key=self.key) - self.assertEqual(cached_value, 'data') - print(' - Set data without no-overwrite flag passed.') + # Step 2: Set some items in the cache. + print("Setting 'user_bob' in the cache with an expiration of 10 seconds.") + cached_users.set_item("user_bob", "{password:'BobPass'}", expire_delta=dt.timedelta(seconds=10)) - print('Testing set_cache() once again with new data without no-overwrite flag:') - self.data.set_cache(data='more_data', key=self.key) + print("Setting 'user_alice' in the cache with an expiration of 20 seconds.") + cached_users.set_item("user_alice", "{password:'AlicePass'}", expire_delta=dt.timedelta(seconds=20)) - # Access the updated cache data - cached_value = self.data.get_cache(key=self.key) - self.assertEqual(cached_value, 'more_data') - print(' - Set data with new data without no-overwrite flag passed.') + # Step 3: Retrieve 'user_bob' from the cache + print("Retrieving 'user_bob' from the cache.") + retrieved_item = cached_users.get_item('user_bob') + print(f"Retrieved: {retrieved_item}") + assert retrieved_item == "{password:'BobPass'}", "user_bob should have been retrieved successfully." - print('Testing set_cache() method once again with more data with no-overwrite flag set:') - self.data.set_cache(data='even_more_data', key=self.key, do_not_overwrite=True) + # Step 4: Add another item, causing the oldest item to be evicted + print("Adding 'user_billy' to the cache, which should evict 'user_bob' due to the limit.") + cached_users.set_item("user_billy", "{password:'BillyPass'}") - # Since do_not_overwrite is True, the cached data should not change - cached_value = self.data.get_cache(key=self.key) - self.assertEqual(cached_value, 'more_data') - print(' - Set data with no-overwrite flag passed.') + # Step 5: Attempt to retrieve the evicted item 'user_bob' + print("Attempting to retrieve the evicted item 'user_bob'.") + evicted_item = cached_users.get_item('user_bob') + print(f"Evicted Item: {evicted_item}") + assert evicted_item is None, "user_bob should have been evicted from the cache." - def test_cache_exists(self): - print('Testing cache_exists() method:') + # Step 6: Retrieve the current items in the cache + print("Retrieving all current items in the cache after eviction.") + all_items = cached_users.get_all_items() + print("Current items in cache:\n", all_items) + assert "user_alice" in all_items['key'].values, "user_alice should still be in the cache." + assert "user_billy" in all_items['key'].values, "user_billy should still be in the cache." - # Check that the cache does not contain the key before setting it - self.assertFalse(self.data.cache_exists(key=self.key)) - print(' - Check for non-existent data passed.') + # Step 7: Simulate waiting for 'user_alice' to expire (assuming 20 seconds pass) + print("Simulating time passing to expire 'user_alice' (20 seconds).") + time.sleep(20) # This is to simulate the passage of time; in real tests, you may mock datetime. - # Set the cache with a DataFrame containing the key-value pair - self.data.set_cache(data='data', key=self.key) + # Step 8: Clean expired items from the cache + print("Cleaning expired items from the cache.") + cached_users.clean_expired_items() - # Check that the cache now contains the key - self.assertTrue(self.data.cache_exists(key=self.key)) - print(' - Check for existent data passed.') + # Step 9: Retrieve the current items in the cache after cleaning expired items + print("Retrieving all current items in the cache after cleaning expired items.") + all_items_after_cleaning = cached_users.get_all_items() + print("Current items in cache after cleaning:\n", all_items_after_cleaning) + assert "user_alice" not in all_items_after_cleaning[ + 'key'].values, "user_alice should have been expired and removed from the cache." + assert "user_billy" in all_items_after_cleaning['key'].values, "user_billy should still be in the cache." + + # Step 10: Check if 'user_billy' still exists as it should not expire + print("Checking if 'user_billy' still exists in the cache (it should not have expired).") + user_billy_item = cached_users.get_item('user_billy') + print(f"'user_billy' still exists: {user_billy_item}") + assert user_billy_item == "{password:'BillyPass'}", "user_billy should still exist in the cache." + + def test_DataCacheBase(self): + # Step 1: Create a DataCacheBase instance + print("Creating a DataCacheBase instance.") + cache_manager = DataCacheBase() + + # Step 2: Set some items in 'my_cache'. The cache is created automatically with limit 2 and 'evict' policy. + print("Setting 'key1' in 'my_cache' with an expiration of 10 seconds.") + cache_manager.set_cache_item('key1', 'data1', expire_delta=dt.timedelta(seconds=10), cache_name='my_cache', + limit=2, eviction_policy='evict') + + print("Setting 'key2' in 'my_cache' with an expiration of 20 seconds.") + cache_manager.set_cache_item('key2', 'data2', expire_delta=dt.timedelta(seconds=20), cache_name='my_cache') + + # Step 3: Set some items in 'second_cache'. The cache is created automatically with limit 3 and 'deny' policy. + print("Setting 'keyA' in 'second_cache' with an expiration of 15 seconds.") + cache_manager.set_cache_item('keyA', 'dataA', expire_delta=dt.timedelta(seconds=15), cache_name='second_cache', + limit=3, eviction_policy='deny') + + print("Setting 'keyB' in 'second_cache' with an expiration of 30 seconds.") + cache_manager.set_cache_item('keyB', 'dataB', expire_delta=dt.timedelta(seconds=30), cache_name='second_cache') + + print("Setting 'keyC' in 'second_cache' with no expiration.") + cache_manager.set_cache_item('keyC', 'dataC', cache_name='second_cache') + + # Step 4: Add another item to 'my_cache', causing the oldest item to be evicted. + print("Adding 'key3' to 'my_cache', which should evict 'key1' due to the limit.") + cache_manager.set_cache_item('key3', 'data3', cache_name='my_cache') + + # Step 5: Attempt to retrieve the evicted item 'key1' from 'my_cache'. + print("Attempting to retrieve the evicted item 'key1' from 'my_cache'.") + evicted_item = cache_manager.get_cache_item('key1', cache_name='my_cache') + print(f"Evicted Item from 'my_cache': {evicted_item}") + assert evicted_item is None, "'key1' should have been evicted from 'my_cache'." + + # Step 6: Retrieve all current items in both caches before cleaning. + print("Retrieving all current items in 'my_cache' before cleaning.") + all_items_my_cache = cache_manager.get_all_cache_items('my_cache') + print("Current items in 'my_cache':\n", all_items_my_cache) + + print("Retrieving all current items in 'second_cache' before cleaning.") + all_items_second_cache = cache_manager.get_all_cache_items('second_cache') + print("Current items in 'second_cache':\n", all_items_second_cache) + + # Step 7: Simulate time passing to expire 'key2' in 'my_cache' and 'keyA' in 'second_cache'. + print("Simulating time passing to expire 'key2' in 'my_cache' (20 seconds)" + " and 'keyA' in 'second_cache' (15 seconds).") + time.sleep(20) # Simulate the passage of time; in real tests, you may mock datetime. + + # Step 8: Clean expired items in all caches + print("Cleaning expired items in all caches.") + cache_manager.clean_expired_items() + + # Step 9: Verify the cleaning of expired items in 'my_cache'. + print("Retrieving all current items in 'my_cache' after cleaning expired items.") + all_items_after_cleaning_my_cache = cache_manager.get_all_cache_items('my_cache') + print("Items in 'my_cache' after cleaning:\n", all_items_after_cleaning_my_cache) + assert 'key2' not in all_items_after_cleaning_my_cache[ + 'key'].values, "'key2' should have been expired and removed from 'my_cache'." + assert 'key3' in all_items_after_cleaning_my_cache['key'].values, "'key3' should still be in 'my_cache'." + + # Step 10: Verify the cleaning of expired items in 'second_cache'. + print("Retrieving all current items in 'second_cache' after cleaning expired items.") + all_items_after_cleaning_second_cache = cache_manager.get_all_cache_items('second_cache') + print("Items in 'second_cache' after cleaning:\n", all_items_after_cleaning_second_cache) + assert 'keyA' not in all_items_after_cleaning_second_cache[ + 'key'].values, "'keyA' should have been expired and removed from 'second_cache'." + assert 'keyB' in all_items_after_cleaning_second_cache[ + 'key'].values, "'keyB' should still be in 'second_cache'." + assert 'keyC' in all_items_after_cleaning_second_cache[ + 'key'].values, "'keyC' should still be in 'second_cache' since it has no expiration." + + def test_SnapshotDataCache(self): + # Step 1: Create a SnapshotDataCache instance + print("Creating a SnapshotDataCache instance.") + snapshot_cache_manager = SnapshotDataCache() + + # Step 2: Create an in-memory cache with a limit of 2 items and 'evict' policy + print("Creating an in-memory cache named 'my_cache' with a limit of 2 items and 'evict' policy.") + snapshot_cache_manager.create_cache('my_cache', cache_type=InMemoryCache, limit=2, eviction_policy='evict') + + # Step 3: Set some items in the cache + print("Setting 'key1' in 'my_cache' with an expiration of 10 seconds.") + snapshot_cache_manager.set_cache_item(key='key1', data='data1', expire_delta=dt.timedelta(seconds=10), + cache_name='my_cache') + + print("Setting 'key2' in 'my_cache' with an expiration of 20 seconds.") + snapshot_cache_manager.set_cache_item(key='key2', data='data2', expire_delta=dt.timedelta(seconds=20), + cache_name='my_cache') + + # Step 4: Take a snapshot of the current state of 'my_cache' + print("Taking a snapshot of the current state of 'my_cache'.") + snapshot_cache_manager.snapshot_cache('my_cache') + + # Step 5: Add another item, causing the oldest item to be evicted + print("Adding 'key3' to 'my_cache', which should evict 'key1' due to the limit.") + snapshot_cache_manager.set_cache_item(key='key3', data='data3', cache_name='my_cache') + + # Step 6: Retrieve the most recent snapshot of 'my_cache' + print("Retrieving the most recent snapshot of 'my_cache'.") + snapshot = snapshot_cache_manager.get_snapshot('my_cache') + print(f"Snapshot Data:\n{snapshot}") + + # Assert that the snapshot contains 'key1' and 'key2', but not 'key3' + assert 'key1' in snapshot['key'].values, "'key1' should be in the snapshot." + assert 'key2' in snapshot['key'].values, "'key2' should be in the snapshot." + assert 'key3' not in snapshot[ + 'key'].values, "'key3' should not be in the snapshot as it was added after the snapshot." + + # Step 7: List all available snapshots with their timestamps + print("Listing all available snapshots with their timestamps.") + snapshots_list = snapshot_cache_manager.list_snapshots() + print(f"Snapshots List: {snapshots_list}") + + # Assert that the snapshot list contains 'my_cache' + assert 'my_cache' in snapshots_list, "'my_cache' should be in the snapshots list." + assert isinstance(snapshots_list['my_cache'], str), "The snapshot for 'my_cache' should have a timestamp." + + # Additional validation: Ensure 'key3' is present in the live cache but not in the snapshot + print("Ensuring 'key3' is present in the live 'my_cache'.") + live_cache_items = snapshot_cache_manager.get_all_cache_items('my_cache') + print(f"Live 'my_cache' items after adding 'key3':\n{live_cache_items}") + assert 'key3' in live_cache_items['key'].values, "'key3' should be in the live cache." + + # Ensure the live cache does not contain 'key1' + assert 'key1' not in live_cache_items['key'].values, "'key1' should have been evicted from the live cache." def test_update_candle_cache(self): print('Testing update_candle_cache() method:') @@ -296,18 +447,18 @@ class TestDataCacheV2(unittest.TestCase): # Initialize the DataGenerator with the 5-minute timeframe data_gen = DataGenerator('5m') - # Create initial DataFrame and insert into cache + # Create initial DataFrame and insert it into the cache df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc)) print(f'Inserting this table into cache:\n{df_initial}\n') - self.data.set_cache(data=df_initial, key=self.key) + self.data.set_cache_item(key=self.key, data=df_initial, cache_name='candles') - # Create new DataFrame to be added to cache + # Create new DataFrame to be added to the cache df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0, tzinfo=dt.timezone.utc)) print(f'Updating cache with this table:\n{df_new}\n') self.data._update_candle_cache(more_records=df_new, key=self.key) - # Retrieve the resulting DataFrame from cache - result = self.data.get_cache(key=self.key) + # Retrieve the resulting DataFrame from the cache + result = self.data.get_cache_item(key=self.key, cache_name='candles') print(f'The resulting table in cache is:\n{result}\n') # Create the expected DataFrame @@ -316,8 +467,7 @@ class TestDataCacheV2(unittest.TestCase): # Assert that the open_time values in the result match those in the expected DataFrame, in order assert result['open_time'].tolist() == expected['open_time'].tolist(), \ - f"open_time values in result are {result['open_time'].tolist()}" \ - f" but expected {expected['open_time'].tolist()}" + f"open_time values in result are {result['open_time'].tolist()} expected {expected['open_time'].tolist()}" print(f'The result open_time values match:\n{result["open_time"].tolist()}\n') print(' - Update cache with new records passed.') @@ -325,32 +475,25 @@ class TestDataCacheV2(unittest.TestCase): def test_update_cached_dict(self): print('Testing update_cached_dict() method:') - # Set an empty dictionary in the cache for the specified key - self.data.set_cache(data={}, key=self.key) + # Step 1: Set an empty dictionary in the cache for the specified key + print(f'Setting an empty dictionary in the cache with key: {self.key}') + self.data.set_cache_item(data={}, key=self.key) - # Update the cached dictionary with a new key-value pair - self.data.update_cached_dict(cache_key=self.key, dict_key='sub_key', data='value') + # Step 2: Update the cached dictionary with a new key-value pair + print(f'Updating the cached dictionary with key: {self.key}, adding sub_key="sub_key" with value="value".') + self.data.update_cached_dict(cache_name='default_cache', cache_key=self.key, dict_key='sub_key', data='value') - # Retrieve the updated cache - cache = self.data.get_cache(key=self.key) + # Step 3: Retrieve the updated cache + print(f'Retrieving the updated dictionary from the cache with key: {self.key}') + cache = self.data.get_cache_item(key=self.key) - # Verify that the 'sub_key' in the cached dictionary has the correct value + # Step 4: Verify that the 'sub_key' in the cached dictionary has the correct value + print(f'Verifying that "sub_key" in the cached dictionary has the value "value".') + self.assertIsInstance(cache, dict, "The cache should be a dictionary.") + self.assertIn('sub_key', cache, "The 'sub_key' should be present in the cached dictionary.") self.assertEqual(cache['sub_key'], 'value') print(' - Update dictionary in cache passed.') - def test_get_cache(self): - print('Testing get_cache() method:') - - # Set some data into the cache - self.data.set_cache(data='data', key=self.key) - - # Retrieve the cached data using the get_cache method - result = self.data.get_cache(key=self.key) - - # Verify that the result matches the data we set - self.assertEqual(result, 'data') - print(' - Retrieve data passed.') - def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None, num_rec=None, ex_details=None, simulate_scenarios=None): """ @@ -371,94 +514,70 @@ class TestDataCacheV2(unittest.TestCase): print('Testing get_records_since() method:') - # Use provided ex_details or fallback to the class attribute. ex_details = ex_details or self.ex_details - # Generate a data/database key using exchange details. key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}' - - # Set default number of records if not provided. num_rec = num_rec or 12 - table_timeframe = ex_details[1] # Extract timeframe from exchange details. + table_timeframe = ex_details[1] - # Initialize DataGenerator with the given timeframe. data_gen = DataGenerator(table_timeframe) if simulate_scenarios == 'not_enough_data': - # Set query_offset to a time earlier than the start of the table data. query_offset = (num_rec + 5) * data_gen.timeframe_amount else: - # Default to querying for 1 record length less than the table duration. query_offset = query_offset or (num_rec - 1) * data_gen.timeframe_amount if simulate_scenarios == 'incomplete_data': - # Set start time to generate fewer records than required. start_time_for_data = data_gen.x_time_ago(num_rec * data_gen.timeframe_amount) - num_rec = 5 # Set a smaller number of records to simulate incomplete data. + num_rec = 5 else: - # No specific start time for data generation. start_time_for_data = None - # Create the initial data table. df_initial = data_gen.create_table(num_rec, start=start_time_for_data) if simulate_scenarios == 'missing_section': - # Simulate missing section in the data by dropping records. df_initial = data_gen.generate_missing_section(df_initial, drop_start=2, drop_end=5) - # Convert 'open_time' to datetime for better readability. temp_df = df_initial.copy() temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms') print(f'Table Created:\n{temp_df}') if set_cache: - # Insert the generated table into the cache. - print('Inserting table into the cache.') - self.data.set_cache(data=df_initial, key=key) + print('Ensuring the cache exists and then inserting table into the cache.') + self.data.set_cache_item(data=df_initial, key=key, cache_name='candles') if set_db: - # Insert the generated table into the database. print('Inserting table into the database.') with SQLite(self.db_file) as con: df_initial.to_sql(key, con, if_exists='replace', index=False) - # Calculate the start time for querying the records. start_datetime = data_gen.x_time_ago(query_offset) - - # Ensure start_datetime is timezone-aware (UTC). if start_datetime.tzinfo is None: start_datetime = start_datetime.replace(tzinfo=dt.timezone.utc) - # Defaults to current time if not provided to get_records_since() query_end_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc) print(f'Requesting records from {start_datetime} to {query_end_time}') - # Query the records since the calculated start time. result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) - # Filter the initial data table to match the query time. expected = df_initial[df_initial['open_time'] >= data_gen.unix_time_millis(start_datetime)].reset_index( drop=True) temp_df = expected.copy() temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms') print(f'Expected table:\n{temp_df}') - # Print the result from the query for comparison. temp_df = result.copy() temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms') print(f'Resulting table:\n{temp_df}') if simulate_scenarios in ['not_enough_data', 'incomplete_data', 'missing_section']: - # Check that the result has more rows than the expected incomplete data. assert result.shape[0] > expected.shape[ 0], "Result has fewer or equal rows compared to the incomplete data." print("\nThe returned DataFrame has filled in the missing data!") else: - # Ensure the result and expected dataframes match in shape and content. assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}" pd.testing.assert_series_equal(result['open_time'], expected['open_time'], check_dtype=False) print("\nThe DataFrames have the same shape and the 'open_time' columns match.") - # Verify that the oldest timestamp in the result is within the allowed time difference. oldest_timestamp = pd.to_datetime(result['open_time'].min(), unit='ms').tz_localize('UTC') time_diff = oldest_timestamp - start_datetime max_allowed_time_diff = dt.timedelta(**{data_gen.timeframe_unit: data_gen.timeframe_amount}) @@ -469,7 +588,6 @@ class TestDataCacheV2(unittest.TestCase): print(f'The first timestamp is {time_diff} from {start_datetime}') - # Verify that the newest timestamp in the result is within the allowed time difference. newest_timestamp = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') time_diff_end = abs(query_end_time - newest_timestamp) @@ -505,6 +623,9 @@ class TestDataCacheV2(unittest.TestCase): def test_other_timeframes(self): print('\nTest get_records_since with a different timeframe') + if 'candles' not in self.data.caches: + self.data.create_cache(cache_name='candles') + ex_details = ['BTC/USD', '15m', 'binance', 'test_guy'] start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2) # Query the records since the calculated start time. @@ -573,37 +694,53 @@ class TestDataCacheV2(unittest.TestCase): def test_remove_row(self): print('Testing remove_row() method:') - # Insert data into the cache with the expected columns - df = pd.DataFrame({ - 'key': [self.key], - 'data': ['test_data'] + # Create a DataFrame to insert as the data + user_data = pd.DataFrame({ + 'user_name': ['test_user'], + 'password': ['test_password'] }) - self.data.set_cache(data='test_data', key=self.key) + + # Insert data into the cache + self.data.set_cache_item( + cache_name='users', + key='user1', + data=user_data + ) # Ensure the data is in the cache - self.assertTrue(self.data.cache_exists(self.key), "Data was not correctly inserted into the cache.") + cache_item = self.data.get_cache_item('user1', 'users') + self.assertIsNotNone(cache_item, "Data was not correctly inserted into the cache.") + + # The cache_item is a DataFrame, so we access the 'user_name' column directly + self.assertEqual(cache_item['user_name'].iloc[0], 'test_user', "Inserted data is incorrect.") # Remove the row from the cache only (soft delete) - self.data.remove_row(table='test_table_2', filter_vals=('key', self.key), remove_from_db=False) + self.data.remove_row(cache_name='users', filter_vals=('user_name', 'test_user'), remove_from_db=False) # Verify the row has been removed from the cache - self.assertFalse(self.data.cache_exists(self.key), "Row was not correctly removed from the cache.") + cache_item = self.data.get_cache_item('user1', 'users') + self.assertIsNone(cache_item, "Row was not correctly removed from the cache.") # Reinsert the data for hard delete test - self.data.set_cache(data='test_data', key=self.key) + self.data.set_cache_item( + cache_name='users', + key='user1', + data=user_data + ) # Mock database delete by adding the row to the database - self.data.db.insert_row(table='test_table_2', columns=('key', 'data'), values=(self.key, 'test_data')) + self.data.db.insert_row(table='users', columns=('user_name', 'password'), values=('test_user', 'test_password')) # Remove the row from both cache and database (hard delete) - self.data.remove_row(table='test_table_2', filter_vals=('key', self.key), remove_from_db=True) + self.data.remove_row(cache_name='users', filter_vals=('user_name', 'test_user'), remove_from_db=True) # Verify the row has been removed from the cache - self.assertFalse(self.data.cache_exists(self.key), "Row was not correctly removed from the cache.") + cache_item = self.data.get_cache_item('user1', 'users') + self.assertIsNone(cache_item, "Row was not correctly removed from the cache.") # Verify the row has been removed from the database with SQLite(self.db_file) as con: - result = pd.read_sql(f'SELECT * FROM test_table_2 WHERE key="{self.key}"', con) + result = pd.read_sql(f'SELECT * FROM users WHERE user_name="test_user"', con) self.assertTrue(result.empty, "Row was not correctly removed from the database.") print(' - Remove row from cache and database passed.') @@ -662,80 +799,164 @@ class TestDataCacheV2(unittest.TestCase): print(' - All estimate_record_count() tests passed.') - def test_fetch_cached_rows(self): - print('Testing fetch_cached_rows() method:') + def test_get_or_fetch_rows(self): - # Set up mock data in the cache - df = pd.DataFrame({ - 'table': ['test_table_2'], - 'key': ['test_key'], - 'data': ['test_data'] + # Create a mock table in the cache with multiple entries + df1 = pd.DataFrame({ + 'user_name': ['billy'], + 'password': ['1234'], + 'exchanges': [['ex1', 'ex2', 'ex3']] }) - self.data.cache = pd.concat([self.data.cache, df]) - # Test fetching from cache - result = self.data.fetch_cached_rows('test_table_2', ('key', 'test_key')) + df2 = pd.DataFrame({ + 'user_name': ['john'], + 'password': ['5678'], + 'exchanges': [['ex4', 'ex5', 'ex6']] + }) + + df3 = pd.DataFrame({ + 'user_name': ['alice'], + 'password': ['91011'], + 'exchanges': [['ex7', 'ex8', 'ex9']] + }) + + # Insert these DataFrames into the 'users' cache + self.data.create_cache('users', cache_type=InMemoryCache) + self.data.set_cache_item(key='user_billy', data=df1, cache_name='users') + self.data.set_cache_item(key='user_john', data=df2, cache_name='users') + self.data.set_cache_item(key='user_alice', data=df3, cache_name='users') + + print('Testing get_or_fetch_rows() method:') + + # Test fetching an existing user from the cache + result = self.data.get_or_fetch_rows('users', ('user_name', 'billy')) self.assertIsInstance(result, pd.DataFrame, "Failed to fetch DataFrame from cache") self.assertFalse(result.empty, "The fetched DataFrame is empty") - self.assertEqual(result.iloc[0]['data'], 'test_data', "Incorrect data fetched from cache") + self.assertEqual(result.iloc[0]['password'], '1234', "Incorrect data fetched from cache") - # Test fetching from database (assuming the method calls it) - # Here we would typically mock the database call - # But since we're not doing I/O, we will skip that part - print(' - Fetch from cache and database simulated.') + # Test fetching another user from the cache + result = self.data.get_or_fetch_rows('users', ('user_name', 'john')) + self.assertIsInstance(result, pd.DataFrame, "Failed to fetch DataFrame from cache") + self.assertFalse(result.empty, "The fetched DataFrame is empty") + self.assertEqual(result.iloc[0]['password'], '5678', "Incorrect data fetched from cache") + + # Test fetching a user that does not exist in the cache + result = self.data.get_or_fetch_rows('users', ('user_name', 'non_existent_user')) + + # Check if result is None (indicating that no data was found) + self.assertIsNone(result, "Expected result to be None for a non-existent user") + + print(' - Fetching rows from cache passed.') def test_is_attr_taken(self): - print('Testing is_attr_taken() method:') + # Create a cache named 'users' + self.data.create_cache('users', cache_type=InMemoryCache) - # Set up mock data in the cache - df = pd.DataFrame({ - 'table': ['users'], - 'user_name': ['test_user'], - 'data': ['test_data'] + # Create mock data for three users + user_data_1 = pd.DataFrame({ + 'user_name': ['billy'], + 'password': ['1234'], + 'exchanges': [['ex1', 'ex2', 'ex3']] + }) + user_data_2 = pd.DataFrame({ + 'user_name': ['john'], + 'password': ['5678'], + 'exchanges': [['ex1', 'ex2', 'ex4']] + }) + user_data_3 = pd.DataFrame({ + 'user_name': ['alice'], + 'password': ['abcd'], + 'exchanges': [['ex5', 'ex6', 'ex7']] }) - self.data.cache = pd.concat([self.data.cache, df]) - # Test for existing attribute - result = self.data.is_attr_taken('users', 'user_name', 'test_user') - self.assertTrue(result, "Failed to detect existing attribute") + # Insert mock data into the cache + self.data.set_cache_item('user1', user_data_1, cache_name='users') + self.data.set_cache_item('user2', user_data_2, cache_name='users') + self.data.set_cache_item('user3', user_data_3, cache_name='users') - # Test for non-existing attribute - result = self.data.is_attr_taken('users', 'user_name', 'non_existing_user') - self.assertFalse(result, "Incorrectly detected non-existing attribute") + # Test when attribute value is taken + result_taken = self.data.is_attr_taken(cache_name='users', attr='user_name', val='billy') + self.assertTrue(result_taken, "Expected 'billy' to be taken, but it was not.") - print(' - All is_attr_taken() tests passed.') + # Test when attribute value is not taken + result_not_taken = self.data.is_attr_taken(cache_name='users', attr='user_name', val='charlie') + self.assertFalse(result_not_taken, "Expected 'charlie' not to be taken, but it was.") - def test_insert_data(self): - print('Testing insert_data() method:') + def test_insert_df(self): + print('Testing insert_df() method:') # Create a DataFrame to insert df = pd.DataFrame({ - 'key': ['new_key'], - 'data': ['new_data'] + 'user_name': ['Alice'], + 'age': [30], + 'users_data': ['user_data_1'], + 'data': ['additional_data'], + 'password': ['1234'] }) # Insert data into the database and cache - self.data.insert_data(df=df, table='test_table_2') + self.data.insert_df(df=df, cache_name='users') - # Verify that the data was added to the cache - cached_value = self.data.get_cache('new_key') - self.assertEqual(cached_value, 'new_data', "Failed to insert data into cache") + # Assume the database will return an auto-incremented ID starting at 1 + auto_incremented_id = 1 - # Normally, we would also verify that the data was inserted into the database - # This would typically be done with a mock database or by checking the database state directly - print(' - Data insertion into cache and database simulated.') + # Verify that the data was added to the cache using the auto-incremented ID as the key + cached_df = self.data.get_cache_item(key=str(auto_incremented_id), cache_name='users') + + # Check that the DataFrame in the cache matches the original DataFrame + pd.testing.assert_frame_equal(cached_df, df, check_dtype=False) + + # Now, let's verify the data was inserted into the database + with SQLite(self.data.db.db_file) as conn: + # Query the users table for the inserted data + query_result = pd.read_sql_query(f"SELECT * FROM users WHERE id = {auto_incremented_id}", conn) + + # Verify the database content matches the inserted DataFrame + expected_db_df = df.copy() + expected_db_df['id'] = auto_incremented_id # Add the auto-incremented ID to the expected DataFrame + # Align column order + expected_db_df = expected_db_df[['id', 'user_name', 'age', 'users_data', 'data', 'password']] + + # Check that the database DataFrame matches the expected DataFrame + pd.testing.assert_frame_equal(query_result, expected_db_df, check_dtype=False) + + print(' - Data insertion into cache and database verified successfully.') def test_insert_row(self): - print('Testing insert_row() method:') + print("Testing insert_row() method:") - self.data.insert_row(table='test_table_2', columns=('key', 'data'), values=('test_key', 'test_data')) + # Define the cache name, columns, and values to insert + cache_name = 'users' + columns = ('user_name', 'age') + values = ('Alice', 30) - # Verify the row was inserted - with SQLite(self.db_file) as con: - result = pd.read_sql('SELECT * FROM test_table_2 WHERE key="test_key"', con) - self.assertFalse(result.empty, "Row was not inserted into the database.") + # Create the cache first + self.data.create_cache(cache_name, cache_type=InMemoryCache) - print(' - Insert row passed.') + # Insert a row into the cache and database without skipping the cache + self.data.insert_row(cache_name=cache_name, columns=columns, values=values, skip_cache=False) + + # Retrieve the inserted item from the cache + result = self.data.get_cache_item(key='1', cache_name=cache_name) + + # Assert that the data in the cache matches what was inserted + self.assertIsNotNone(result, "No data found in the cache for the inserted ID.") + self.assertEqual(result.iloc[0]['user_name'], 'Alice', "The name in the cache doesn't match the inserted value.") + self.assertEqual(result.iloc[0]['age'], 30, "The age in the cache does not match the inserted value.") + + # Now test with skipping the cache + print("Testing insert_row() with skip_cache=True") + + # Insert another row into the database, this time skipping the cache + self.data.insert_row(cache_name=cache_name, columns=columns, values=('Bob', 40), skip_cache=True) + + # Attempt to retrieve the newly inserted row from the cache + result_after_skip = self.data.get_cache_item(key='2', cache_name=cache_name) + + # Assert that no data is found in the cache for the new row + self.assertIsNone(result_after_skip, "Data should not have been cached when skip_cache=True.") + + print(" - Insert row with and without caching passed all checks.") def test_fill_data_holes(self): print('Testing _fill_data_holes() method:')