# brighter-trading/tests/test_DataCache.py
# (viewer metadata retained as a comment so the module parses: 1393 lines, 67 KiB, Python)

import pickle
import time
import pytz
from DataCache_v3 import DataCache, timeframe_to_timedelta, estimate_record_count, InMemoryCache, DataCacheBase, \
SnapshotDataCache, IndicatorCache
from ExchangeInterface import ExchangeInterface
import unittest
import pandas as pd
import datetime as dt
import os
from Database import SQLite, Database
import logging
from indicators import Indicator
logging.basicConfig(level=logging.DEBUG)
class DataGenerator:
    """Generate deterministic, evenly spaced OHLCV candle tables for tests.

    A timeframe string such as '2h', '5m', '1d', '1w', '1M' or '1Y' controls
    the spacing of generated records.  Month and year intervals are
    approximated as 30 and 365 days so they can be expressed as plain
    timedeltas (matching `_get_seconds_per_unit`).
    """

    def __init__(self, timeframe_str):
        """
        Initialize the DataGenerator with a timeframe string like '2h', '5m', '1d', '1w', '1M', or '1Y'.
        """
        # Initialize attributes with placeholder values so they always exist,
        # even if set_timeframe raises.
        self.timeframe_amount = None
        self.timeframe_unit = None
        self.timeframe_unit_char = None
        # Set the actual timeframe
        self.set_timeframe(timeframe_str)

    def set_timeframe(self, timeframe_str):
        """
        Set the timeframe unit and amount based on a string like '2h', '5m', '1d', '1w', '1M', or '1Y'.

        Raises:
            ValueError: if the unit suffix is not one of 's,m,h,d,w,M,Y'.
        """
        self.timeframe_amount = int(timeframe_str[:-1])
        unit = timeframe_str[-1]
        unit_names = {'s': 'seconds', 'm': 'minutes', 'h': 'hours',
                      'd': 'days', 'w': 'weeks', 'M': 'months', 'Y': 'years'}
        if unit not in unit_names:
            raise ValueError(
                "Unsupported timeframe unit. Use 's,m,h,d,w,M,Y'.")
        self.timeframe_unit = unit_names[unit]
        # Keep the raw suffix character: minutes ('m') and months ('M') both
        # start with 'm' once spelled out, so timeframe_unit[0] is ambiguous
        # and previously caused monthly data to be rounded as minutes.
        self.timeframe_unit_char = unit

    def create_table(self, num_rec=None, start=None, end=None):
        """
        Create a table with simulated data. If both start and end are provided, num_rec is derived from the interval.
        If neither are provided the table will have num_rec and end at the current time.

        Parameters:
            num_rec (int, optional): The number of records to generate.
            start (datetime, optional): The start time for the first record.
            end (datetime, optional): The end time for the last record.

        Returns:
            pd.DataFrame: A DataFrame with the simulated data.

        Raises:
            ValueError: if a supplied datetime is naive, or num_rec cannot be
                derived from the supplied arguments.
        """
        # Ensure provided datetime parameters are timezone aware
        if start and start.tzinfo is None:
            raise ValueError('start datetime must be timezone aware.')
        if end and end.tzinfo is None:
            raise ValueError('end datetime must be timezone aware.')
        # If neither start nor end are provided, anchor the table at "now".
        if start is None and end is None:
            end = dt.datetime.now(dt.timezone.utc)
            if num_rec is None:
                raise ValueError("num_rec must be provided if both start and end are not specified.")
        # If start and end are provided, derive the record count from the span.
        if start is not None and end is not None:
            total_duration = (end - start).total_seconds()
            interval_seconds = self.timeframe_amount * self._get_seconds_per_unit(self.timeframe_unit)
            num_rec = int(total_duration // interval_seconds) + 1
        # If only end is provided, walk back num_rec intervals to find start.
        if end is not None and start is None:
            if num_rec is None:
                raise ValueError("num_rec must be provided if both start and end are not specified.")
            interval_seconds = self.timeframe_amount * self._get_seconds_per_unit(self.timeframe_unit)
            start = end - dt.timedelta(seconds=(num_rec - 1) * interval_seconds)
        # Only start was provided: nothing else pins the record count.
        if num_rec is None:
            raise ValueError("num_rec must be provided if both start and end are not specified.")
        start = start.replace(tzinfo=dt.timezone.utc)
        # Ensure start is aligned to the timeframe interval.  Use the raw unit
        # character so months ('M') are not mistaken for minutes ('m').
        start = self.round_down_datetime(start, self.timeframe_unit_char, self.timeframe_amount)
        # Generate times
        times = [self.unix_time_millis(start + self._delta(i)) for i in range(num_rec)]
        df = pd.DataFrame({
            'market_id': 1,
            'time': times,
            'open': [100 + i for i in range(num_rec)],
            'high': [110 + i for i in range(num_rec)],
            'low': [90 + i for i in range(num_rec)],
            'close': [105 + i for i in range(num_rec)],
            'volume': [1000 + i for i in range(num_rec)]
        })
        return df

    @staticmethod
    def _get_seconds_per_unit(unit):
        """Helper method to convert timeframe units to seconds."""
        units_in_seconds = {
            'seconds': 1,
            'minutes': 60,
            'hours': 3600,
            'days': 86400,
            'weeks': 604800,
            'months': 2592000,  # Assuming 30 days per month
            'years': 31536000  # Assuming 365 days per year
        }
        if unit not in units_in_seconds:
            raise ValueError(f"Unsupported timeframe unit: {unit}")
        return units_in_seconds[unit]

    def generate_incomplete_data(self, query_offset, num_rec=5):
        """
        Generate data that is incomplete, i.e., starts before the query but doesn't fully satisfy it.
        """
        query_start_time = self.x_time_ago(query_offset)
        start_time_for_data = self.get_start_time(query_start_time)
        return self.create_table(num_rec, start=start_time_for_data)

    @staticmethod
    def generate_missing_section(df, drop_start=5, drop_end=8):
        """
        Generate data with a missing section by dropping rows [drop_start, drop_end).
        """
        df = df.drop(df.index[drop_start:drop_end]).reset_index(drop=True)
        return df

    def get_start_time(self, query_start_time):
        """Return a time two timeframe intervals before query_start_time."""
        margin = 2
        # Express the delta in seconds: dt.timedelta() rejects months/years
        # keywords, which previously broke 'M'/'Y' timeframes.
        seconds = margin * self.timeframe_amount * self._get_seconds_per_unit(self.timeframe_unit)
        return query_start_time - dt.timedelta(seconds=seconds)

    def x_time_ago(self, offset):
        """
        Returns a datetime object representing the current time minus the offset in the specified units.
        """
        seconds = offset * self._get_seconds_per_unit(self.timeframe_unit)
        return dt.datetime.now(dt.timezone.utc) - dt.timedelta(seconds=seconds)

    def _delta(self, i):
        """
        Returns a timedelta object for the ith increment based on the timeframe unit and amount.
        """
        seconds = i * self.timeframe_amount * self._get_seconds_per_unit(self.timeframe_unit)
        return dt.timedelta(seconds=seconds)

    @staticmethod
    def unix_time_millis(dt_obj: dt.datetime):
        """
        Convert a timezone-aware datetime object to Unix time in milliseconds.
        """
        if dt_obj.tzinfo is None:
            raise ValueError('dt_obj needs to be timezone aware.')
        epoch = dt.datetime(1970, 1, 1, tzinfo=dt.timezone.utc)
        return int((dt_obj - epoch).total_seconds() * 1000)

    @staticmethod
    def round_down_datetime(dt_obj: dt.datetime, unit: str, interval: int) -> dt.datetime:
        """Round dt_obj down to the nearest multiple of `interval` in `unit`.

        `unit` is the single-character timeframe suffix: 's', 'm', 'h', 'd',
        'w', 'M' (months) or 'y'/'Y' (years).  Unknown units return dt_obj
        unchanged.
        """
        if dt_obj.tzinfo is None:
            raise ValueError('dt_obj needs to be timezone aware.')
        if unit == 's':  # Round down to the nearest interval of seconds
            seconds = (dt_obj.second // interval) * interval
            dt_obj = dt_obj.replace(second=seconds, microsecond=0)
        elif unit == 'm':  # Round down to the nearest interval of minutes
            minutes = (dt_obj.minute // interval) * interval
            dt_obj = dt_obj.replace(minute=minutes, second=0, microsecond=0)
        elif unit == 'h':  # Round down to the nearest interval of hours
            hours = (dt_obj.hour // interval) * interval
            dt_obj = dt_obj.replace(hour=hours, minute=0, second=0, microsecond=0)
        elif unit == 'd':  # Round down to the nearest interval of days
            # Days are 1-based: the previous (day // interval) * interval could
            # produce day 0 (e.g. day 1 with interval 2) and raise ValueError.
            days = ((dt_obj.day - 1) // interval) * interval + 1
            dt_obj = dt_obj.replace(day=days, hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'w':  # Round down to the nearest interval of weeks
            dt_obj -= dt.timedelta(days=dt_obj.weekday() % (interval * 7))
            dt_obj = dt_obj.replace(hour=0, minute=0, second=0, microsecond=0)
        elif unit == 'M':  # Round down to the nearest interval of months
            months = ((dt_obj.month - 1) // interval) * interval + 1
            dt_obj = dt_obj.replace(month=months, day=1, hour=0, minute=0, second=0, microsecond=0)
        elif unit in ('y', 'Y'):  # Round down to the nearest interval of years
            years = (dt_obj.year // interval) * interval
            dt_obj = dt_obj.replace(year=years, month=1, day=1, hour=0, minute=0, second=0, microsecond=0)
        return dt_obj
class TestDataCache(unittest.TestCase):
def setUp(self):
    """Build the per-test fixture: exchange connections, a fresh SQLite
    schema, and a DataCache wired to that database.

    Runs before every test method; tearDown() deletes the database file.
    """
    # Exchange interface with several connected test users (no real API keys).
    self.exchanges = ExchangeInterface()
    for user in ('test_guy', 'user_1', 'user_2'):
        self.exchanges.connect_exchange(exchange_name='binance', user_name=user, api_keys=None)
    self.db_file = 'test_db.sqlite'
    self.database = Database(db_file=self.db_file)
    # Schema used by the tests.  (Two of these strings previously carried a
    # needless f-prefix with no placeholders.)
    # NOTE(review): test_table's FK references 'market (id)' while the table
    # created below is named 'markets' -- confirm which name the production
    # schema actually uses.
    ddl_statements = [
        """
        CREATE TABLE IF NOT EXISTS test_table (
            id INTEGER PRIMARY KEY,
            market_id INTEGER,
            time INTEGER UNIQUE ON CONFLICT IGNORE,
            open REAL NOT NULL,
            high REAL NOT NULL,
            low REAL NOT NULL,
            close REAL NOT NULL,
            volume REAL NOT NULL,
            FOREIGN KEY (market_id) REFERENCES market (id)
        )""",
        """
        CREATE TABLE IF NOT EXISTS exchange (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT UNIQUE
        )""",
        """
        CREATE TABLE IF NOT EXISTS markets (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            symbol TEXT,
            exchange_id INTEGER,
            FOREIGN KEY (exchange_id) REFERENCES exchange(id)
        )""",
        """
        CREATE TABLE IF NOT EXISTS test_table_2 (
            key TEXT PRIMARY KEY,
            data TEXT NOT NULL
        )""",
        """
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_name TEXT,
            age INTEGER,
            users_data TEXT,
            data TEXT,
            password TEXT -- Moved to a new line and added a comma after 'data'
        )
        """,
    ]
    with SQLite(db_file=self.db_file) as con:
        for ddl in ddl_statements:
            con.execute(ddl)
    # Initialize DataCache, which inherits IndicatorCache, and point it at
    # the freshly created database.
    self.data = DataCache(self.exchanges)
    self.data.db = self.database  # Keep the database setup
    # Create caches needed for testing
    self.data.create_cache('candles', cache_type=InMemoryCache)
    # Reusable exchange/market details and the cache key they map to.
    self.ex_details = ['BTC/USD', '2h', 'binance', 'test_guy']
    self.key = f'{self.ex_details[0]}_{self.ex_details[1]}_{self.ex_details[2]}'
def tearDown(self):
if os.path.exists(self.db_file):
os.remove(self.db_file)
def test_InMemoryCache(self):
    """Exercise InMemoryCache end to end: set/get, size-limited eviction
    ('evict' policy with limit=2), and expiry-based cleanup.

    NOTE(review): this test sleeps 20 real seconds to let 'user_alice'
    expire -- consider mocking the clock if suite runtime matters.
    """
    # Step 1: Create a cache with a limit of 2 items and 'evict' policy
    print("Creating a cache with a limit of 2 items and 'evict' policy.")
    cached_users = InMemoryCache(limit=2, eviction_policy='evict')
    # Step 2: Set some items in the cache.
    print("Setting 'user_bob' in the cache with an expiration of 10 seconds.")
    cached_users.set_item("user_bob", "{password:'BobPass'}", expire_delta=dt.timedelta(seconds=10))
    print("Setting 'user_alice' in the cache with an expiration of 20 seconds.")
    cached_users.set_item("user_alice", "{password:'AlicePass'}", expire_delta=dt.timedelta(seconds=20))
    # Step 3: Retrieve 'user_bob' from the cache
    print("Retrieving 'user_bob' from the cache.")
    retrieved_item = cached_users.get_item('user_bob')
    print(f"Retrieved: {retrieved_item}")
    assert retrieved_item == "{password:'BobPass'}", "user_bob should have been retrieved successfully."
    # Step 4: Add another item, causing the oldest item to be evicted
    print("Adding 'user_billy' to the cache, which should evict 'user_bob' due to the limit.")
    cached_users.set_item("user_billy", "{password:'BillyPass'}")
    # Step 5: Attempt to retrieve the evicted item 'user_bob'
    print("Attempting to retrieve the evicted item 'user_bob'.")
    evicted_item = cached_users.get_item('user_bob')
    print(f"Evicted Item: {evicted_item}")
    assert evicted_item is None, "user_bob should have been evicted from the cache."
    # Step 6: Retrieve the current items in the cache.
    # get_all_items() is indexed by a 'key' column -- presumably a DataFrame;
    # verify against InMemoryCache's implementation.
    print("Retrieving all current items in the cache after eviction.")
    all_items = cached_users.get_all_items()
    print("Current items in cache:\n", all_items)
    assert "user_alice" in all_items['key'].values, "user_alice should still be in the cache."
    assert "user_billy" in all_items['key'].values, "user_billy should still be in the cache."
    # Step 7: Simulate waiting for 'user_alice' to expire (assuming 20 seconds pass)
    print("Simulating time passing to expire 'user_alice' (20 seconds).")
    time.sleep(20)  # This is to simulate the passage of time; in real tests, you may mock datetime.
    # Step 8: Clean expired items from the cache
    print("Cleaning expired items from the cache.")
    cached_users.clean_expired_items()
    # Step 9: Retrieve the current items in the cache after cleaning expired items
    print("Retrieving all current items in the cache after cleaning expired items.")
    all_items_after_cleaning = cached_users.get_all_items()
    print("Current items in cache after cleaning:\n", all_items_after_cleaning)
    assert "user_alice" not in all_items_after_cleaning[
        'key'].values, "user_alice should have been expired and removed from the cache."
    assert "user_billy" in all_items_after_cleaning['key'].values, "user_billy should still be in the cache."
    # Step 10: Check if 'user_billy' still exists as it should not expire
    # ('user_billy' was set with no expire_delta, so cleanup must keep it).
    print("Checking if 'user_billy' still exists in the cache (it should not have expired).")
    user_billy_item = cached_users.get_item('user_billy')
    print(f"'user_billy' still exists: {user_billy_item}")
    assert user_billy_item == "{password:'BillyPass'}", "user_billy should still exist in the cache."
def test_DataCacheBase(self):
    """Exercise DataCacheBase's multi-cache management: implicit cache
    creation on first set, per-cache limits/policies, eviction, and
    expiry cleanup across all caches at once.

    NOTE(review): sleeps 20 real seconds to trigger expiry -- consider
    mocking the clock if suite runtime matters.
    """
    # Step 1: Create a DataCacheBase instance
    print("Creating a DataCacheBase instance.")
    cache_manager = DataCacheBase()
    # Step 2: Set some items in 'my_cache'. The cache is created automatically with limit 2 and 'evict' policy.
    print("Setting 'key1' in 'my_cache' with an expiration of 10 seconds.")
    cache_manager.set_cache_item('key1', 'data1', expire_delta=dt.timedelta(seconds=10), cache_name='my_cache',
                                 limit=2, eviction_policy='evict')
    print("Setting 'key2' in 'my_cache' with an expiration of 20 seconds.")
    cache_manager.set_cache_item('key2', 'data2', expire_delta=dt.timedelta(seconds=20), cache_name='my_cache')
    # Step 3: Set some items in 'second_cache'. The cache is created automatically with limit 3 and 'deny' policy.
    print("Setting 'keyA' in 'second_cache' with an expiration of 15 seconds.")
    cache_manager.set_cache_item('keyA', 'dataA', expire_delta=dt.timedelta(seconds=15), cache_name='second_cache',
                                 limit=3, eviction_policy='deny')
    print("Setting 'keyB' in 'second_cache' with an expiration of 30 seconds.")
    cache_manager.set_cache_item('keyB', 'dataB', expire_delta=dt.timedelta(seconds=30), cache_name='second_cache')
    print("Setting 'keyC' in 'second_cache' with no expiration.")
    cache_manager.set_cache_item('keyC', 'dataC', cache_name='second_cache')
    # Step 4: Add another item to 'my_cache', causing the oldest item to be evicted.
    print("Adding 'key3' to 'my_cache', which should evict 'key1' due to the limit.")
    cache_manager.set_cache_item('key3', 'data3', cache_name='my_cache')
    # Step 5: Attempt to retrieve the evicted item 'key1' from 'my_cache'.
    print("Attempting to retrieve the evicted item 'key1' from 'my_cache'.")
    evicted_item = cache_manager.get_cache_item('key1', cache_name='my_cache')
    print(f"Evicted Item from 'my_cache': {evicted_item}")
    assert evicted_item is None, "'key1' should have been evicted from 'my_cache'."
    # Step 6: Retrieve all current items in both caches before cleaning.
    print("Retrieving all current items in 'my_cache' before cleaning.")
    all_items_my_cache = cache_manager.get_all_cache_items('my_cache')
    print("Current items in 'my_cache':\n", all_items_my_cache)
    print("Retrieving all current items in 'second_cache' before cleaning.")
    all_items_second_cache = cache_manager.get_all_cache_items('second_cache')
    print("Current items in 'second_cache':\n", all_items_second_cache)
    # Step 7: Simulate time passing to expire 'key2' in 'my_cache' and 'keyA' in 'second_cache'.
    print("Simulating time passing to expire 'key2' in 'my_cache' (20 seconds)"
          " and 'keyA' in 'second_cache' (15 seconds).")
    time.sleep(20)  # Simulate the passage of time; in real tests, you may mock datetime.
    # Step 8: Clean expired items in all caches (no cache_name: applies globally).
    print("Cleaning expired items in all caches.")
    cache_manager.clean_expired_items()
    # Step 9: Verify the cleaning of expired items in 'my_cache'.
    print("Retrieving all current items in 'my_cache' after cleaning expired items.")
    all_items_after_cleaning_my_cache = cache_manager.get_all_cache_items('my_cache')
    print("Items in 'my_cache' after cleaning:\n", all_items_after_cleaning_my_cache)
    assert 'key2' not in all_items_after_cleaning_my_cache[
        'key'].values, "'key2' should have been expired and removed from 'my_cache'."
    assert 'key3' in all_items_after_cleaning_my_cache['key'].values, "'key3' should still be in 'my_cache'."
    # Step 10: Verify the cleaning of expired items in 'second_cache'.
    print("Retrieving all current items in 'second_cache' after cleaning expired items.")
    all_items_after_cleaning_second_cache = cache_manager.get_all_cache_items('second_cache')
    print("Items in 'second_cache' after cleaning:\n", all_items_after_cleaning_second_cache)
    assert 'keyA' not in all_items_after_cleaning_second_cache[
        'key'].values, "'keyA' should have been expired and removed from 'second_cache'."
    assert 'keyB' in all_items_after_cleaning_second_cache[
        'key'].values, "'keyB' should still be in 'second_cache'."
    assert 'keyC' in all_items_after_cleaning_second_cache[
        'key'].values, "'keyC' should still be in 'second_cache' since it has no expiration."
def test_SnapshotDataCache(self):
    """Verify SnapshotDataCache: a snapshot freezes the cache's state at the
    moment it is taken, so later mutations (eviction of 'key1' by 'key3')
    appear in the live cache but not in the stored snapshot.
    """
    # Step 1: Create a SnapshotDataCache instance
    print("Creating a SnapshotDataCache instance.")
    snapshot_cache_manager = SnapshotDataCache()
    # Step 2: Create an in-memory cache with a limit of 2 items and 'evict' policy
    print("Creating an in-memory cache named 'my_cache' with a limit of 2 items and 'evict' policy.")
    snapshot_cache_manager.create_cache('my_cache', cache_type=InMemoryCache, limit=2, eviction_policy='evict')
    # Step 3: Set some items in the cache
    print("Setting 'key1' in 'my_cache' with an expiration of 10 seconds.")
    snapshot_cache_manager.set_cache_item(key='key1', data='data1', expire_delta=dt.timedelta(seconds=10),
                                          cache_name='my_cache')
    print("Setting 'key2' in 'my_cache' with an expiration of 20 seconds.")
    snapshot_cache_manager.set_cache_item(key='key2', data='data2', expire_delta=dt.timedelta(seconds=20),
                                          cache_name='my_cache')
    # Step 4: Take a snapshot of the current state of 'my_cache'
    print("Taking a snapshot of the current state of 'my_cache'.")
    snapshot_cache_manager.snapshot_cache('my_cache')
    # Step 5: Add another item, causing the oldest item to be evicted
    print("Adding 'key3' to 'my_cache', which should evict 'key1' due to the limit.")
    snapshot_cache_manager.set_cache_item(key='key3', data='data3', cache_name='my_cache')
    # Step 6: Retrieve the most recent snapshot of 'my_cache'
    print("Retrieving the most recent snapshot of 'my_cache'.")
    snapshot = snapshot_cache_manager.get_snapshot('my_cache')
    print(f"Snapshot Data:\n{snapshot}")
    # Assert that the snapshot contains 'key1' and 'key2', but not 'key3'
    assert 'key1' in snapshot['key'].values, "'key1' should be in the snapshot."
    assert 'key2' in snapshot['key'].values, "'key2' should be in the snapshot."
    assert 'key3' not in snapshot[
        'key'].values, "'key3' should not be in the snapshot as it was added after the snapshot."
    # Step 7: List all available snapshots with their timestamps
    # (list_snapshots() maps cache name -> timestamp string, per the asserts below).
    print("Listing all available snapshots with their timestamps.")
    snapshots_list = snapshot_cache_manager.list_snapshots()
    print(f"Snapshots List: {snapshots_list}")
    # Assert that the snapshot list contains 'my_cache'
    assert 'my_cache' in snapshots_list, "'my_cache' should be in the snapshots list."
    assert isinstance(snapshots_list['my_cache'], str), "The snapshot for 'my_cache' should have a timestamp."
    # Additional validation: Ensure 'key3' is present in the live cache but not in the snapshot
    print("Ensuring 'key3' is present in the live 'my_cache'.")
    live_cache_items = snapshot_cache_manager.get_all_cache_items('my_cache')
    print(f"Live 'my_cache' items after adding 'key3':\n{live_cache_items}")
    assert 'key3' in live_cache_items['key'].values, "'key3' should be in the live cache."
    # Ensure the live cache does not contain 'key1'
    assert 'key1' not in live_cache_items['key'].values, "'key1' should have been evicted from the live cache."
def test_update_candle_cache(self):
    """Verify _update_candle_cache() appends newer candles to an existing
    cached DataFrame: 3 initial 5-minute candles plus 3 later ones must
    yield the same 'time' column as one contiguous 6-candle table.
    """
    print('Testing update_candle_cache() method:')
    # Initialize the DataGenerator with the 5-minute timeframe
    data_gen = DataGenerator('5m')
    # Create initial DataFrame and insert it into the cache
    df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc))
    print(f'Inserting this table into cache:\n{df_initial}\n')
    self.data.set_cache_item(key=self.key, data=df_initial, cache_name='candles')
    # Create new DataFrame to be added to the cache.
    # Starts at 00:15, i.e. immediately after the initial 00:00-00:10 candles.
    df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0, tzinfo=dt.timezone.utc))
    print(f'Updating cache with this table:\n{df_new}\n')
    self.data._update_candle_cache(more_records=df_new, key=self.key)
    # Retrieve the resulting DataFrame from the cache
    result = self.data.get_cache_item(key=self.key, cache_name='candles')
    print(f'The resulting table in cache is:\n{result}\n')
    # Create the expected DataFrame: one contiguous run of 6 candles.
    expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc))
    print(f'The expected time values are:\n{expected["time"].tolist()}\n')
    # Assert that the time values in the result match those in the expected DataFrame, in order
    assert result['time'].tolist() == expected['time'].tolist(), \
        f"time values in result are {result['time'].tolist()} expected {expected['time'].tolist()}"
    print(f'The result time values match:\n{result["time"].tolist()}\n')
    print(' - Update cache with new records passed.')
def test_update_cached_dict(self):
    """Verify update_cached_dict() mutates a dict stored under a cache key:
    after adding 'sub_key' -> 'value', the retrieved dict must contain it.
    """
    print('Testing update_cached_dict() method:')
    # Step 1: Set an empty dictionary in the cache for the specified key
    # (no cache_name here, so this goes into the default cache).
    print(f'Setting an empty dictionary in the cache with key: {self.key}')
    self.data.set_cache_item(data={}, key=self.key)
    # Step 2: Update the cached dictionary with a new key-value pair
    print(f'Updating the cached dictionary with key: {self.key}, adding sub_key="sub_key" with value="value".')
    self.data.update_cached_dict(cache_name='default_cache', cache_key=self.key, dict_key='sub_key', data='value')
    # Step 3: Retrieve the updated cache
    print(f'Retrieving the updated dictionary from the cache with key: {self.key}')
    cache = self.data.get_cache_item(key=self.key)
    # Step 4: Verify that the 'sub_key' in the cached dictionary has the correct value
    print(f'Verifying that "sub_key" in the cached dictionary has the value "value".')
    self.assertIsInstance(cache, dict, "The cache should be a dictionary.")
    self.assertIn('sub_key', cache, "The 'sub_key' should be present in the cached dictionary.")
    self.assertEqual(cache['sub_key'], 'value')
    print(' - Update dictionary in cache passed.')
def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None, num_rec=None, ex_details=None,
                            simulate_scenarios=None):
    """
    Test the get_records_since() method by generating a table of simulated data,
    inserting it into data and/or database, and then querying the records.

    Shared driver used by test_get_records_since(): it seeds the cache and/or
    database, runs the query, and asserts both completeness (gap scenarios are
    backfilled) and boundary alignment (first/last timestamps within one
    timeframe interval of the query window).

    Parameters:
        set_cache (bool): If True, the generated table is inserted into the cache.
        set_db (bool): If True, the generated table is inserted into the database.
        query_offset (int, optional): The offset in the timeframe units for the query.
        num_rec (int, optional): The number of records to generate in the simulated table.
        ex_details (list, optional): Exchange details to generate the data key.
        simulate_scenarios (str, optional): The type of scenario to simulate. Options are:
            - 'not_enough_data': The table data doesn't go far enough back.
            - 'incomplete_data': The table doesn't have enough records to satisfy the query.
            - 'missing_section': The table has missing records in the middle.
    """
    print('Testing get_records_since() method:')
    ex_details = ex_details or self.ex_details
    key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}'
    num_rec = num_rec or 12
    table_timeframe = ex_details[1]
    data_gen = DataGenerator(table_timeframe)
    # Choose the query window: for 'not_enough_data' the query reaches back
    # further than the generated table covers.
    if simulate_scenarios == 'not_enough_data':
        query_offset = (num_rec + 5) * data_gen.timeframe_amount
    else:
        query_offset = query_offset or (num_rec - 1) * data_gen.timeframe_amount
    # For 'incomplete_data', start the table early but generate only 5 rows,
    # leaving a shortfall the method must backfill.
    if simulate_scenarios == 'incomplete_data':
        start_time_for_data = data_gen.x_time_ago(num_rec * data_gen.timeframe_amount)
        num_rec = 5
    else:
        start_time_for_data = None
    df_initial = data_gen.create_table(num_rec, start=start_time_for_data)
    if simulate_scenarios == 'missing_section':
        df_initial = data_gen.generate_missing_section(df_initial, drop_start=2, drop_end=5)
    # Human-readable dump of the seeded table (times shown as datetimes).
    temp_df = df_initial.copy()
    temp_df['time'] = pd.to_datetime(temp_df['time'], unit='ms')
    print(f'Table Created:\n{temp_df}')
    if set_cache:
        print('Ensuring the cache exists and then inserting table into the cache.')
        self.data.set_cache_item(data=df_initial, key=key, cache_name='candles')
    if set_db:
        print('Inserting table into the database.')
        with SQLite(self.db_file) as con:
            df_initial.to_sql(key, con, if_exists='replace', index=False)
    start_datetime = data_gen.x_time_ago(query_offset)
    if start_datetime.tzinfo is None:
        start_datetime = start_datetime.replace(tzinfo=dt.timezone.utc)
    # NOTE(review): dt.datetime.utcnow() is deprecated in Python 3.12;
    # dt.datetime.now(dt.timezone.utc) is the modern equivalent.
    query_end_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
    print(f'Requesting records from {start_datetime} to {query_end_time}')
    result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
    # Expected = the seeded rows at/after the query start (gap scenarios will
    # legitimately return MORE rows than this).
    expected = df_initial[df_initial['time'] >= data_gen.unix_time_millis(start_datetime)].reset_index(
        drop=True)
    temp_df = expected.copy()
    temp_df['time'] = pd.to_datetime(temp_df['time'], unit='ms')
    print(f'Expected table:\n{temp_df}')
    temp_df = result.copy()
    temp_df['time'] = pd.to_datetime(temp_df['time'], unit='ms')
    print(f'Resulting table:\n{temp_df}')
    if simulate_scenarios in ['not_enough_data', 'incomplete_data', 'missing_section']:
        # Gap scenarios: the method must have fetched the missing rows.
        assert result.shape[0] > expected.shape[
            0], "Result has fewer or equal rows compared to the incomplete data."
        print("\nThe returned DataFrame has filled in the missing data!")
    else:
        assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}"
        pd.testing.assert_series_equal(result['time'], expected['time'], check_dtype=False)
        print("\nThe DataFrames have the same shape and the 'time' columns match.")
    # The oldest returned record must sit within one timeframe interval of
    # the requested start...
    oldest_timestamp = pd.to_datetime(result['time'].min(), unit='ms').tz_localize('UTC')
    time_diff = oldest_timestamp - start_datetime
    max_allowed_time_diff = dt.timedelta(**{data_gen.timeframe_unit: data_gen.timeframe_amount})
    assert dt.timedelta(0) <= time_diff <= max_allowed_time_diff, \
        f"Oldest timestamp {oldest_timestamp} is not within " \
        f"{data_gen.timeframe_amount} {data_gen.timeframe_unit} of {start_datetime}"
    print(f'The first timestamp is {time_diff} from {start_datetime}')
    # ...and the newest within one interval of the query end ("now").
    newest_timestamp = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
    time_diff_end = abs(query_end_time - newest_timestamp)
    assert dt.timedelta(0) <= time_diff_end <= max_allowed_time_diff, \
        f"Newest timestamp {newest_timestamp} is not within {data_gen.timeframe_amount} " \
        f"{data_gen.timeframe_unit} of {query_end_time}"
    print(f'The last timestamp is {time_diff_end} from {query_end_time}')
    print(' - Fetch records within the specified time range passed.')
def test_get_records_since(self):
    """Run _test_get_records_since() across its supported scenarios:
    cache hit, cache miss, full miss (cache and DB), an alternate
    timeframe, and the three gap-backfill cases.
    """
    print('\nTest get_records_since with records set in data')
    self._test_get_records_since()
    print('\nTest get_records_since with records not in data')
    self._test_get_records_since(set_cache=False)
    print('\nTest get_records_since with records not in database')
    self._test_get_records_since(set_cache=False, set_db=False)
    print('\nTest get_records_since with a different timeframe')
    self._test_get_records_since(query_offset=None, num_rec=None,
                                 ex_details=['BTC/USD', '15m', 'binance', 'test_guy'])
    print('\nTest get_records_since where data does not go far enough back')
    self._test_get_records_since(simulate_scenarios='not_enough_data')
    print('\nTest get_records_since with incomplete data')
    self._test_get_records_since(simulate_scenarios='incomplete_data')
    print('\nTest get_records_since with missing section in data')
    self._test_get_records_since(simulate_scenarios='missing_section')
def test_other_timeframes(self):
    """Check get_records_since() for 15m, 5m and 4h timeframes: in each case
    the newest returned candle must be less than one interval (plus a small
    tolerance) old.

    NOTE(review): no data is seeded here, so this appears to exercise the
    live exchange-fetch path -- confirm it is safe for offline CI runs.
    """
    print('\nTest get_records_since with a different timeframe')
    if 'candles' not in self.data.caches:
        self.data.create_cache(cache_name='candles')
    ex_details = ['BTC/USD', '15m', 'binance', 'test_guy']
    start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2)
    # Query the records since the calculated start time.
    result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
    last_record_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
    # 15.1 = one 15m interval plus a 6-second slack for test runtime.
    assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=15.1)
    print('\nTest get_records_since with a different timeframe')
    ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']
    start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=1)
    # Query the records since the calculated start time.
    result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
    last_record_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
    assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=5.1)
    print('\nTest get_records_since with a different timeframe')
    ex_details = ['BTC/USD', '4h', 'binance', 'test_guy']
    start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=12)
    # Query the records since the calculated start time.
    result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
    last_record_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
    assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=4.1)
def test_populate_db(self):
    """Verify _populate_db() writes candle rows to a table named after the
    symbol/timeframe/exchange key (self.key), readable back via SQL.
    """
    print('Testing _populate_db() method:')
    # Create a table of candle records using the fixture's timeframe ('2h').
    data_gen = DataGenerator(self.ex_details[1])
    data = data_gen.create_table(num_rec=5)
    self.data._populate_db(ex_details=self.ex_details, data=data)
    # The table name contains '/' (e.g. 'BTC/USD_2h_binance'), hence the quoting.
    with SQLite(self.db_file) as con:
        result = pd.read_sql(f'SELECT * FROM "{self.key}"', con)
    self.assertFalse(result.empty)
    print(' - Populate database with data passed.')
def test_fetch_candles_from_exchange(self):
    """Verify _fetch_candles_from_exchange() returns a non-empty DataFrame
    with a 'time' column whose min/max fall inside the requested 24h window.

    NOTE(review): this appears to call a live exchange endpoint -- confirm
    it is mocked or safe for offline CI runs.
    """
    print('Testing _fetch_candles_from_exchange() method:')
    # Define start and end times for the data fetch.
    # NOTE(review): dt.datetime.utcnow() is deprecated in Python 3.12;
    # dt.datetime.now(dt.timezone.utc) is the modern equivalent.
    start_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc) - dt.timedelta(days=1)
    end_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
    # Fetch the candles from the exchange using the method
    result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h',
                                                    exchange_name='binance', user_name='test_guy',
                                                    start_datetime=start_time, end_datetime=end_time)
    # Validate that the result is a DataFrame
    self.assertIsInstance(result, pd.DataFrame)
    # Validate that the DataFrame is not empty
    self.assertFalse(result.empty, "The DataFrame returned from the exchange is empty.")
    # Ensure that the 'time' column exists in the DataFrame
    self.assertIn('time', result.columns, "'time' column is missing in the result DataFrame.")
    # Check if the DataFrame contains valid timestamps within the specified range
    min_time = pd.to_datetime(result['time'].min(), unit='ms').tz_localize('UTC')
    max_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC')
    self.assertTrue(start_time <= min_time <= end_time, f"Data starts outside the expected range: {min_time}")
    self.assertTrue(start_time <= max_time <= end_time, f"Data ends outside the expected range: {max_time}")
    print(' - Fetch candle data from exchange passed.')
def test_remove_row(self):
    """Verify remove_row() in both modes: cache-only (soft delete with
    remove_from_db=False) and cache+database (hard delete with
    remove_from_db=True), matching rows by a (column, value) filter.
    """
    print('Testing remove_row() method:')
    # Create a DataFrame to insert as the data
    user_data = pd.DataFrame({
        'user_name': ['test_user'],
        'password': ['test_password']
    })
    # Insert data into the cache
    self.data.set_cache_item(
        cache_name='users',
        key='user1',
        data=user_data
    )
    # Ensure the data is in the cache
    cache_item = self.data.get_cache_item('user1', 'users')
    self.assertIsNotNone(cache_item, "Data was not correctly inserted into the cache.")
    # The cache_item is a DataFrame, so we access the 'user_name' column directly
    self.assertEqual(cache_item['user_name'].iloc[0], 'test_user', "Inserted data is incorrect.")
    # Remove the row from the cache only (soft delete)
    self.data.remove_row(cache_name='users', filter_vals=('user_name', 'test_user'), remove_from_db=False)
    # Verify the row has been removed from the cache
    cache_item = self.data.get_cache_item('user1', 'users')
    self.assertIsNone(cache_item, "Row was not correctly removed from the cache.")
    # Reinsert the data for hard delete test
    self.data.set_cache_item(
        cache_name='users',
        key='user1',
        data=user_data
    )
    # Mock database delete by adding the row to the database
    # (the 'users' table was created in setUp).
    self.data.db.insert_row(table='users', columns=('user_name', 'password'), values=('test_user', 'test_password'))
    # Remove the row from both cache and database (hard delete)
    self.data.remove_row(cache_name='users', filter_vals=('user_name', 'test_user'), remove_from_db=True)
    # Verify the row has been removed from the cache
    cache_item = self.data.get_cache_item('user1', 'users')
    self.assertIsNone(cache_item, "Row was not correctly removed from the cache.")
    # Verify the row has been removed from the database
    with SQLite(self.db_file) as con:
        result = pd.read_sql(f'SELECT * FROM users WHERE user_name="test_user"', con)
    self.assertTrue(result.empty, "Row was not correctly removed from the database.")
    print(' - Remove row from cache and database passed.')
def test_timeframe_to_timedelta(self):
print('Testing timeframe_to_timedelta() function:')
result = timeframe_to_timedelta('2h')
expected = pd.Timedelta(hours=2)
self.assertEqual(result, expected, "Failed to convert '2h' to Timedelta")
result = timeframe_to_timedelta('5m')
expected = pd.Timedelta(minutes=5)
self.assertEqual(result, expected, "Failed to convert '5m' to Timedelta")
result = timeframe_to_timedelta('1d')
expected = pd.Timedelta(days=1)
self.assertEqual(result, expected, "Failed to convert '1d' to Timedelta")
result = timeframe_to_timedelta('3M')
expected = pd.DateOffset(months=3)
self.assertEqual(result, expected, "Failed to convert '3M' to DateOffset")
result = timeframe_to_timedelta('1Y')
expected = pd.DateOffset(years=1)
self.assertEqual(result, expected, "Failed to convert '1Y' to DateOffset")
with self.assertRaises(ValueError):
timeframe_to_timedelta('5x')
print(' - All timeframe_to_timedelta() tests passed.')
    def test_estimate_record_count(self):
        """
        Exercise estimate_record_count() across units (h/d/m/M/s), input types
        (tz-aware datetimes vs. millisecond Unix timestamps), mixed timezones,
        and degenerate (zero-length / negative / too-short) intervals.
        """
        print('Testing estimate_record_count() function:')
        # Test with '1h' timeframe (24 records expected)
        start_time = dt.datetime(2023, 8, 1, 0, 0, 0, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 8, 2, 0, 0, 0, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '1h')
        self.assertEqual(result, 24, "Failed to estimate record count for 1h timeframe")
        # Test with '1d' timeframe (1 record expected)
        result = estimate_record_count(start_time, end_time, '1d')
        self.assertEqual(result, 1, "Failed to estimate record count for 1d timeframe")
        # Test with '1h' timeframe and timestamps in milliseconds
        start_time_ms = int(start_time.timestamp() * 1000)  # Convert to milliseconds
        end_time_ms = int(end_time.timestamp() * 1000)  # Convert to milliseconds
        result = estimate_record_count(start_time_ms, end_time_ms, '1h')
        self.assertEqual(result, 24, "Failed to estimate record count for 1h timeframe with milliseconds")
        # Test with '5m' timeframe and Unix timestamps in milliseconds
        start_time_ms = 1672531200000  # Equivalent to '2023-01-01 00:00:00 UTC'
        end_time_ms = 1672534800000  # Equivalent to '2023-01-01 01:00:00 UTC'
        result = estimate_record_count(start_time_ms, end_time_ms, '5m')
        self.assertEqual(result, 12, "Failed to estimate record count for 5m timeframe with milliseconds")
        # Test with '5m' timeframe (12 records expected for 1-hour duration)
        start_time = dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 1, 1, 1, 0, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '5m')
        self.assertEqual(result, 12, "Failed to estimate record count for 5m timeframe")
        # Test with '1M' (3 records expected for 3 months)
        start_time = dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 4, 1, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '1M')
        self.assertEqual(result, 3, "Failed to estimate record count for 1M timeframe")
        # Test with invalid timeframe
        with self.assertRaises(ValueError):
            estimate_record_count(start_time, end_time, 'xyz')  # Invalid timeframe
        # Test with invalid start_time passed in
        with self.assertRaises(ValueError):
            estimate_record_count("invalid_start", end_time, '1h')
        # Cross-Year Transition (Months): Dec 2022 -> Jan 2023 is still one month
        start_time = dt.datetime(2022, 12, 1, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '1M')
        self.assertEqual(result, 1, "Failed to estimate record count for month across years")
        # Leap Year (Months): Feb 2020 has 29 days but the span is still 12 months
        start_time = dt.datetime(2020, 2, 1, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2021, 2, 1, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '1M')
        self.assertEqual(result, 12, "Failed to estimate record count for months during leap year")
        # Sub-Minute Timeframes (e.g., 30 seconds)
        start_time = dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 1, 1, 0, 1, tzinfo=dt.timezone.utc)
        result = estimate_record_count(start_time, end_time, '30s')
        self.assertEqual(result, 2, "Failed to estimate record count for 30 seconds timeframe")
        # Different Timezones: 00:00 UTC+5 is 19:00 UTC the previous day, so the
        # real interval is 6 hours, not 1.
        start_time = dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone(dt.timedelta(hours=5)))  # UTC+5
        end_time = dt.datetime(2023, 1, 1, 1, 0, tzinfo=dt.timezone.utc)  # UTC
        result = estimate_record_count(start_time, end_time, '1h')
        self.assertEqual(result, 6,
                         "Failed to estimate record count for different timezones")  # Expect 6 records, not 1
        # Test with zero-length interval (should return 0)
        result = estimate_record_count(start_time, start_time, '1h')
        self.assertEqual(result, 0, "Failed to return 0 for zero-length interval")
        # Test with negative interval (end_time earlier than start_time, should return 0)
        result = estimate_record_count(end_time, start_time, '1h')
        self.assertEqual(result, 0, "Failed to return 0 for negative interval")
        # Test with small interval compared to timeframe (should return 0)
        start_time = dt.datetime(2023, 8, 1, 0, 0, tzinfo=dt.timezone.utc)
        end_time = dt.datetime(2023, 8, 1, 0, 30, tzinfo=dt.timezone.utc)  # 30 minutes
        result = estimate_record_count(start_time, end_time, '1h')
        self.assertEqual(result, 0, "Failed to return 0 for small interval compared to timeframe")
        print(' - All estimate_record_count() tests passed.')
def test_get_or_fetch_rows(self):
# Create a mock table in the cache with multiple entries
df1 = pd.DataFrame({
'user_name': ['billy'],
'password': ['1234'],
'exchanges': [['ex1', 'ex2', 'ex3']]
})
df2 = pd.DataFrame({
'user_name': ['john'],
'password': ['5678'],
'exchanges': [['ex4', 'ex5', 'ex6']]
})
df3 = pd.DataFrame({
'user_name': ['alice'],
'password': ['91011'],
'exchanges': [['ex7', 'ex8', 'ex9']]
})
# Insert these DataFrames into the 'users' cache
self.data.create_cache('users', cache_type=InMemoryCache)
self.data.set_cache_item(key='user_billy', data=df1, cache_name='users')
self.data.set_cache_item(key='user_john', data=df2, cache_name='users')
self.data.set_cache_item(key='user_alice', data=df3, cache_name='users')
print('Testing get_or_fetch_rows() method:')
# Test fetching an existing user from the cache
result = self.data.get_or_fetch_rows('users', ('user_name', 'billy'))
self.assertIsInstance(result, pd.DataFrame, "Failed to fetch DataFrame from cache")
self.assertFalse(result.empty, "The fetched DataFrame is empty")
self.assertEqual(result.iloc[0]['password'], '1234', "Incorrect data fetched from cache")
# Test fetching another user from the cache
result = self.data.get_or_fetch_rows('users', ('user_name', 'john'))
self.assertIsInstance(result, pd.DataFrame, "Failed to fetch DataFrame from cache")
self.assertFalse(result.empty, "The fetched DataFrame is empty")
self.assertEqual(result.iloc[0]['password'], '5678', "Incorrect data fetched from cache")
# Test fetching a user that does not exist in the cache
result = self.data.get_or_fetch_rows('users', ('user_name', 'non_existent_user'))
# Check if result is None (indicating that no data was found)
self.assertIsNone(result, "Expected result to be None for a non-existent user")
print(' - Fetching rows from cache passed.')
def test_is_attr_taken(self):
# Create a cache named 'users'
self.data.create_cache('users', cache_type=InMemoryCache)
# Create mock data for three users
user_data_1 = pd.DataFrame({
'user_name': ['billy'],
'password': ['1234'],
'exchanges': [['ex1', 'ex2', 'ex3']]
})
user_data_2 = pd.DataFrame({
'user_name': ['john'],
'password': ['5678'],
'exchanges': [['ex1', 'ex2', 'ex4']]
})
user_data_3 = pd.DataFrame({
'user_name': ['alice'],
'password': ['abcd'],
'exchanges': [['ex5', 'ex6', 'ex7']]
})
# Insert mock data into the cache
self.data.set_cache_item('user1', user_data_1, cache_name='users')
self.data.set_cache_item('user2', user_data_2, cache_name='users')
self.data.set_cache_item('user3', user_data_3, cache_name='users')
# Test when attribute value is taken
result_taken = self.data.is_attr_taken(cache_name='users', attr='user_name', val='billy')
self.assertTrue(result_taken, "Expected 'billy' to be taken, but it was not.")
# Test when attribute value is not taken
result_not_taken = self.data.is_attr_taken(cache_name='users', attr='user_name', val='charlie')
self.assertFalse(result_not_taken, "Expected 'charlie' not to be taken, but it was.")
def test_insert_df(self):
print('Testing insert_df() method:')
# Create a DataFrame to insert
df = pd.DataFrame({
'user_name': ['Alice'],
'age': [30],
'users_data': ['user_data_1'],
'data': ['additional_data'],
'password': ['1234']
})
# Insert data into the database and cache
self.data.insert_df(df=df, cache_name='users')
# Assume the database will return an auto-incremented ID starting at 1
auto_incremented_id = 1
# Verify that the data was added to the cache using the auto-incremented ID as the key
cached_df = self.data.get_cache_item(key=str(auto_incremented_id), cache_name='users')
# Check that the DataFrame in the cache matches the original DataFrame
pd.testing.assert_frame_equal(cached_df, df, check_dtype=False)
# Now, let's verify the data was inserted into the database
with SQLite(self.data.db.db_file) as conn:
# Query the users table for the inserted data
query_result = pd.read_sql_query(f"SELECT * FROM users WHERE id = {auto_incremented_id}", conn)
# Verify the database content matches the inserted DataFrame
expected_db_df = df.copy()
expected_db_df['id'] = auto_incremented_id # Add the auto-incremented ID to the expected DataFrame
# Align column order
expected_db_df = expected_db_df[['id', 'user_name', 'age', 'users_data', 'data', 'password']]
# Check that the database DataFrame matches the expected DataFrame
pd.testing.assert_frame_equal(query_result, expected_db_df, check_dtype=False)
print(' - Data insertion into cache and database verified successfully.')
def test_insert_row(self):
print("Testing insert_row() method:")
# Define the cache name, columns, and values to insert
cache_name = 'users'
columns = ('user_name', 'age')
values = ('Alice', 30)
# Create the cache first
self.data.create_cache(cache_name, cache_type=InMemoryCache)
# Insert a row into the cache and database without skipping the cache
self.data.insert_row(cache_name=cache_name, columns=columns, values=values, skip_cache=False)
# Retrieve the inserted item from the cache
result = self.data.get_cache_item(key='1', cache_name=cache_name)
# Assert that the data in the cache matches what was inserted
self.assertIsNotNone(result, "No data found in the cache for the inserted ID.")
self.assertEqual(result.iloc[0]['user_name'], 'Alice',
"The name in the cache doesn't match the inserted value.")
self.assertEqual(result.iloc[0]['age'], 30, "The age in the cache does not match the inserted value.")
# Now test with skipping the cache
print("Testing insert_row() with skip_cache=True")
# Insert another row into the database, this time skipping the cache
self.data.insert_row(cache_name=cache_name, columns=columns, values=('Bob', 40), skip_cache=True)
# Attempt to retrieve the newly inserted row from the cache
result_after_skip = self.data.get_cache_item(key='2', cache_name=cache_name)
# Assert that no data is found in the cache for the new row
self.assertIsNone(result_after_skip, "Data should not have been cached when skip_cache=True.")
print(" - Insert row with and without caching passed all checks.")
def test_fill_data_holes(self):
print('Testing _fill_data_holes() method:')
# Create mock data with gaps
df = pd.DataFrame({
'time': [dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc).timestamp() * 1000,
dt.datetime(2023, 1, 1, 2, tzinfo=dt.timezone.utc).timestamp() * 1000,
dt.datetime(2023, 1, 1, 6, tzinfo=dt.timezone.utc).timestamp() * 1000,
dt.datetime(2023, 1, 1, 8, tzinfo=dt.timezone.utc).timestamp() * 1000,
dt.datetime(2023, 1, 1, 12, tzinfo=dt.timezone.utc).timestamp() * 1000]
})
# Call the method
result = self.data._fill_data_holes(records=df, interval='2h')
self.assertEqual(len(result), 7, "Data holes were not filled correctly.")
print(' - _fill_data_holes passed.')
    def test_get_cache_item(self):
        """
        Validate get_cache_item() across seven cases: Indicator deserialization,
        plain data, expiry, missing keys, invalid key types, corrupt payloads,
        and size-limited eviction.
        """
        # Case 1: Retrieve a stored Indicator instance (serialized)
        indicator = Indicator(name='SMA', indicator_type='SMA', properties={'period': 5})
        self.data.set_cache_item('indicator_key', indicator, cache_name='indicators')
        stored_data = self.data.get_cache_item('indicator_key', cache_name='indicators')
        self.assertIsInstance(stored_data, Indicator, "Failed to retrieve and deserialize the Indicator instance")
        # Case 2: Retrieve non-Indicator data (e.g., dict)
        data = {'key': 'value'}
        self.data.set_cache_item('non_indicator_key', data)
        stored_data = self.data.get_cache_item('non_indicator_key')
        self.assertEqual(stored_data, data, "Failed to retrieve non-Indicator data correctly")
        # Case 3: Retrieve expired cache item (should return None)
        self.data.set_cache_item('expiring_key', 'test_data', expire_delta=dt.timedelta(seconds=1))
        time.sleep(2)  # Wait for the 1-second TTL to elapse
        self.assertIsNone(self.data.get_cache_item('expiring_key'), "Expired cache item should return None")
        # Case 4: Retrieve non-existent key (should return None)
        self.assertIsNone(self.data.get_cache_item('non_existent_key'), "Non-existent key should return None")
        # Case 5: Retrieve with invalid key type (should raise ValueError)
        with self.assertRaises(ValueError):
            self.data.get_cache_item(12345)  # Invalid key type
        # Case 6: Test Deserialization Failure
        # Simulate corrupted serialized data (invalid pickle bytes): the cache is
        # expected to log at ERROR level and return None rather than raise.
        corrupted_data = b'\x80\x03corrupted_data'
        self.data.set_cache_item('corrupted_key', corrupted_data, cache_name='indicators')
        with self.assertLogs(level='ERROR') as log:
            self.assertIsNone(self.data.get_cache_item('corrupted_key', cache_name='indicators'))
        self.assertIn("Deserialization failed", log.output[0])
        # Case 7: Test Cache Eviction
        # Create a cache with a limit of 2 items; inserting a third entry should
        # evict the oldest one ('key1').
        self.data.set_cache_item('key1', 'data1', cache_name='test_cache', limit=2)
        self.data.set_cache_item('key2', 'data2', cache_name='test_cache', limit=2)
        self.data.set_cache_item('key3', 'data3', cache_name='test_cache', limit=2)
        # Verify that the oldest item (key1) has been evicted
        self.assertIsNone(self.data.get_cache_item('key1', cache_name='test_cache'))
        self.assertEqual(self.data.get_cache_item('key2', cache_name='test_cache'), 'data2')
        self.assertEqual(self.data.get_cache_item('key3', cache_name='test_cache'), 'data3')
def test_set_user_indicator_properties(self):
# Case 1: Store user-specific display properties
user_id = 'user123'
indicator_type = 'SMA'
symbol = 'AAPL'
timeframe = '1h'
exchange_name = 'NYSE'
display_properties = {'color': 'blue', 'line_width': 2}
# Call the method to set properties
self.data.set_user_indicator_properties(user_id, indicator_type, symbol, timeframe, exchange_name,
display_properties)
# Construct the cache key manually for validation
user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}"
# Retrieve the stored properties
stored_properties = self.data.get_cache_item(user_cache_key, cache_name='user_display_properties')
# Check if the properties were stored correctly
self.assertEqual(stored_properties, display_properties, "Failed to store user-specific display properties")
# Case 2: Update existing user-specific properties
updated_properties = {'color': 'red', 'line_width': 3}
# Update the properties
self.data.set_user_indicator_properties(user_id, indicator_type, symbol, timeframe, exchange_name,
updated_properties)
# Retrieve the updated properties
updated_stored_properties = self.data.get_cache_item(user_cache_key, cache_name='user_display_properties')
# Check if the properties were updated correctly
self.assertEqual(updated_stored_properties, updated_properties,
"Failed to update user-specific display properties")
# Case 3: Handle invalid user properties (e.g., non-dict input)
with self.assertRaises(ValueError):
self.data.set_user_indicator_properties(user_id, indicator_type, symbol, timeframe, exchange_name,
"invalid_properties")
def test_get_user_indicator_properties(self):
# Case 1: Retrieve existing user-specific display properties
user_id = 'user123'
indicator_type = 'SMA'
symbol = 'AAPL'
timeframe = '1h'
exchange_name = 'NYSE'
display_properties = {'color': 'blue', 'line_width': 2}
# Set the properties first
self.data.set_user_indicator_properties(user_id, indicator_type, symbol, timeframe, exchange_name,
display_properties)
# Retrieve the properties
retrieved_properties = self.data.get_user_indicator_properties(user_id, indicator_type, symbol, timeframe,
exchange_name)
self.assertEqual(retrieved_properties, display_properties,
"Failed to retrieve user-specific display properties")
# Case 2: Handle missing key (should return None)
missing_properties = self.data.get_user_indicator_properties('nonexistent_user', indicator_type, symbol,
timeframe, exchange_name)
self.assertIsNone(missing_properties, "Expected None for missing user-specific display properties")
# Case 3: Invalid argument handling
with self.assertRaises(TypeError):
self.data.get_user_indicator_properties(123, indicator_type, symbol, timeframe,
exchange_name) # Invalid user_id type
def test_set_cache_item(self):
# Case 1: Store and retrieve an Indicator instance (serialized)
indicator = Indicator(name='SMA', indicator_type='SMA', properties={'period': 5})
self.data.set_cache_item('indicator_key', indicator, cache_name='indicators')
stored_data = self.data.get_cache_item('indicator_key', cache_name='indicators')
self.assertIsInstance(stored_data, Indicator, "Failed to deserialize the Indicator instance")
# Case 2: Store and retrieve non-Indicator data (e.g., dict)
data = {'key': 'value'}
self.data.set_cache_item('non_indicator_key', data)
stored_data = self.data.get_cache_item('non_indicator_key')
self.assertEqual(stored_data, data, "Non-Indicator data was modified or not stored correctly")
# Case 3: Handle invalid key type (non-string)
with self.assertRaises(ValueError):
self.data.set_cache_item(12345, 'test_data') # Invalid key type
# Case 4: Cache item expiration (item should expire after set time)
self.data.set_cache_item('expiring_key', 'test_data', expire_delta=dt.timedelta(seconds=1))
time.sleep(2) # Wait for expiration time
self.assertIsNone(self.data.get_cache_item('expiring_key'), "Cached item did not expire as expected")
def test_calculate_and_cache_indicator(self):
# Testing the calculation and caching of an indicator through DataCache (which includes IndicatorCache
# functionality)
user_properties = {'color_line_1': 'blue', 'thickness_line_1': 2}
ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']
# Define the time range for the calculation
start_datetime = dt.datetime(2023, 9, 1, 0, 0, 0, tzinfo=dt.timezone.utc)
end_datetime = dt.datetime(2023, 9, 2, 0, 0, 0, tzinfo=dt.timezone.utc)
# Simulate calculating an indicator and caching it through DataCache
result = self.data.calculate_indicator(
user_name='test_guy',
symbol=ex_details[0],
timeframe=ex_details[1],
exchange_name=ex_details[2],
indicator_type='SMA', # Type of indicator
start_datetime=start_datetime,
end_datetime=end_datetime,
properties={'period': 5} # Add the necessary indicator properties like period
)
# Ensure that result is not None
self.assertIsNotNone(result, "Indicator calculation returned None.")
def test_calculate_indicator_multiple_users(self):
"""
Test that the calculate_indicator method handles multiple users' requests with different properties.
"""
ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']
user1_properties = {'color': 'blue', 'thickness': 2}
user2_properties = {'color': 'red', 'thickness': 1}
# Set user-specific properties
self.data.set_user_indicator_properties('user_1', 'SMA', 'BTC/USD', '5m', 'binance', user1_properties)
self.data.set_user_indicator_properties('user_2', 'SMA', 'BTC/USD', '5m', 'binance', user2_properties)
# User 1 calculates the SMA indicator
result_user1 = self.data.calculate_indicator(
user_name='user_1',
symbol='BTC/USD',
timeframe='5m',
exchange_name='binance',
indicator_type='SMA',
start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
properties={'period': 5}
)
# User 2 calculates the same SMA indicator but with different display properties
result_user2 = self.data.calculate_indicator(
user_name='user_2',
symbol='BTC/USD',
timeframe='5m',
exchange_name='binance',
indicator_type='SMA',
start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
properties={'period': 5}
)
# Assert that the calculation data is the same
self.assertEqual(result_user1['calculation_data'], result_user2['calculation_data'])
# Assert that the display properties are different
self.assertNotEqual(result_user1['display_properties'], result_user2['display_properties'])
# Assert that the correct display properties are returned
self.assertEqual(result_user1['display_properties']['color'], 'blue')
self.assertEqual(result_user2['display_properties']['color'], 'red')
def test_calculate_indicator_cache_retrieval(self):
"""
Test that cached data is retrieved efficiently without recalculating when the same request is made.
"""
ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']
properties = {'period': 5}
cache_key = 'BTC/USD_5m_binance_SMA_5'
# First calculation (should store result in cache)
result_first = self.data.calculate_indicator(
user_name='user_1',
symbol='BTC/USD',
timeframe='5m',
exchange_name='binance',
indicator_type='SMA',
start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
properties=properties
)
# Check if the data was cached after the first calculation
cached_data = self.data.get_cache_item(cache_key, cache_name='indicator_data')
print(f"Cached Data after first calculation: {cached_data}")
# Ensure the data was cached correctly
self.assertIsNotNone(cached_data, "The first calculation did not cache the result properly.")
# Second calculation with the same parameters (should retrieve from cache)
with self.assertLogs(level='INFO') as log:
result_second = self.data.calculate_indicator(
user_name='user_1',
symbol='BTC/USD',
timeframe='5m',
exchange_name='binance',
indicator_type='SMA',
start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
properties=properties
)
# Verify the log message for cache retrieval
self.assertTrue(
any(f"DataFrame retrieved from cache for key: {cache_key}" in message for message in log.output),
f"Cache retrieval log message not found for key: {cache_key}"
)
def test_calculate_indicator_partial_cache(self):
"""
Test handling of partial cache where some of the requested data is already cached,
and the rest needs to be fetched.
"""
ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']
properties = {'period': 5}
# Simulate cache for part of the range (manual setup, no call to `get_records_since`)
cached_data = pd.DataFrame({
'time': pd.date_range(start="2023-01-01", periods=144, freq='5min', tz=dt.timezone.utc),
# Cached half a day of data
'value': [16500 + i for i in range(144)]
})
# Generate cache key with correct format
cache_key = self.data._make_indicator_key('BTC/USD', '5m', 'binance', 'SMA', properties['period'])
# Store the cached data as DataFrame (no need for to_dict('records'))
self.data.set_cache_item(cache_key, cached_data, cache_name='indicator_data')
# Print cached data to inspect its range
print("Cached data time range:")
print(f"Min cached time: {cached_data['time'].min()}")
print(f"Max cached time: {cached_data['time'].max()}")
# Now request a range that partially overlaps the cached data
result = self.data.calculate_indicator(
user_name='user_1',
symbol='BTC/USD',
timeframe='5m',
exchange_name='binance',
indicator_type='SMA',
start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc),
properties=properties
)
# Convert the result into a DataFrame
result_df = pd.DataFrame(result['calculation_data'])
# Convert the 'time' column from Unix timestamp (ms) back to datetime with timezone
result_df['time'] = pd.to_datetime(result_df['time'], unit='ms', utc=True)
# Debugging: print the full result to inspect the time range
print("Result data time range:")
print(f"Min result time: {result_df['time'].min()}")
print(f"Max result time: {result_df['time'].max()}")
# Now you can safely find the min and max values
min_time = result_df['time'].min()
max_time = result_df['time'].max()
# Debugging print statements to confirm the values
print(f"Min time in result: {min_time}")
print(f"Max time in result: {max_time}")
# Assert that the min and max time in the result cover the full range from the cache and new data
self.assertEqual(min_time, pd.Timestamp("2023-01-01 00:00:00", tz=dt.timezone.utc))
self.assertEqual(max_time, pd.Timestamp("2023-01-02 00:00:00", tz=dt.timezone.utc))
def test_calculate_indicator_no_data(self):
"""
Test that the indicator calculation handles cases where no data is available for the requested range.
"""
ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']
properties = {'period': 5}
# Request data for a period where no data exists
result = self.data.calculate_indicator(
user_name='user_1',
symbol='BTC/USD',
timeframe='5m',
exchange_name='binance',
indicator_type='SMA',
start_datetime=dt.datetime(1900, 1, 1, tzinfo=dt.timezone.utc),
end_datetime=dt.datetime(1900, 1, 2, tzinfo=dt.timezone.utc),
properties=properties
)
# Ensure no calculation data is returned
self.assertEqual(len(result['calculation_data']), 0)
# Run the full test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()