229 lines
8.6 KiB
Python
229 lines
8.6 KiB
Python
from DataCache import DataCache
|
|
from ExchangeInterface import ExchangeInterface
|
|
import unittest
|
|
import pandas as pd
|
|
import datetime as dt
|
|
import os
|
|
from Database import SQLite, Database
|
|
from shared_utilities import unix_time_millis
|
|
|
|
|
|
class TestDataCache(unittest.TestCase):
|
|
def setUp(self):
|
|
# Set the database connection here
|
|
self.exchanges = ExchangeInterface()
|
|
self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)
|
|
# This object maintains all the cached data. Pass it connection to the exchanges.
|
|
self.db_file = 'test_db.sqlite'
|
|
self.database = Database(db_file=self.db_file)
|
|
|
|
# Create necessary tables
|
|
with SQLite(db_file=self.db_file) as con:
|
|
cursor = con.cursor()
|
|
cursor.execute("""
|
|
CREATE TABLE IF NOT EXISTS exchange (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
name TEXT UNIQUE
|
|
)
|
|
""")
|
|
cursor.execute("""
|
|
CREATE TABLE IF NOT EXISTS markets (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
symbol TEXT,
|
|
exchange_id INTEGER,
|
|
FOREIGN KEY (exchange_id) REFERENCES exchange(id)
|
|
)
|
|
""")
|
|
cursor.execute("""
|
|
CREATE TABLE IF NOT EXISTS test_table (
|
|
id INTEGER PRIMARY KEY,
|
|
market_id INTEGER,
|
|
open_time INTEGER UNIQUE,
|
|
open REAL NOT NULL,
|
|
high REAL NOT NULL,
|
|
low REAL NOT NULL,
|
|
close REAL NOT NULL,
|
|
volume REAL NOT NULL,
|
|
FOREIGN KEY (market_id) REFERENCES markets(id)
|
|
)
|
|
""")
|
|
|
|
self.data = DataCache(self.exchanges)
|
|
self.data.db = self.database
|
|
|
|
asset, timeframe, exchange = 'BTC/USD', '2h', 'binance'
|
|
self.key1 = f'{asset}_{timeframe}_{exchange}'
|
|
|
|
asset, timeframe, exchange = 'ETH/USD', '2h', 'binance'
|
|
self.key2 = f'{asset}_{timeframe}_{exchange}'
|
|
|
|
def tearDown(self):
|
|
if os.path.exists(self.db_file):
|
|
os.remove(self.db_file)
|
|
|
|
def test_set_cache(self):
|
|
print('Testing set_cache flag not set:')
|
|
self.data.set_cache(data='data', key=self.key1)
|
|
attr = self.data.__getattribute__('cached_data')
|
|
self.assertEqual(attr[self.key1], 'data')
|
|
|
|
self.data.set_cache(data='more_data', key=self.key1)
|
|
attr = self.data.__getattribute__('cached_data')
|
|
self.assertEqual(attr[self.key1], 'more_data')
|
|
|
|
print('Testing set_cache no-overwrite flag set:')
|
|
self.data.set_cache(data='even_more_data', key=self.key1, do_not_overwrite=True)
|
|
attr = self.data.__getattribute__('cached_data')
|
|
self.assertEqual(attr[self.key1], 'more_data')
|
|
|
|
def test_cache_exists(self):
|
|
print('Testing cache_exists() method:')
|
|
self.assertFalse(self.data.cache_exists(key=self.key2))
|
|
self.data.set_cache(data='data', key=self.key1)
|
|
self.assertTrue(self.data.cache_exists(key=self.key1))
|
|
|
|
def test_update_candle_cache(self):
|
|
print('Testing update_candle_cache() method:')
|
|
df_initial = pd.DataFrame({
|
|
'open_time': [1, 2, 3],
|
|
'open': [100, 101, 102],
|
|
'high': [110, 111, 112],
|
|
'low': [90, 91, 92],
|
|
'close': [105, 106, 107],
|
|
'volume': [1000, 1001, 1002]
|
|
})
|
|
|
|
df_new = pd.DataFrame({
|
|
'open_time': [3, 4, 5],
|
|
'open': [102, 103, 104],
|
|
'high': [112, 113, 114],
|
|
'low': [92, 93, 94],
|
|
'close': [107, 108, 109],
|
|
'volume': [1002, 1003, 1004]
|
|
})
|
|
|
|
self.data.set_cache(data=df_initial, key=self.key1)
|
|
self.data.update_candle_cache(more_records=df_new, key=self.key1)
|
|
|
|
result = self.data.get_cache(key=self.key1)
|
|
expected = pd.DataFrame({
|
|
'open_time': [1, 2, 3, 4, 5],
|
|
'open': [100, 101, 102, 103, 104],
|
|
'high': [110, 111, 112, 113, 114],
|
|
'low': [90, 91, 92, 93, 94],
|
|
'close': [105, 106, 107, 108, 109],
|
|
'volume': [1000, 1001, 1002, 1003, 1004]
|
|
})
|
|
|
|
pd.testing.assert_frame_equal(result, expected)
|
|
|
|
def test_update_cached_dict(self):
|
|
print('Testing update_cached_dict() method:')
|
|
self.data.set_cache(data={}, key=self.key1)
|
|
self.data.update_cached_dict(cache_key=self.key1, dict_key='sub_key', data='value')
|
|
|
|
cache = self.data.get_cache(key=self.key1)
|
|
self.assertEqual(cache['sub_key'], 'value')
|
|
|
|
def test_get_cache(self):
|
|
print('Testing get_cache() method:')
|
|
self.data.set_cache(data='data', key=self.key1)
|
|
result = self.data.get_cache(key=self.key1)
|
|
self.assertEqual(result, 'data')
|
|
|
|
def test_get_records_since(self):
|
|
print('Testing get_records_since() method:')
|
|
df_initial = pd.DataFrame({
|
|
'open_time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=i)) for i in range(3)],
|
|
'open': [100, 101, 102],
|
|
'high': [110, 111, 112],
|
|
'low': [90, 91, 92],
|
|
'close': [105, 106, 107],
|
|
'volume': [1000, 1001, 1002]
|
|
})
|
|
|
|
self.data.set_cache(data=df_initial, key=self.key1)
|
|
start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=2)
|
|
result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60,
|
|
ex_details=['BTC/USD', '2h', 'binance'])
|
|
|
|
expected = pd.DataFrame({
|
|
'open_time': df_initial['open_time'][:2].values,
|
|
'open': [100, 101],
|
|
'high': [110, 111],
|
|
'low': [90, 91],
|
|
'close': [105, 106],
|
|
'volume': [1000, 1001]
|
|
})
|
|
|
|
pd.testing.assert_frame_equal(result, expected)
|
|
|
|
def test_get_records_since_from_db(self):
|
|
print('Testing get_records_since_from_db() method:')
|
|
df_initial = pd.DataFrame({
|
|
'market_id': [None],
|
|
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
|
'open': [1.0],
|
|
'high': [1.0],
|
|
'low': [1.0],
|
|
'close': [1.0],
|
|
'volume': [1.0]
|
|
})
|
|
|
|
with SQLite(self.db_file) as con:
|
|
df_initial.to_sql('test_table', con, if_exists='append', index=False)
|
|
|
|
start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=1)
|
|
end_datetime = dt.datetime.utcnow()
|
|
result = self.data.get_records_since_from_db(table_name='test_table', st=start_datetime, et=end_datetime,
|
|
rl=1, ex_details=['BTC/USD', '2h', 'binance']).sort_values(
|
|
by='open_time').reset_index(drop=True)
|
|
|
|
print("Columns in the result DataFrame:", result.columns)
|
|
print("Result DataFrame:\n", result)
|
|
|
|
# Remove 'id' column from the result DataFrame if it exists
|
|
if 'id' in result.columns:
|
|
result = result.drop(columns=['id'])
|
|
|
|
expected = pd.DataFrame({
|
|
'market_id': [None],
|
|
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
|
'open': [1.0],
|
|
'high': [1.0],
|
|
'low': [1.0],
|
|
'close': [1.0],
|
|
'volume': [1.0]
|
|
})
|
|
|
|
print("Expected DataFrame:\n", expected)
|
|
|
|
pd.testing.assert_frame_equal(result, expected)
|
|
|
|
def test_populate_db(self):
|
|
print('Testing _populate_db() method:')
|
|
start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
|
|
end_time = dt.datetime.utcnow()
|
|
|
|
result = self.data._populate_db(table_name='test_table', start_time=start_time,
|
|
end_time=end_time, ex_details=['BTC/USD', '2h', 'binance', 'test_guy'])
|
|
|
|
self.assertIsInstance(result, pd.DataFrame)
|
|
self.assertFalse(result.empty)
|
|
|
|
def test_fetch_candles_from_exchange(self):
|
|
print('Testing _fetch_candles_from_exchange() method:')
|
|
start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
|
|
end_time = dt.datetime.utcnow()
|
|
|
|
result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance',
|
|
user_name='test_guy', start_datetime=start_time,
|
|
end_datetime=end_time)
|
|
|
|
self.assertIsInstance(result, pd.DataFrame)
|
|
self.assertFalse(result.empty)
|
|
|
|
|
|
def _run_tests():
    """Hand control to unittest's command-line runner for this module's tests."""
    unittest.main()


if __name__ == '__main__':
    # Only run the suite when the file is executed directly, not when imported.
    _run_tests()
|