# brighter-trading/tests/test_DataCache.py
# (file-listing metadata preserved as a comment: 2012 lines, 97 KiB, Python)

import pickle
import time
import pytz
from DataCache_v3 import DataCache, timeframe_to_timedelta, estimate_record_count, \
SnapshotDataCache, CacheManager, RowBasedCache, TableBasedCache
from ExchangeInterface import ExchangeInterface
import unittest
import pandas as pd
import datetime as dt
import os
from Database import SQLite, Database
import logging
from indicators import Indicator
logging.basicConfig(level=logging.DEBUG)
class DataGenerator:
    """Generate deterministic, timeframe-aligned OHLCV test data.

    The timeframe is given as a string such as '2h', '5m', '1d', '1w',
    '1M', or '1Y'. Months and years are approximated as fixed spans
    (30 and 365 days respectively), matching _get_seconds_per_unit.
    """

    # Canonical single-character code for each stored unit name. Needed when
    # rounding datetimes: taking timeframe_unit[0] would map 'months' to 'm'
    # (minutes) — the collision this table fixes.
    _UNIT_CODES = {
        'seconds': 's', 'minutes': 'm', 'hours': 'h', 'days': 'd',
        'weeks': 'w', 'months': 'M', 'years': 'y',
    }

    def __init__(self, timeframe_str):
        """
        Initialize the DataGenerator with a timeframe string like '2h', '5m',
        '1d', '1w', '1M', or '1Y'.
        """
        # Placeholder values; set_timeframe fills them in.
        self.timeframe_amount = None
        self.timeframe_unit = None
        # Set the actual timeframe
        self.set_timeframe(timeframe_str)

    def set_timeframe(self, timeframe_str):
        """
        Set the timeframe unit and amount based on a string like '2h', '5m',
        '1d', '1w', '1M', or '1Y'.

        Raises:
            ValueError: if the trailing unit character is not one of s,m,h,d,w,M,Y.
        """
        self.timeframe_amount = int(timeframe_str[:-1])
        unit = timeframe_str[-1]
        unit_names = {'s': 'seconds', 'm': 'minutes', 'h': 'hours',
                      'd': 'days', 'w': 'weeks', 'M': 'months', 'Y': 'years'}
        if unit not in unit_names:
            raise ValueError(
                "Unsupported timeframe unit. Use 's,m,h,d,w,M,Y'.")
        self.timeframe_unit = unit_names[unit]

    def create_table(self, num_rec=None, start=None, end=None):
        """
        Create a table with simulated data. If both start and end are provided,
        num_rec is derived from the interval. If neither is provided the table
        will have num_rec rows and end at the current time.

        Parameters:
            num_rec (int, optional): The number of records to generate.
            start (datetime, optional): The start time for the first record.
            end (datetime, optional): The end time for the last record.

        Returns:
            pd.DataFrame: A DataFrame with the simulated data.

        Raises:
            ValueError: if a supplied datetime is naive, or num_rec is missing
                when it cannot be derived from start and end.
        """
        # Ensure provided datetime parameters are timezone aware.
        if start and start.tzinfo is None:
            raise ValueError('start datetime must be timezone aware.')
        if end and end.tzinfo is None:
            raise ValueError('end datetime must be timezone aware.')
        # If neither start nor end are provided, anchor the series to "now".
        if start is None and end is None:
            end = dt.datetime.now(dt.timezone.utc)
            if num_rec is None:
                raise ValueError("num_rec must be provided if both start and end are not specified.")
        # If start and end are provided, derive num_rec from the span.
        if start is not None and end is not None:
            total_duration = (end - start).total_seconds()
            interval_seconds = self.timeframe_amount * self._get_seconds_per_unit(self.timeframe_unit)
            num_rec = int(total_duration // interval_seconds) + 1
        # If only end is provided, walk back (num_rec - 1) intervals.
        # (The old `start.replace(tzinfo=pytz.utc)` here was dropped: start is
        # already aware, and replacing tzinfo on a non-UTC instant would
        # silently shift it.)
        if end is not None and start is None:
            if num_rec is None:
                raise ValueError("num_rec must be provided if both start and end are not specified.")
            interval_seconds = self.timeframe_amount * self._get_seconds_per_unit(self.timeframe_unit)
            start = end - dt.timedelta(seconds=(num_rec - 1) * interval_seconds)
        # BUG FIX: if only start was provided, num_rec was never validated and
        # range(num_rec) below raised TypeError — fail early with a clear error.
        if num_rec is None:
            raise ValueError("num_rec must be provided when it cannot be derived from start and end.")
        # Ensure start is aligned to the timeframe interval. Using _UNIT_CODES
        # instead of timeframe_unit[0] fixes monthly timeframes ('m' would
        # have meant minutes).
        start = self.round_down_datetime(start, self._UNIT_CODES[self.timeframe_unit], self.timeframe_amount)
        # Generate millisecond timestamps, one per interval.
        times = [self.unix_time_millis(start + self._delta(i)) for i in range(num_rec)]
        df = pd.DataFrame({
            'market_id': 1,
            'time': times,
            'open': [100 + i for i in range(num_rec)],
            'high': [110 + i for i in range(num_rec)],
            'low': [90 + i for i in range(num_rec)],
            'close': [105 + i for i in range(num_rec)],
            'volume': [1000 + i for i in range(num_rec)]
        })
        return df

    @staticmethod
    def _get_seconds_per_unit(unit):
        """Helper method to convert timeframe units to seconds.

        Raises:
            ValueError: for an unknown unit name.
        """
        units_in_seconds = {
            'seconds': 1,
            'minutes': 60,
            'hours': 3600,
            'days': 86400,
            'weeks': 604800,
            'months': 2592000,  # Assuming 30 days per month
            'years': 31536000   # Assuming 365 days per year
        }
        if unit not in units_in_seconds:
            raise ValueError(f"Unsupported timeframe unit: {unit}")
        return units_in_seconds[unit]

    def _span(self, count):
        """Return a timedelta covering `count` timeframe units.

        BUG FIX: dt.timedelta() has no 'months'/'years' keywords, so the old
        dt.timedelta(**{self.timeframe_unit: count}) raised TypeError for
        monthly/yearly timeframes; those now use the same fixed second counts
        as _get_seconds_per_unit.
        """
        if self.timeframe_unit in ('months', 'years'):
            return dt.timedelta(seconds=count * self._get_seconds_per_unit(self.timeframe_unit))
        return dt.timedelta(**{self.timeframe_unit: count})

    def generate_incomplete_data(self, query_offset, num_rec=5):
        """
        Generate data that is incomplete, i.e., starts before the query but
        doesn't fully satisfy it.
        """
        query_start_time = self.x_time_ago(query_offset)
        start_time_for_data = self.get_start_time(query_start_time)
        return self.create_table(num_rec, start=start_time_for_data)

    @staticmethod
    def generate_missing_section(df, drop_start=5, drop_end=8):
        """
        Generate data with a missing section (rows [drop_start, drop_end) removed).
        """
        df = df.drop(df.index[drop_start:drop_end]).reset_index(drop=True)
        return df

    def get_start_time(self, query_start_time):
        """Return query_start_time pushed back by a two-interval safety margin."""
        margin = 2
        return query_start_time - self._span(margin * self.timeframe_amount)

    def x_time_ago(self, offset):
        """
        Returns a datetime object representing the current time minus the
        offset in the specified units.
        """
        # datetime.now(timezone.utc) replaces the deprecated naive utcnow().
        return dt.datetime.now(dt.timezone.utc) - self._span(offset)

    def _delta(self, i):
        """
        Returns a timedelta object for the ith increment based on the
        timeframe unit and amount.
        """
        return self._span(i * self.timeframe_amount)

    @staticmethod
    def unix_time_millis(dt_obj: dt.datetime) -> int:
        """
        Convert a timezone-aware datetime object to Unix time in milliseconds.

        Raises:
            ValueError: if dt_obj is naive.
        """
        if dt_obj.tzinfo is None:
            raise ValueError('dt_obj needs to be timezone aware.')
        epoch = dt.datetime(1970, 1, 1, tzinfo=dt.timezone.utc)
        return int((dt_obj - epoch).total_seconds() * 1000)

    @staticmethod
    def round_down_datetime(dt_obj: dt.datetime, unit: str, interval: int) -> dt.datetime:
        """Round dt_obj down to the nearest `interval` of `unit`.

        `unit` is a single character: 's','m','h','d','w','M','y'. Unknown
        unit characters return dt_obj unchanged (original behavior preserved).

        Raises:
            ValueError: if dt_obj is naive.
        """
        if dt_obj.tzinfo is None:
            raise ValueError('dt_obj needs to be timezone aware.')
        if unit == 's':  # Round down to the nearest interval of seconds
            seconds = (dt_obj.second // interval) * interval
            dt_obj = dt_obj.replace(second=seconds, microsecond=0)
        elif unit == 'm':  # Round down to the nearest interval of minutes
            minutes = (dt_obj.minute // interval) * interval
            dt_obj = dt_obj.replace(minute=minutes, second=0, microsecond=0)
        elif unit == 'h':  # Round down to the nearest interval of hours
            hours = (dt_obj.hour // interval) * interval
            dt_obj = dt_obj.replace(hour=hours, minute=0, second=0, microsecond=0)
        elif unit == 'd':  # Round down to the nearest interval of days
            # BUG FIX: days are 1-based; the old (day // interval) * interval
            # could produce day 0 (invalid) whenever day < interval. Shift to
            # a 0-based index, round, then shift back.
            days = ((dt_obj.day - 1) // interval) * interval + 1
            dt_obj = dt_obj.replace(day=days, hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'w':  # Round down to the start of the week (Monday)
            # weekday() is 0..6, so the old `% (interval * 7)` was a no-op for
            # any interval >= 1; the simplified form is behaviorally identical.
            dt_obj -= dt.timedelta(days=dt_obj.weekday())
            dt_obj = dt_obj.replace(hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'M':  # Round down to the nearest interval of months
            months = ((dt_obj.month - 1) // interval) * interval + 1
            dt_obj = dt_obj.replace(month=months, day=1, hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'y':  # Round down to the nearest interval of years
            years = (dt_obj.year // interval) * interval
            dt_obj = dt_obj.replace(year=years, month=1, day=1, hour=0, minute=0, second=0, microsecond=0)
        return dt_obj
class TestDataCache(unittest.TestCase):
def setUp(self):
    """Build a fresh DataCache/ExchangeInterface pair and run all test
    prerequisites (exchange connections, SQLite schema, seed data) before
    each test method."""
    # Initialize DataCache
    self.data = DataCache()
    self.exchanges = ExchangeInterface(self.data)
    # Idempotence flags consulted by the load_prerequisites() helpers so each
    # setup step runs at most once per test case.
    self.exchanges_connected = False
    self.database_is_setup = False
    self.test_data_loaded = False
    self.load_prerequisites()
def tearDown(self):
if self.database_is_setup:
if os.path.exists(self.db_file):
os.remove(self.db_file)
def load_prerequisites(self):
self.connect_exchanges()
self.set_up_database()
self.load_test_data()
def load_test_data(self):
if self.test_data_loaded:
return
# Create caches needed for testing
self.data.create_cache('candles', cache_type='row')
# Reuse details for exchange and market
self.ex_details = ['BTC/USD', '2h', 'binance', 'test_guy']
self.key = f'{self.ex_details[0]}_{self.ex_details[1]}_{self.ex_details[2]}'
self.test_data_loaded = True
def connect_exchanges(self):
if not self.exchanges_connected:
self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)
self.exchanges.connect_exchange(exchange_name='binance', user_name='user_1', api_keys=None)
self.exchanges.connect_exchange(exchange_name='binance', user_name='user_2', api_keys=None)
self.exchanges_connected = True
def set_up_database(self):
    """Create the SQLite test database file, its schema, and attach it to the
    DataCache under test. Guarded so it runs at most once per test case."""
    if not self.database_is_setup:
        self.db_file = 'test_db.sqlite'
        self.database = Database(db_file=self.db_file)
        # Create necessary tables
        # NOTE(review): the f prefixes on tables 1 and 4 have no placeholders —
        # presumably leftovers from a parameterized table name; harmless.
        sql_create_table_1 = f"""
CREATE TABLE IF NOT EXISTS test_table (
id INTEGER PRIMARY KEY,
market_id INTEGER,
time INTEGER UNIQUE ON CONFLICT IGNORE,
open REAL NOT NULL,
high REAL NOT NULL,
low REAL NOT NULL,
close REAL NOT NULL,
volume REAL NOT NULL,
FOREIGN KEY (market_id) REFERENCES market (id)
)"""
        sql_create_table_2 = """
CREATE TABLE IF NOT EXISTS exchange (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE
)"""
        sql_create_table_3 = """
CREATE TABLE IF NOT EXISTS markets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT,
exchange_id INTEGER,
FOREIGN KEY (exchange_id) REFERENCES exchange(id)
)"""
        sql_create_table_4 = f"""
CREATE TABLE IF NOT EXISTS test_table_2 (
key TEXT PRIMARY KEY,
data TEXT NOT NULL
)"""
        sql_create_table_5 = """
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT,
age INTEGER,
users_data TEXT,
data TEXT,
password TEXT -- Moved to a new line and added a comma after 'data'
)
"""
        # Execute all DDL statements inside one connection context.
        with SQLite(db_file=self.db_file) as con:
            con.execute(sql_create_table_1)
            con.execute(sql_create_table_2)
            con.execute(sql_create_table_3)
            con.execute(sql_create_table_4)
            con.execute(sql_create_table_5)
        self.data.db = self.database  # Keep the database setup
        self.database_is_setup = True
def test_cache_system(self):
    """End-to-end smoke test of CacheManager: table-cache add/query/expiry,
    row-cache add/query/expiry, and row-cache size-limit eviction.

    NOTE(review): real time.sleep calls make this test take roughly 9 seconds
    of wall time."""
    print("\n---- Starting Cache System Test ----")
    # Step 1: Create CacheManager instance
    cache_manager = CacheManager()
    print("\n1. Created CacheManager instance.")
    # Step 2: Create a Table-Based Cache for users
    cache_manager.create_cache(name="users", cache_type="table",
                               default_expiration=dt.timedelta(seconds=2), size_limit=5)
    table_cache = cache_manager.get_cache("users")
    print("\n2. Created a Table-Based Cache for 'users' with expiration set to 2 seconds.")
    # Step 3: Add users DataFrame to the table-based cache
    users_df = pd.DataFrame([
        {'name': 'Bob', 'age': 25, 'email': 'bob@example.com'},
        {'name': 'Alice', 'age': 30, 'email': 'alice@example.com'}
    ])
    table_cache.add_table(users_df)
    print("\n3. Added user data to 'users' table-based cache:")
    print(users_df)
    # Step 4: Query the cache to retrieve Bob's information
    result = table_cache.query([('name', 'Bob')])
    print("\n4. Queried 'users' table-based cache for user 'Bob':")
    print(result)
    self.assertEqual(len(result), 1, "Should return exactly 1 row for Bob.")
    self.assertEqual(result.iloc[0]['name'], 'Bob', "The name should be Bob.")
    # Step 5: Wait for 3 seconds (after expiration time) and query again to check expiration
    print("\n5. Waiting for 3 seconds to allow the cache to expire...")
    time.sleep(3)
    result_after_expiry = table_cache.query([('name', 'Bob')])
    print("\n5. After 3 seconds, queried again for Bob. Result should be empty due to expiration:")
    print(result_after_expiry)
    self.assertTrue(result_after_expiry.empty, "Result should be empty as Bob's entry has expired.")
    # Step 6: Create a Row-Based Cache for candles
    cache_manager.create_cache(name="candles", cache_type="row",
                               default_expiration=dt.timedelta(seconds=5), size_limit=10)
    row_cache = cache_manager.get_cache("candles")
    print("\n6. Created a Row-Based Cache for 'candles' with expiration set to 5 seconds.")
    # Step 7: Add candle data to the row-based cache
    candle_data_1 = pd.DataFrame([
        {'time': '2024-09-11 00:00', 'open': 100, 'high': 105, 'low': 99, 'close': 102},
        {'time': '2024-09-11 01:00', 'open': 101, 'high': 106, 'low': 100, 'close': 103}
    ])
    row_cache.add_entry("candle_1", candle_data_1)
    print("\n7. Added candle data to 'candles' row-based cache under key 'candle_1':")
    print(candle_data_1)
    # Step 8: Query the row-based cache to retrieve specific time entry for candle_1
    result_candle = row_cache.query([("key", "candle_1"), ("time", "2024-09-11 00:00")])
    print("\n8. Queried 'candles' row-based cache for 'candle_1' and time '2024-09-11 00:00':")
    print(result_candle)
    self.assertEqual(len(result_candle), 1, "Should return exactly 1 row for the specified time.")
    self.assertEqual(result_candle.iloc[0]['time'], '2024-09-11 00:00', "The time should match the queried time.")
    # Step 9: Wait for 6 seconds (after expiration time) and query again to check expiration
    print("\n9. Waiting for 6 seconds to allow the 'candle_1' cache to expire...")
    time.sleep(6)
    result_candle_after_expiry = row_cache.query([("key", "candle_1"), ("time", "2024-09-11 00:00")])
    print("\n9. After 6 seconds, queried again for 'candle_1'. Result should be empty due to expiration:")
    print(result_candle_after_expiry)
    self.assertTrue(result_candle_after_expiry.empty, "Result should be empty as 'candle_1' has expired.")
    # Step 10: Test the size limit of the row-based cache (adding more than limit)
    print("\n10. Testing row-based cache size limit (max 10 entries).")
    for i in range(1, 12):
        row_cache.add_entry(f"candle_{i}", pd.DataFrame([
            {'time': f'2024-09-11 00:00', 'open': 100 + i, 'high': 105 + i, 'low': 99 + i, 'close': 102 + i}
        ]))
        print(f"Added entry: candle_{i}")
    print("\n11. Checking the size of the cache after adding 11 entries (limit is 10):")
    result = row_cache.get_entry("candle_1")
    print(f"Checking 'candle_1': {result}")
    self.assertIsNone(result, "'candle_1' should have been evicted as the size limit is 10.")
    # Final print statement for clarity of test ending
    print("\n---- Cache System Test Completed ----")
def test_cache_system_advanced_usage(self):
    """Exercise both cache types with heterogeneous payloads (strings, dicts,
    lists, mixed-type DataFrames) and verify table-row expiry."""
    print("\n---- Starting Advanced Cache System Test ----")
    # Step 1: Create CacheManager instance
    cache_manager = CacheManager()
    print("\n1. Created CacheManager instance.")
    # Row-Based Cache Test with Different Data Types
    cache_manager.create_cache(name="row_cache", cache_type="row",
                               default_expiration=dt.timedelta(seconds=5), size_limit=10)
    row_cache = cache_manager.get_cache("row_cache")
    print("\n2. Created a Row-Based Cache with expiration set to 5 seconds.")
    # Step 2: Add different types of data into Row-Based Cache
    # Add a string
    row_cache.add_entry("message", "Hello, World!")
    print("\n3. Added a string to Row-Based Cache under key 'message'.")
    # Add a dictionary
    row_cache.add_entry("user_profile", {"name": "Charlie", "age": 28, "email": "charlie@example.com"})
    print("\n4. Added a dictionary to Row-Based Cache under key 'user_profile'.")
    # Add a list of numbers
    row_cache.add_entry("numbers", [1, 2, 3, 4, 5])
    print("\n5. Added a list of numbers to Row-Based Cache under key 'numbers'.")
    # Step 3: Query the Row-Based Cache
    print("\n6. Querying Row-Based Cache for different types of data:")
    result_message = row_cache.query([("key", "message")])
    print(f"Query result for key 'message': {result_message}")
    result_profile = row_cache.query([("key", "user_profile")])
    print(f"Query result for key 'user_profile': {result_profile}")
    result_numbers = row_cache.query([("key", "numbers")])
    print(f"Query result for key 'numbers': {result_numbers}")
    # Assert non-expired entries
    # NOTE(review): .iloc[0][0] positional access assumes the scalar payload
    # lands in the first column of the returned frame — confirm against
    # the query() return shape for non-DataFrame payloads.
    self.assertEqual(result_message.iloc[0][0], "Hello, World!", "Message should be 'Hello, World!'")
    self.assertEqual(result_profile.iloc[0]['name'], 'Charlie', "User profile should have name 'Charlie'")
    # Convert the DataFrame row back to a list and assert the values match
    numbers_list = result_numbers.iloc[0].tolist()
    self.assertEqual(numbers_list, [1, 2, 3, 4, 5], "Should return list of numbers.")
    # Table-Based Cache Test with DataFrames
    cache_manager.create_cache(name="table_cache", cache_type="table",
                               default_expiration=dt.timedelta(seconds=5), size_limit=5)
    table_cache = cache_manager.get_cache("table_cache")
    print("\n7. Created a Table-Based Cache with expiration set to 5 seconds.")
    # Step 4: Add a DataFrame with mixed data types to Table-Based Cache
    mixed_df = pd.DataFrame([
        {'category': 'A', 'value': 100, 'timestamp': '2024-09-12 12:00'},
        {'category': 'B', 'value': 200, 'timestamp': '2024-09-12 13:00'},
        {'category': 'A', 'value': 150, 'timestamp': '2024-09-12 14:00'}
    ])
    table_cache.add_table(mixed_df)
    print("\n8. Added mixed DataFrame to Table-Based Cache:")
    print(mixed_df)
    # Step 5: Query the Table-Based Cache
    print("\n9. Querying Table-Based Cache for category 'A':")
    result_category_a = table_cache.query([("category", "A")])
    print(result_category_a)
    self.assertEqual(len(result_category_a), 2, "There should be 2 rows with category 'A'.")
    # NOTE(review): despite the printed text, this is an equality query
    # (value == 150), not "greater than 100".
    print("\n10. Querying Table-Based Cache for value greater than 100:")
    result_value_gt_100 = table_cache.query([("value", 150)])
    print(result_value_gt_100)
    self.assertEqual(len(result_value_gt_100), 1, "There should be 1 row with value of 150.")
    # Step 6: Wait for entries to expire and query again
    print("\n11. Waiting for 6 seconds to let all rows expire...")
    time.sleep(6)
    result_after_expiry = table_cache.query([("category", "A")])
    print(f"\n12. After 6 seconds, querying again for category 'A'. Result should be empty:")
    print(result_after_expiry)
    self.assertTrue(result_after_expiry.empty, "Result should be empty due to expiration.")
    # Final print statement for clarity of test ending
    print("\n---- Advanced Cache System Test Completed ----")
def test_cache_system_edge_cases(self):
    """Verify row-cache size-limit behavior under the explicit 'evict'
    policy: once the limit (3) is exceeded, the oldest entry is dropped."""
    print("\n---- Starting Edge Case Cache System Test ----")
    # Step 1: Create CacheManager instance
    cache_manager = CacheManager()
    print("\n1. Created CacheManager instance.")
    # Test 1: Cache Size Limit (Row-Based Cache)
    cache_manager.create_cache(name="limited_row_cache", cache_type="row", size_limit=3, eviction_policy="evict")
    limited_row_cache = cache_manager.get_cache("limited_row_cache")
    print("\n2. Created a Row-Based Cache with size limit of 3 and eviction policy 'evict'.")
    # Add entries beyond the size limit and check eviction
    limited_row_cache.add_entry("item1", "Data 1")
    print(f"Cache after adding item1: {limited_row_cache.cache}")
    limited_row_cache.add_entry("item2", "Data 2")
    print(f"Cache after adding item2: {limited_row_cache.cache}")
    limited_row_cache.add_entry("item3", "Data 3")
    print(f"Cache after adding item3: {limited_row_cache.cache}")
    # Add 4th entry, which should cause eviction of the first entry
    limited_row_cache.add_entry("item4", "Data 4")
    print(f"Cache after adding item4: {limited_row_cache.cache}")
    # Verify eviction of the oldest entry (item1)
    result_item1 = limited_row_cache.query([("key", "item1")])
    print(f"Query result for 'item1' (should be evicted): {result_item1}")
    self.assertTrue(result_item1.empty, "'item1' should be evicted.")
    # Verify that other items exist
    result_item4 = limited_row_cache.query([("key", "item4")])
    print(f"Query result for 'item4': {result_item4}")
    self.assertFalse(result_item4.empty, "'item4' should be present.")
def test_access_counter_and_purging(self):
    """Test access counter and purging mechanism.

    NOTE(review): presumably every purge_threshold-th access sweeps expired
    entries out of the cache — confirm against RowBasedCache's implementation.
    """
    print("\n---- Starting Access Counter and Purge Mechanism Test ----")
    # Step 1: Create Row-Based Cache with a purge threshold of 5 accesses
    cache = RowBasedCache(default_expiration=1, purge_threshold=5)
    print("\n1. Created a Row-Based Cache with purge threshold of 5 accesses.")
    # Step 2: Add entries with expiration times
    cache.add_entry("item1", "Data 1")
    cache.add_entry("item2", "Data 2")
    print("\n2. Added 2 entries with a 1-second expiration time.")
    # Step 3: Access cache 5 times to trigger purge
    for i in range(5):
        cache.get_entry("item1")
    print(f"\n3. Accessed cache {i+1} times.")
    # Step 4: Ensure item1 is still in cache (it hasn't expired yet because of timing)
    result_item1 = cache.get_entry("item1")
    print(f"\n4. Retrieved 'item1' from cache before expiration: {result_item1}")
    self.assertIsNotNone(result_item1, "'item1' should still be in cache.")
    # Step 5: Wait for expiration and access again to trigger purge
    time.sleep(2)
    result_item1_expired = cache.get_entry("item1")
    print(f"\n5. Retrieved 'item1' after expiration (should be None): {result_item1_expired}")
    self.assertIsNone(result_item1_expired, "'item1' should have expired.")
    # Step 6: Access cache 5 more times to confirm expired entries are purged
    for i in range(5):
        cache.get_entry("item2")
    print(f"\n6. Accessed cache {i+1} more times to trigger another purge.")
    # Verify item2 is also expired after accesses
    result_item2_expired = cache.get_entry("item2")
    print(f"\n7. Retrieved 'item2' after expiration (should be None): {result_item2_expired}")
    self.assertIsNone(result_item2_expired, "'item2' should have expired.")
def test_is_attr_taken_in_row_cache(self):
    """Test the is_attr_taken method in the Row-Based Cache."""
    print("\n---- Starting is_attr_taken in Row-Based Cache Test ----")
    # Step 1: Create a Row-Based Cache
    cache = RowBasedCache()
    print("\n1. Created a Row-Based Cache.")
    # Step 2: Store a two-row DataFrame under a single key
    users = pd.DataFrame([{'name': 'Alice', 'age': 30},
                          {'name': 'Bob', 'age': 25}])
    cache.add_entry("users", users)
    print("\n2. Added a DataFrame to the cache under key 'users'.")
    # Step 3: A present value is reported as taken ...
    found = cache.is_attr_taken('name', 'Alice')
    print(f"\n3. Checked if 'name' column contains 'Alice': {found}")
    self.assertTrue(found, "'name' column should contain 'Alice'.")
    # ... and an absent value is not.
    missing = cache.is_attr_taken('name', 'Charlie')
    print(f"\n4. Checked if 'name' column contains 'Charlie': {missing}")
    self.assertFalse(missing, "'name' column should not contain 'Charlie'.")
def test_is_attr_taken_in_table_cache(self):
    """Test the is_attr_taken method in the Table-Based Cache."""
    print("\n---- Starting is_attr_taken in Table-Based Cache Test ----")
    # Step 1: Create a Table-Based Cache
    cache = TableBasedCache()
    print("\n1. Created a Table-Based Cache.")
    # Step 2: Store a two-row DataFrame in the table cache
    people = pd.DataFrame([{'name': 'Alice', 'age': 30},
                           {'name': 'Charlie', 'age': 40}])
    cache.add_table(people)
    print("\n2. Added a DataFrame to the Table-Based Cache.")
    # Step 3: A present value is reported as taken ...
    found = cache.is_attr_taken('name', 'Alice')
    print(f"\n3. Checked if 'name' column contains 'Alice': {found}")
    self.assertTrue(found, "'name' column should contain 'Alice'.")
    # ... and an absent value is not.
    missing = cache.is_attr_taken('name', 'Bob')
    print(f"\n4. Checked if 'name' column contains 'Bob': {missing}")
    self.assertFalse(missing, "'name' column should not contain 'Bob'.")
def test_expired_entry_handling(self):
    """Test that expired entries are not returned when querying."""
    print("\n---- Starting Expired Entry Handling Test ----")
    # Step 1: Create a Row-Based Cache with a 1-second expiration
    short_lived = RowBasedCache(default_expiration=1)
    print("\n1. Created a Row-Based Cache with a 1-second expiration.")
    # Step 2: Add an entry
    short_lived.add_entry("item1", "Temporary Data")
    print("\n2. Added an entry 'item1' with a 1-second expiration.")
    # Step 3: Sleep past the expiration window
    time.sleep(2)
    # Step 4: The expired entry must not come back
    fetched = short_lived.get_entry("item1")
    print(f"\n4. Queried 'item1' after expiration: {fetched}")
    self.assertIsNone(fetched, "'item1' should have expired and not be returned.")
def test_remove_item_with_conditions_row_cache(self):
    """Test remove_item method in Row-Based Cache with conditions."""
    print("\n---- Starting remove_item with conditions in Row-Based Cache Test ----")
    # Step 1: Create a Row-Based Cache
    cache = RowBasedCache()
    print("\n1. Created a Row-Based Cache.")
    # Step 2: Add a DataFrame to the cache under 'users'
    df = pd.DataFrame({'name': ['Alice', 'Bob', 'Charlie'], 'age': [30, 25, 35]})
    cache.add_entry("users", df)
    print("\n2. Added a DataFrame to the cache under key 'users'.")
    print(cache.get_entry("users"))
    # Step 3: Remove a specific row (where 'name' == 'Alice') from the DataFrame
    removed = cache.remove_item([('key', 'users'), ('name', 'Alice')])
    print(f"\n3. Removed 'Alice' from the DataFrame: {removed}")
    self.assertTrue(removed, "'Alice' should be removed from the DataFrame.")
    # Verify that 'Alice' was removed
    remaining_data = cache.get_entry('users')
    print(f"\n4. Remaining data in 'users': \n{remaining_data}")
    self.assertNotIn('Alice', remaining_data['name'].values, "'Alice' should no longer be in the DataFrame.")
    self.assertIn('Bob', remaining_data['name'].values, "'Bob' should still be in the DataFrame.")
    # Step 4: Remove another row (where 'name' == 'Charlie'); 'Bob' is still
    # present after this step.
    removed = cache.remove_item([('key', 'users'), ('name', 'Charlie')])
    print(f"\n5. Removed 'Charlie' from the DataFrame: {removed}")
    remaining_data = cache.get_entry('users')
    print(f"\n6. Remaining data in 'users' after 'Charlie' removal: \n{remaining_data}")
    self.assertNotIn('Charlie', remaining_data['name'].values, "'Charlie' should no longer be in the DataFrame.")
    # Step 5: Remove the last row (where 'name' == 'Bob'), this should remove the entire entry
    removed = cache.remove_item([('key', 'users'), ('name', 'Bob')])
    print(f"\n7. Removed 'Bob' from the DataFrame: {removed}")
    remaining_data = cache.get_entry('users')
    print(f"\n8. Remaining data in 'users' after removing 'Bob' (should be None): {remaining_data}")
    self.assertIsNone(remaining_data, "'users' entry should no longer exist in the cache.")
def test_remove_item_with_conditions_table_cache(self):
    """Test remove_item method in Table-Based Cache with conditions.

    Rows are removed one by one until the shared table is empty."""
    print("\n---- Starting remove_item with conditions in Table-Based Cache Test ----")
    # Step 1: Create a Table-Based Cache
    cache = TableBasedCache()
    print("\n1. Created a Table-Based Cache.")
    # Step 2: Add a DataFrame to the cache
    df = pd.DataFrame({'name': ['Alice', 'Bob', 'Charlie'], 'age': [30, 25, 35]})
    cache.add_table(df)
    print("\n2. Added a DataFrame to the Table-Based Cache.")
    print(cache.get_all_items())
    # Step 3: Remove a specific row (where 'name' == 'Alice')
    removed = cache.remove_item([('name', 'Alice')])
    print(f"\n3. Removed 'Alice' from the table: {removed}")
    self.assertTrue(removed, "'Alice' should be removed from the table-based cache.")
    # Verify that 'Alice' was removed
    remaining_data = cache.get_all_items()
    print(f"\n4. Remaining data in the cache: \n{remaining_data}")
    self.assertNotIn('Alice', remaining_data['name'].values, "'Alice' should no longer be in the table.")
    self.assertIn('Bob', remaining_data['name'].values, "'Bob' should still be in the table.")
    # Step 4: Remove another row (where 'name' == 'Charlie')
    removed = cache.remove_item([('name', 'Charlie')])
    print(f"\n5. Removed 'Charlie' from the table: {removed}")
    remaining_data = cache.get_all_items()
    print(f"\n6. Remaining data in the cache after removing 'Charlie': \n{remaining_data}")
    self.assertNotIn('Charlie', remaining_data['name'].values, "'Charlie' should no longer be in the table.")
    # Step 5: Remove the last row (where 'name' == 'Bob')
    removed = cache.remove_item([('name', 'Bob')])
    print(f"\n7. Removed 'Bob' from the table: {removed}")
    remaining_data = cache.get_all_items()
    print(f"\n8. Remaining data in the cache after removing 'Bob' (should be empty): \n{remaining_data}")
    self.assertTrue(remaining_data.empty, "The table should be empty after removing all rows.")
def test_remove_item_with_conditions_market_data(self):
    """Test remove_item method in Row-Based Cache with market OHLC data.

    Covers both partial removal (one row of a keyed DataFrame) and whole-key
    removal (the entire 'ETH' entry)."""
    print("\n---- Starting remove_item with market OHLC data in Row-Based Cache Test ----")
    # Step 1: Create a Row-Based Cache for market data
    cache = RowBasedCache()
    print("\n1. Created a Row-Based Cache for market data.")
    # Step 2: Add OHLC data for 'BTC' and 'ETH'
    btc_data = pd.DataFrame({
        'timestamp': ['2024-09-10 12:00', '2024-09-10 12:05', '2024-09-10 12:10'],
        'open': [30000, 30100, 30200],
        'high': [30500, 30600, 30700],
        'low': [29900, 30050, 30150],
        'close': [30400, 30550, 30650]
    })
    eth_data = pd.DataFrame({
        'timestamp': ['2024-09-10 12:00', '2024-09-10 12:05', '2024-09-10 12:10'],
        'open': [2000, 2010, 2020],
        'high': [2050, 2060, 2070],
        'low': [1990, 2005, 2015],
        'close': [2040, 2055, 2065]
    })
    cache.add_entry("BTC", btc_data)
    cache.add_entry("ETH", eth_data)
    print("\n2. Added OHLC data for 'BTC' and 'ETH'.")
    print(f"BTC Data:\n{cache.get_entry('BTC')}")
    print(f"ETH Data:\n{cache.get_entry('ETH')}")
    # Step 3: Remove a specific row from 'BTC' data where timestamp == '2024-09-10 12:05'
    removed = cache.remove_item([('key', 'BTC'), ('timestamp', '2024-09-10 12:05')])
    print(f"\n3. Removed '2024-09-10 12:05' row from 'BTC' data: {removed}")
    self.assertTrue(removed, "'2024-09-10 12:05' should be removed from the 'BTC' data.")
    # Verify that the timestamp was removed from 'BTC' data
    remaining_btc = cache.get_entry('BTC')
    print(f"\n4. Remaining BTC data after removal:\n{remaining_btc}")
    self.assertNotIn('2024-09-10 12:05', remaining_btc['timestamp'].values,
                     "'2024-09-10 12:05' should no longer be in the 'BTC' data.")
    # Step 4: Remove entire 'ETH' data entry (key-only condition drops the key)
    removed_eth = cache.remove_item([('key', 'ETH')])
    print(f"\n5. Removed entire 'ETH' data entry: {removed_eth}")
    self.assertTrue(removed_eth, "'ETH' data should be removed from the cache.")
    # Verify that 'ETH' was completely removed from the cache
    remaining_eth = cache.get_entry('ETH')
    print(f"\n6. Remaining ETH data after removal (should be None): {remaining_eth}")
    self.assertIsNone(remaining_eth, "'ETH' entry should no longer exist in the cache.")
def test_remove_item_with_conditions_trade_stats(self):
    """Test remove_item method in Row-Based Cache with trade statistics data.

    Mirrors the market-data test: partial removal of one dated row, then
    whole-key removal of a second strategy's entry."""
    print("\n---- Starting remove_item with trade statistics in Row-Based Cache Test ----")
    # Step 1: Create a Row-Based Cache for trade statistics
    cache = RowBasedCache()
    print("\n1. Created a Row-Based Cache for trade statistics.")
    # Step 2: Add trade statistics for 'strategy_1' and 'strategy_2'
    strategy_1_data = pd.DataFrame({
        'date': ['2024-09-10', '2024-09-11', '2024-09-12'],
        'success_rate': [80, 85, 75],
        'trades': [10, 12, 8]
    })
    strategy_2_data = pd.DataFrame({
        'date': ['2024-09-10', '2024-09-11', '2024-09-12'],
        'success_rate': [60, 70, 65],
        'trades': [15, 17, 14]
    })
    cache.add_entry("strategy_1", strategy_1_data)
    cache.add_entry("strategy_2", strategy_2_data)
    print("\n2. Added trade statistics for 'strategy_1' and 'strategy_2'.")
    print(f"Strategy 1 Data:\n{cache.get_entry('strategy_1')}")
    print(f"Strategy 2 Data:\n{cache.get_entry('strategy_2')}")
    # Step 3: Remove a specific row from 'strategy_1' where date == '2024-09-11'
    removed = cache.remove_item([('key', 'strategy_1'), ('date', '2024-09-11')])
    print(f"\n3. Removed '2024-09-11' row from 'strategy_1' data: {removed}")
    self.assertTrue(removed, "'2024-09-11' should be removed from the 'strategy_1' data.")
    # Verify that the date was removed from 'strategy_1' data
    remaining_strategy_1 = cache.get_entry('strategy_1')
    print(f"\n4. Remaining strategy_1 data after removal:\n{remaining_strategy_1}")
    self.assertNotIn('2024-09-11', remaining_strategy_1['date'].values,
                     "'2024-09-11' should no longer be in the 'strategy_1' data.")
    # Step 4: Remove entire 'strategy_2' data entry
    removed_strategy_2 = cache.remove_item([('key', 'strategy_2')])
    print(f"\n5. Removed entire 'strategy_2' data entry: {removed_strategy_2}")
    self.assertTrue(removed_strategy_2, "'strategy_2' data should be removed from the cache.")
    # Verify that 'strategy_2' was completely removed from the cache
    remaining_strategy_2 = cache.get_entry('strategy_2')
    print(f"\n6. Remaining strategy_2 data after removal (should be None): {remaining_strategy_2}")
    self.assertIsNone(remaining_strategy_2, "'strategy_2' entry should no longer exist in the cache.")
def test_remove_item_with_other_data_types(self):
"""Test remove_item method in Row-Based Cache with different data types."""
print("\n---- Starting remove_item with different data types in Row-Based Cache Test ----")
# Step 1: Create a Row-Based Cache for mixed data types
cache = RowBasedCache()
print("\n1. Created a Row-Based Cache for mixed data types.")
# Step 2: Add entries with different data types
# String
cache.add_entry("message", "Hello, World!")
print("\n2. Added a string 'Hello, World!' under key 'message'.")
# Dictionary
cache.add_entry("user_profile", {"name": "Alice", "age": 30, "email": "alice@example.com"})
print("\n3. Added a dictionary under key 'user_profile'.")
# List
cache.add_entry("numbers", [1, 2, 3, 4, 5])
print("\n4. Added a list of numbers under key 'numbers'.")
# Integer
cache.add_entry("count", 42)
print("\n5. Added an integer '42' under key 'count'.")
# Step 3: Remove specific entries based on key
# Remove string entry
removed_message = cache.remove_item([('key', 'message')])
print(f"\n6. Removed string entry: {removed_message}")
self.assertTrue(removed_message, "'message' should be removed from the cache.")
self.assertIsNone(cache.get_entry('message'), "'message' entry should no longer exist.")
# Remove dictionary entry
removed_user_profile = cache.remove_item([('key', 'user_profile')])
print(f"\n7. Removed dictionary entry: {removed_user_profile}")
self.assertTrue(removed_user_profile, "'user_profile' should be removed from the cache.")
self.assertIsNone(cache.get_entry('user_profile'), "'user_profile' entry should no longer exist.")
# Remove list entry
removed_numbers = cache.remove_item([('key', 'numbers')])
print(f"\n8. Removed list entry: {removed_numbers}")
self.assertTrue(removed_numbers, "'numbers' should be removed from the cache.")
self.assertIsNone(cache.get_entry('numbers'), "'numbers' entry should no longer exist.")
# Remove integer entry
removed_count = cache.remove_item([('key', 'count')])
print(f"\n9. Removed integer entry: {removed_count}")
self.assertTrue(removed_count, "'count' should be removed from the cache.")
self.assertIsNone(cache.get_entry('count'), "'count' entry should no longer exist.")
    def test_snapshot_row_based_cache(self):
        """Test snapshot functionality with row-based cache.

        Verifies that a snapshot is a frozen copy of the cache state: adding
        more data to the live cache afterwards must not change the snapshot.
        """
        print("\n---- Starting Snapshot Test with Row-Based Cache ----")
        # Step 1: Create an instance of SnapshotDataCache
        snapshot_cache = SnapshotDataCache()
        print("\n1. Created SnapshotDataCache instance.")
        # Step 2: Create a row-based cache and add data
        snapshot_cache.create_cache(name="market_data", cache_type="row")
        market_data = pd.DataFrame({
            'timestamp': ['2024-09-10 12:00', '2024-09-10 12:05'],
            'open': [30000, 30100],
            'high': [30500, 30600],
            'low': [29900, 30050],
            'close': [30400, 30550]
        })
        snapshot_cache.get_cache("market_data").add_entry("BTC", market_data)
        print("\n2. Added 'BTC' market data to row-based cache.")
        # Step 3: Take a snapshot of the row-based cache
        snapshot_cache.snapshot_cache("market_data")
        snapshot_list = snapshot_cache.list_snapshots()
        print(f"\n3. Snapshot list after taking snapshot: {snapshot_list}")
        self.assertIn("market_data", snapshot_list, "Snapshot for 'market_data' should be present.")
        # Step 4: Retrieve the snapshot and verify its contents
        snapshot = snapshot_cache.get_snapshot("market_data")
        print(f"\n4. Retrieved snapshot of 'market_data':\n{snapshot.get_entry('BTC')}")
        pd.testing.assert_frame_equal(snapshot.get_entry('BTC'), market_data)
        # Step 5: Add more data to the live cache and verify the snapshot is unchanged
        additional_data = pd.DataFrame({
            'timestamp': ['2024-09-10 12:10'],
            'open': [30200],
            'high': [30700],
            'low': [30150],
            'close': [30650]
        })
        # NOTE(review): the 3-row assertion below implies add_entry appends to
        # an existing key rather than replacing it.
        snapshot_cache.get_cache("market_data").add_entry("BTC", additional_data)
        print("\n5. Added additional data to the live 'BTC' cache.")
        # Verify live cache has updated but the snapshot remains unchanged
        live_data = snapshot_cache.get_cache("market_data").get_entry("BTC")
        print(f"\n6. Live 'BTC' cache data:\n{live_data}")
        self.assertEqual(len(live_data), 3, "Live cache should have 3 rows after adding more data.")
        # Ensure the snapshot still has the original data
        snapshot_data = snapshot_cache.get_snapshot("market_data").get_entry("BTC")
        print(f"\n7. Snapshot data (should still be original):\n{snapshot_data}")
        self.assertEqual(len(snapshot_data), 2, "Snapshot should still have the original 2 rows.")
    def test_snapshot_table_based_cache_with_overwrite_column(self):
        """Test snapshot functionality with table-based cache and overwrite by column.

        Takes a snapshot of a table-based cache, then updates the live table
        using add_table(..., overwrite='name') and checks the snapshot keeps
        the original rows while the live table reflects the update.
        """
        print("\n---- Starting Snapshot Test with Table-Based Cache ----")
        # Step 1: Create an instance of SnapshotDataCache
        snapshot_cache = SnapshotDataCache()
        print("\n1. Created SnapshotDataCache instance.")
        # Step 2: Create a table-based cache and add initial data
        user_data = pd.DataFrame({
            'name': ['Alice', 'Bob'],
            'email': ['alice@example.com', 'bob@example.com'],
            'age': [30, 25]
        })
        snapshot_cache.create_cache(name="user_data", cache_type="table")
        snapshot_cache.get_cache("user_data").add_table(user_data)
        print("\n2. Added user data to table-based cache.")
        # Step 3: Take a snapshot of the table-based cache
        snapshot_cache.snapshot_cache("user_data")
        snapshot_list = snapshot_cache.list_snapshots()
        print(f"\n3. Snapshot list after taking snapshot: {snapshot_list}")
        self.assertIn("user_data", snapshot_list, "Snapshot for 'user_data' should be present.")
        # Step 4: Retrieve the snapshot and verify its contents (excluding metadata)
        # The cache adds a 'metadata' column of its own; drop it so the frame
        # comparison is against the caller-supplied columns only.
        snapshot = snapshot_cache.get_snapshot("user_data")
        snapshot_data = snapshot.get_all_items().drop(columns=['metadata'])
        print(f"\n4. Retrieved snapshot of 'user_data' (without metadata):\n{snapshot_data}")
        pd.testing.assert_frame_equal(snapshot_data, user_data)
        # Step 5: Modify the live cache and overwrite specific rows by 'name'
        updated_user_data = pd.DataFrame({
            'name': ['Alice', 'Bob', 'Charlie'],
            'email': ['alice@example.com', 'bob@example.com', 'charlie@example.com'],
            'age': [35, 25, 40]
        })
        snapshot_cache.get_cache("user_data").add_table(updated_user_data, overwrite='name')
        print("\n5. Updated live table by overwriting rows based on 'name'.")
        # Verify live cache has updated but the snapshot remains unchanged
        live_data = snapshot_cache.get_cache("user_data").get_all_items().drop(columns=['metadata'])
        print(f"\n6. Live user_data table (without metadata):\n{live_data}")
        self.assertEqual(len(live_data), 3,
                         "Live cache should have 3 rows after adding 'Charlie' and overwriting 'Alice'.")
        # Ensure the snapshot still has the original data
        snapshot_data = snapshot_cache.get_snapshot("user_data").get_all_items().drop(columns=['metadata'])
        print(f"\n7. Snapshot data (should still be original, without metadata):\n{snapshot_data}")
        self.assertEqual(len(snapshot_data), 2, "Snapshot should still have the original 2 rows.")
    def test_update_candle_cache(self):
        """Verify _update_candle_cache appends new candle rows to an existing
        cached table, yielding one contiguous, time-ordered series.

        Seeds the 'candles' cache with three 5m records, appends three more
        starting 15 minutes later, and checks the cached result matches a
        single 6-record table generated over the combined span.
        """
        self.load_prerequisites()
        print('Testing update_candle_cache() method:')
        # Set a cache key
        candle_cache_key = f'{self.ex_details[0]}_{self.ex_details[1]}_{self.ex_details[2]}'
        # Initialize the DataGenerator with the 5-minute timeframe
        data_gen = DataGenerator('5m')
        # Create initial DataFrame and insert it into the cache
        df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc))
        print(f'Inserting this table into cache:\n{df_initial}\n')
        self.data.serialized_datacache_insert(key=candle_cache_key, data=df_initial, cache_name='candles')
        # Create new DataFrame to be added to the cache
        # (starts at 00:15, i.e. immediately after the initial 3 x 5m records)
        df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0, tzinfo=dt.timezone.utc))
        print(f'Updating cache with this table:\n{df_new}\n')
        self.data._update_candle_cache(more_records=df_new, key=candle_cache_key)
        # Retrieve the resulting DataFrame from the cache
        result = self.data.get_serialized_datacache(key=candle_cache_key, cache_name='candles')
        print(f'The resulting table in cache is:\n{result}\n')
        # Create the expected DataFrame
        expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc))
        print(f'The expected time values are:\n{expected["time"].tolist()}\n')
        # Assert that the time values in the result match those in the expected DataFrame, in order
        assert result['time'].tolist() == expected['time'].tolist(), \
            f"time values in result are {result['time'].tolist()} expected {expected['time'].tolist()}"
        print(f'The result time values match:\n{result["time"].tolist()}\n')
        print(' - Update cache with new records passed.')
def analyze_cache_update(self, existing_records, more_records):
print("\n### Initial Data ###")
print("Existing Records:")
print(existing_records)
print("\nMore Records:")
print(more_records)
# Column-by-column comparison
print("\n### Column Comparison ###")
for col in existing_records.columns:
if col in more_records.columns:
print(f"\nAnalyzing column: {col}")
print(f"Existing Records '{col}' values:\n{existing_records[col].tolist()}")
print(f"More Records '{col}' values:\n{more_records[col].tolist()}")
print(f"Existing Records '{col}' type: {existing_records[col].dtype}")
print(f"More Records '{col}' type: {more_records[col].dtype}")
# Check for duplicate rows based on the 'time' column
print("\n### Duplicate Detection ###")
combined = pd.concat([existing_records, more_records], ignore_index=True)
print("Combined Records (before removing duplicates):")
print(combined)
# Method 1: Drop duplicates keeping the last occurrence
no_duplicates_last = combined.drop_duplicates(subset='time', keep='last')
print("\nAfter Dropping Duplicates (keep='last'):")
print(no_duplicates_last)
# Method 2: Drop duplicates keeping the first occurrence
no_duplicates_first = combined.drop_duplicates(subset='time', keep='first')
print("\nAfter Dropping Duplicates (keep='first'):")
print(no_duplicates_first)
# Method 3: Ensure 'time' is in a consistent data type, and drop duplicates
combined['time'] = combined['time'].astype('int64')
consistent_time = combined.drop_duplicates(subset='time', keep='last')
print("\nAfter Dropping Duplicates with 'time' as int64:")
print(consistent_time)
print("\n### Final Analysis ###")
print("Resulting DataFrame after sorting by 'time':")
final_result = consistent_time.sort_values(by='time').reset_index(drop=True)
print(final_result)
def test_reproduce_duplicate_issue(self):
# Simulating DataFrames like in your original test
# Time as epoch timestamps
existing_records = pd.DataFrame({
'market_id': [1, 1, 1],
'time': [1723161600000, 1723161900000, 1723162200000],
'open': [100, 101, 102],
'high': [110, 111, 112],
'low': [90, 91, 92],
'close': [105, 106, 107],
'volume': [1000, 1001, 1002]
})
more_records = pd.DataFrame({
'market_id': [1, 1, 1],
'time': [1723161600000, 1723161900000, 1723162500000], # Overlap at index 0 and 1
'open': [100, 101, 100],
'high': [110, 111, 110],
'low': [90, 91, 90],
'close': [105, 106, 105],
'volume': [1000, 1001, 1000]
})
# Run analysis
self.analyze_cache_update(existing_records, more_records)
def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None, num_rec=None, ex_details=None,
simulate_scenarios=None):
"""
Test the get_records_since() method by generating a table of simulated data,
inserting it into data and/or database, and then querying the records.
Parameters:
set_cache (bool): If True, the generated table is inserted into the cache.
set_db (bool): If True, the generated table is inserted into the database.
query_offset (int, optional): The offset in the timeframe units for the query.
num_rec (int, optional): The number of records to generate in the simulated table.
ex_details (list, optional): Exchange details to generate the data key.
simulate_scenarios (str, optional): The type of scenario to simulate. Options are:
- 'not_enough_data': The table data doesn't go far enough back.
- 'incomplete_data': The table doesn't have enough records to satisfy the query.
- 'missing_section': The table has missing records in the middle.
"""
print('Testing get_records_since() method:')
ex_details = ex_details or self.ex_details
key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}'
num_rec = num_rec or 12
table_timeframe = ex_details[1]
data_gen = DataGenerator(table_timeframe)
if simulate_scenarios == 'not_enough_data':
query_offset = (num_rec + 5) * data_gen.timeframe_amount
else:
query_offset = query_offset or (num_rec - 1) * data_gen.timeframe_amount
if simulate_scenarios == 'incomplete_data':
start_time_for_data = data_gen.x_time_ago(num_rec * data_gen.timeframe_amount)
num_rec = 5
else:
start_time_for_data = None
df_initial = data_gen.create_table(num_rec, start=start_time_for_data)
if simulate_scenarios == 'missing_section':
df_initial = data_gen.generate_missing_section(df_initial, drop_start=2, drop_end=5)
temp_df = df_initial.copy()
temp_df['time'] = pd.to_datetime(temp_df['time'], unit='ms')
print(f'Table Created:\n{temp_df}')
if set_cache:
print('Ensuring the cache exists and then inserting table into the cache.')
self.data.serialized_datacache_insert(data=df_initial, key=key, cache_name='candles')
if set_db:
print('Inserting table into the database.')
with SQLite(self.db_file) as con:
df_initial.to_sql(key, con, if_exists='replace', index=False)
start_datetime = data_gen.x_time_ago(query_offset)
if start_datetime.tzinfo is None:
start_datetime = start_datetime.replace(tzinfo=dt.timezone.utc)
query_end_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
print(f'Requesting records from {start_datetime} to {query_end_time}')
result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
expected = df_initial[df_initial['time'] >= data_gen.unix_time_millis(start_datetime)].reset_index(
drop=True)
temp_df = expected.copy()
temp_df['time'] = pd.to_datetime(temp_df['time'], unit='ms')
print(f'Expected table:\n{temp_df}')
temp_df = result.copy()
temp_df['time'] = pd.to_datetime(temp_df['time'], unit='ms')
print(f'Resulting table:\n{temp_df}')
if simulate_scenarios in ['not_enough_data', 'incomplete_data', 'missing_section']:
assert result.shape[0] > expected.shape[
0], "Result has fewer or equal rows compared to the incomplete data."
print("\nThe returned DataFrame has filled in the missing data!")
else:
assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}"
pd.testing.assert_series_equal(result['time'], expected['time'], check_dtype=False)
print("\nThe DataFrames have the same shape and the 'time' columns match.")
oldest_timestamp = pd.to_datetime(result['time'].min(), unit='ms').tz_localize('UTC')
time_diff = oldest_timestamp - start_datetime
max_allowed_time_diff = dt.timedelta(**{data_gen.timeframe_unit: data_gen.timeframe_amount})
assert dt.timedelta(0) <= time_diff <= max_allowed_time_diff, \
f"Oldest timestamp {oldest_timestamp} is not within " \
f"{data_gen.timeframe_amount} {data_gen.timeframe_unit} of {start_datetime}"
print(f'The first timestamp is {time_diff} from {start_datetime}')
newest_timestamp = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
time_diff_end = abs(query_end_time - newest_timestamp)
assert dt.timedelta(0) <= time_diff_end <= max_allowed_time_diff, \
f"Newest timestamp {newest_timestamp} is not within {data_gen.timeframe_amount} " \
f"{data_gen.timeframe_unit} of {query_end_time}"
print(f'The last timestamp is {time_diff_end} from {query_end_time}')
print(' - Fetch records within the specified time range passed.')
    def test_get_records_since(self):
        """Run _test_get_records_since across cache/db availability and
        degraded-data scenarios (missing cache, missing db, other timeframe,
        not_enough_data, incomplete_data, missing_section)."""
        self.connect_exchanges()
        self.set_up_database()
        self.load_test_data()
        print('\nTest get_records_since with records set in data')
        self._test_get_records_since()
        print('\nTest get_records_since with records not in data')
        self._test_get_records_since(set_cache=False)
        print('\nTest get_records_since with records not in database')
        self._test_get_records_since(set_cache=False, set_db=False)
        print('\nTest get_records_since with a different timeframe')
        self._test_get_records_since(query_offset=None, num_rec=None,
                                     ex_details=['BTC/USD', '15m', 'binance', 'test_guy'])
        print('\nTest get_records_since where data does not go far enough back')
        self._test_get_records_since(simulate_scenarios='not_enough_data')
        print('\nTest get_records_since with incomplete data')
        self._test_get_records_since(simulate_scenarios='incomplete_data')
        print('\nTest get_records_since with missing section in data')
        self._test_get_records_since(simulate_scenarios='missing_section')
def test_other_timeframes(self):
print('\nTest get_records_since with a different timeframe')
ex_details = ['BTC/USD', '15m', 'binance', 'test_guy']
start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2)
# Query the records since the calculated start time.
result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
last_record_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=15.1)
print('\nTest get_records_since with a different timeframe')
ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']
start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=1)
# Query the records since the calculated start time.
result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
last_record_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=5.1)
print('\nTest get_records_since with a different timeframe')
ex_details = ['BTC/USD', '4h', 'binance', 'test_guy']
start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=12)
# Query the records since the calculated start time.
result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
last_record_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=4.1)
def test_populate_db(self):
print('Testing _populate_db() method:')
# Create a table of candle records.
data_gen = DataGenerator(self.ex_details[1])
data = data_gen.create_table(num_rec=5)
self.data._populate_db(ex_details=self.ex_details, data=data)
with SQLite(self.db_file) as con:
result = pd.read_sql(f'SELECT * FROM "{self.key}"', con)
self.assertFalse(result.empty)
print(' - Populate database with data passed.')
def test_fetch_candles_from_exchange(self):
print('Testing _fetch_candles_from_exchange() method:')
# Define start and end times for the data fetch
start_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc) - dt.timedelta(days=1)
end_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
# Fetch the candles from the exchange using the method
result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h',
exchange_name='binance', user_name='test_guy',
start_datetime=start_time, end_datetime=end_time)
# Validate that the result is a DataFrame
self.assertIsInstance(result, pd.DataFrame)
# Validate that the DataFrame is not empty
self.assertFalse(result.empty, "The DataFrame returned from the exchange is empty.")
# Ensure that the 'time' column exists in the DataFrame
self.assertIn('time', result.columns, "'time' column is missing in the result DataFrame.")
# Check if the DataFrame contains valid timestamps within the specified range
min_time = pd.to_datetime(result['time'].min(), unit='ms').tz_localize('UTC')
max_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
self.assertTrue(start_time <= min_time <= end_time, f"Data starts outside the expected range: {min_time}")
self.assertTrue(start_time <= max_time <= end_time, f"Data ends outside the expected range: {max_time}")
print(' - Fetch candle data from exchange passed.')
    def test_remove_row(self):
        """Exercise remove_row_from_datacache across four scenarios:
        cache+db removal, cache-only removal, a missing cache (KeyError),
        and malformed filter_vals (ValueError)."""
        # Step 1: Create the cache and insert data
        self.data.create_cache('users', cache_type='table')  # Create 'users' cache for this test
        # Insert test data into the cache and database
        df = pd.DataFrame({
            'user_name': ['Alice', 'Bob', 'Charlie'],
            'age': [30, 25, 40],
            'users_data': ['data1', 'data2', 'data3'],
            'data': ['info1', 'info2', 'info3'],
            'password': ['pass1', 'pass2', 'pass3']
        })
        self.data.insert_df_into_datacache(df=df, cache_name="users", skip_cache=False)
        # Scenario 1: Remove a row from both cache and database
        filter_vals = [('user_name', 'Bob')]
        self.data.remove_row_from_datacache(cache_name="users", filter_vals=filter_vals, remove_from_db=True)
        # Verify the row was removed from the cache
        cache = self.data.get_cache('users')
        cached_data = cache.get_all_items()
        self.assertEqual(len(cached_data), 2)  # Only 2 rows should remain
        self.assertNotIn('Bob', cached_data['user_name'].values)  # Ensure 'Bob' is not in the cache
        # Verify the row was removed from the database
        with SQLite(db_file=self.db_file) as con:
            remaining_users = pd.read_sql_query("SELECT * FROM users", con)
        self.assertEqual(len(remaining_users), 2)  # Ensure 2 rows remain in the database
        self.assertNotIn('Bob', remaining_users['user_name'].values)  # Ensure 'Bob' is not in the database
        # Scenario 2: Remove a row from the cache only (not from the database)
        filter_vals = [('user_name', 'Charlie')]
        self.data.remove_row_from_datacache(cache_name="users", filter_vals=filter_vals, remove_from_db=False)
        # Verify the row was removed from the cache
        cached_data = cache.get_all_items()
        self.assertEqual(len(cached_data), 1)  # Only 1 row should remain in the cache
        self.assertNotIn('Charlie', cached_data['user_name'].values)  # Ensure 'Charlie' is not in the cache
        # Verify the row still exists in the database
        with SQLite(db_file=self.db_file) as con:
            remaining_users = pd.read_sql_query("SELECT * FROM users", con)
        self.assertEqual(len(remaining_users), 2)  # Ensure Charlie is still in the database
        self.assertIn('Charlie', remaining_users['user_name'].values)  # Charlie should still exist in the database
        # Scenario 3: Try removing from a non-existing cache (expecting KeyError)
        filter_vals = [('user_name', 'Bob')]
        with self.assertRaises(KeyError) as context:
            self.data.remove_row_from_datacache(cache_name="non_existing_cache", filter_vals=filter_vals, remove_from_db=True)
        self.assertEqual(context.exception.args[0], "Cache: non_existing_cache, does not exist.")
        # Scenario 4: Invalid filter_vals format (expecting ValueError)
        invalid_filter_vals = 'invalid_filter'  # Not a list of tuples
        with self.assertRaises(ValueError) as context:
            self.data.remove_row_from_datacache(cache_name="users", filter_vals=invalid_filter_vals, remove_from_db=True)
        self.assertEqual(str(context.exception), "filter_vals must be a list of tuples (column, value)")
def test_timeframe_to_timedelta(self):
print('Testing timeframe_to_timedelta() function:')
result = timeframe_to_timedelta('2h')
expected = pd.Timedelta(hours=2)
self.assertEqual(result, expected, "Failed to convert '2h' to Timedelta")
result = timeframe_to_timedelta('5m')
expected = pd.Timedelta(minutes=5)
self.assertEqual(result, expected, "Failed to convert '5m' to Timedelta")
result = timeframe_to_timedelta('1d')
expected = pd.Timedelta(days=1)
self.assertEqual(result, expected, "Failed to convert '1d' to Timedelta")
result = timeframe_to_timedelta('3M')
expected = pd.DateOffset(months=3)
self.assertEqual(result, expected, "Failed to convert '3M' to DateOffset")
result = timeframe_to_timedelta('1Y')
expected = pd.DateOffset(years=1)
self.assertEqual(result, expected, "Failed to convert '1Y' to DateOffset")
with self.assertRaises(ValueError):
timeframe_to_timedelta('5x')
print(' - All timeframe_to_timedelta() tests passed.')
    def test_estimate_record_count(self):
        """Exercise estimate_record_count over datetimes and epoch-millisecond
        inputs, month/year offsets, timezone offsets, zero/negative intervals,
        and invalid arguments."""
        print('Testing estimate_record_count() function:')
        # Test with '1h' timeframe (24 records expected)
        start_time = dt.datetime(2023, 8, 1, 0, 0, 0, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 8, 2, 0, 0, 0, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '1h')
        self.assertEqual(result, 24, "Failed to estimate record count for 1h timeframe")
        # Test with '1d' timeframe (1 record expected)
        result = estimate_record_count(start_time, end_time, '1d')
        self.assertEqual(result, 1, "Failed to estimate record count for 1d timeframe")
        # Test with '1h' timeframe and timestamps in milliseconds
        start_time_ms = int(start_time.timestamp() * 1000)  # Convert to milliseconds
        end_time_ms = int(end_time.timestamp() * 1000)  # Convert to milliseconds
        result = estimate_record_count(start_time_ms, end_time_ms, '1h')
        self.assertEqual(result, 24, "Failed to estimate record count for 1h timeframe with milliseconds")
        # Test with '5m' timeframe and Unix timestamps in milliseconds
        start_time_ms = 1672531200000  # Equivalent to '2023-01-01 00:00:00 UTC'
        end_time_ms = 1672534800000  # Equivalent to '2023-01-01 01:00:00 UTC'
        result = estimate_record_count(start_time_ms, end_time_ms, '5m')
        self.assertEqual(result, 12, "Failed to estimate record count for 5m timeframe with milliseconds")
        # Test with '5m' timeframe (12 records expected for 1-hour duration)
        start_time = dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 1, 1, 1, 0, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '5m')
        self.assertEqual(result, 12, "Failed to estimate record count for 5m timeframe")
        # Test with '1M' (3 records expected for 3 months)
        start_time = dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 4, 1, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '1M')
        self.assertEqual(result, 3, "Failed to estimate record count for 1M timeframe")
        # Test with invalid timeframe
        with self.assertRaises(ValueError):
            estimate_record_count(start_time, end_time, 'xyz')  # Invalid timeframe
        # Test with invalid start_time passed in
        with self.assertRaises(ValueError):
            estimate_record_count("invalid_start", end_time, '1h')
        # Cross-Year Transition (Months)
        start_time = dt.datetime(2022, 12, 1, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '1M')
        self.assertEqual(result, 1, "Failed to estimate record count for month across years")
        # Leap Year (Months)
        start_time = dt.datetime(2020, 2, 1, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2021, 2, 1, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '1M')
        self.assertEqual(result, 12, "Failed to estimate record count for months during leap year")
        # Sub-Minute Timeframes (e.g., 30 seconds)
        start_time = dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 1, 1, 0, 1, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '30s')
        self.assertEqual(result, 2, "Failed to estimate record count for 30 seconds timeframe")
        # Different Timezones
        # UTC+5 midnight to 01:00 UTC is a 6-hour span once normalized to UTC.
        start_time = dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone(dt.timedelta(hours=5)))  # UTC+5
        end_time = dt.datetime(2023, 1, 1, 1, 0, tzinfo=dt.timezone.utc)  # UTC
        result = estimate_record_count(start_time, end_time, '1h')
        self.assertEqual(result, 6,
                         "Failed to estimate record count for different timezones")  # Expect 6 records, not 1
        # Test with zero-length interval (should return 0)
        result = estimate_record_count(start_time, start_time, '1h')
        self.assertEqual(result, 0, "Failed to return 0 for zero-length interval")
        # Test with negative interval (end_time earlier than start_time, should return 0)
        result = estimate_record_count(end_time, start_time, '1h')
        self.assertEqual(result, 0, "Failed to return 0 for negative interval")
        # Test with small interval compared to timeframe (should return 0)
        start_time = dt.datetime(2023, 8, 1, 0, 0, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 8, 1, 0, 30, tzinfo=dt.timezone.utc)  # 30 minutes
        result = estimate_record_count(start_time, end_time, '1h')
        self.assertEqual(result, 0, "Failed to return 0 for small interval compared to timeframe")
        print(' - All estimate_record_count() tests passed.')
    def test_get_or_fetch_rows(self):
        """Verify per-key retrieval from a row-based 'users' cache, including
        a None result for a key that was never inserted."""
        # Create mock DataFrames for different users
        df1 = pd.DataFrame({
            'user_name': ['billy'],
            'password': ['1234'],
            'exchanges': [['ex1', 'ex2', 'ex3']]
        })
        df2 = pd.DataFrame({
            'user_name': ['john'],
            'password': ['5678'],
            'exchanges': [['ex4', 'ex5', 'ex6']]
        })
        df3 = pd.DataFrame({
            'user_name': ['alice'],
            'password': ['91011'],
            'exchanges': [['ex7', 'ex8', 'ex9']]
        })
        # Insert these DataFrames into the 'users' cache with row-based caching
        self.data.create_cache('users', cache_type='row')  # Assuming 'row' cache type for this test
        self.data.serialized_datacache_insert(key='user_billy', data=df1, cache_name='users')
        self.data.serialized_datacache_insert(key='user_john', data=df2, cache_name='users')
        self.data.serialized_datacache_insert(key='user_alice', data=df3, cache_name='users')
        print('Testing get_or_fetch_rows() method:')
        # Fetch user directly by key since this is a row-based cache
        result = self.data.get_serialized_datacache(key='user_billy', cache_name='users')
        self.assertIsInstance(result, pd.DataFrame, "Failed to fetch DataFrame from cache")
        self.assertFalse(result.empty, "The fetched DataFrame is empty")
        self.assertEqual(result.iloc[0]['password'], '1234', "Incorrect data fetched from cache")
        # Fetch another user by key
        result = self.data.get_serialized_datacache(key='user_john', cache_name='users')
        self.assertIsInstance(result, pd.DataFrame, "Failed to fetch DataFrame from cache")
        self.assertFalse(result.empty, "The fetched DataFrame is empty")
        self.assertEqual(result.iloc[0]['password'], '5678', "Incorrect data fetched from cache")
        # Test fetching a user that does not exist in the cache
        result = self.data.get_serialized_datacache(key='non_existent_user', cache_name='users')
        # Check if result is None (indicating that no data was found)
        self.assertIsNone(result, "Expected result to be None for a non-existent user")
        print(' - Fetching rows from cache passed.')
def test_is_attr_taken(self):
# Create a cache named 'users'
user_cache = self.data.create_cache('users', cache_type='table')
# Create mock data for three users
user_data_1 = pd.DataFrame({
'user_name': ['billy'],
'password': ['1234'],
'exchanges': [['ex1', 'ex2', 'ex3']]
})
user_data_2 = pd.DataFrame({
'user_name': ['john'],
'password': ['5678'],
'exchanges': [['ex1', 'ex2', 'ex4']]
})
user_data_3 = pd.DataFrame({
'user_name': ['alice'],
'password': ['abcd'],
'exchanges': [['ex5', 'ex6', 'ex7']]
})
# Insert mock data into the cache
self.data.serialized_datacache_insert(cache_name='users', data=user_data_1)
self.data.serialized_datacache_insert(cache_name='users', data=user_data_2)
self.data.serialized_datacache_insert(cache_name='users', data=user_data_3)
# Test when attribute value is taken
result_taken = user_cache.is_attr_taken('user_name', 'billy')
self.assertTrue(result_taken, "Expected 'billy' to be taken, but it was not.")
# Test when attribute value is not taken
result_not_taken = user_cache.is_attr_taken('user_name', 'charlie')
self.assertFalse(result_not_taken, "Expected 'charlie' not to be taken, but it was.")
    def test_insert_df_row_based_cache(self):
        """Run the shared insert-DataFrame checks against a row-based cache."""
        self._test_insert_df(cache_type='row')
    def test_insert_df_table_based_cache(self):
        """Run the shared insert-DataFrame checks against a table-based cache."""
        self._test_insert_df(cache_type='table')
def _test_insert_df(self, cache_type):
    """Insert a DataFrame through insert_df_into_datacache() and verify DB + cache.

    Shared helper for the row-based and table-based cache test variants.

    Args:
        cache_type: 'row' or 'table' — forwarded to create_cache().
    """
    self.data.create_cache('users', cache_type=cache_type)
    frame = pd.DataFrame({
        'user_name': ['Alice', 'Bob'],
        'age': [30, 25],
        'users_data': ['data1', 'data2'],
        'data': ['info1', 'info2'],
        'password': ['pass1', 'pass2'],
    })
    # Start from an empty users table so the row counts below are deterministic.
    with SQLite(db_file=self.db_file) as con:
        con.execute("DELETE FROM users;")
    # Insert without skipping the cache so both storage layers get populated.
    self.data.insert_df_into_datacache(df=frame, cache_name="users", skip_cache=False)
    # The rows must land in the database ...
    with SQLite(db_file=self.db_file) as con:
        stored = pd.read_sql_query("SELECT * FROM users", con)
    self.assertEqual(len(stored), 2)
    self.assertEqual(stored.iloc[0]['user_name'], 'Alice')
    self.assertEqual(stored.iloc[1]['user_name'], 'Bob')
    # ... and in the cache, whose layout depends on the cache type.
    cache = self.data.get_cache('users')
    if isinstance(cache, RowBasedCache):
        # Row-based caches key each row individually.
        for key in ('Alice', 'Bob'):
            entry = cache.get_entry(key)
            self.assertIsNotNone(entry)
            self.assertEqual(entry.iloc[0]['user_name'], key)
    elif isinstance(cache, TableBasedCache):
        # Table-based caches hold the whole DataFrame.
        cached = cache.get_all_items()
        self.assertEqual(len(cached), 2)
        self.assertEqual(cached.iloc[0]['user_name'], 'Alice')
        self.assertEqual(cached.iloc[1]['user_name'], 'Bob')
def test_insert_row(self):
    """insert_row_into_datacache() populates the cache unless skip_cache is set."""
    print("Testing insert_row() method:")
    cache_name = 'users'
    user_cache = self.data.create_cache(cache_name, cache_type='row')
    # Insert with caching enabled; the row must be retrievable by its key.
    self.data.insert_row_into_datacache(
        cache_name=cache_name,
        columns=('user_name', 'age'),
        values=('Alice', 30),
        key='1',
        skip_cache=False,
    )
    cached = user_cache.get_entry(key='1')
    self.assertIsNotNone(cached, "No data found in the cache for the inserted ID.")
    self.assertEqual(cached.iloc[0]['user_name'], 'Alice',
                     "The name in the cache doesn't match the inserted value.")
    self.assertEqual(cached.iloc[0]['age'], 30,
                     "The age in the cache does not match the inserted value.")
    # Insert with skip_cache=True; the cache must stay empty for that key.
    print("Testing insert_row() with skip_cache=True")
    self.data.insert_row_into_datacache(
        cache_name=cache_name,
        columns=('user_name', 'age'),
        values=('Bob', 40),
        key='2',
        skip_cache=True,
    )
    self.assertIsNone(user_cache.get_entry(key='2'),
                      "Data should not have been cached when skip_cache=True.")
    print(" - Insert row with and without caching passed all checks.")
def test_fill_data_holes(self):
    """_fill_data_holes() must pad gaps in a 2h series to a contiguous range."""
    print('Testing _fill_data_holes() method:')
    # Timestamps (ms since epoch) at hours 0, 2, 6, 8 and 12 on 2023-01-01 UTC,
    # leaving two missing 2h slots (04:00 and 10:00).
    hours_present = [0, 2, 6, 8, 12]
    frame = pd.DataFrame({
        'time': [dt.datetime(2023, 1, 1, h, tzinfo=dt.timezone.utc).timestamp() * 1000
                 for h in hours_present]
    })
    filled = self.data._fill_data_holes(records=frame, interval='2h')
    # 00:00 .. 12:00 at 2h steps is 7 rows once the holes are filled.
    self.assertEqual(len(filled), 7, "Data holes were not filled correctly.")
    print(' - _fill_data_holes passed.')
def test_get_cache_item(self):
    """get_serialized_datacache() round-trips several payload types.

    Cases: an Indicator instance (serialized/deserialized), a dict, a list,
    a DataFrame in a table-based cache, and a missing key (returns None).
    Cases 2 and 3 deliberately share the same 'default_cache'.
    """
    self.load_prerequisites()
    # Case 1: Retrieve a stored Indicator instance (serialized and deserialized).
    # NOTE(review): this asserts the getter returns an already-deserialized
    # Indicator, while test_set_cache_item Case 2 pickle.loads() the same
    # getter's result — the two tests assume different contracts; confirm which.
    indicator = Indicator(name='SMA', indicator_type='SMA', properties={'period': 5})
    # Create a row-based cache for indicators and store serialized Indicator data
    self.data.create_cache('indicators', cache_type='row')
    self.data.serialized_datacache_insert(key='indicator_key', data=indicator, cache_name='indicators')
    # Retrieve the indicator and check for deserialization
    stored_data = self.data.get_serialized_datacache('indicator_key', cache_name='indicators')
    self.assertIsInstance(stored_data, Indicator, "Failed to retrieve and deserialize the Indicator instance")
    # Case 2: Retrieve non-Indicator data (e.g., dict)
    data_dict = {'key': 'value'}
    # Create a cache for generic data (row-based)
    self.data.create_cache('default_cache', cache_type='row')
    # Store a dictionary
    self.data.serialized_datacache_insert(key='dict_key', data=data_dict, cache_name='default_cache')
    # Retrieve and check if the data matches the original dict
    stored_data = self.data.get_serialized_datacache('dict_key', cache_name='default_cache')
    self.assertEqual(stored_data, data_dict, "Failed to retrieve non-Indicator data correctly")
    # Case 3: Retrieve a list stored in the cache (reuses 'default_cache' from Case 2)
    data_list = [1, 2, 3, 4, 5]
    # Store a list in row-based cache
    self.data.serialized_datacache_insert(key='list_key', data=data_list, cache_name='default_cache')
    # Retrieve and check if the data matches the original list
    stored_data = self.data.get_serialized_datacache('list_key', cache_name='default_cache')
    self.assertEqual(stored_data, data_list, "Failed to retrieve list data correctly")
    # Case 4: Retrieve a DataFrame stored in the cache (Table-Based Cache)
    data_df = pd.DataFrame({
        'column1': [10, 20, 30],
        'column2': ['A', 'B', 'C']
    })
    # Create a table-based cache
    self.data.create_cache('table_cache', cache_type='table')
    # Store a DataFrame in table-based cache
    self.data.serialized_datacache_insert(key='testkey', data=data_df, cache_name='table_cache')
    # Retrieve and check if the DataFrame matches the original
    stored_data = self.data.get_serialized_datacache(key='testkey', cache_name='table_cache')
    pd.testing.assert_frame_equal(stored_data, data_df)
    # Case 5: Attempt to retrieve a non-existent key — must return None, not raise
    non_existent = self.data.get_serialized_datacache('non_existent_key', cache_name='default_cache')
    self.assertIsNone(non_existent, "Expected None for non-existent cache key")
    print(" - All get_cache_item tests passed.")
def test_set_user_indicator_properties(self):
    """set_user_indicator_properties() stores, updates, and validates per-user settings."""
    user_id = 'user123'
    indicator_type = 'SMA'
    symbol = 'AAPL'
    timeframe = '1h'
    exchange_name = 'NYSE'
    # Key format the implementation is expected to use for user display settings.
    user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}"
    # Case 1: the initial store must round-trip through the cache.
    initial = {'color': 'blue', 'line_width': 2}
    self.data.set_user_indicator_properties(
        user_id, indicator_type, symbol, timeframe, exchange_name, initial)
    stored = self.data.get_serialized_datacache(user_cache_key, cache_name='user_display_properties')
    self.assertEqual(stored, initial, "Failed to store user-specific display properties")
    # Case 2: a second call must overwrite the stored properties.
    updated = {'color': 'red', 'line_width': 3}
    self.data.set_user_indicator_properties(
        user_id, indicator_type, symbol, timeframe, exchange_name, updated)
    stored = self.data.get_serialized_datacache(user_cache_key, cache_name='user_display_properties')
    self.assertEqual(stored, updated, "Failed to update user-specific display properties")
    # Case 3: non-dict properties must be rejected with ValueError.
    with self.assertRaises(ValueError):
        self.data.set_user_indicator_properties(
            user_id, indicator_type, symbol, timeframe, exchange_name, "invalid_properties")
def test_get_user_indicator_properties(self):
    """get_user_indicator_properties() returns stored settings, None when missing, raises on bad types."""
    user_id = 'user123'
    indicator_type = 'SMA'
    symbol = 'AAPL'
    timeframe = '1h'
    exchange_name = 'NYSE'
    display_properties = {'color': 'blue', 'line_width': 2}
    # Case 1: store, then read back the same properties.
    self.data.set_user_indicator_properties(
        user_id, indicator_type, symbol, timeframe, exchange_name, display_properties)
    fetched = self.data.get_user_indicator_properties(
        user_id, indicator_type, symbol, timeframe, exchange_name)
    self.assertEqual(fetched, display_properties,
                     "Failed to retrieve user-specific display properties")
    # Case 2: an unknown user yields None rather than raising.
    self.assertIsNone(
        self.data.get_user_indicator_properties(
            'nonexistent_user', indicator_type, symbol, timeframe, exchange_name),
        "Expected None for missing user-specific display properties")
    # Case 3: a non-string user_id must raise TypeError.
    with self.assertRaises(TypeError):
        self.data.get_user_indicator_properties(
            123, indicator_type, symbol, timeframe, exchange_name)
def test_set_cache_item(self):
    """End-to-end checks for serialized_datacache_insert() across cache types.

    Covers: row-based and table-based caches, Indicator serialization,
    overwrite protection (do_not_overwrite), expiration handling, and
    rejection of an unknown cache name.
    """
    data_cache = self.data
    # -------------------------
    # Row-Based Cache Test Cases
    # -------------------------
    # Case 1: Store and retrieve an item in a RowBasedCache with a key
    data_cache.create_cache('row_cache', cache_type='row')  # Create row-based cache
    key = 'row_key'
    data = {'some': 'data'}
    data_cache.serialized_datacache_insert(cache_name='row_cache', data=data, key=key)
    cached_item = data_cache.get_serialized_datacache(key, cache_name='row_cache')
    self.assertEqual(cached_item, data, "Failed to store and retrieve data in RowBasedCache")
    # Case 2: Store and retrieve an Indicator instance (serialization)
    indicator = Indicator(name='SMA', indicator_type='SMA', properties={'period': 5})
    data_cache.serialized_datacache_insert(cache_name='row_cache', data=indicator, key='indicator_key')
    cached_indicator = data_cache.get_serialized_datacache('indicator_key', cache_name='row_cache')
    # Consistent with test_get_cache_item: get_serialized_datacache() returns the
    # already-deserialized object, so no extra pickle.loads() round-trip is needed.
    self.assertIsInstance(cached_indicator, Indicator, "Failed to deserialize Indicator instance")
    # Case 3: Prevent overwriting an existing key if do_not_overwrite=True
    new_data = {'new': 'data'}
    data_cache.serialized_datacache_insert(cache_name='row_cache', data=new_data, key=key, do_not_overwrite=True)
    cached_item_after = data_cache.get_serialized_datacache(key, cache_name='row_cache')
    self.assertEqual(cached_item_after, data, "Overwriting occurred when it should have been prevented")
    # Case 4: Raise ValueError if key is None in RowBasedCache
    with self.assertRaises(ValueError, msg="RowBasedCache requires a key to store the data."):
        data_cache.serialized_datacache_insert(cache_name='row_cache', data=data, key=None)
    # -------------------------
    # Table-Based Cache Test Cases
    # -------------------------
    # Case 5: Store and retrieve a DataFrame in a TableBasedCache
    data_cache.create_cache('table_cache', cache_type='table')  # Create table-based cache
    df = pd.DataFrame({'col1': [1, 2], 'col2': ['A', 'B']})
    data_cache.serialized_datacache_insert(cache_name='table_cache', data=df, key='table_key')
    cached_df = data_cache.get_serialized_datacache('table_key', cache_name='table_cache')
    pd.testing.assert_frame_equal(cached_df, df, "Failed to store and retrieve DataFrame in TableBasedCache")
    # Case 6: Prevent overwriting an existing key if do_not_overwrite=True in TableBasedCache
    new_df = pd.DataFrame({'col1': [3, 4], 'col2': ['C', 'D']})
    data_cache.serialized_datacache_insert(cache_name='table_cache', data=new_df, key='table_key', do_not_overwrite=True)
    cached_df_after = data_cache.get_serialized_datacache('table_key', cache_name='table_cache')
    pd.testing.assert_frame_equal(cached_df_after, df, "Overwriting occurred when it should have been prevented")
    # Case 7: Raise ValueError if non-DataFrame data is provided in TableBasedCache
    with self.assertRaises(ValueError, msg="TableBasedCache can only store DataFrames."):
        data_cache.serialized_datacache_insert(cache_name='table_cache', data={'not': 'a dataframe'}, key='table_key')
    # -------------------------
    # Expiration Handling Test Case
    # -------------------------
    # Case 8: Store an item with an expiration time (RowBasedCache).
    # A short expiry keeps the suite fast; the previous 5s/6s pair added 6 idle
    # seconds to every run. Bump both values if the cache's expiry resolution
    # turns out to be coarser than one second.
    key = 'expiring_key'
    data = {'some': 'data'}
    expire_delta = dt.timedelta(seconds=1)
    data_cache.serialized_datacache_insert(cache_name='row_cache', data=data, key=key, expire_delta=expire_delta)
    cached_item = data_cache.get_serialized_datacache(key, cache_name='row_cache')
    self.assertEqual(cached_item, data, "Failed to store and retrieve data with expiration")
    # Wait for expiration to occur (ensure data is removed after expiration).
    # `time` is already imported at module level — no local import needed.
    time.sleep(1.5)
    expired_item = data_cache.get_serialized_datacache(key, cache_name='row_cache')
    self.assertIsNone(expired_item, "Data was not removed after expiration time")
    # -------------------------
    # Invalid Cache Type Test Case
    # -------------------------
    # Case 9: Raise KeyError if the cache name was never created
    with self.assertRaises(KeyError, msg="Unsupported cache type for 'unsupported_cache'"):
        data_cache.serialized_datacache_insert(cache_name='unsupported_cache', data={'some': 'data'}, key='some_key')
def test_calculate_and_cache_indicator(self):
    """calculate_indicator() must return a non-None result for a valid request.

    Exercises the DataCache path that includes IndicatorCache functionality.
    (An unused `user_properties` local was removed — display properties are
    covered by test_calculate_indicator_multiple_users.)
    """
    # symbol, timeframe, exchange, user
    ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']
    # Define the time range for the calculation
    start_datetime = dt.datetime(2023, 9, 1, 0, 0, 0, tzinfo=dt.timezone.utc)
    end_datetime = dt.datetime(2023, 9, 2, 0, 0, 0, tzinfo=dt.timezone.utc)
    # Simulate calculating an indicator and caching it through DataCache
    result = self.data.calculate_indicator(
        user_name=ex_details[3],
        symbol=ex_details[0],
        timeframe=ex_details[1],
        exchange_name=ex_details[2],
        indicator_type='SMA',  # Type of indicator
        start_datetime=start_datetime,
        end_datetime=end_datetime,
        properties={'period': 5}  # Indicator properties, e.g. the SMA period
    )
    # Ensure that result is not None
    self.assertIsNotNone(result, "Indicator calculation returned None.")
def test_calculate_indicator_multiple_users(self):
    """
    Test that calculate_indicator() handles multiple users' requests: identical
    calculation data, but each user's own display properties.
    (An unused `ex_details` local was removed.)
    """
    user1_properties = {'color': 'blue', 'thickness': 2}
    user2_properties = {'color': 'red', 'thickness': 1}
    # Set user-specific properties
    self.data.set_user_indicator_properties('user_1', 'SMA', 'BTC/USD', '5m', 'binance', user1_properties)
    self.data.set_user_indicator_properties('user_2', 'SMA', 'BTC/USD', '5m', 'binance', user2_properties)
    # Both users request the same SMA over the same range; only the
    # display properties attached to the result should differ.
    common_kwargs = dict(
        symbol='BTC/USD',
        timeframe='5m',
        exchange_name='binance',
        indicator_type='SMA',
        start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
        end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
        properties={'period': 5},
    )
    result_user1 = self.data.calculate_indicator(user_name='user_1', **common_kwargs)
    result_user2 = self.data.calculate_indicator(user_name='user_2', **common_kwargs)
    # Assert that the calculation data is the same
    self.assertEqual(result_user1['calculation_data'], result_user2['calculation_data'])
    # Assert that the display properties are different
    self.assertNotEqual(result_user1['display_properties'], result_user2['display_properties'])
    # Assert that the correct display properties are returned
    self.assertEqual(result_user1['display_properties']['color'], 'blue')
    self.assertEqual(result_user2['display_properties']['color'], 'red')
def test_calculate_indicator_cache_retrieval(self):
    """
    Test that cached data is retrieved efficiently without recalculating when
    the same request is made twice.
    (Unused `ex_details` removed; the cache key is now built with
    _make_indicator_key() — consistent with test_calculate_indicator_partial_cache —
    instead of a hand-written string that could drift from the real format.)
    """
    properties = {'period': 5}
    cache_key = self.data._make_indicator_key('BTC/USD', '5m', 'binance', 'SMA', properties['period'])
    # First calculation (should store result in cache)
    result_first = self.data.calculate_indicator(
        user_name='user_1',
        symbol='BTC/USD',
        timeframe='5m',
        exchange_name='binance',
        indicator_type='SMA',
        start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
        end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
        properties=properties
    )
    # Check if the data was cached after the first calculation
    cached_data = self.data.get_serialized_datacache(cache_key, cache_name='indicator_data')
    print(f"Cached Data after first calculation: {cached_data}")
    # Ensure the data was cached correctly
    self.assertIsNotNone(cached_data, "The first calculation did not cache the result properly.")
    # Second calculation with the same parameters (should retrieve from cache)
    with self.assertLogs(level='INFO') as log:
        result_second = self.data.calculate_indicator(
            user_name='user_1',
            symbol='BTC/USD',
            timeframe='5m',
            exchange_name='binance',
            indicator_type='SMA',
            start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
            end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
            properties=properties
        )
    # Verify the log message for cache retrieval
    self.assertTrue(
        any(f"DataFrame retrieved from cache for key: {cache_key}" in message for message in log.output),
        f"Cache retrieval log message not found for key: {cache_key}"
    )
def test_calculate_indicator_partial_cache(self):
    """
    Test handling of partial cache where some of the requested data is already
    cached and the rest needs to be fetched.
    (Unused `ex_details` removed; min/max were previously computed twice and
    printed twice — now computed once.)
    """
    properties = {'period': 5}
    # Simulate cache for part of the range (manual setup, no call to `get_records_since`):
    # half a day of 5-minute bars starting 2023-01-01 00:00 UTC.
    cached_data = pd.DataFrame({
        'time': pd.date_range(start="2023-01-01", periods=144, freq='5min', tz=dt.timezone.utc),
        'value': [16500 + i for i in range(144)]
    })
    # Generate cache key with correct format
    cache_key = self.data._make_indicator_key('BTC/USD', '5m', 'binance', 'SMA', properties['period'])
    # Store the cached data as a DataFrame
    self.data.serialized_datacache_insert(cache_name='indicator_data', data=cached_data, key=cache_key)
    # Print cached data to inspect its range
    print("Cached data time range:")
    print(f"Min cached time: {cached_data['time'].min()}")
    print(f"Max cached time: {cached_data['time'].max()}")
    # Now request a range that partially overlaps the cached data
    result = self.data.calculate_indicator(
        user_name='user_1',
        symbol='BTC/USD',
        timeframe='5m',
        exchange_name='binance',
        indicator_type='SMA',
        start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
        end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
        properties=properties
    )
    # Convert the result into a DataFrame; 'time' comes back as a Unix
    # timestamp in milliseconds, so restore it to tz-aware datetimes.
    result_df = pd.DataFrame(result['calculation_data'])
    result_df['time'] = pd.to_datetime(result_df['time'], unit='ms', utc=True)
    min_time = result_df['time'].min()
    max_time = result_df['time'].max()
    print("Result data time range:")
    print(f"Min time in result: {min_time}")
    print(f"Max time in result: {max_time}")
    # The result must cover the full requested range, combining the cached
    # half-day with the freshly fetched remainder.
    self.assertEqual(min_time, pd.Timestamp("2023-01-01 00:00:00", tz=dt.timezone.utc))
    self.assertEqual(max_time, pd.Timestamp("2023-01-02 00:00:00", tz=dt.timezone.utc))
def test_calculate_indicator_no_data(self):
    """
    Test that the indicator calculation handles cases where no data is
    available for the requested range (year 1900, long before any data).
    (An unused `ex_details` local was removed.)
    """
    properties = {'period': 5}
    # Request data for a period where no data exists
    result = self.data.calculate_indicator(
        user_name='user_1',
        symbol='BTC/USD',
        timeframe='5m',
        exchange_name='binance',
        indicator_type='SMA',
        start_datetime=dt.datetime(1900, 1, 1, tzinfo=dt.timezone.utc),
        end_datetime=dt.datetime(1900, 1, 2, tzinfo=dt.timezone.utc),
        properties=properties
    )
    # Ensure no calculation data is returned
    self.assertEqual(len(result['calculation_data']), 0)
if __name__ == '__main__':
    # Run the whole test suite when this file is executed directly.
    unittest.main()