2012 lines
96 KiB
Python
2012 lines
96 KiB
Python
import pickle
|
|
import time
|
|
import pytz
|
|
|
|
|
|
from DataCache_v3 import DataCache, timeframe_to_timedelta, estimate_record_count, \
|
|
SnapshotDataCache, CacheManager, RowBasedCache, TableBasedCache
|
|
from ExchangeInterface import ExchangeInterface
|
|
import unittest
|
|
import pandas as pd
|
|
import datetime as dt
|
|
import os
|
|
from Database import SQLite, Database
|
|
|
|
import logging
|
|
|
|
from indicators import Indicator
|
|
|
|
# Emit debug-level logs while the test suite runs.
logging.basicConfig(level=logging.DEBUG)
|
|
|
|
|
|
class DataGenerator:
    """Generate simulated OHLCV candle data for tests.

    Configured with a timeframe string such as '2h', '5m', '1d', '1w',
    '1M' (months) or '1Y' (years).  Produced DataFrames carry a 'time'
    column holding Unix timestamps in milliseconds.
    """

    # Single-character unit codes understood by round_down_datetime.
    # Months must map to 'M' explicitly: deriving the code from
    # timeframe_unit[0] would yield 'm' and collide with minutes.
    _UNIT_CHAR = {
        'seconds': 's',
        'minutes': 'm',
        'hours': 'h',
        'days': 'd',
        'weeks': 'w',
        'months': 'M',
        'years': 'y',
    }

    def __init__(self, timeframe_str):
        """
        Initialize the DataGenerator with a timeframe string like '2h', '5m', '1d', '1w', '1M', or '1Y'.
        """
        # Placeholder values; set_timeframe fills them in.
        self.timeframe_amount = None
        self.timeframe_unit = None
        self.set_timeframe(timeframe_str)

    def set_timeframe(self, timeframe_str):
        """
        Set the timeframe unit and amount based on a string like '2h', '5m', '1d', '1w', '1M', or '1Y'.

        Raises:
            ValueError: If the unit suffix is not one of s, m, h, d, w, M, Y.
        """
        self.timeframe_amount = int(timeframe_str[:-1])
        unit = timeframe_str[-1]

        unit_names = {
            's': 'seconds',
            'm': 'minutes',
            'h': 'hours',
            'd': 'days',
            'w': 'weeks',
            'M': 'months',
            'Y': 'years',
        }
        if unit not in unit_names:
            raise ValueError(
                "Unsupported timeframe unit. Use 's,m,h,d,w,M,Y'.")
        self.timeframe_unit = unit_names[unit]

    def create_table(self, num_rec=None, start=None, end=None):
        """
        Create a table with simulated data. If both start and end are provided,
        num_rec is derived from the interval.  If neither is provided the table
        will have num_rec rows and end at the current time.

        Parameters:
            num_rec (int, optional): The number of records to generate.
            start (datetime, optional): The start time for the first record.
            end (datetime, optional): The end time for the last record.

        Returns:
            pd.DataFrame: A DataFrame with the simulated data.

        Raises:
            ValueError: If a supplied datetime is naive, or if num_rec is
                missing when it cannot be derived from start/end.
        """
        # Provided datetimes must be timezone aware.
        if start and start.tzinfo is None:
            raise ValueError('start datetime must be timezone aware.')
        if end and end.tzinfo is None:
            raise ValueError('end datetime must be timezone aware.')

        # Default to "now" (UTC) when no anchor is given.
        if start is None and end is None:
            end = dt.datetime.now(dt.timezone.utc)
            if num_rec is None:
                raise ValueError("num_rec must be provided if both start and end are not specified.")

        interval_seconds = self.timeframe_amount * self._get_seconds_per_unit(self.timeframe_unit)

        # Derive num_rec from the [start, end] span (both endpoints included).
        if start is not None and end is not None:
            total_duration = (end - start).total_seconds()
            num_rec = int(total_duration // interval_seconds) + 1

        # Derive start by stepping back (num_rec - 1) intervals from end.
        # end is already aware, so the result is aware; no tzinfo overwrite
        # is needed (the old replace(tzinfo=utc) could shift the instant).
        if end is not None and start is None:
            if num_rec is None:
                raise ValueError("num_rec must be provided if both start and end are not specified.")
            start = end - dt.timedelta(seconds=(num_rec - 1) * interval_seconds)

        # Align start to the timeframe grid.  Use the explicit unit-code map:
        # timeframe_unit[0] would wrongly send 'months' down the minutes path.
        start = self.round_down_datetime(start, self._UNIT_CHAR[self.timeframe_unit], self.timeframe_amount)

        # Millisecond timestamps for each record.
        times = [self.unix_time_millis(start + self._delta(i)) for i in range(num_rec)]

        df = pd.DataFrame({
            'market_id': 1,
            'time': times,
            'open': [100 + i for i in range(num_rec)],
            'high': [110 + i for i in range(num_rec)],
            'low': [90 + i for i in range(num_rec)],
            'close': [105 + i for i in range(num_rec)],
            'volume': [1000 + i for i in range(num_rec)]
        })

        return df

    @staticmethod
    def _get_seconds_per_unit(unit):
        """Helper method to convert timeframe units to (approximate) seconds."""
        units_in_seconds = {
            'seconds': 1,
            'minutes': 60,
            'hours': 3600,
            'days': 86400,
            'weeks': 604800,
            'months': 2592000,  # Assuming 30 days per month
            'years': 31536000   # Assuming 365 days per year
        }
        if unit not in units_in_seconds:
            raise ValueError(f"Unsupported timeframe unit: {unit}")
        return units_in_seconds[unit]

    def _timeframe_delta(self, count):
        """Return a timedelta spanning `count` of the configured unit.

        datetime.timedelta has no 'months'/'years' keywords (the old code
        raised TypeError for those units), so they are approximated via the
        fixed seconds-per-unit values above.
        """
        if self.timeframe_unit in ('months', 'years'):
            return dt.timedelta(seconds=count * self._get_seconds_per_unit(self.timeframe_unit))
        return dt.timedelta(**{self.timeframe_unit: count})

    def generate_incomplete_data(self, query_offset, num_rec=5):
        """
        Generate data that is incomplete, i.e., starts before the query but doesn't fully satisfy it.
        """
        query_start_time = self.x_time_ago(query_offset)
        start_time_for_data = self.get_start_time(query_start_time)
        return self.create_table(num_rec, start=start_time_for_data)

    @staticmethod
    def generate_missing_section(df, drop_start=5, drop_end=8):
        """
        Generate data with a missing section by dropping rows [drop_start, drop_end).
        """
        df = df.drop(df.index[drop_start:drop_end]).reset_index(drop=True)
        return df

    def get_start_time(self, query_start_time):
        """Return a start time two full timeframe intervals before query_start_time."""
        margin = 2
        return query_start_time - self._timeframe_delta(margin * self.timeframe_amount)

    def x_time_ago(self, offset):
        """
        Returns a datetime object representing the current time minus the offset in the specified units.
        """
        # now(timezone.utc) replaces the deprecated utcnow().replace(...) dance.
        return dt.datetime.now(dt.timezone.utc) - self._timeframe_delta(offset)

    def _delta(self, i):
        """
        Returns a timedelta object for the ith increment based on the timeframe unit and amount.
        """
        return self._timeframe_delta(i * self.timeframe_amount)

    @staticmethod
    def unix_time_millis(dt_obj: dt.datetime):
        """
        Convert a timezone-aware datetime object to Unix time in milliseconds.
        """
        if dt_obj.tzinfo is None:
            raise ValueError('dt_obj needs to be timezone aware.')
        epoch = dt.datetime(1970, 1, 1, tzinfo=dt.timezone.utc)
        return int((dt_obj - epoch).total_seconds() * 1000)

    @staticmethod
    def round_down_datetime(dt_obj: dt.datetime, unit: str, interval: int) -> dt.datetime:
        """Round dt_obj down to the nearest multiple of `interval` of `unit`.

        Unit codes: 's', 'm', 'h', 'd', 'w', 'M' (months), 'y' (years).
        """
        if dt_obj.tzinfo is None:
            raise ValueError('dt_obj needs to be timezone aware.')

        if unit == 's':  # Round down to the nearest interval of seconds
            seconds = (dt_obj.second // interval) * interval
            dt_obj = dt_obj.replace(second=seconds, microsecond=0)
        elif unit == 'm':  # Round down to the nearest interval of minutes
            minutes = (dt_obj.minute // interval) * interval
            dt_obj = dt_obj.replace(minute=minutes, second=0, microsecond=0)
        elif unit == 'h':  # Round down to the nearest interval of hours
            hours = (dt_obj.hour // interval) * interval
            dt_obj = dt_obj.replace(hour=hours, minute=0, second=0, microsecond=0)
        elif unit == 'd':  # Round down to the nearest interval of days
            # Days are 1-based; work on (day - 1) so the result can never be
            # day 0 (the old formula raised ValueError for interval > 1).
            days = ((dt_obj.day - 1) // interval) * interval + 1
            dt_obj = dt_obj.replace(day=days, hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'w':  # Round down to the nearest interval of weeks
            dt_obj -= dt.timedelta(days=dt_obj.weekday() % (interval * 7))
            dt_obj = dt_obj.replace(hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'M':  # Round down to the nearest interval of months
            months = ((dt_obj.month - 1) // interval) * interval + 1
            dt_obj = dt_obj.replace(month=months, day=1, hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'y':  # Round down to the nearest interval of years
            years = (dt_obj.year // interval) * interval
            dt_obj = dt_obj.replace(year=years, month=1, day=1, hour=0, minute=0, second=0, microsecond=0)

        return dt_obj
|
|
|
|
|
|
class TestDataCache(unittest.TestCase):
|
|
def setUp(self):
|
|
# Initialize DataCache
|
|
self.exchanges = ExchangeInterface()
|
|
self.data = DataCache(self.exchanges)
|
|
|
|
self.exchanges_connected = False
|
|
self.database_is_setup = False
|
|
self.test_data_loaded = False
|
|
|
|
self.load_prerequisites()
|
|
|
|
def tearDown(self):
|
|
if self.database_is_setup:
|
|
if os.path.exists(self.db_file):
|
|
os.remove(self.db_file)
|
|
|
|
def load_prerequisites(self):
|
|
self.connect_exchanges()
|
|
self.set_up_database()
|
|
self.load_test_data()
|
|
|
|
def load_test_data(self):
|
|
if self.test_data_loaded:
|
|
return
|
|
# Create caches needed for testing
|
|
self.data.create_cache('candles', cache_type='row')
|
|
|
|
# Reuse details for exchange and market
|
|
self.ex_details = ['BTC/USD', '2h', 'binance', 'test_guy']
|
|
self.key = f'{self.ex_details[0]}_{self.ex_details[1]}_{self.ex_details[2]}'
|
|
|
|
self.test_data_loaded = True
|
|
|
|
def connect_exchanges(self):
|
|
if not self.exchanges_connected:
|
|
self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)
|
|
self.exchanges.connect_exchange(exchange_name='binance', user_name='user_1', api_keys=None)
|
|
self.exchanges.connect_exchange(exchange_name='binance', user_name='user_2', api_keys=None)
|
|
self.exchanges_connected = True
|
|
|
|
def set_up_database(self):
|
|
if not self.database_is_setup:
|
|
self.db_file = 'test_db.sqlite'
|
|
self.database = Database(db_file=self.db_file)
|
|
|
|
# Create necessary tables
|
|
sql_create_table_1 = f"""
|
|
CREATE TABLE IF NOT EXISTS test_table (
|
|
id INTEGER PRIMARY KEY,
|
|
market_id INTEGER,
|
|
time INTEGER UNIQUE ON CONFLICT IGNORE,
|
|
open REAL NOT NULL,
|
|
high REAL NOT NULL,
|
|
low REAL NOT NULL,
|
|
close REAL NOT NULL,
|
|
volume REAL NOT NULL,
|
|
FOREIGN KEY (market_id) REFERENCES market (id)
|
|
)"""
|
|
sql_create_table_2 = """
|
|
CREATE TABLE IF NOT EXISTS exchange (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
name TEXT UNIQUE
|
|
)"""
|
|
sql_create_table_3 = """
|
|
CREATE TABLE IF NOT EXISTS markets (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
symbol TEXT,
|
|
exchange_id INTEGER,
|
|
FOREIGN KEY (exchange_id) REFERENCES exchange(id)
|
|
)"""
|
|
sql_create_table_4 = f"""
|
|
CREATE TABLE IF NOT EXISTS test_table_2 (
|
|
key TEXT PRIMARY KEY,
|
|
data TEXT NOT NULL
|
|
)"""
|
|
sql_create_table_5 = """
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
user_name TEXT,
|
|
age INTEGER,
|
|
users_data TEXT,
|
|
data TEXT,
|
|
password TEXT -- Moved to a new line and added a comma after 'data'
|
|
)
|
|
"""
|
|
|
|
with SQLite(db_file=self.db_file) as con:
|
|
con.execute(sql_create_table_1)
|
|
con.execute(sql_create_table_2)
|
|
con.execute(sql_create_table_3)
|
|
con.execute(sql_create_table_4)
|
|
con.execute(sql_create_table_5)
|
|
|
|
self.data.db = self.database # Keep the database setup
|
|
self.database_is_setup = True
|
|
|
|
def test_cache_system(self):
|
|
print("\n---- Starting Cache System Test ----")
|
|
|
|
# Step 1: Create CacheManager instance
|
|
cache_manager = CacheManager()
|
|
print("\n1. Created CacheManager instance.")
|
|
|
|
# Step 2: Create a Table-Based Cache for users
|
|
cache_manager.create_cache(name="users", cache_type="table",
|
|
default_expiration=dt.timedelta(seconds=2), size_limit=5)
|
|
table_cache = cache_manager.get_cache("users")
|
|
print("\n2. Created a Table-Based Cache for 'users' with expiration set to 2 seconds.")
|
|
|
|
# Step 3: Add users DataFrame to the table-based cache
|
|
users_df = pd.DataFrame([
|
|
{'name': 'Bob', 'age': 25, 'email': 'bob@example.com'},
|
|
{'name': 'Alice', 'age': 30, 'email': 'alice@example.com'}
|
|
])
|
|
table_cache.add_table(users_df)
|
|
print("\n3. Added user data to 'users' table-based cache:")
|
|
print(users_df)
|
|
|
|
# Step 4: Query the cache to retrieve Bob's information
|
|
result = table_cache.query([('name', 'Bob')])
|
|
print("\n4. Queried 'users' table-based cache for user 'Bob':")
|
|
print(result)
|
|
self.assertEqual(len(result), 1, "Should return exactly 1 row for Bob.")
|
|
self.assertEqual(result.iloc[0]['name'], 'Bob', "The name should be Bob.")
|
|
|
|
# Step 5: Wait for 3 seconds (after expiration time) and query again to check expiration
|
|
print("\n5. Waiting for 3 seconds to allow the cache to expire...")
|
|
time.sleep(3)
|
|
result_after_expiry = table_cache.query([('name', 'Bob')])
|
|
print("\n5. After 3 seconds, queried again for Bob. Result should be empty due to expiration:")
|
|
print(result_after_expiry)
|
|
self.assertTrue(result_after_expiry.empty, "Result should be empty as Bob's entry has expired.")
|
|
|
|
# Step 6: Create a Row-Based Cache for candles
|
|
cache_manager.create_cache(name="candles", cache_type="row",
|
|
default_expiration=dt.timedelta(seconds=5), size_limit=10)
|
|
row_cache = cache_manager.get_cache("candles")
|
|
print("\n6. Created a Row-Based Cache for 'candles' with expiration set to 5 seconds.")
|
|
|
|
# Step 7: Add candle data to the row-based cache
|
|
candle_data_1 = pd.DataFrame([
|
|
{'time': '2024-09-11 00:00', 'open': 100, 'high': 105, 'low': 99, 'close': 102},
|
|
{'time': '2024-09-11 01:00', 'open': 101, 'high': 106, 'low': 100, 'close': 103}
|
|
])
|
|
row_cache.add_entry("candle_1", candle_data_1)
|
|
print("\n7. Added candle data to 'candles' row-based cache under key 'candle_1':")
|
|
print(candle_data_1)
|
|
|
|
# Step 8: Query the row-based cache to retrieve specific time entry for candle_1
|
|
result_candle = row_cache.query([("key", "candle_1"), ("time", "2024-09-11 00:00")])
|
|
print("\n8. Queried 'candles' row-based cache for 'candle_1' and time '2024-09-11 00:00':")
|
|
print(result_candle)
|
|
self.assertEqual(len(result_candle), 1, "Should return exactly 1 row for the specified time.")
|
|
self.assertEqual(result_candle.iloc[0]['time'], '2024-09-11 00:00', "The time should match the queried time.")
|
|
|
|
# Step 9: Wait for 6 seconds (after expiration time) and query again to check expiration
|
|
print("\n9. Waiting for 6 seconds to allow the 'candle_1' cache to expire...")
|
|
time.sleep(6)
|
|
result_candle_after_expiry = row_cache.query([("key", "candle_1"), ("time", "2024-09-11 00:00")])
|
|
print("\n9. After 6 seconds, queried again for 'candle_1'. Result should be empty due to expiration:")
|
|
print(result_candle_after_expiry)
|
|
self.assertTrue(result_candle_after_expiry.empty, "Result should be empty as 'candle_1' has expired.")
|
|
|
|
# Step 10: Test the size limit of the row-based cache (adding more than limit)
|
|
print("\n10. Testing row-based cache size limit (max 10 entries).")
|
|
for i in range(1, 12):
|
|
row_cache.add_entry(f"candle_{i}", pd.DataFrame([
|
|
{'time': f'2024-09-11 00:00', 'open': 100 + i, 'high': 105 + i, 'low': 99 + i, 'close': 102 + i}
|
|
]))
|
|
print(f"Added entry: candle_{i}")
|
|
|
|
print("\n11. Checking the size of the cache after adding 11 entries (limit is 10):")
|
|
result = row_cache.get_entry("candle_1")
|
|
print(f"Checking 'candle_1': {result}")
|
|
self.assertIsNone(result, "'candle_1' should have been evicted as the size limit is 10.")
|
|
|
|
# Final print statement for clarity of test ending
|
|
print("\n---- Cache System Test Completed ----")
|
|
|
|
def test_cache_system_advanced_usage(self):
|
|
print("\n---- Starting Advanced Cache System Test ----")
|
|
|
|
# Step 1: Create CacheManager instance
|
|
cache_manager = CacheManager()
|
|
print("\n1. Created CacheManager instance.")
|
|
|
|
# Row-Based Cache Test with Different Data Types
|
|
cache_manager.create_cache(name="row_cache", cache_type="row",
|
|
default_expiration=dt.timedelta(seconds=5), size_limit=10)
|
|
row_cache = cache_manager.get_cache("row_cache")
|
|
print("\n2. Created a Row-Based Cache with expiration set to 5 seconds.")
|
|
|
|
# Step 2: Add different types of data into Row-Based Cache
|
|
# Add a string
|
|
row_cache.add_entry("message", "Hello, World!")
|
|
print("\n3. Added a string to Row-Based Cache under key 'message'.")
|
|
|
|
# Add a dictionary
|
|
row_cache.add_entry("user_profile", {"name": "Charlie", "age": 28, "email": "charlie@example.com"})
|
|
print("\n4. Added a dictionary to Row-Based Cache under key 'user_profile'.")
|
|
|
|
# Add a list of numbers
|
|
row_cache.add_entry("numbers", [1, 2, 3, 4, 5])
|
|
print("\n5. Added a list of numbers to Row-Based Cache under key 'numbers'.")
|
|
|
|
# Step 3: Query the Row-Based Cache
|
|
print("\n6. Querying Row-Based Cache for different types of data:")
|
|
result_message = row_cache.query([("key", "message")])
|
|
print(f"Query result for key 'message': {result_message}")
|
|
|
|
result_profile = row_cache.query([("key", "user_profile")])
|
|
print(f"Query result for key 'user_profile': {result_profile}")
|
|
|
|
result_numbers = row_cache.query([("key", "numbers")])
|
|
print(f"Query result for key 'numbers': {result_numbers}")
|
|
|
|
# Assert non-expired entries
|
|
self.assertEqual(result_message.iloc[0][0], "Hello, World!", "Message should be 'Hello, World!'")
|
|
self.assertEqual(result_profile.iloc[0]['name'], 'Charlie', "User profile should have name 'Charlie'")
|
|
|
|
# Convert the DataFrame row back to a list and assert the values match
|
|
numbers_list = result_numbers.iloc[0].tolist()
|
|
self.assertEqual(numbers_list, [1, 2, 3, 4, 5], "Should return list of numbers.")
|
|
|
|
# Table-Based Cache Test with DataFrames
|
|
cache_manager.create_cache(name="table_cache", cache_type="table",
|
|
default_expiration=dt.timedelta(seconds=5), size_limit=5)
|
|
table_cache = cache_manager.get_cache("table_cache")
|
|
print("\n7. Created a Table-Based Cache with expiration set to 5 seconds.")
|
|
|
|
# Step 4: Add a DataFrame with mixed data types to Table-Based Cache
|
|
mixed_df = pd.DataFrame([
|
|
{'category': 'A', 'value': 100, 'timestamp': '2024-09-12 12:00'},
|
|
{'category': 'B', 'value': 200, 'timestamp': '2024-09-12 13:00'},
|
|
{'category': 'A', 'value': 150, 'timestamp': '2024-09-12 14:00'}
|
|
])
|
|
table_cache.add_table(mixed_df)
|
|
print("\n8. Added mixed DataFrame to Table-Based Cache:")
|
|
print(mixed_df)
|
|
|
|
# Step 5: Query the Table-Based Cache
|
|
print("\n9. Querying Table-Based Cache for category 'A':")
|
|
result_category_a = table_cache.query([("category", "A")])
|
|
print(result_category_a)
|
|
self.assertEqual(len(result_category_a), 2, "There should be 2 rows with category 'A'.")
|
|
|
|
print("\n10. Querying Table-Based Cache for value greater than 100:")
|
|
result_value_gt_100 = table_cache.query([("value", 150)])
|
|
print(result_value_gt_100)
|
|
self.assertEqual(len(result_value_gt_100), 1, "There should be 1 row with value of 150.")
|
|
|
|
# Step 6: Wait for entries to expire and query again
|
|
print("\n11. Waiting for 6 seconds to let all rows expire...")
|
|
time.sleep(6)
|
|
result_after_expiry = table_cache.query([("category", "A")])
|
|
print(f"\n12. After 6 seconds, querying again for category 'A'. Result should be empty:")
|
|
print(result_after_expiry)
|
|
self.assertTrue(result_after_expiry.empty, "Result should be empty due to expiration.")
|
|
|
|
# Final print statement for clarity of test ending
|
|
print("\n---- Advanced Cache System Test Completed ----")
|
|
|
|
def test_cache_system_edge_cases(self):
|
|
print("\n---- Starting Edge Case Cache System Test ----")
|
|
|
|
# Step 1: Create CacheManager instance
|
|
cache_manager = CacheManager()
|
|
print("\n1. Created CacheManager instance.")
|
|
|
|
# Test 1: Cache Size Limit (Row-Based Cache)
|
|
cache_manager.create_cache(name="limited_row_cache", cache_type="row", size_limit=3, eviction_policy="evict")
|
|
limited_row_cache = cache_manager.get_cache("limited_row_cache")
|
|
print("\n2. Created a Row-Based Cache with size limit of 3 and eviction policy 'evict'.")
|
|
|
|
# Add entries beyond the size limit and check eviction
|
|
limited_row_cache.add_entry("item1", "Data 1")
|
|
print(f"Cache after adding item1: {limited_row_cache.cache}")
|
|
|
|
limited_row_cache.add_entry("item2", "Data 2")
|
|
print(f"Cache after adding item2: {limited_row_cache.cache}")
|
|
|
|
limited_row_cache.add_entry("item3", "Data 3")
|
|
print(f"Cache after adding item3: {limited_row_cache.cache}")
|
|
|
|
# Add 4th entry, which should cause eviction of the first entry
|
|
limited_row_cache.add_entry("item4", "Data 4")
|
|
print(f"Cache after adding item4: {limited_row_cache.cache}")
|
|
|
|
# Verify eviction of the oldest entry (item1)
|
|
result_item1 = limited_row_cache.query([("key", "item1")])
|
|
print(f"Query result for 'item1' (should be evicted): {result_item1}")
|
|
self.assertTrue(result_item1.empty, "'item1' should be evicted.")
|
|
|
|
# Verify that other items exist
|
|
result_item4 = limited_row_cache.query([("key", "item4")])
|
|
print(f"Query result for 'item4': {result_item4}")
|
|
self.assertFalse(result_item4.empty, "'item4' should be present.")
|
|
|
|
def test_access_counter_and_purging(self):
|
|
"""Test access counter and purging mechanism."""
|
|
print("\n---- Starting Access Counter and Purge Mechanism Test ----")
|
|
|
|
# Step 1: Create Row-Based Cache with a purge threshold of 5 accesses
|
|
cache = RowBasedCache(default_expiration=1, purge_threshold=5)
|
|
print("\n1. Created a Row-Based Cache with purge threshold of 5 accesses.")
|
|
|
|
# Step 2: Add entries with expiration times
|
|
cache.add_entry("item1", "Data 1")
|
|
cache.add_entry("item2", "Data 2")
|
|
print("\n2. Added 2 entries with a 1-second expiration time.")
|
|
|
|
# Step 3: Access cache 5 times to trigger purge
|
|
for i in range(5):
|
|
cache.get_entry("item1")
|
|
print(f"\n3. Accessed cache {i+1} times.")
|
|
|
|
# Step 4: Ensure item1 is still in cache (it hasn't expired yet because of timing)
|
|
result_item1 = cache.get_entry("item1")
|
|
print(f"\n4. Retrieved 'item1' from cache before expiration: {result_item1}")
|
|
self.assertIsNotNone(result_item1, "'item1' should still be in cache.")
|
|
|
|
# Step 5: Wait for expiration and access again to trigger purge
|
|
time.sleep(2)
|
|
result_item1_expired = cache.get_entry("item1")
|
|
print(f"\n5. Retrieved 'item1' after expiration (should be None): {result_item1_expired}")
|
|
self.assertIsNone(result_item1_expired, "'item1' should have expired.")
|
|
|
|
# Step 6: Access cache 5 more times to confirm expired entries are purged
|
|
for i in range(5):
|
|
cache.get_entry("item2")
|
|
print(f"\n6. Accessed cache {i+1} more times to trigger another purge.")
|
|
|
|
# Verify item2 is also expired after accesses
|
|
result_item2_expired = cache.get_entry("item2")
|
|
print(f"\n7. Retrieved 'item2' after expiration (should be None): {result_item2_expired}")
|
|
self.assertIsNone(result_item2_expired, "'item2' should have expired.")
|
|
|
|
def test_is_attr_taken_in_row_cache(self):
|
|
"""Test the is_attr_taken method in the Row-Based Cache."""
|
|
print("\n---- Starting is_attr_taken in Row-Based Cache Test ----")
|
|
|
|
# Step 1: Create a Row-Based Cache
|
|
cache = RowBasedCache()
|
|
print("\n1. Created a Row-Based Cache.")
|
|
|
|
# Step 2: Add entries with DataFrames
|
|
df = pd.DataFrame({'name': ['Alice', 'Bob'], 'age': [30, 25]})
|
|
cache.add_entry("users", df)
|
|
print("\n2. Added a DataFrame to the cache under key 'users'.")
|
|
|
|
# Step 3: Test is_attr_taken
|
|
attr_taken = cache.is_attr_taken('name', 'Alice')
|
|
print(f"\n3. Checked if 'name' column contains 'Alice': {attr_taken}")
|
|
self.assertTrue(attr_taken, "'name' column should contain 'Alice'.")
|
|
|
|
attr_not_taken = cache.is_attr_taken('name', 'Charlie')
|
|
print(f"\n4. Checked if 'name' column contains 'Charlie': {attr_not_taken}")
|
|
self.assertFalse(attr_not_taken, "'name' column should not contain 'Charlie'.")
|
|
|
|
def test_is_attr_taken_in_table_cache(self):
|
|
"""Test the is_attr_taken method in the Table-Based Cache."""
|
|
print("\n---- Starting is_attr_taken in Table-Based Cache Test ----")
|
|
|
|
# Step 1: Create a Table-Based Cache
|
|
cache = TableBasedCache()
|
|
print("\n1. Created a Table-Based Cache.")
|
|
|
|
# Step 2: Add a DataFrame to the Table-Based Cache
|
|
df = pd.DataFrame({'name': ['Alice', 'Charlie'], 'age': [30, 40]})
|
|
cache.add_table(df)
|
|
print("\n2. Added a DataFrame to the Table-Based Cache.")
|
|
|
|
# Step 3: Test is_attr_taken
|
|
attr_taken = cache.is_attr_taken('name', 'Alice')
|
|
print(f"\n3. Checked if 'name' column contains 'Alice': {attr_taken}")
|
|
self.assertTrue(attr_taken, "'name' column should contain 'Alice'.")
|
|
|
|
attr_not_taken = cache.is_attr_taken('name', 'Bob')
|
|
print(f"\n4. Checked if 'name' column contains 'Bob': {attr_not_taken}")
|
|
self.assertFalse(attr_not_taken, "'name' column should not contain 'Bob'.")
|
|
|
|
def test_expired_entry_handling(self):
|
|
"""Test that expired entries are not returned when querying."""
|
|
print("\n---- Starting Expired Entry Handling Test ----")
|
|
|
|
# Step 1: Create a Row-Based Cache with a 1-second expiration
|
|
cache = RowBasedCache(default_expiration=1)
|
|
print("\n1. Created a Row-Based Cache with a 1-second expiration.")
|
|
|
|
# Step 2: Add an entry
|
|
cache.add_entry("item1", "Temporary Data")
|
|
print("\n2. Added an entry 'item1' with a 1-second expiration.")
|
|
|
|
# Step 3: Wait for expiration
|
|
time.sleep(2)
|
|
|
|
# Step 4: Query the expired entry
|
|
result = cache.get_entry("item1")
|
|
print(f"\n4. Queried 'item1' after expiration: {result}")
|
|
self.assertIsNone(result, "'item1' should have expired and not be returned.")
|
|
|
|
def test_remove_item_with_conditions_row_cache(self):
|
|
"""Test remove_item method in Row-Based Cache with conditions."""
|
|
print("\n---- Starting remove_item with conditions in Row-Based Cache Test ----")
|
|
|
|
# Step 1: Create a Row-Based Cache
|
|
cache = RowBasedCache()
|
|
print("\n1. Created a Row-Based Cache.")
|
|
|
|
# Step 2: Add a DataFrame to the cache under 'users'
|
|
df = pd.DataFrame({'name': ['Alice', 'Bob', 'Charlie'], 'age': [30, 25, 35]})
|
|
cache.add_entry("users", df)
|
|
print("\n2. Added a DataFrame to the cache under key 'users'.")
|
|
print(cache.get_entry("users"))
|
|
|
|
# Step 3: Remove a specific row (where 'name' == 'Alice') from the DataFrame
|
|
removed = cache.remove_item([('key', 'users'), ('name', 'Alice')])
|
|
print(f"\n3. Removed 'Alice' from the DataFrame: {removed}")
|
|
self.assertTrue(removed, "'Alice' should be removed from the DataFrame.")
|
|
|
|
# Verify that 'Alice' was removed
|
|
remaining_data = cache.get_entry('users')
|
|
print(f"\n4. Remaining data in 'users': \n{remaining_data}")
|
|
self.assertNotIn('Alice', remaining_data['name'].values, "'Alice' should no longer be in the DataFrame.")
|
|
self.assertIn('Bob', remaining_data['name'].values, "'Bob' should still be in the DataFrame.")
|
|
|
|
# Step 4: Remove the last remaining row (where 'name' == 'Charlie')
|
|
removed = cache.remove_item([('key', 'users'), ('name', 'Charlie')])
|
|
print(f"\n5. Removed 'Charlie' from the DataFrame: {removed}")
|
|
remaining_data = cache.get_entry('users')
|
|
print(f"\n6. Remaining data in 'users' after 'Charlie' removal: \n{remaining_data}")
|
|
self.assertNotIn('Charlie', remaining_data['name'].values, "'Charlie' should no longer be in the DataFrame.")
|
|
|
|
# Step 5: Remove the last row (where 'name' == 'Bob'), this should remove the entire entry
|
|
removed = cache.remove_item([('key', 'users'), ('name', 'Bob')])
|
|
print(f"\n7. Removed 'Bob' from the DataFrame: {removed}")
|
|
remaining_data = cache.get_entry('users')
|
|
print(f"\n8. Remaining data in 'users' after removing 'Bob' (should be None): {remaining_data}")
|
|
self.assertIsNone(remaining_data, "'users' entry should no longer exist in the cache.")
|
|
|
|
def test_remove_item_with_conditions_table_cache(self):
|
|
"""Test remove_item method in Table-Based Cache with conditions."""
|
|
print("\n---- Starting remove_item with conditions in Table-Based Cache Test ----")
|
|
|
|
# Step 1: Create a Table-Based Cache
|
|
cache = TableBasedCache()
|
|
print("\n1. Created a Table-Based Cache.")
|
|
|
|
# Step 2: Add a DataFrame to the cache
|
|
df = pd.DataFrame({'name': ['Alice', 'Bob', 'Charlie'], 'age': [30, 25, 35]})
|
|
cache.add_table(df)
|
|
print("\n2. Added a DataFrame to the Table-Based Cache.")
|
|
print(cache.get_all_items())
|
|
|
|
# Step 3: Remove a specific row (where 'name' == 'Alice')
|
|
removed = cache.remove_item([('name', 'Alice')])
|
|
print(f"\n3. Removed 'Alice' from the table: {removed}")
|
|
self.assertTrue(removed, "'Alice' should be removed from the table-based cache.")
|
|
|
|
# Verify that 'Alice' was removed
|
|
remaining_data = cache.get_all_items()
|
|
print(f"\n4. Remaining data in the cache: \n{remaining_data}")
|
|
self.assertNotIn('Alice', remaining_data['name'].values, "'Alice' should no longer be in the table.")
|
|
self.assertIn('Bob', remaining_data['name'].values, "'Bob' should still be in the table.")
|
|
|
|
# Step 4: Remove another row (where 'name' == 'Charlie')
|
|
removed = cache.remove_item([('name', 'Charlie')])
|
|
print(f"\n5. Removed 'Charlie' from the table: {removed}")
|
|
remaining_data = cache.get_all_items()
|
|
print(f"\n6. Remaining data in the cache after removing 'Charlie': \n{remaining_data}")
|
|
self.assertNotIn('Charlie', remaining_data['name'].values, "'Charlie' should no longer be in the table.")
|
|
|
|
# Step 5: Remove the last row (where 'name' == 'Bob')
|
|
removed = cache.remove_item([('name', 'Bob')])
|
|
print(f"\n7. Removed 'Bob' from the table: {removed}")
|
|
remaining_data = cache.get_all_items()
|
|
print(f"\n8. Remaining data in the cache after removing 'Bob' (should be empty): \n{remaining_data}")
|
|
self.assertTrue(remaining_data.empty, "The table should be empty after removing all rows.")
|
|
|
|
def test_remove_item_with_conditions_market_data(self):
|
|
"""Test remove_item method in Row-Based Cache with market OHLC data."""
|
|
print("\n---- Starting remove_item with market OHLC data in Row-Based Cache Test ----")
|
|
|
|
# Step 1: Create a Row-Based Cache for market data
|
|
cache = RowBasedCache()
|
|
print("\n1. Created a Row-Based Cache for market data.")
|
|
|
|
# Step 2: Add OHLC data for 'BTC' and 'ETH'
|
|
btc_data = pd.DataFrame({
|
|
'timestamp': ['2024-09-10 12:00', '2024-09-10 12:05', '2024-09-10 12:10'],
|
|
'open': [30000, 30100, 30200],
|
|
'high': [30500, 30600, 30700],
|
|
'low': [29900, 30050, 30150],
|
|
'close': [30400, 30550, 30650]
|
|
})
|
|
eth_data = pd.DataFrame({
|
|
'timestamp': ['2024-09-10 12:00', '2024-09-10 12:05', '2024-09-10 12:10'],
|
|
'open': [2000, 2010, 2020],
|
|
'high': [2050, 2060, 2070],
|
|
'low': [1990, 2005, 2015],
|
|
'close': [2040, 2055, 2065]
|
|
})
|
|
cache.add_entry("BTC", btc_data)
|
|
cache.add_entry("ETH", eth_data)
|
|
print("\n2. Added OHLC data for 'BTC' and 'ETH'.")
|
|
print(f"BTC Data:\n{cache.get_entry('BTC')}")
|
|
print(f"ETH Data:\n{cache.get_entry('ETH')}")
|
|
|
|
# Step 3: Remove a specific row from 'BTC' data where timestamp == '2024-09-10 12:05'
|
|
removed = cache.remove_item([('key', 'BTC'), ('timestamp', '2024-09-10 12:05')])
|
|
print(f"\n3. Removed '2024-09-10 12:05' row from 'BTC' data: {removed}")
|
|
self.assertTrue(removed, "'2024-09-10 12:05' should be removed from the 'BTC' data.")
|
|
|
|
# Verify that the timestamp was removed from 'BTC' data
|
|
remaining_btc = cache.get_entry('BTC')
|
|
print(f"\n4. Remaining BTC data after removal:\n{remaining_btc}")
|
|
self.assertNotIn('2024-09-10 12:05', remaining_btc['timestamp'].values,
|
|
"'2024-09-10 12:05' should no longer be in the 'BTC' data.")
|
|
|
|
# Step 4: Remove entire 'ETH' data entry
|
|
removed_eth = cache.remove_item([('key', 'ETH')])
|
|
print(f"\n5. Removed entire 'ETH' data entry: {removed_eth}")
|
|
self.assertTrue(removed_eth, "'ETH' data should be removed from the cache.")
|
|
|
|
# Verify that 'ETH' was completely removed from the cache
|
|
remaining_eth = cache.get_entry('ETH')
|
|
print(f"\n6. Remaining ETH data after removal (should be None): {remaining_eth}")
|
|
self.assertIsNone(remaining_eth, "'ETH' entry should no longer exist in the cache.")
|
|
|
|
def test_remove_item_with_conditions_trade_stats(self):
|
|
"""Test remove_item method in Row-Based Cache with trade statistics data."""
|
|
print("\n---- Starting remove_item with trade statistics in Row-Based Cache Test ----")
|
|
|
|
# Step 1: Create a Row-Based Cache for trade statistics
|
|
cache = RowBasedCache()
|
|
print("\n1. Created a Row-Based Cache for trade statistics.")
|
|
|
|
# Step 2: Add trade statistics for 'strategy_1' and 'strategy_2'
|
|
strategy_1_data = pd.DataFrame({
|
|
'date': ['2024-09-10', '2024-09-11', '2024-09-12'],
|
|
'success_rate': [80, 85, 75],
|
|
'trades': [10, 12, 8]
|
|
})
|
|
strategy_2_data = pd.DataFrame({
|
|
'date': ['2024-09-10', '2024-09-11', '2024-09-12'],
|
|
'success_rate': [60, 70, 65],
|
|
'trades': [15, 17, 14]
|
|
})
|
|
cache.add_entry("strategy_1", strategy_1_data)
|
|
cache.add_entry("strategy_2", strategy_2_data)
|
|
print("\n2. Added trade statistics for 'strategy_1' and 'strategy_2'.")
|
|
print(f"Strategy 1 Data:\n{cache.get_entry('strategy_1')}")
|
|
print(f"Strategy 2 Data:\n{cache.get_entry('strategy_2')}")
|
|
|
|
# Step 3: Remove a specific row from 'strategy_1' where date == '2024-09-11'
|
|
removed = cache.remove_item([('key', 'strategy_1'), ('date', '2024-09-11')])
|
|
print(f"\n3. Removed '2024-09-11' row from 'strategy_1' data: {removed}")
|
|
self.assertTrue(removed, "'2024-09-11' should be removed from the 'strategy_1' data.")
|
|
|
|
# Verify that the date was removed from 'strategy_1' data
|
|
remaining_strategy_1 = cache.get_entry('strategy_1')
|
|
print(f"\n4. Remaining strategy_1 data after removal:\n{remaining_strategy_1}")
|
|
self.assertNotIn('2024-09-11', remaining_strategy_1['date'].values,
|
|
"'2024-09-11' should no longer be in the 'strategy_1' data.")
|
|
|
|
# Step 4: Remove entire 'strategy_2' data entry
|
|
removed_strategy_2 = cache.remove_item([('key', 'strategy_2')])
|
|
print(f"\n5. Removed entire 'strategy_2' data entry: {removed_strategy_2}")
|
|
self.assertTrue(removed_strategy_2, "'strategy_2' data should be removed from the cache.")
|
|
|
|
# Verify that 'strategy_2' was completely removed from the cache
|
|
remaining_strategy_2 = cache.get_entry('strategy_2')
|
|
print(f"\n6. Remaining strategy_2 data after removal (should be None): {remaining_strategy_2}")
|
|
self.assertIsNone(remaining_strategy_2, "'strategy_2' entry should no longer exist in the cache.")
|
|
|
|
def test_remove_item_with_other_data_types(self):
|
|
"""Test remove_item method in Row-Based Cache with different data types."""
|
|
print("\n---- Starting remove_item with different data types in Row-Based Cache Test ----")
|
|
|
|
# Step 1: Create a Row-Based Cache for mixed data types
|
|
cache = RowBasedCache()
|
|
print("\n1. Created a Row-Based Cache for mixed data types.")
|
|
|
|
# Step 2: Add entries with different data types
|
|
|
|
# String
|
|
cache.add_entry("message", "Hello, World!")
|
|
print("\n2. Added a string 'Hello, World!' under key 'message'.")
|
|
|
|
# Dictionary
|
|
cache.add_entry("user_profile", {"name": "Alice", "age": 30, "email": "alice@example.com"})
|
|
print("\n3. Added a dictionary under key 'user_profile'.")
|
|
|
|
# List
|
|
cache.add_entry("numbers", [1, 2, 3, 4, 5])
|
|
print("\n4. Added a list of numbers under key 'numbers'.")
|
|
|
|
# Integer
|
|
cache.add_entry("count", 42)
|
|
print("\n5. Added an integer '42' under key 'count'.")
|
|
|
|
# Step 3: Remove specific entries based on key
|
|
|
|
# Remove string entry
|
|
removed_message = cache.remove_item([('key', 'message')])
|
|
print(f"\n6. Removed string entry: {removed_message}")
|
|
self.assertTrue(removed_message, "'message' should be removed from the cache.")
|
|
self.assertIsNone(cache.get_entry('message'), "'message' entry should no longer exist.")
|
|
|
|
# Remove dictionary entry
|
|
removed_user_profile = cache.remove_item([('key', 'user_profile')])
|
|
print(f"\n7. Removed dictionary entry: {removed_user_profile}")
|
|
self.assertTrue(removed_user_profile, "'user_profile' should be removed from the cache.")
|
|
self.assertIsNone(cache.get_entry('user_profile'), "'user_profile' entry should no longer exist.")
|
|
|
|
# Remove list entry
|
|
removed_numbers = cache.remove_item([('key', 'numbers')])
|
|
print(f"\n8. Removed list entry: {removed_numbers}")
|
|
self.assertTrue(removed_numbers, "'numbers' should be removed from the cache.")
|
|
self.assertIsNone(cache.get_entry('numbers'), "'numbers' entry should no longer exist.")
|
|
|
|
# Remove integer entry
|
|
removed_count = cache.remove_item([('key', 'count')])
|
|
print(f"\n9. Removed integer entry: {removed_count}")
|
|
self.assertTrue(removed_count, "'count' should be removed from the cache.")
|
|
self.assertIsNone(cache.get_entry('count'), "'count' entry should no longer exist.")
|
|
|
|
    def test_snapshot_row_based_cache(self):
        """Test snapshot functionality with row-based cache.

        Verifies a snapshot is a frozen copy: rows added to the live cache
        after snapshot_cache() must not appear in the stored snapshot.
        """
        print("\n---- Starting Snapshot Test with Row-Based Cache ----")

        # Step 1: Create an instance of SnapshotDataCache
        snapshot_cache = SnapshotDataCache()
        print("\n1. Created SnapshotDataCache instance.")

        # Step 2: Create a row-based cache and add data
        snapshot_cache.create_cache(name="market_data", cache_type="row")
        market_data = pd.DataFrame({
            'timestamp': ['2024-09-10 12:00', '2024-09-10 12:05'],
            'open': [30000, 30100],
            'high': [30500, 30600],
            'low': [29900, 30050],
            'close': [30400, 30550]
        })
        snapshot_cache.get_cache("market_data").add_entry("BTC", market_data)
        print("\n2. Added 'BTC' market data to row-based cache.")

        # Step 3: Take a snapshot of the row-based cache
        snapshot_cache.snapshot_cache("market_data")
        snapshot_list = snapshot_cache.list_snapshots()
        print(f"\n3. Snapshot list after taking snapshot: {snapshot_list}")
        self.assertIn("market_data", snapshot_list, "Snapshot for 'market_data' should be present.")

        # Step 4: Retrieve the snapshot and verify its contents
        snapshot = snapshot_cache.get_snapshot("market_data")
        print(f"\n4. Retrieved snapshot of 'market_data':\n{snapshot.get_entry('BTC')}")
        pd.testing.assert_frame_equal(snapshot.get_entry('BTC'), market_data)

        # Step 5: Add more data to the live cache and verify the snapshot is unchanged
        additional_data = pd.DataFrame({
            'timestamp': ['2024-09-10 12:10'],
            'open': [30200],
            'high': [30700],
            'low': [30150],
            'close': [30650]
        })
        # NOTE(review): add_entry on an existing key appears to append (the
        # live row count grows to 3 below) — confirm against RowBasedCache.
        snapshot_cache.get_cache("market_data").add_entry("BTC", additional_data)
        print("\n5. Added additional data to the live 'BTC' cache.")

        # Verify live cache has updated but the snapshot remains unchanged
        live_data = snapshot_cache.get_cache("market_data").get_entry("BTC")
        print(f"\n6. Live 'BTC' cache data:\n{live_data}")
        self.assertEqual(len(live_data), 3, "Live cache should have 3 rows after adding more data.")

        # Ensure the snapshot still has the original data
        snapshot_data = snapshot_cache.get_snapshot("market_data").get_entry("BTC")
        print(f"\n7. Snapshot data (should still be original):\n{snapshot_data}")
        self.assertEqual(len(snapshot_data), 2, "Snapshot should still have the original 2 rows.")
    def test_snapshot_table_based_cache_with_overwrite_column(self):
        """Test snapshot functionality with table-based cache and overwrite by column.

        Takes a snapshot of a table cache, rewrites the live table with
        add_table(..., overwrite='name'), and checks the snapshot kept the
        pre-update rows.
        """
        print("\n---- Starting Snapshot Test with Table-Based Cache ----")

        # Step 1: Create an instance of SnapshotDataCache
        snapshot_cache = SnapshotDataCache()
        print("\n1. Created SnapshotDataCache instance.")

        # Step 2: Create a table-based cache and add initial data
        user_data = pd.DataFrame({
            'name': ['Alice', 'Bob'],
            'email': ['alice@example.com', 'bob@example.com'],
            'age': [30, 25]
        })
        snapshot_cache.create_cache(name="user_data", cache_type="table")
        snapshot_cache.get_cache("user_data").add_table(user_data)
        print("\n2. Added user data to table-based cache.")

        # Step 3: Take a snapshot of the table-based cache
        snapshot_cache.snapshot_cache("user_data")
        snapshot_list = snapshot_cache.list_snapshots()
        print(f"\n3. Snapshot list after taking snapshot: {snapshot_list}")
        self.assertIn("user_data", snapshot_list, "Snapshot for 'user_data' should be present.")

        # Step 4: Retrieve the snapshot and verify its contents (excluding metadata)
        # get_all_items() includes a 'metadata' column; drop it so the comparison
        # is against the caller-supplied frame only.
        snapshot = snapshot_cache.get_snapshot("user_data")
        snapshot_data = snapshot.get_all_items().drop(columns=['metadata'])
        print(f"\n4. Retrieved snapshot of 'user_data' (without metadata):\n{snapshot_data}")
        pd.testing.assert_frame_equal(snapshot_data, user_data)

        # Step 5: Modify the live cache and overwrite specific rows by 'name'
        updated_user_data = pd.DataFrame({
            'name': ['Alice', 'Bob', 'Charlie'],
            'email': ['alice@example.com', 'bob@example.com', 'charlie@example.com'],
            'age': [35, 25, 40]
        })
        # overwrite='name' is expected to replace matching rows (Alice) and
        # append new ones (Charlie) — confirm against TableBasedCache.add_table.
        snapshot_cache.get_cache("user_data").add_table(updated_user_data, overwrite='name')
        print("\n5. Updated live table by overwriting rows based on 'name'.")

        # Verify live cache has updated but the snapshot remains unchanged
        live_data = snapshot_cache.get_cache("user_data").get_all_items().drop(columns=['metadata'])
        print(f"\n6. Live user_data table (without metadata):\n{live_data}")
        self.assertEqual(len(live_data), 3,
                         "Live cache should have 3 rows after adding 'Charlie' and overwriting 'Alice'.")

        # Ensure the snapshot still has the original data
        snapshot_data = snapshot_cache.get_snapshot("user_data").get_all_items().drop(columns=['metadata'])
        print(f"\n7. Snapshot data (should still be original, without metadata):\n{snapshot_data}")
        self.assertEqual(len(snapshot_data), 2, "Snapshot should still have the original 2 rows.")
def test_update_candle_cache(self):
|
|
self.load_prerequisites()
|
|
|
|
print('Testing update_candle_cache() method:')
|
|
|
|
# Set a cache key
|
|
candle_cache_key = f'{self.ex_details[0]}_{self.ex_details[1]}_{self.ex_details[2]}'
|
|
|
|
# Initialize the DataGenerator with the 5-minute timeframe
|
|
data_gen = DataGenerator('5m')
|
|
|
|
# Create initial DataFrame and insert it into the cache
|
|
df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc))
|
|
|
|
print(f'Inserting this table into cache:\n{df_initial}\n')
|
|
self.data.set_cache_item(key=candle_cache_key, data=df_initial, cache_name='candles')
|
|
|
|
# Create new DataFrame to be added to the cache
|
|
df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0, tzinfo=dt.timezone.utc))
|
|
|
|
print(f'Updating cache with this table:\n{df_new}\n')
|
|
self.data._update_candle_cache(more_records=df_new, key=candle_cache_key)
|
|
|
|
# Retrieve the resulting DataFrame from the cache
|
|
result = self.data.get_cache_item(key=candle_cache_key, cache_name='candles')
|
|
print(f'The resulting table in cache is:\n{result}\n')
|
|
|
|
# Create the expected DataFrame
|
|
expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc))
|
|
print(f'The expected time values are:\n{expected["time"].tolist()}\n')
|
|
|
|
# Assert that the time values in the result match those in the expected DataFrame, in order
|
|
assert result['time'].tolist() == expected['time'].tolist(), \
|
|
f"time values in result are {result['time'].tolist()} expected {expected['time'].tolist()}"
|
|
|
|
print(f'The result time values match:\n{result["time"].tolist()}\n')
|
|
print(' - Update cache with new records passed.')
|
|
|
|
def analyze_cache_update(self, existing_records, more_records):
|
|
print("\n### Initial Data ###")
|
|
print("Existing Records:")
|
|
print(existing_records)
|
|
print("\nMore Records:")
|
|
print(more_records)
|
|
|
|
# Column-by-column comparison
|
|
print("\n### Column Comparison ###")
|
|
for col in existing_records.columns:
|
|
if col in more_records.columns:
|
|
print(f"\nAnalyzing column: {col}")
|
|
print(f"Existing Records '{col}' values:\n{existing_records[col].tolist()}")
|
|
print(f"More Records '{col}' values:\n{more_records[col].tolist()}")
|
|
print(f"Existing Records '{col}' type: {existing_records[col].dtype}")
|
|
print(f"More Records '{col}' type: {more_records[col].dtype}")
|
|
|
|
# Check for duplicate rows based on the 'time' column
|
|
print("\n### Duplicate Detection ###")
|
|
combined = pd.concat([existing_records, more_records], ignore_index=True)
|
|
print("Combined Records (before removing duplicates):")
|
|
print(combined)
|
|
|
|
# Method 1: Drop duplicates keeping the last occurrence
|
|
no_duplicates_last = combined.drop_duplicates(subset='time', keep='last')
|
|
print("\nAfter Dropping Duplicates (keep='last'):")
|
|
print(no_duplicates_last)
|
|
|
|
# Method 2: Drop duplicates keeping the first occurrence
|
|
no_duplicates_first = combined.drop_duplicates(subset='time', keep='first')
|
|
print("\nAfter Dropping Duplicates (keep='first'):")
|
|
print(no_duplicates_first)
|
|
|
|
# Method 3: Ensure 'time' is in a consistent data type, and drop duplicates
|
|
combined['time'] = combined['time'].astype('int64')
|
|
consistent_time = combined.drop_duplicates(subset='time', keep='last')
|
|
print("\nAfter Dropping Duplicates with 'time' as int64:")
|
|
print(consistent_time)
|
|
|
|
print("\n### Final Analysis ###")
|
|
print("Resulting DataFrame after sorting by 'time':")
|
|
final_result = consistent_time.sort_values(by='time').reset_index(drop=True)
|
|
print(final_result)
|
|
|
|
def test_reproduce_duplicate_issue(self):
|
|
# Simulating DataFrames like in your original test
|
|
# Time as epoch timestamps
|
|
existing_records = pd.DataFrame({
|
|
'market_id': [1, 1, 1],
|
|
'time': [1723161600000, 1723161900000, 1723162200000],
|
|
'open': [100, 101, 102],
|
|
'high': [110, 111, 112],
|
|
'low': [90, 91, 92],
|
|
'close': [105, 106, 107],
|
|
'volume': [1000, 1001, 1002]
|
|
})
|
|
|
|
more_records = pd.DataFrame({
|
|
'market_id': [1, 1, 1],
|
|
'time': [1723161600000, 1723161900000, 1723162500000], # Overlap at index 0 and 1
|
|
'open': [100, 101, 100],
|
|
'high': [110, 111, 110],
|
|
'low': [90, 91, 90],
|
|
'close': [105, 106, 105],
|
|
'volume': [1000, 1001, 1000]
|
|
})
|
|
|
|
# Run analysis
|
|
self.analyze_cache_update(existing_records, more_records)
|
|
|
|
def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None, num_rec=None, ex_details=None,
|
|
simulate_scenarios=None):
|
|
"""
|
|
Test the get_records_since() method by generating a table of simulated data,
|
|
inserting it into data and/or database, and then querying the records.
|
|
|
|
Parameters:
|
|
set_cache (bool): If True, the generated table is inserted into the cache.
|
|
set_db (bool): If True, the generated table is inserted into the database.
|
|
query_offset (int, optional): The offset in the timeframe units for the query.
|
|
num_rec (int, optional): The number of records to generate in the simulated table.
|
|
ex_details (list, optional): Exchange details to generate the data key.
|
|
simulate_scenarios (str, optional): The type of scenario to simulate. Options are:
|
|
- 'not_enough_data': The table data doesn't go far enough back.
|
|
- 'incomplete_data': The table doesn't have enough records to satisfy the query.
|
|
- 'missing_section': The table has missing records in the middle.
|
|
"""
|
|
|
|
print('Testing get_records_since() method:')
|
|
|
|
ex_details = ex_details or self.ex_details
|
|
key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}'
|
|
num_rec = num_rec or 12
|
|
table_timeframe = ex_details[1]
|
|
|
|
data_gen = DataGenerator(table_timeframe)
|
|
|
|
if simulate_scenarios == 'not_enough_data':
|
|
query_offset = (num_rec + 5) * data_gen.timeframe_amount
|
|
else:
|
|
query_offset = query_offset or (num_rec - 1) * data_gen.timeframe_amount
|
|
|
|
if simulate_scenarios == 'incomplete_data':
|
|
start_time_for_data = data_gen.x_time_ago(num_rec * data_gen.timeframe_amount)
|
|
num_rec = 5
|
|
else:
|
|
start_time_for_data = None
|
|
|
|
df_initial = data_gen.create_table(num_rec, start=start_time_for_data)
|
|
|
|
if simulate_scenarios == 'missing_section':
|
|
df_initial = data_gen.generate_missing_section(df_initial, drop_start=2, drop_end=5)
|
|
|
|
temp_df = df_initial.copy()
|
|
temp_df['time'] = pd.to_datetime(temp_df['time'], unit='ms')
|
|
print(f'Table Created:\n{temp_df}')
|
|
|
|
if set_cache:
|
|
print('Ensuring the cache exists and then inserting table into the cache.')
|
|
self.data.set_cache_item(data=df_initial, key=key, cache_name='candles')
|
|
|
|
if set_db:
|
|
print('Inserting table into the database.')
|
|
with SQLite(self.db_file) as con:
|
|
df_initial.to_sql(key, con, if_exists='replace', index=False)
|
|
|
|
start_datetime = data_gen.x_time_ago(query_offset)
|
|
if start_datetime.tzinfo is None:
|
|
start_datetime = start_datetime.replace(tzinfo=dt.timezone.utc)
|
|
|
|
query_end_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
|
|
print(f'Requesting records from {start_datetime} to {query_end_time}')
|
|
|
|
result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
|
|
|
|
expected = df_initial[df_initial['time'] >= data_gen.unix_time_millis(start_datetime)].reset_index(
|
|
drop=True)
|
|
temp_df = expected.copy()
|
|
temp_df['time'] = pd.to_datetime(temp_df['time'], unit='ms')
|
|
print(f'Expected table:\n{temp_df}')
|
|
|
|
temp_df = result.copy()
|
|
temp_df['time'] = pd.to_datetime(temp_df['time'], unit='ms')
|
|
print(f'Resulting table:\n{temp_df}')
|
|
|
|
if simulate_scenarios in ['not_enough_data', 'incomplete_data', 'missing_section']:
|
|
assert result.shape[0] > expected.shape[
|
|
0], "Result has fewer or equal rows compared to the incomplete data."
|
|
print("\nThe returned DataFrame has filled in the missing data!")
|
|
else:
|
|
assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}"
|
|
pd.testing.assert_series_equal(result['time'], expected['time'], check_dtype=False)
|
|
print("\nThe DataFrames have the same shape and the 'time' columns match.")
|
|
|
|
oldest_timestamp = pd.to_datetime(result['time'].min(), unit='ms').tz_localize('UTC')
|
|
time_diff = oldest_timestamp - start_datetime
|
|
max_allowed_time_diff = dt.timedelta(**{data_gen.timeframe_unit: data_gen.timeframe_amount})
|
|
|
|
assert dt.timedelta(0) <= time_diff <= max_allowed_time_diff, \
|
|
f"Oldest timestamp {oldest_timestamp} is not within " \
|
|
f"{data_gen.timeframe_amount} {data_gen.timeframe_unit} of {start_datetime}"
|
|
|
|
print(f'The first timestamp is {time_diff} from {start_datetime}')
|
|
|
|
newest_timestamp = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
|
|
time_diff_end = abs(query_end_time - newest_timestamp)
|
|
|
|
assert dt.timedelta(0) <= time_diff_end <= max_allowed_time_diff, \
|
|
f"Newest timestamp {newest_timestamp} is not within {data_gen.timeframe_amount} " \
|
|
f"{data_gen.timeframe_unit} of {query_end_time}"
|
|
|
|
print(f'The last timestamp is {time_diff_end} from {query_end_time}')
|
|
print(' - Fetch records within the specified time range passed.')
|
|
|
|
def test_get_records_since(self):
|
|
self.connect_exchanges()
|
|
self.set_up_database()
|
|
self.load_test_data()
|
|
|
|
print('\nTest get_records_since with records set in data')
|
|
self._test_get_records_since()
|
|
|
|
print('\nTest get_records_since with records not in data')
|
|
self._test_get_records_since(set_cache=False)
|
|
|
|
print('\nTest get_records_since with records not in database')
|
|
self._test_get_records_since(set_cache=False, set_db=False)
|
|
|
|
print('\nTest get_records_since with a different timeframe')
|
|
self._test_get_records_since(query_offset=None, num_rec=None,
|
|
ex_details=['BTC/USD', '15m', 'binance', 'test_guy'])
|
|
|
|
print('\nTest get_records_since where data does not go far enough back')
|
|
self._test_get_records_since(simulate_scenarios='not_enough_data')
|
|
|
|
print('\nTest get_records_since with incomplete data')
|
|
self._test_get_records_since(simulate_scenarios='incomplete_data')
|
|
|
|
print('\nTest get_records_since with missing section in data')
|
|
self._test_get_records_since(simulate_scenarios='missing_section')
|
|
|
|
def test_other_timeframes(self):
|
|
print('\nTest get_records_since with a different timeframe')
|
|
|
|
ex_details = ['BTC/USD', '15m', 'binance', 'test_guy']
|
|
start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2)
|
|
# Query the records since the calculated start time.
|
|
result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
|
|
last_record_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
|
|
assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=15.1)
|
|
|
|
print('\nTest get_records_since with a different timeframe')
|
|
ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']
|
|
start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=1)
|
|
# Query the records since the calculated start time.
|
|
result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
|
|
last_record_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
|
|
assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=5.1)
|
|
|
|
print('\nTest get_records_since with a different timeframe')
|
|
ex_details = ['BTC/USD', '4h', 'binance', 'test_guy']
|
|
start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=12)
|
|
# Query the records since the calculated start time.
|
|
result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
|
|
last_record_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
|
|
assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=4.1)
|
|
|
|
def test_populate_db(self):
|
|
print('Testing _populate_db() method:')
|
|
# Create a table of candle records.
|
|
data_gen = DataGenerator(self.ex_details[1])
|
|
data = data_gen.create_table(num_rec=5)
|
|
|
|
self.data._populate_db(ex_details=self.ex_details, data=data)
|
|
|
|
with SQLite(self.db_file) as con:
|
|
result = pd.read_sql(f'SELECT * FROM "{self.key}"', con)
|
|
self.assertFalse(result.empty)
|
|
print(' - Populate database with data passed.')
|
|
|
|
def test_fetch_candles_from_exchange(self):
|
|
print('Testing _fetch_candles_from_exchange() method:')
|
|
|
|
# Define start and end times for the data fetch
|
|
start_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc) - dt.timedelta(days=1)
|
|
end_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
|
|
|
|
# Fetch the candles from the exchange using the method
|
|
result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h',
|
|
exchange_name='binance', user_name='test_guy',
|
|
start_datetime=start_time, end_datetime=end_time)
|
|
|
|
# Validate that the result is a DataFrame
|
|
self.assertIsInstance(result, pd.DataFrame)
|
|
|
|
# Validate that the DataFrame is not empty
|
|
self.assertFalse(result.empty, "The DataFrame returned from the exchange is empty.")
|
|
|
|
# Ensure that the 'time' column exists in the DataFrame
|
|
self.assertIn('time', result.columns, "'time' column is missing in the result DataFrame.")
|
|
|
|
# Check if the DataFrame contains valid timestamps within the specified range
|
|
min_time = pd.to_datetime(result['time'].min(), unit='ms').tz_localize('UTC')
|
|
max_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
|
|
self.assertTrue(start_time <= min_time <= end_time, f"Data starts outside the expected range: {min_time}")
|
|
self.assertTrue(start_time <= max_time <= end_time, f"Data ends outside the expected range: {max_time}")
|
|
|
|
print(' - Fetch candle data from exchange passed.')
|
|
|
|
    def test_remove_row(self):
        """remove_row_from_datacache(): cache-and-db removal, cache-only removal,
        and error handling for a missing cache and a malformed filter."""
        # Step 1: Create the cache and insert data
        self.data.create_cache('users', cache_type='table')  # Create 'users' cache for this test

        # Insert test data into the cache and database
        df = pd.DataFrame({
            'user_name': ['Alice', 'Bob', 'Charlie'],
            'age': [30, 25, 40],
            'users_data': ['data1', 'data2', 'data3'],
            'data': ['info1', 'info2', 'info3'],
            'password': ['pass1', 'pass2', 'pass3']
        })
        self.data.insert_df_into_datacache(df=df, cache_name="users", skip_cache=False)

        # Scenario 1: Remove a row from both cache and database
        filter_vals = [('user_name', 'Bob')]
        self.data.remove_row_from_datacache(cache_name="users", filter_vals=filter_vals, remove_from_db=True)

        # Verify the row was removed from the cache
        cache = self.data.get_cache('users')
        cached_data = cache.get_all_items()
        self.assertEqual(len(cached_data), 2)  # Only 2 rows should remain
        self.assertNotIn('Bob', cached_data['user_name'].values)  # Ensure 'Bob' is not in the cache

        # Verify the row was removed from the database
        with SQLite(db_file=self.db_file) as con:
            remaining_users = pd.read_sql_query("SELECT * FROM users", con)
            self.assertEqual(len(remaining_users), 2)  # Ensure 2 rows remain in the database
            self.assertNotIn('Bob', remaining_users['user_name'].values)  # Ensure 'Bob' is not in the database

        # Scenario 2: Remove a row from the cache only (not from the database)
        filter_vals = [('user_name', 'Charlie')]
        self.data.remove_row_from_datacache(cache_name="users", filter_vals=filter_vals, remove_from_db=False)

        # Verify the row was removed from the cache
        cached_data = cache.get_all_items()
        self.assertEqual(len(cached_data), 1)  # Only 1 row should remain in the cache
        self.assertNotIn('Charlie', cached_data['user_name'].values)  # Ensure 'Charlie' is not in the cache

        # Verify the row still exists in the database
        with SQLite(db_file=self.db_file) as con:
            remaining_users = pd.read_sql_query("SELECT * FROM users", con)
            self.assertEqual(len(remaining_users), 2)  # Ensure Charlie is still in the database
            self.assertIn('Charlie', remaining_users['user_name'].values)  # Charlie should still exist in the database

        # Scenario 3: Try removing from a non-existing cache (expecting KeyError)
        filter_vals = [('user_name', 'Bob')]
        with self.assertRaises(KeyError) as context:
            self.data.remove_row_from_datacache(cache_name="non_existing_cache", filter_vals=filter_vals, remove_from_db=True)
        # KeyError wraps its message in quotes, so compare args[0] not str().
        self.assertEqual(context.exception.args[0], "Cache: non_existing_cache, does not exist.")

        # Scenario 4: Invalid filter_vals format (expecting ValueError)
        invalid_filter_vals = 'invalid_filter'  # Not a list of tuples
        with self.assertRaises(ValueError) as context:
            self.data.remove_row_from_datacache(cache_name="users", filter_vals=invalid_filter_vals, remove_from_db=True)
        self.assertEqual(str(context.exception), "filter_vals must be a list of tuples (column, value)")
def test_timeframe_to_timedelta(self):
|
|
print('Testing timeframe_to_timedelta() function:')
|
|
|
|
result = timeframe_to_timedelta('2h')
|
|
expected = pd.Timedelta(hours=2)
|
|
self.assertEqual(result, expected, "Failed to convert '2h' to Timedelta")
|
|
|
|
result = timeframe_to_timedelta('5m')
|
|
expected = pd.Timedelta(minutes=5)
|
|
self.assertEqual(result, expected, "Failed to convert '5m' to Timedelta")
|
|
|
|
result = timeframe_to_timedelta('1d')
|
|
expected = pd.Timedelta(days=1)
|
|
self.assertEqual(result, expected, "Failed to convert '1d' to Timedelta")
|
|
|
|
result = timeframe_to_timedelta('3M')
|
|
expected = pd.DateOffset(months=3)
|
|
self.assertEqual(result, expected, "Failed to convert '3M' to DateOffset")
|
|
|
|
result = timeframe_to_timedelta('1Y')
|
|
expected = pd.DateOffset(years=1)
|
|
self.assertEqual(result, expected, "Failed to convert '1Y' to DateOffset")
|
|
|
|
with self.assertRaises(ValueError):
|
|
timeframe_to_timedelta('5x')
|
|
|
|
print(' - All timeframe_to_timedelta() tests passed.')
|
|
|
|
    def test_estimate_record_count(self):
        """estimate_record_count() across units, timestamp formats and edge cases.

        Covers datetime and epoch-millisecond inputs, month offsets across
        year boundaries and leap years, sub-minute timeframes, mixed
        timezones, and degenerate intervals (zero-length, negative, shorter
        than one timeframe).
        """
        print('Testing estimate_record_count() function:')

        # Test with '1h' timeframe (24 records expected)
        start_time = dt.datetime(2023, 8, 1, 0, 0, 0, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 8, 2, 0, 0, 0, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '1h')
        self.assertEqual(result, 24, "Failed to estimate record count for 1h timeframe")

        # Test with '1d' timeframe (1 record expected)
        result = estimate_record_count(start_time, end_time, '1d')
        self.assertEqual(result, 1, "Failed to estimate record count for 1d timeframe")

        # Test with '1h' timeframe and timestamps in milliseconds
        start_time_ms = int(start_time.timestamp() * 1000)  # Convert to milliseconds
        end_time_ms = int(end_time.timestamp() * 1000)  # Convert to milliseconds
        result = estimate_record_count(start_time_ms, end_time_ms, '1h')
        self.assertEqual(result, 24, "Failed to estimate record count for 1h timeframe with milliseconds")

        # Test with '5m' timeframe and Unix timestamps in milliseconds
        start_time_ms = 1672531200000  # Equivalent to '2023-01-01 00:00:00 UTC'
        end_time_ms = 1672534800000  # Equivalent to '2023-01-01 01:00:00 UTC'
        result = estimate_record_count(start_time_ms, end_time_ms, '5m')
        self.assertEqual(result, 12, "Failed to estimate record count for 5m timeframe with milliseconds")

        # Test with '5m' timeframe (12 records expected for 1-hour duration)
        start_time = dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 1, 1, 1, 0, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '5m')
        self.assertEqual(result, 12, "Failed to estimate record count for 5m timeframe")

        # Test with '1M' (3 records expected for 3 months)
        start_time = dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 4, 1, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '1M')
        self.assertEqual(result, 3, "Failed to estimate record count for 1M timeframe")

        # Test with invalid timeframe
        with self.assertRaises(ValueError):
            estimate_record_count(start_time, end_time, 'xyz')  # Invalid timeframe

        # Test with invalid start_time passed in
        with self.assertRaises(ValueError):
            estimate_record_count("invalid_start", end_time, '1h')

        # Cross-Year Transition (Months)
        start_time = dt.datetime(2022, 12, 1, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '1M')
        self.assertEqual(result, 1, "Failed to estimate record count for month across years")

        # Leap Year (Months)
        start_time = dt.datetime(2020, 2, 1, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2021, 2, 1, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '1M')
        self.assertEqual(result, 12, "Failed to estimate record count for months during leap year")

        # Sub-Minute Timeframes (e.g., 30 seconds)
        start_time = dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 1, 1, 0, 1, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '30s')
        self.assertEqual(result, 2, "Failed to estimate record count for 30 seconds timeframe")

        # Different Timezones: 00:00 at UTC+5 is 19:00 UTC the previous day,
        # so the span up to 01:00 UTC is 6 hours.
        start_time = dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone(dt.timedelta(hours=5)))  # UTC+5
        end_time = dt.datetime(2023, 1, 1, 1, 0, tzinfo=dt.timezone.utc)  # UTC
        result = estimate_record_count(start_time, end_time, '1h')
        self.assertEqual(result, 6,
                         "Failed to estimate record count for different timezones")  # Expect 6 records, not 1

        # Test with zero-length interval (should return 0)
        result = estimate_record_count(start_time, start_time, '1h')
        self.assertEqual(result, 0, "Failed to return 0 for zero-length interval")

        # Test with negative interval (end_time earlier than start_time, should return 0)
        result = estimate_record_count(end_time, start_time, '1h')
        self.assertEqual(result, 0, "Failed to return 0 for negative interval")

        # Test with small interval compared to timeframe (should return 0)
        start_time = dt.datetime(2023, 8, 1, 0, 0, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 8, 1, 0, 30, tzinfo=dt.timezone.utc)  # 30 minutes
        result = estimate_record_count(start_time, end_time, '1h')
        self.assertEqual(result, 0, "Failed to return 0 for small interval compared to timeframe")

        print(' - All estimate_record_count() tests passed.')
def test_get_or_fetch_rows(self):
    """Row-based 'users' cache: stored frames come back by key; unknown keys yield None."""
    # Seed data: one single-row DataFrame per user, keyed by cache key.
    user_rows = {
        'user_billy': ('billy', '1234', ['ex1', 'ex2', 'ex3']),
        'user_john': ('john', '5678', ['ex4', 'ex5', 'ex6']),
        'user_alice': ('alice', '91011', ['ex7', 'ex8', 'ex9']),
    }

    # Row-based caching is assumed for this test.
    self.data.create_cache('users', cache_type='row')
    for cache_key, (name, pwd, exchanges) in user_rows.items():
        frame = pd.DataFrame({
            'user_name': [name],
            'password': [pwd],
            'exchanges': [exchanges],
        })
        self.data.set_cache_item(key=cache_key, data=frame, cache_name='users')

    print('Testing get_or_fetch_rows() method:')

    # Direct key lookups should return the stored single-row frames intact.
    for cache_key, expected_pwd in (('user_billy', '1234'), ('user_john', '5678')):
        result = self.data.get_cache_item(key=cache_key, cache_name='users')
        self.assertIsInstance(result, pd.DataFrame, "Failed to fetch DataFrame from cache")
        self.assertFalse(result.empty, "The fetched DataFrame is empty")
        self.assertEqual(result.iloc[0]['password'], expected_pwd, "Incorrect data fetched from cache")

    # A key that was never stored must come back as None rather than raising.
    result = self.data.get_cache_item(key='non_existent_user', cache_name='users')
    self.assertIsNone(result, "Expected result to be None for a non-existent user")

    print(' - Fetching rows from cache passed.')
|
|
|
|
def test_is_attr_taken(self):
    """is_attr_taken() reports whether a column value already exists in a table cache."""
    user_cache = self.data.create_cache('users', cache_type='table')

    # Three single-row user records inserted into the table-based cache.
    seed_users = [
        ('billy', '1234', ['ex1', 'ex2', 'ex3']),
        ('john', '5678', ['ex1', 'ex2', 'ex4']),
        ('alice', 'abcd', ['ex5', 'ex6', 'ex7']),
    ]
    for name, pwd, exchanges in seed_users:
        record = pd.DataFrame({
            'user_name': [name],
            'password': [pwd],
            'exchanges': [exchanges],
        })
        self.data.set_cache_item(cache_name='users', data=record)

    # A value present in the cache must be reported as taken.
    result_taken = user_cache.is_attr_taken('user_name', 'billy')
    self.assertTrue(result_taken, "Expected 'billy' to be taken, but it was not.")

    # A value absent from the cache must be reported as free.
    result_not_taken = user_cache.is_attr_taken('user_name', 'charlie')
    self.assertFalse(result_not_taken, "Expected 'charlie' not to be taken, but it was.")
|
|
|
|
def test_insert_df_row_based_cache(self):
    """DataFrame insertion exercised against a row-based 'users' cache."""
    self._test_insert_df(cache_type='row')
|
|
|
|
def test_insert_df_table_based_cache(self):
    """DataFrame insertion exercised against a table-based 'users' cache."""
    self._test_insert_df(cache_type='table')
|
|
|
|
def _test_insert_df(self, cache_type):
    """
    Shared driver: insert a DataFrame into the 'users' table and verify both
    the database contents and the cache behavior for the given cache type.

    Parameters:
        cache_type (str): 'row' or 'table'; selects which cache backs 'users'.
    """
    self.data.create_cache('users', cache_type=cache_type)  # Create 'users' cache for this test

    # Arrange: a simple two-row DataFrame to insert
    df = pd.DataFrame({
        'user_name': ['Alice', 'Bob'],
        'age': [30, 25],
        'users_data': ['data1', 'data2'],
        'data': ['info1', 'info2'],
        'password': ['pass1', 'pass2']
    })

    # Clear any existing rows so the assertions below see only this test's data
    with SQLite(db_file=self.db_file) as con:
        con.execute("DELETE FROM users;")  # Clear existing data for clean testing

    # Act: insert the DataFrame into the 'users' table without skipping the cache
    self.data.insert_df_into_datacache(df=df, cache_name="users", skip_cache=False)

    # Assert: both rows landed in the database
    with SQLite(db_file=self.db_file) as con:
        inserted_users = pd.read_sql_query("SELECT * FROM users", con)

    self.assertEqual(len(inserted_users), 2)  # Ensure both rows are inserted
    self.assertEqual(inserted_users.iloc[0]['user_name'], 'Alice')  # Verify first row data
    self.assertEqual(inserted_users.iloc[1]['user_name'], 'Bob')  # Verify second row data

    # Cache verification depends on which implementation backs 'users'
    cache = self.data.get_cache('users')
    if isinstance(cache, RowBasedCache):
        # Row-based cache: each inserted row should be retrievable by its key
        cached_user1 = cache.get_entry('Alice')
        cached_user2 = cache.get_entry('Bob')

        self.assertIsNotNone(cached_user1)  # Ensure user 'Alice' is cached
        self.assertIsNotNone(cached_user2)  # Ensure user 'Bob' is cached
        self.assertEqual(cached_user1.iloc[0]['user_name'], 'Alice')  # Verify cache content for Alice
        self.assertEqual(cached_user2.iloc[0]['user_name'], 'Bob')  # Verify cache content for Bob

    elif isinstance(cache, TableBasedCache):
        # Table-based cache: the whole DataFrame should be cached as one unit
        cached_data = cache.get_all_items()
        self.assertEqual(len(cached_data), 2)  # Ensure both rows are cached
        self.assertEqual(cached_data.iloc[0]['user_name'], 'Alice')
        self.assertEqual(cached_data.iloc[1]['user_name'], 'Bob')
|
|
|
|
def test_insert_row(self):
    """insert_row_into_datacache() populates the cache unless skip_cache is set."""
    print("Testing insert_row() method:")

    cache_name = 'users'
    columns = ('user_name', 'age')

    # Row-based cache backing the 'users' table.
    user_cache = self.data.create_cache(cache_name, cache_type='row')

    # First insert goes through the cache, so a lookup by key must hit.
    self.data.insert_row_into_datacache(
        cache_name=cache_name, columns=columns, values=('Alice', 30), key='1', skip_cache=False)

    result = user_cache.get_entry(key='1')
    self.assertIsNotNone(result, "No data found in the cache for the inserted ID.")
    self.assertEqual(result.iloc[0]['user_name'], 'Alice',
                     "The name in the cache doesn't match the inserted value.")
    self.assertEqual(result.iloc[0]['age'], 30, "The age in the cache does not match the inserted value.")

    print("Testing insert_row() with skip_cache=True")

    # Second insert bypasses the cache entirely; the lookup must miss.
    self.data.insert_row_into_datacache(
        cache_name=cache_name, columns=columns, values=('Bob', 40), key='2', skip_cache=True)
    result_after_skip = user_cache.get_entry(key='2')
    self.assertIsNone(result_after_skip, "Data should not have been cached when skip_cache=True.")

    print(" - Insert row with and without caching passed all checks.")
|
|
|
|
def test_fill_data_holes(self):
    """_fill_data_holes() pads a gapped 2h series out to a complete grid."""
    print('Testing _fill_data_holes() method:')

    # Hours 0..12 with holes at 4h and 10h; timestamps are epoch-milliseconds.
    sampled_hours = [0, 2, 6, 8, 12]
    df = pd.DataFrame({
        'time': [dt.datetime(2023, 1, 1, hour, tzinfo=dt.timezone.utc).timestamp() * 1000
                 for hour in sampled_hours]
    })

    result = self.data._fill_data_holes(records=df, interval='2h')
    # 0h..12h at a 2h step is 7 records once the two holes are filled.
    self.assertEqual(len(result), 7, "Data holes were not filled correctly.")
    print(' - _fill_data_holes passed.')
|
|
|
|
def test_get_cache_item(self):
    """
    get_cache_item() round-trips several payload types: Indicator instances
    (stored serialized, returned deserialized), dicts, lists, DataFrames in a
    table-based cache, and a miss that must return None.
    """
    self.load_prerequisites()
    # Case 1: Retrieve a stored Indicator instance (serialized and deserialized)
    indicator = Indicator(name='SMA', indicator_type='SMA', properties={'period': 5})

    # Create a row-based cache for indicators and store serialized Indicator data
    self.data.create_cache('indicators', cache_type='row')
    self.data.set_cache_item(key='indicator_key', data=indicator, cache_name='indicators')

    # Retrieval must hand back a live Indicator, not raw pickled bytes
    stored_data = self.data.get_cache_item('indicator_key', cache_name='indicators')
    self.assertIsInstance(stored_data, Indicator, "Failed to retrieve and deserialize the Indicator instance")

    # Case 2: Retrieve non-Indicator data (e.g., dict)
    data_dict = {'key': 'value'}

    # Create a cache for generic data (row-based)
    self.data.create_cache('default_cache', cache_type='row')

    # Store a dictionary
    self.data.set_cache_item(key='dict_key', data=data_dict, cache_name='default_cache')

    # Retrieve and check if the data matches the original dict
    stored_data = self.data.get_cache_item('dict_key', cache_name='default_cache')
    self.assertEqual(stored_data, data_dict, "Failed to retrieve non-Indicator data correctly")

    # Case 3: Retrieve a list stored in the cache
    data_list = [1, 2, 3, 4, 5]

    # Store a list in row-based cache
    self.data.set_cache_item(key='list_key', data=data_list, cache_name='default_cache')

    # Retrieve and check if the data matches the original list
    stored_data = self.data.get_cache_item('list_key', cache_name='default_cache')
    self.assertEqual(stored_data, data_list, "Failed to retrieve list data correctly")

    # Case 4: Retrieve a DataFrame stored in the cache (Table-Based Cache)
    data_df = pd.DataFrame({
        'column1': [10, 20, 30],
        'column2': ['A', 'B', 'C']
    })

    # Create a table-based cache
    self.data.create_cache('table_cache', cache_type='table')

    # Store a DataFrame in table-based cache
    self.data.set_cache_item(key='testkey', data=data_df, cache_name='table_cache')

    # Retrieve and check if the DataFrame matches the original
    stored_data = self.data.get_cache_item(key='testkey', cache_name='table_cache')
    pd.testing.assert_frame_equal(stored_data, data_df)

    # Case 5: A key that was never stored must return None, not raise
    non_existent = self.data.get_cache_item('non_existent_key', cache_name='default_cache')
    self.assertIsNone(non_existent, "Expected None for non-existent cache key")

    print(" - All get_cache_item tests passed.")
|
|
|
|
def test_set_user_indicator_properties(self):
    """
    set_user_indicator_properties() stores, updates, and validates per-user
    indicator display settings in the 'user_display_properties' cache.
    """
    # Case 1: Store user-specific display properties
    user_id = 'user123'
    indicator_type = 'SMA'
    symbol = 'AAPL'
    timeframe = '1h'
    exchange_name = 'NYSE'
    display_properties = {'color': 'blue', 'line_width': 2}

    # Call the method to set properties
    self.data.set_user_indicator_properties(user_id, indicator_type, symbol, timeframe, exchange_name,
                                            display_properties)

    # Rebuild the cache key the implementation is expected to use, for validation
    user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}"

    # Retrieve the stored properties
    stored_properties = self.data.get_cache_item(user_cache_key, cache_name='user_display_properties')

    # Check if the properties were stored correctly
    self.assertEqual(stored_properties, display_properties, "Failed to store user-specific display properties")

    # Case 2: Update existing user-specific properties (same key, new values)
    updated_properties = {'color': 'red', 'line_width': 3}

    # Update the properties
    self.data.set_user_indicator_properties(user_id, indicator_type, symbol, timeframe, exchange_name,
                                            updated_properties)

    # Retrieve the updated properties
    updated_stored_properties = self.data.get_cache_item(user_cache_key, cache_name='user_display_properties')

    # Check if the properties were updated correctly
    self.assertEqual(updated_stored_properties, updated_properties,
                     "Failed to update user-specific display properties")

    # Case 3: Non-dict properties must be rejected with ValueError
    with self.assertRaises(ValueError):
        self.data.set_user_indicator_properties(user_id, indicator_type, symbol, timeframe, exchange_name,
                                                "invalid_properties")
|
|
|
|
def test_get_user_indicator_properties(self):
    """get_user_indicator_properties(): hit, miss, and type-validation paths."""
    # Shared lookup coordinates for the stored properties.
    indicator_type = 'SMA'
    symbol = 'AAPL'
    timeframe = '1h'
    exchange_name = 'NYSE'
    display_properties = {'color': 'blue', 'line_width': 2}

    # Case 1: store, then read back — the properties must round-trip intact.
    self.data.set_user_indicator_properties(
        'user123', indicator_type, symbol, timeframe, exchange_name, display_properties)
    retrieved_properties = self.data.get_user_indicator_properties(
        'user123', indicator_type, symbol, timeframe, exchange_name)
    self.assertEqual(retrieved_properties, display_properties,
                     "Failed to retrieve user-specific display properties")

    # Case 2: an unknown user has no stored properties -> None.
    missing_properties = self.data.get_user_indicator_properties(
        'nonexistent_user', indicator_type, symbol, timeframe, exchange_name)
    self.assertIsNone(missing_properties, "Expected None for missing user-specific display properties")

    # Case 3: a non-string user_id must be rejected with TypeError.
    with self.assertRaises(TypeError):
        self.data.get_user_indicator_properties(
            123, indicator_type, symbol, timeframe, exchange_name)
|
|
|
|
def test_set_cache_item(self):
    """
    set_cache_item() across both cache implementations: row-based storage,
    serialization, overwrite protection and key validation; table-based
    DataFrame handling; expiration; and unknown-cache-name errors.
    """
    data_cache = self.data

    # -------------------------
    # Row-Based Cache Test Cases
    # -------------------------
    # Case 1: Store and retrieve an item in a RowBasedCache with a key
    data_cache.create_cache('row_cache', cache_type='row')  # Create row-based cache
    key = 'row_key'
    data = {'some': 'data'}

    data_cache.set_cache_item(cache_name='row_cache', data=data, key=key)
    cached_item = data_cache.get_cache_item(key, cache_name='row_cache')
    self.assertEqual(cached_item, data, "Failed to store and retrieve data in RowBasedCache")

    # Case 2: Store and retrieve an Indicator instance (serialization round-trip).
    # get_cache_item() returns the already-deserialized object (this is the
    # contract test_get_cache_item asserts), so calling pickle.loads() on the
    # result — as this test previously did — would raise TypeError.
    indicator = Indicator(name='SMA', indicator_type='SMA', properties={'period': 5})
    data_cache.set_cache_item(cache_name='row_cache', data=indicator, key='indicator_key')
    cached_indicator = data_cache.get_cache_item('indicator_key', cache_name='row_cache')
    self.assertIsInstance(cached_indicator, Indicator, "Failed to deserialize Indicator instance")

    # Case 3: Prevent overwriting an existing key if do_not_overwrite=True
    new_data = {'new': 'data'}
    data_cache.set_cache_item(cache_name='row_cache', data=new_data, key=key, do_not_overwrite=True)
    cached_item_after = data_cache.get_cache_item(key, cache_name='row_cache')
    self.assertEqual(cached_item_after, data, "Overwriting occurred when it should have been prevented")

    # Case 4: Raise ValueError if key is None in RowBasedCache
    with self.assertRaises(ValueError, msg="RowBasedCache requires a key to store the data."):
        data_cache.set_cache_item(cache_name='row_cache', data=data, key=None)

    # -------------------------
    # Table-Based Cache Test Cases
    # -------------------------
    # Case 5: Store and retrieve a DataFrame in a TableBasedCache
    data_cache.create_cache('table_cache', cache_type='table')  # Create table-based cache
    df = pd.DataFrame({'col1': [1, 2], 'col2': ['A', 'B']})

    data_cache.set_cache_item(cache_name='table_cache', data=df, key='table_key')
    cached_df = data_cache.get_cache_item('table_key', cache_name='table_cache')
    # NOTE: assert_frame_equal takes no message argument; the string previously
    # passed here positionally was silently consumed as check_dtype.
    pd.testing.assert_frame_equal(cached_df, df)

    # Case 6: Prevent overwriting an existing key if do_not_overwrite=True in TableBasedCache
    new_df = pd.DataFrame({'col1': [3, 4], 'col2': ['C', 'D']})
    data_cache.set_cache_item(cache_name='table_cache', data=new_df, key='table_key', do_not_overwrite=True)
    cached_df_after = data_cache.get_cache_item('table_key', cache_name='table_cache')
    pd.testing.assert_frame_equal(cached_df_after, df)

    # Case 7: Raise ValueError if non-DataFrame data is provided in TableBasedCache
    with self.assertRaises(ValueError, msg="TableBasedCache can only store DataFrames."):
        data_cache.set_cache_item(cache_name='table_cache', data={'not': 'a dataframe'}, key='table_key')

    # -------------------------
    # Expiration Handling Test Case
    # -------------------------
    # Case 8: Store an item with an expiration time (RowBasedCache)
    key = 'expiring_key'
    data = {'some': 'data'}
    expire_delta = dt.timedelta(seconds=5)

    data_cache.set_cache_item(cache_name='row_cache', data=data, key=key, expire_delta=expire_delta)
    cached_item = data_cache.get_cache_item(key, cache_name='row_cache')
    self.assertEqual(cached_item, data, "Failed to store and retrieve data with expiration")

    # Wait past the expiration window; the entry must then be gone.
    # ('time' is imported at module level; the previous local import was redundant.)
    time.sleep(6)
    expired_item = data_cache.get_cache_item(key, cache_name='row_cache')
    self.assertIsNone(expired_item, "Data was not removed after expiration time")

    # -------------------------
    # Invalid Cache Name Test Case
    # -------------------------
    # Case 9: Raise KeyError if the cache name was never created
    with self.assertRaises(KeyError, msg="Unsupported cache type for 'unsupported_cache'"):
        data_cache.set_cache_item(cache_name='unsupported_cache', data={'some': 'data'}, key='some_key')
|
|
|
|
def test_calculate_and_cache_indicator(self):
    """
    calculate_indicator() (DataCache, which includes IndicatorCache
    functionality) returns a non-None result for a valid SMA request
    over a one-day range.
    """
    # symbol, timeframe, exchange, user — the unused user_properties dict
    # the original version declared has been removed.
    ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']

    # Define the time range for the calculation
    start_datetime = dt.datetime(2023, 9, 1, 0, 0, 0, tzinfo=dt.timezone.utc)
    end_datetime = dt.datetime(2023, 9, 2, 0, 0, 0, tzinfo=dt.timezone.utc)

    # Simulate calculating an indicator and caching it through DataCache
    result = self.data.calculate_indicator(
        user_name='test_guy',
        symbol=ex_details[0],
        timeframe=ex_details[1],
        exchange_name=ex_details[2],
        indicator_type='SMA',  # Type of indicator
        start_datetime=start_datetime,
        end_datetime=end_datetime,
        properties={'period': 5}  # Indicator parameters, e.g. the SMA window
    )

    # Ensure that result is not None
    self.assertIsNotNone(result, "Indicator calculation returned None.")
|
|
|
|
def test_calculate_indicator_multiple_users(self):
    """
    calculate_indicator() shares the underlying calculation between users
    while keeping per-user display properties separate. (The unused
    ex_details local from the original version has been removed.)
    """
    user1_properties = {'color': 'blue', 'thickness': 2}
    user2_properties = {'color': 'red', 'thickness': 1}

    # Set user-specific display properties for the same indicator request
    self.data.set_user_indicator_properties('user_1', 'SMA', 'BTC/USD', '5m', 'binance', user1_properties)
    self.data.set_user_indicator_properties('user_2', 'SMA', 'BTC/USD', '5m', 'binance', user2_properties)

    # User 1 calculates the SMA indicator
    result_user1 = self.data.calculate_indicator(
        user_name='user_1',
        symbol='BTC/USD',
        timeframe='5m',
        exchange_name='binance',
        indicator_type='SMA',
        start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
        end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
        properties={'period': 5}
    )

    # User 2 calculates the same SMA indicator but with different display properties
    result_user2 = self.data.calculate_indicator(
        user_name='user_2',
        symbol='BTC/USD',
        timeframe='5m',
        exchange_name='binance',
        indicator_type='SMA',
        start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
        end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
        properties={'period': 5}
    )

    # The underlying calculation data must be identical for both users
    self.assertEqual(result_user1['calculation_data'], result_user2['calculation_data'])

    # ...but each user sees their own display properties
    self.assertNotEqual(result_user1['display_properties'], result_user2['display_properties'])
    self.assertEqual(result_user1['display_properties']['color'], 'blue')
    self.assertEqual(result_user2['display_properties']['color'], 'red')
|
|
|
|
def test_calculate_indicator_cache_retrieval(self):
    """
    A repeated calculate_indicator() call with identical parameters is served
    from the cache, verified via the cache-retrieval log message. (Unused
    locals ex_details / result_first / result_second have been removed.)
    """
    properties = {'period': 5}
    cache_key = 'BTC/USD_5m_binance_SMA_5'

    # First calculation (should store the result in the indicator cache)
    self.data.calculate_indicator(
        user_name='user_1',
        symbol='BTC/USD',
        timeframe='5m',
        exchange_name='binance',
        indicator_type='SMA',
        start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
        end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
        properties=properties
    )

    # Check if the data was cached after the first calculation
    cached_data = self.data.get_cache_item(cache_key, cache_name='indicator_data')
    print(f"Cached Data after first calculation: {cached_data}")

    # Ensure the data was cached correctly
    self.assertIsNotNone(cached_data, "The first calculation did not cache the result properly.")

    # Second calculation with the same parameters (should retrieve from cache)
    with self.assertLogs(level='INFO') as log:
        self.data.calculate_indicator(
            user_name='user_1',
            symbol='BTC/USD',
            timeframe='5m',
            exchange_name='binance',
            indicator_type='SMA',
            start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
            end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
            properties=properties
        )
    # Verify the log message for cache retrieval
    self.assertTrue(
        any(f"DataFrame retrieved from cache for key: {cache_key}" in message for message in log.output),
        f"Cache retrieval log message not found for key: {cache_key}"
    )
|
|
|
|
def test_calculate_indicator_partial_cache(self):
    """
    When part of the requested range is already cached, calculate_indicator()
    must merge cached and freshly-fetched data so the result spans the full
    range. (Removed the unused ex_details local and the duplicated min/max
    debug print block from the original version.)
    """
    properties = {'period': 5}

    # Pre-populate the cache with half a day of 5m data (manual setup,
    # no call to get_records_since)
    cached_data = pd.DataFrame({
        'time': pd.date_range(start="2023-01-01", periods=144, freq='5min', tz=dt.timezone.utc),
        'value': [16500 + i for i in range(144)]
    })

    # Cache key in the same format calculate_indicator() uses internally
    cache_key = self.data._make_indicator_key('BTC/USD', '5m', 'binance', 'SMA', properties['period'])

    # Store the cached data as a DataFrame
    self.data.set_cache_item(cache_name='indicator_data', data=cached_data, key=cache_key)

    # Debug: show what range the cache covers before the request
    print("Cached data time range:")
    print(f"Min cached time: {cached_data['time'].min()}")
    print(f"Max cached time: {cached_data['time'].max()}")

    # Request a full day: the uncached second half must be fetched and merged
    result = self.data.calculate_indicator(
        user_name='user_1',
        symbol='BTC/USD',
        timeframe='5m',
        exchange_name='binance',
        indicator_type='SMA',
        start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
        end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
        properties=properties
    )

    # 'time' comes back as Unix epoch milliseconds; convert for comparison
    result_df = pd.DataFrame(result['calculation_data'])
    result_df['time'] = pd.to_datetime(result_df['time'], unit='ms', utc=True)

    min_time = result_df['time'].min()
    max_time = result_df['time'].max()
    print("Result data time range:")
    print(f"Min result time: {min_time}")
    print(f"Max result time: {max_time}")

    # The merged result must cover the whole requested day,
    # cached portion plus newly fetched portion
    self.assertEqual(min_time, pd.Timestamp("2023-01-01 00:00:00", tz=dt.timezone.utc))
    self.assertEqual(max_time, pd.Timestamp("2023-01-02 00:00:00", tz=dt.timezone.utc))
|
|
|
|
def test_calculate_indicator_no_data(self):
    """
    calculate_indicator() returns an empty calculation_data payload when the
    requested range predates any available data. (The unused ex_details local
    from the original version has been removed.)
    """
    properties = {'period': 5}

    # Request data for a period where no data exists
    result = self.data.calculate_indicator(
        user_name='user_1',
        symbol='BTC/USD',
        timeframe='5m',
        exchange_name='binance',
        indicator_type='SMA',
        start_datetime=dt.datetime(1900, 1, 1, tzinfo=dt.timezone.utc),
        end_datetime=dt.datetime(1900, 1, 2, tzinfo=dt.timezone.utc),
        properties=properties
    )

    # Ensure no calculation data is returned
    self.assertEqual(len(result['calculation_data']), 0)
|
|
|
|
|
|
if __name__ == '__main__':
    # Discover and run every test in this module when executed as a script.
    unittest.main()
|