brighter-trading/tests/test_DataCache.py

229 lines
8.6 KiB
Python

from DataCache import DataCache
from exchangeinterface import ExchangeInterface
import unittest
import pandas as pd
import datetime as dt
import os
from Database import SQLite, Database
from shared_utilities import unix_time_millis
class TestDataCache(unittest.TestCase):
    """Tests for DataCache's in-memory cache helpers and its database/exchange loaders.

    NOTE(review): setUp connects to the 'binance' exchange, and some tests
    (test_populate_db, test_fetch_candles_from_exchange) presumably need
    network access — confirm before running them in an offline CI job.
    """

    def setUp(self):
        """Connect an exchange, create a fresh SQLite test database, and bind a DataCache to both."""
        self.exchanges = ExchangeInterface()
        self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)

        self.db_file = 'test_db.sqlite'
        # Tables below are created with IF NOT EXISTS and test_table.open_time is
        # UNIQUE, so a database left behind by an aborted run would leak rows into
        # this run. Start from an empty file.
        if os.path.exists(self.db_file):
            os.remove(self.db_file)
        self.database = Database(db_file=self.db_file)

        # Create the tables the tests read from and write to.
        with SQLite(db_file=self.db_file) as con:
            cursor = con.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS exchange (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT UNIQUE
                )
            """)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS markets (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    symbol TEXT,
                    exchange_id INTEGER,
                    FOREIGN KEY (exchange_id) REFERENCES exchange(id)
                )
            """)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS test_table (
                    id INTEGER PRIMARY KEY,
                    market_id INTEGER,
                    open_time INTEGER UNIQUE,
                    open REAL NOT NULL,
                    high REAL NOT NULL,
                    low REAL NOT NULL,
                    close REAL NOT NULL,
                    volume REAL NOT NULL,
                    FOREIGN KEY (market_id) REFERENCES markets(id)
                )
            """)

        # This object maintains all the cached data; it is given the exchange
        # connection and pointed at the test database.
        self.data = DataCache(self.exchanges)
        self.data.db = self.database

        # Cache keys follow the '<asset>_<timeframe>_<exchange>' pattern.
        asset, timeframe, exchange = 'BTC/USD', '2h', 'binance'
        self.key1 = f'{asset}_{timeframe}_{exchange}'
        asset, timeframe, exchange = 'ETH/USD', '2h', 'binance'
        self.key2 = f'{asset}_{timeframe}_{exchange}'

    def tearDown(self):
        """Delete the SQLite file created by setUp."""
        if os.path.exists(self.db_file):
            os.remove(self.db_file)

    def test_set_cache(self):
        """set_cache overwrites by default and keeps the old value when do_not_overwrite=True."""
        print('Testing set_cache flag not set:')
        self.data.set_cache(data='data', key=self.key1)
        self.assertEqual(self.data.cached_data[self.key1], 'data')

        self.data.set_cache(data='more_data', key=self.key1)
        self.assertEqual(self.data.cached_data[self.key1], 'more_data')

        print('Testing set_cache no-overwrite flag set:')
        self.data.set_cache(data='even_more_data', key=self.key1, do_not_overwrite=True)
        self.assertEqual(self.data.cached_data[self.key1], 'more_data')

    def test_cache_exists(self):
        """cache_exists is False for an unset key and True after set_cache."""
        print('Testing cache_exists() method:')
        self.assertFalse(self.data.cache_exists(key=self.key2))
        self.data.set_cache(data='data', key=self.key1)
        self.assertTrue(self.data.cache_exists(key=self.key1))

    def test_update_candle_cache(self):
        """Appending candles merges on open_time without duplicating the overlapping row."""
        print('Testing update_candle_cache() method:')
        df_initial = pd.DataFrame({
            'open_time': [1, 2, 3],
            'open': [100, 101, 102],
            'high': [110, 111, 112],
            'low': [90, 91, 92],
            'close': [105, 106, 107],
            'volume': [1000, 1001, 1002]
        })
        # open_time=3 overlaps with df_initial.
        df_new = pd.DataFrame({
            'open_time': [3, 4, 5],
            'open': [102, 103, 104],
            'high': [112, 113, 114],
            'low': [92, 93, 94],
            'close': [107, 108, 109],
            'volume': [1002, 1003, 1004]
        })
        self.data.set_cache(data=df_initial, key=self.key1)
        self.data.update_candle_cache(more_records=df_new, key=self.key1)
        result = self.data.get_cache(key=self.key1)
        expected = pd.DataFrame({
            'open_time': [1, 2, 3, 4, 5],
            'open': [100, 101, 102, 103, 104],
            'high': [110, 111, 112, 113, 114],
            'low': [90, 91, 92, 93, 94],
            'close': [105, 106, 107, 108, 109],
            'volume': [1000, 1001, 1002, 1003, 1004]
        })
        pd.testing.assert_frame_equal(result, expected)

    def test_update_cached_dict(self):
        """update_cached_dict stores a value under a sub-key of a cached dict."""
        print('Testing update_cached_dict() method:')
        self.data.set_cache(data={}, key=self.key1)
        self.data.update_cached_dict(cache_key=self.key1, dict_key='sub_key', data='value')
        cache = self.data.get_cache(key=self.key1)
        self.assertEqual(cache['sub_key'], 'value')

    def test_get_cache(self):
        """get_cache returns what set_cache stored."""
        print('Testing get_cache() method:')
        self.data.set_cache(data='data', key=self.key1)
        result = self.data.get_cache(key=self.key1)
        self.assertEqual(result, 'data')

    def test_get_records_since(self):
        """Only cached candles newer than start_datetime are returned."""
        print('Testing get_records_since() method:')
        # One reference time keeps the candles exactly 0, 1 and 2 minutes old.
        now = dt.datetime.utcnow()
        df_initial = pd.DataFrame({
            'open_time': [unix_time_millis(now - dt.timedelta(minutes=i)) for i in range(3)],
            'open': [100, 101, 102],
            'high': [110, 111, 112],
            'low': [90, 91, 92],
            'close': [105, 106, 107],
            'volume': [1000, 1001, 1002]
        })
        self.data.set_cache(data=df_initial, key=self.key1)
        # Put the window start between the 1- and 2-minute-old candles. Placing it
        # at "now - 2 min" would make inclusion of the oldest candle depend on
        # clock resolution.
        start_datetime = now - dt.timedelta(minutes=1, seconds=30)
        result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60,
                                             ex_details=['BTC/USD', '2h', 'binance'])
        expected = pd.DataFrame({
            'open_time': df_initial['open_time'][:2].values,
            'open': [100, 101],
            'high': [110, 111],
            'low': [90, 91],
            'close': [105, 106],
            'volume': [1000, 1001]
        })
        pd.testing.assert_frame_equal(result, expected)

    def test_get_records_since_from_db(self):
        """A row inserted into test_table is read back when it lies inside the requested window."""
        print('Testing get_records_since_from_db() method:')
        # Read the clock once: the expected frame must carry the exact open_time
        # that was inserted, not a later reading of the clock.
        inserted_open_time = unix_time_millis(dt.datetime.utcnow())
        df_initial = pd.DataFrame({
            'market_id': [None],
            'open_time': [inserted_open_time],
            'open': [1.0],
            'high': [1.0],
            'low': [1.0],
            'close': [1.0],
            'volume': [1.0]
        })
        with SQLite(self.db_file) as con:
            df_initial.to_sql('test_table', con, if_exists='append', index=False)

        start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=1)
        end_datetime = dt.datetime.utcnow()
        result = (
            self.data.get_records_since_from_db(table_name='test_table', st=start_datetime, et=end_datetime,
                                                rl=1, ex_details=['BTC/USD', '2h', 'binance'])
            .sort_values(by='open_time')
            .reset_index(drop=True)
            # The table's 'id' primary key was not part of the inserted data.
            .drop(columns=['id'], errors='ignore')
        )

        expected = df_initial.copy()
        pd.testing.assert_frame_equal(result, expected)

    def test_populate_db(self):
        """_populate_db returns a non-empty DataFrame for the last day of BTC/USD 2h candles."""
        print('Testing _populate_db() method:')
        start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
        end_time = dt.datetime.utcnow()
        result = self.data._populate_db(table_name='test_table', start_time=start_time,
                                        end_time=end_time, ex_details=['BTC/USD', '2h', 'binance', 'test_guy'])
        self.assertIsInstance(result, pd.DataFrame)
        self.assertFalse(result.empty)

    def test_fetch_candles_from_exchange(self):
        """_fetch_candles_from_exchange returns a non-empty DataFrame for the last day."""
        print('Testing _fetch_candles_from_exchange() method:')
        start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
        end_time = dt.datetime.utcnow()
        result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance',
                                                        user_name='test_guy', start_datetime=start_time,
                                                        end_datetime=end_time)
        self.assertIsInstance(result, pd.DataFrame)
        self.assertFalse(result.empty)
if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it.
    unittest.main()