From 4130e0ca9abd4f2c60e71dde76b9c8b4cf87d7d2 Mon Sep 17 00:00:00 2001 From: Rob Date: Sat, 3 Aug 2024 16:56:13 -0300 Subject: [PATCH] Completed unittests for Database and DataCache. --- src/BrighterTrades.py | 2 +- src/DataCache.py | 292 +++++++++-- src/Database.py | 352 +++++++++++++ src/Users.py | 2 +- src/candles.py | 7 +- src/database.py | 496 ------------------ tests/test_DataCache.py | 132 ++++- tests/test_database.py | 297 ++++++++--- ...n.py => test_live_exchange_integration.py} | 2 +- 9 files changed, 942 insertions(+), 640 deletions(-) create mode 100644 src/Database.py delete mode 100644 src/database.py rename tests/{test_binance_integration.py => test_live_exchange_integration.py} (99%) diff --git a/src/BrighterTrades.py b/src/BrighterTrades.py index c1d524b..61d31d4 100644 --- a/src/BrighterTrades.py +++ b/src/BrighterTrades.py @@ -27,7 +27,7 @@ class BrighterTrades: self.signals = Signals(self.config.signals_list) # Object that maintains candlestick and price data. - self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, database=self.data) + self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, data_source=self.data) # Object that interacts with and maintains data from available indicators self.indicators = Indicators(self.candles, self.config) diff --git a/src/DataCache.py b/src/DataCache.py index 2a41d71..0d2f4de 100644 --- a/src/DataCache.py +++ b/src/DataCache.py @@ -1,25 +1,41 @@ from typing import Any, List - import pandas as pd import datetime as dt +from Database import Database +from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes +import logging -from database import Database -from shared_utilities import query_satisfied, query_uptodate, unix_time_millis +# Set up logging +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) class DataCache: """ - Fetches manages data limits and optimises memory storage. + Fetches manages data limits and optimises memory storage. + Handles connections and operations for the given exchanges. + + Example usage: + -------------- + db = DataCache(exchanges=some_exchanges_object) """ def __init__(self, exchanges): + """ + Initializes the DataCache class. - # Define the cache size - self.max_tables = 50 # Maximum amount of tables that will be cached at any given time. - self.max_records = 1000 # Maximum number of records that will be kept per table. - self.cached_data = {} # A dictionary that holds all the cached records. - self.db = Database( - exchanges) # The class that handles the DB interactions pass it a connection to the exchange_interface. + :param exchanges: The exchanges object handling communication with connected exchanges. + """ + # Maximum number of tables to cache at any given time. + self.max_tables = 50 + # Maximum number of records to be kept per table. + self.max_records = 1000 + # A dictionary that holds all the cached records. + self.cached_data = {} + # The class that handles the DB interactions. + self.db = Database() + # The class that handles exchange interactions. + self.exchanges = exchanges def cache_exists(self, key: str) -> bool: """ @@ -89,27 +105,32 @@ class DataCache: """ Return any records from the cache indexed by table_name that are newer than start_datetime. - :param ex_details: Details to pass to the server - :param key: : - The dictionary table_name of the records. - :param start_datetime: : - The datetime of the first record requested. - :param record_length: The timespan of the records. + :param ex_details: List[str] - Details to pass to the server + :param key: str - The dictionary table_name of the records. + :param start_datetime: dt.datetime - The datetime of the first record requested. + :param record_length: int - The timespan of the records. - :return pd.DataFrame: - The Requested records + :return: pd.DataFrame - The Requested records + + Example: + -------- + records = data_cache.get_records_since('BTC/USD_2h_binance', dt.datetime.utcnow() - dt.timedelta(minutes=60), 60, ['BTC/USD', '2h', 'binance']) """ try: # End time of query defaults to the current time. end_datetime = dt.datetime.utcnow() if self.cache_exists(key=key): - print('\nGetting records from cache.') - # If the records exist retrieve them from the cache. + logger.debug('Getting records from cache.') + # If the records exist, retrieve them from the cache. records = self.get_cache(key) else: - # If they don't exist in cache, get them from the db. - print(f'\nRecords not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}') - records = self.db.get_records_since(table_name=key, st=start_datetime, - et=end_datetime, rl=record_length, ex_details=ex_details) - print(f'Got {len(records.index)} records from db') + # If they don't exist in cache, get them from the database. + logger.debug( + f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}') + records = self.get_records_since_from_db(table_name=key, st=start_datetime, + et=end_datetime, rl=record_length, ex_details=ex_details) + logger.debug(f'Got {len(records.index)} records from DB.') self.set_cache(data=records, key=key) # Check if the records in the cache go far enough back to satisfy the query. @@ -119,13 +140,14 @@ class DataCache: if first_timestamp: # The records didn't go far enough back if a timestamp was returned. end_time = dt.datetime.utcfromtimestamp(first_timestamp) - # Request candles with open_times between [start_time:end_time] from the database. - print(f'requesting additional records from {start_datetime} to {end_time}') - additional_records = self.db.get_records_since(table_name=key, st=start_datetime, - et=end_time, rl=record_length, ex_details=ex_details) - print(f'Got {len(additional_records.index)} additional records from db') + logger.debug(f'Requesting additional records from {start_datetime} to {end_time}') + # Request additional records from the database. + additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime, + et=end_time, rl=record_length, + ex_details=ex_details) + logger.debug(f'Got {len(additional_records.index)} additional records from DB.') if not additional_records.empty: - # If more records were received update the cache. + # If more records were received, update the cache. self.update_candle_cache(additional_records, key) # Check if the records received are up-to-date. @@ -134,21 +156,219 @@ class DataCache: if last_timestamp: # The query was not up-to-date if a timestamp was returned. start_time = dt.datetime.utcfromtimestamp(last_timestamp) - print(f'requesting additional records from {start_time} to {end_datetime}') - # Request the database update its table starting from start_datetime. - additional_records = self.db.get_records_since(table_name=key, st=start_time, - et=end_datetime, rl=record_length, - ex_details=ex_details) - print(f'Got {len(additional_records.index)} additional records from db') + logger.debug(f'Requesting additional records from {start_time} to {end_datetime}') + # Request additional records from the database. + additional_records = self.get_records_since_from_db(table_name=key, st=start_time, + et=end_datetime, rl=record_length, + ex_details=ex_details) + logger.debug(f'Got {len(additional_records.index)} additional records from DB.') if not additional_records.empty: self.update_candle_cache(additional_records, key) # Create a UTC timestamp. _timestamp = unix_time_millis(start_datetime) - # Return all records equal or newer than timestamp. + logger.debug(f"Start timestamp in human-readable form: {dt.datetime.utcfromtimestamp(_timestamp / 1000.0)}") + + # Return all records equal to or newer than the timestamp. result = self.get_cache(key).query('open_time >= @_timestamp') return result except Exception as e: - print(f"An error occurred: {str(e)}") + logger.error(f"An error occurred: {str(e)}") raise + + def get_records_since_from_db(self, table_name: str, st: dt.datetime, + et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame: + """ + Returns records from a specified table that meet a criteria and ensures the records are complete. + If the records do not go back far enough or are not up-to-date, it fetches additional records + from an exchange and populates the table. + + :param table_name: Database table name. + :param st: Start datetime. + :param et: End datetime. + :param rl: Timespan in minutes each record represents. + :param ex_details: Exchange details [symbol, interval, exchange_name]. + :return: DataFrame of records. + + Example: + -------- + records = db.get_records_since('test_table', start_time, end_time, 1, ['BTC/USDT', '1m', 'binance']) + """ + + def add_data(data, tn, start_t, end_t): + new_records = self._populate_db(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details) + print(f'Got {len(new_records.index)} records from exchange_name') + if not new_records.empty: + data = pd.concat([data, new_records], axis=0, ignore_index=True) + data = data.drop_duplicates(subset="open_time", keep='first') + return data + + if self.db.table_exists(table_name=table_name): + print('\nTable existed retrieving records from DB') + print(f'Requesting from {st} to {et}') + records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=st, et=et) + print(f'Got {len(records.index)} records from db') + else: + print(f'\nTable didnt exist fetching from {ex_details[2]}') + temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl + print(f'Requesting from {st} to {et}, Should be {temp} records') + records = self._populate_db(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details) + print(f'Got {len(records.index)} records from {ex_details[2]}') + + first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl) + if first_timestamp: + print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}') + print(f'first ts on record is: {first_timestamp}') + end_time = dt.datetime.utcfromtimestamp(first_timestamp) + print(f'Requesting from {st} to {end_time}') + records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time) + + last_timestamp = query_uptodate(records=records, r_length_min=rl) + if last_timestamp: + print(f'\nRecords were not updated. Requesting from {ex_details[2]}.') + print(f'the last record on file is: {last_timestamp}') + start_time = dt.datetime.utcfromtimestamp(last_timestamp) + print(f'Requesting from {start_time} to {et}') + records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et) + + return records + + def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: list, + end_time: dt.datetime = None) -> pd.DataFrame: + """ + Populates a database table with records from the exchange. + + :param table_name: Name of the table in the database. + :param start_time: Start time to fetch the records from. + :param end_time: End time to fetch the records until (optional). + :param ex_details: Exchange details [symbol, interval, exchange_name, user_name]. + :return: DataFrame of the data downloaded. + + Example: + -------- + records = db._populate_table('test_table', start_time, ['BTC/USDT', '1m', 'binance', 'user1']) + """ + if end_time is None: + end_time = dt.datetime.utcnow() + sym, inter, ex, un = ex_details + records = self._fetch_candles_from_exchange(symbol=sym, interval=inter, exchange_name=ex, user_name=un, + start_datetime=start_time, end_datetime=end_time) + if not records.empty: + self.db.insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex) + else: + print(f'No records inserted {records}') + return records + + def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str, + start_datetime: dt.datetime = None, + end_datetime: dt.datetime = None) -> pd.DataFrame: + """ + Fetches and returns all candles from the specified market, timeframe, and exchange. + + :param symbol: Symbol of the market. + :param interval: Timeframe in the format '' (e.g., '15m', '4h'). + :param exchange_name: Name of the exchange. + :param user_name: Name of the user. + :param start_datetime: Start datetime for fetching data (optional). + :param end_datetime: End datetime for fetching data (optional). + :return: DataFrame of candle data. + + Example: + -------- + candles = db._fetch_candles_from_exchange('BTC/USDT', '1m', 'binance', 'user1', start_time, end_time) + """ + + def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame: + """ + Fills gaps in the data by replicating the last known data point for the missing periods. + + :param records: DataFrame containing the original records. + :param interval: Interval of the data (e.g., '1m', '5m'). + :return: DataFrame with gaps filled. + + Example: + -------- + filled_records = fill_data_holes(df, '1m') + """ + time_span = timeframe_to_minutes(interval) + last_timestamp = None + filled_records = [] + + logger.info(f"Starting to fill data holes for interval: {interval}") + + for index, row in records.iterrows(): + time_stamp = row['open_time'] + + # If last_timestamp is None, this is the first record + if last_timestamp is None: + last_timestamp = time_stamp + filled_records.append(row) + logger.debug(f"First timestamp: {time_stamp}") + continue + + # Calculate the difference in milliseconds and minutes + delta_ms = time_stamp - last_timestamp + delta_minutes = (delta_ms / 1000) / 60 + + logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}") + + # If the gap is larger than the time span of the interval, fill the gap + if delta_minutes > time_span: + num_missing_rec = int(delta_minutes / time_span) + step = int(delta_ms / num_missing_rec) + logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}") + + for ts in range(int(last_timestamp) + step, int(time_stamp), step): + new_row = row.copy() + new_row['open_time'] = ts + filled_records.append(new_row) + logger.debug(f"Filled timestamp: {ts}") + + filled_records.append(row) + last_timestamp = time_stamp + + logger.info("Data holes filled successfully.") + return pd.DataFrame(filled_records) + + # Default start date if not provided + if start_datetime is None: + start_datetime = dt.datetime(year=2017, month=1, day=1) + + # Default end date if not provided + if end_datetime is None: + end_datetime = dt.datetime.utcnow() + + # Check if start date is greater than end date + if start_datetime > end_datetime: + raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.") + + # Get the exchange object + exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name) + + # Calculate the expected number of records + temp = (((unix_time_millis(end_datetime) - unix_time_millis( + start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval) + + logger.info(f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {temp}') + + # If start and end times are the same, set end_datetime to None + if start_datetime == end_datetime: + end_datetime = None + + # Fetch historical candlestick data from the exchange + candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime, + end_dt=end_datetime) + num_rec_records = len(candles.index) + + logger.info(f'{num_rec_records} candles retrieved from the exchange.') + + # Check if the retrieved data covers the expected time range + open_times = candles.open_time + estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval) + + # Fill in any missing data if the retrieved data is less than expected + if num_rec_records < estimated_num_records: + candles = fill_data_holes(candles, interval) + + return candles diff --git a/src/Database.py b/src/Database.py new file mode 100644 index 0000000..cecbe7c --- /dev/null +++ b/src/Database.py @@ -0,0 +1,352 @@ +import sqlite3 +from functools import lru_cache +from typing import Any +import config +import datetime as dt +import pandas as pd +from shared_utilities import unix_time_millis + + +class SQLite: + """ + Context manager for SQLite database connections. + Accepts a database file name or defaults to the file in config.DB_FILE. + + Example usage: + -------------- + with SQLite(db_file='test_db.sqlite') as con: + cursor = con.cursor() + cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') + """ + + def __init__(self, db_file=None): + self.db_file = db_file if db_file else config.DB_FILE + self.connection = sqlite3.connect(self.db_file) + + def __enter__(self): + return self.connection + + def __exit__(self, exc_type, exc_val, exc_tb): + self.connection.commit() + self.connection.close() + + +class HDict(dict): + """ + Hashable dictionary to use as cache keys. + + Example usage: + -------------- + hdict = HDict({'key1': 'value1', 'key2': 'value2'}) + hash(hdict) + """ + + def __hash__(self): + return hash(frozenset(self.items())) + + +def make_query(item: str, table: str, columns: list) -> str: + """ + Creates a SQL select query string with the required number of placeholders. + + :param item: The field to select. + :param table: The table to select from. + :param columns: List of columns for the where clause. + :return: The query string. + + Example: + -------- + query = make_query('id', 'test_table', ['name', 'age']) + # Result: 'SELECT id FROM test_table WHERE name = ? AND age = ?;' + """ + an_itr = iter(columns) + k = next(an_itr) + where_str = f"SELECT {item} FROM {table} WHERE {k} = ?" + where_str += "".join([f" AND {k} = ?" for k in an_itr]) + ';' + return where_str + + +def make_insert(table: str, values: tuple) -> str: + """ + Creates a SQL insert query string with the required number of placeholders. + + :param table: The table to insert into. + :param values: Tuple of values to insert. + :return: The query string. + + Example: + -------- + insert = make_insert('test_table', ('name', 'age')) + # Result: "INSERT INTO test_table ('name', 'age') VALUES(?, ?);" + """ + itr1 = iter(values) + itr2 = iter(values) + k1 = next(itr1) + _ = next(itr2) + insert_str = f"INSERT INTO {table} ('{k1}'" + insert_str += "".join([f", '{k1}'" for k1 in itr1]) + ") VALUES(?" + "".join( + [", ?" for _ in enumerate(itr2)]) + ");" + return insert_str + + +class Database: + """ + Database class to communicate and maintain the database. + Handles connections and operations for the given exchanges. + + Example usage: + -------------- + db = Database(exchanges=some_exchanges_object, db_file='test_db.sqlite') + db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') + """ + + def __init__(self, db_file=None): + """ + Initializes the Database class. + + :param db_file: Optional database file name. + """ + self.db_file = db_file + + def execute_sql(self, sql: str) -> None: + """ + Executes a raw SQL statement. + + :param sql: SQL statement to execute. + + Example: + -------- + db = Database(exchanges=some_exchanges_object, db_file='test_db.sqlite') + db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') + """ + with SQLite(self.db_file) as con: + cur = con.cursor() + cur.execute(sql) + + def get_item_where(self, item_name: str, table_name: str, filter_vals: tuple) -> int: + """ + Returns an item from a table where the filter results should isolate a single row. + + :param item_name: Name of the item to fetch. + :param table_name: Name of the table. + :param filter_vals: Tuple of column name and value to filter by. + :return: The item. + + Example: + -------- + item = db.get_item_where('name', 'test_table', ('id', 1)) + # Fetches the 'name' from 'test_table' where 'id' is 1 + """ + with SQLite(self.db_file) as con: + cur = con.cursor() + qry = make_query(item_name, table_name, [filter_vals[0]]) + cur.execute(qry, (filter_vals[1],)) + if user_id := cur.fetchone(): + return user_id[0] + else: + error = f"Couldn't fetch item {item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}" + raise ValueError(error) + + def get_rows_where(self, table: str, filter_vals: tuple) -> pd.DataFrame | None: + """ + Returns a DataFrame containing all rows of a table that meet the filter criteria. + + :param table: Name of the table. + :param filter_vals: Tuple of column name and value to filter by. + :return: DataFrame of the query result or None if empty. + + Example: + -------- + rows = db.get_rows_where('test_table', ('name', 'test')) + # Fetches all rows from 'test_table' where 'name' is 'test' + """ + with SQLite(self.db_file) as con: + qry = f"SELECT * FROM {table} WHERE {filter_vals[0]}='{filter_vals[1]}'" + result = pd.read_sql(qry, con=con) + if not result.empty: + return result + else: + return None + + def insert_dataframe(self, df: pd.DataFrame, table: str) -> None: + """ + Inserts a DataFrame into a specified table. + + :param df: DataFrame to insert. + :param table: Name of the table. + + Example: + -------- + df = pd.DataFrame({'id': [1], 'name': ['test']}) + db.insert_dataframe(df, 'test_table') + """ + with SQLite(self.db_file) as con: + df.to_sql(name=table, con=con, index=False, if_exists='append') + + def insert_row(self, table: str, columns: tuple, values: tuple) -> None: + """ + Inserts a row into a specified table. + + :param table: Name of the table. + :param columns: Tuple of column names. + :param values: Tuple of values to insert. + + Example: + -------- + db.insert_row('test_table', ('id', 'name'), (1, 'test')) + """ + with SQLite(self.db_file) as conn: + cursor = conn.cursor() + sql = make_insert(table=table, values=columns) + cursor.execute(sql, values) + + def table_exists(self, table_name: str) -> bool: + """ + Checks if a table exists in the database. + + :param table_name: Name of the table. + :return: True if the table exists, False otherwise. + + Example: + -------- + exists = db._table_exists('test_table') + # Checks if 'test_table' exists in the database + """ + with SQLite(self.db_file) as conn: + cursor = conn.cursor() + sql = make_query(item='name', table='sqlite_schema', columns=['type', 'name']) + cursor.execute(sql, ('table', table_name)) + result = cursor.fetchone() + return result is not None + + @lru_cache(maxsize=1000) + def get_from_static_table(self, item: str, table: str, indexes: HDict, create_id: bool = False) -> Any: + """ + Returns the row id of an item from a table specified. If the item isn't listed in the table, + it will insert the item into a new row and return the autoincremented id. The item received as a hashable + dictionary so the results can be cached. + + :param item: Name of the item requested. + :param table: Table being queried. + :param indexes: Hashable dictionary of indexing columns and their values. + :param create_id: If True, create a row if it doesn't exist and return the autoincrement ID. + :return: The content of the field. + + Example: + -------- + exchange_id = db.get_from_static_table('id', 'exchange', HDict({'name': 'binance'}), create_id=True) + """ + with SQLite(self.db_file) as conn: + cursor = conn.cursor() + sql = make_query(item, table, list(indexes.keys())) + cursor.execute(sql, tuple(indexes.values())) + result = cursor.fetchone() + if result is None and create_id: + sql = make_insert(table, tuple(indexes.keys())) + cursor.execute(sql, tuple(indexes.values())) + sql = make_query(item, table, list(indexes.keys())) + cursor.execute(sql, tuple(indexes.values())) + result = cursor.fetchone() + return result[0] if result else None + + def _fetch_exchange_id(self, exchange_name: str) -> int: + """ + Fetches the primary ID of an exchange from the database. + + :param exchange_name: Name of the exchange. + :return: Primary ID of the exchange. + + Example: + -------- + exchange_id = db._fetch_exchange_id('binance') + """ + return self.get_from_static_table(item='id', table='exchange', create_id=True, + indexes=HDict({'name': exchange_name})) + + def _fetch_market_id(self, symbol: str, exchange_name: str) -> int: + """ + Returns the market ID for a trading pair listed in the database. + + :param symbol: Symbol of the trading pair. + :param exchange_name: Name of the exchange. + :return: Market ID. + + Example: + -------- + market_id = db._fetch_market_id('BTC/USDT', 'binance') + """ + exchange_id = self._fetch_exchange_id(exchange_name) + market_id = self.get_from_static_table(item='id', table='markets', create_id=True, + indexes=HDict({'symbol': symbol, 'exchange_id': exchange_id})) + return market_id + + def insert_candles_into_db(self, candlesticks: pd.DataFrame, table_name: str, symbol: str, + exchange_name: str) -> None: + """ + Inserts all candlesticks from a DataFrame into the database. + + :param candlesticks: DataFrame of candlestick data. + :param table_name: Name of the table to insert into. + :param symbol: Symbol of the trading pair. + :param exchange_name: Name of the exchange. + + Example: + -------- + df = pd.DataFrame({ + 'open_time': [unix_time_millis(dt.datetime.utcnow())], + 'open': [1.0], + 'high': [1.0], + 'low': [1.0], + 'close': [1.0], + 'volume': [1.0] + }) + db._insert_candles_into_db(df, 'test_table', 'BTC/USDT', 'binance') + """ + market_id = self._fetch_market_id(symbol, exchange_name) + candlesticks.insert(0, 'market_id', market_id) + sql_create = f""" + CREATE TABLE IF NOT EXISTS '{table_name}' ( + id INTEGER PRIMARY KEY, + market_id INTEGER, + open_time INTEGER UNIQUE ON CONFLICT IGNORE, + open REAL NOT NULL, + high REAL NOT NULL, + low REAL NOT NULL, + close REAL NOT NULL, + volume REAL NOT NULL, + FOREIGN KEY (market_id) REFERENCES market (id) + )""" + with SQLite(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(sql_create) + candlesticks.to_sql(table_name, conn, if_exists='append', index=False) + + def get_timestamped_records(self, table_name: str, timestamp_field: str, st: dt.datetime, + et: dt.datetime = None) -> pd.DataFrame: + """ + Returns records from a specified table in the database that have timestamps greater than or equal to a given + start time and, optionally, less than or equal to a given end time. + + :param table_name: Database table name. + :param timestamp_field: Field name that contains the timestamp. + :param st: Start datetime. + :param et: End datetime (optional). + :return: DataFrame of records. + + Example: + -------- + records = db.get_timestamped_records('test_table', 'open_time', start_time, end_time) + """ + with SQLite(self.db_file) as conn: + start_stamp = unix_time_millis(st) + if et is not None: + end_stamp = unix_time_millis(et) + q_str = ( + f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp} " + f"AND {timestamp_field} <= {end_stamp};" + ) + else: + q_str = f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp};" + records = pd.read_sql(q_str, conn) + records = records.drop('id', axis=1) + return records diff --git a/src/Users.py b/src/Users.py index 51c54af..7bcb84a 100644 --- a/src/Users.py +++ b/src/Users.py @@ -5,7 +5,7 @@ from typing import Any from passlib.hash import bcrypt import pandas as pd -from database import HDict +from Database import HDict class Users: diff --git a/src/candles.py b/src/candles.py index 271237a..def398a 100644 --- a/src/candles.py +++ b/src/candles.py @@ -1,13 +1,12 @@ import datetime as dt import logging as log -from DataCache import DataCache from shared_utilities import timeframe_to_minutes, ts_of_n_minutes_ago # log.basicConfig(level=log.ERROR) class Candles: - def __init__(self, exchanges, config_obj, database): + def __init__(self, exchanges, config_obj, data_source): # A reference to the app configuration self.config = config_obj @@ -15,8 +14,8 @@ class Candles: # The maximum amount of candles to load at one time. self.max_records = self.config.app_data.get('max_data_loaded') - # This object maintains all the cached data. Pass it connection to the exchanges. - self.data = database + # This object maintains all the cached data. + self.data = data_source # print('Setting the candle cache.') # # Populate the cache: diff --git a/src/database.py b/src/database.py deleted file mode 100644 index 36b30f8..0000000 --- a/src/database.py +++ /dev/null @@ -1,496 +0,0 @@ -import sqlite3 -from functools import lru_cache -from typing import Any -import config -import datetime as dt -import pandas as pd - -from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes - - -class SQLite: - """ - Context manager returns a cursor. The connection is closed when - the cursor is destroyed, even if an exception is thrown. - """ - - def __init__(self): - self.connection = sqlite3.connect(config.DB_FILE) - - def __enter__(self): - return self.connection - - def __exit__(self, exc_type, exc_val, exc_tb): - self.connection.commit() - self.connection.close() - - -class HDict(dict): - def __hash__(self): - return hash(frozenset(self.items())) - - -def make_query(item: str, table: str, columns: list) -> str: - """ - Creates a sql select string with the required number of ?'s to match the given columns. - - :param item: The field to select. - :param table: The table to select. - :param columns: list - A list of database columns. - :return: str: The query string. - """ - an_itr = iter(columns) - k = next(an_itr) - where_str = f"SELECT {item} FROM {table} WHERE {k} = ?" - where_str = where_str + "".join([f" AND {k} = ?" for k in an_itr]) + ';' - return where_str - - -def make_insert(table: str, values: tuple) -> str: - """ - Creates a sql insert string with the required number of ?'s to match the given values. - - :param table: The table to insert into. - :param values: dict - A dictionary of table_name-value pairs used to index a db query. - :return: str: The query string. - """ - itr1 = iter(values) - itr2 = iter(values) - k1 = next(itr1) - _ = next(itr2) - insert_str = f"INSERT INTO {table} ('{k1}'" - insert_str = insert_str + "".join([f", '{k1}'" for k1 in itr1]) + ") VALUES(?" + "".join( - [", ?" for _ in enumerate(itr2)]) + ");" - return insert_str - - -class Database: - """ - Communicates and maintains the database. - """ - - def __init__(self, exchanges): - # The exchanges object handles communication with all connected exchanges. - self.exchanges = exchanges - - @staticmethod - def execute_sql(sql: str) -> None: - """ - Executes a sql statement. This is for stuff I haven't created a function for yet. - - :param sql: str - sql statement. - :return: None - """ - with SQLite() as con: - cur = con.cursor() - cur.execute(sql) - - @staticmethod - def get_item_where(item_name: str, table_name: str, filter_vals: tuple) -> int: - """ - Returns an item from a table where the filter results should isolate a single row. - - :param item_name: str - The name of the item to fetch. - :param table_name: str - The name of the table. - :param filter_vals: tuple(str, str) - The column and value to filter the results with. - :return: str - The item. - """ - with SQLite() as con: - cur = con.cursor() - qry = make_query(item_name, table_name, [filter_vals[0]]) - cur.execute(qry, (filter_vals[1],)) - if user_id := cur.fetchone(): - return user_id[0] - else: - error = f"Couldn't fetch item{item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}" - raise ValueError(error) - - @staticmethod - def get_rows_where(table: str, filter_vals: tuple) -> pd.DataFrame | None: - """ - Returns a dataframe containing all rows of a table that meet the filter criteria. - - :param table: str - The name of the table. - :param filter_vals: tuple(column: str, value: str) - the criteria - :return: dataframe|None - returns the data in a dataframe or None if the query fails. - """ - with SQLite() as con: - qry = f"select * from {table} where {filter_vals[0]}='{filter_vals[1]}'" - result = pd.read_sql(qry, con=con) - if not result.empty: - return result - else: - return None - - @staticmethod - def insert_dataframe(df, table): - # Connect to the database. - with SQLite() as con: - # Insert the modified user as a new record in the table. - df.to_sql(name=table, con=con, index=False, if_exists='append') - # Commit the changes to the database. - con.commit() - - @staticmethod - def insert_row(table: str, columns: tuple, values: tuple) -> None: - """ - Saves user specific data from a table in the database. - - :param table: str - The table to insert into - :param columns: tuple(str1, str2, ...) - The columns of the database. - :param values: tuple(val1, val2, ...) - The values to be inserted. - :return: None - """ - # Connect to the database. - with SQLite() as conn: - # Get a cursor from the sql connection. - cursor = conn.cursor() - sql = make_insert(table=table, values=columns) - cursor.execute(sql, values) - - @staticmethod - def _table_exists(table_name: str) -> bool: - """ - Returns True if table_name exists in the database. - - :param table_name: The name of the database. - :return: bool - True|False - """ - # Connect to the database. - with SQLite() as conn: - # Get a cursor from the sql connection. - cursor = conn.cursor() - # sql = f"SELECT name FROM sqlite_schema WHERE type = 'table' AND name = '{table_name}';" - sql = make_query(item='name', table='sqlite_schema', columns=['type', 'name']) - # Check if the table exists. - cursor.execute(sql, ('table', table_name)) - # Fetch the results from the cursor. - result = cursor.fetchone() - if not result: - # If the table doesn't exist return False. - return False - return True - - def _populate_table(self, table_name: str, start_time: dt.datetime, ex_details: list, end_time: dt.datetime = None): - """ - Populates a database table with records from the exchange_name. - :param table_name: str - The name of the table in the database. - :param start_time: datetime - The starting time to fetch the records from. - :param end_time: datetime - The end time to get the records until. - :return: pdDataframe: - The data that was downloaded. - """ - # Set the default end_time to UTC now. - if end_time is None: - end_time = dt.datetime.utcnow() - # Fetch the records from the exchange_name. - # Extract the parameters from the details. Format: __. - sym, inter, ex, un = ex_details - records = self._fetch_candles_from_exchange(symbol=sym, interval=inter, exchange_name=ex, user_name=un, - start_datetime=start_time, end_datetime=end_time) - # Update the database. - if not records.empty: - # Inert into the database any received records. - self._insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex) - else: - print(f'No records inserted {records}') - return records - - @staticmethod - @lru_cache(maxsize=1000) - def get_from_static_table(item: str, table: str, indexes: HDict, create_id: bool = False) -> Any: - """ - Retrieves a single item from a table. This method returns a cached result and is ment - for fetching static data like settings, names and ID's. - - :param create_id: bool: - If True, create a row if it doesn't exist and return the autoincrement ID. - :param item: str - The name of the item requested. - :param table: str - The table being queried. - :param indexes: str - A hashable dictionary containing the indexing columns and their values. - :return: Any - The content of the field. - """ - - # Connect to the database. - with SQLite() as conn: - # Get a cursor from the sql connection. - cursor = conn.cursor() - # Retrieve the record from the db. - sql = make_query(item, table, list(indexes.keys())) - cursor.execute(sql, tuple(indexes.values())) - # The result is returned as tuple. Example: (id,) - result = cursor.fetchone() - - if result is None and create_id is True: - # Insert the indexes into the db. - sql = make_insert(table, tuple(indexes.keys())) - cursor.execute(sql, tuple(indexes.values())) - # Retrieve the record from the db. - sql = make_query(item, table, list(indexes.keys())) - cursor.execute(sql, tuple(indexes.values())) - # Get the first element of the tuple received from sql query. - result = cursor.fetchone() - - # Return the result from the tuple if it exists. - if result: - return result[0] - else: - return None - - def _fetch_exchange_id(self, exchange_name: str) -> int: - """ - Fetch the primary id of exchange_name from the database. - - :param exchange_name: str - The name of the exchange_name. - :return: int - The primary id of the exchange_name. - """ - return self.get_from_static_table(item='id', table='exchange', create_id=True,indexes=HDict({'name': exchange_name})) - - def _fetch_market_id(self, symbol: str, exchange_name: str) -> int: - """ - Returns the market id that belongs to a trading pair listed in the database. - - :param symbol: str - The symbol of the trading pair. - :param exchange_name: str - The exchange_name name. - :return: int - The market ID - """ - # Fetch the id of the exchange_name. - exchange_id = self._fetch_exchange_id(exchange_name) - - # Ask the db for the market_id. Tell it to create one if it doesn't already exist. - market_id = self.get_from_static_table(item='id', table='markets', create_id=True, - indexes=HDict({'symbol': symbol, 'exchange_id': exchange_id})) - # Return the market id. - return market_id - - def _insert_candles_into_db(self, candlesticks, table_name: str, symbol, exchange_name) -> None: - """ - Insert all the candlesticks from a dataframe into the database. - - :param exchange_name: The name of the exchange_name. - :param symbol: The symbol of the trading pair. - :param candlesticks: pd.dataframe - A rows of candlestick attributes. - :param table_name: str - The name of the table to inset. - :return: None - """ - - # Retrieve the market id for the symbol. - market_id = self._fetch_market_id(symbol, exchange_name) - # Insert the market id into the dataframe. - candlesticks.insert(0, 'market_id', market_id) - # Create a table schema. todo delete these line if not needed anymore - # # Get a list of all the columns in the dataframe. - # columns = list(candlesticks.columns.values) - # # Isolate any extra columns specific to individual exchanges. - # # The carriage return and tabs are unnecessary, they just tidy output for debugging. - # columns = ',\n\t\t\t\t\t'.join(columns[7:], ) - # # Define the columns common with all exchanges and append any extras columns. - sql_create = f""" - CREATE TABLE IF NOT EXISTS '{table_name}' ( - id INTEGER PRIMARY KEY, - market_id INTEGER, - open_time INTEGER UNIQUE ON CONFLICT IGNORE, - open REAL NOT NULL, - high REAL NOT NULL, - low REAL NOT NULL, - close REAL NOT NULL, - volume REAL NOT NULL, - FOREIGN KEY (market_id) REFERENCES market (id) - )""" - # Connect to the database. - with SQLite() as conn: - # Get a cursor from the sql connection. - cursor = conn.cursor() - # Create the table if it doesn't exist. - cursor.execute(sql_create) - # Insert the candles into the table. - candlesticks.to_sql(table_name, conn, if_exists='append', index=False) - return - - def get_records_since(self, table_name: str, st: dt.datetime, - et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame: - """ - Returns all the candles newer than the provided start_datetime from the specified table. - - :param ex_details: list of details to pass to the server. [symbol, interval, exchange_name] - :param table_name: str - The database table name. Format: : __. - :param st: dt.datetime.start_datetime - The start_datetime of the first record requested. - :param et: dt.datetime - The end time of the query - :param rl: float - The timespan in minutes each record represents. - :return: pd.dataframe - - """ - - def add_data(data, tn, start_t, end_t): - new_records = self._populate_table(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details) - print(f'Got {len(new_records.index)} records from exchange_name') - if not new_records.empty: - # Combine the new records with the previously records. - data = pd.concat([data, new_records], axis=0, ignore_index=True) - # Drop any duplicates from overlap. - data = data.drop_duplicates(subset="open_time", keep='first') - # Return the modified dataframe. - return data - - if self._table_exists(table_name=table_name): - # If the table exists retrieve all the records. - print('\nTable existed retrieving records from DB') - print(f'Requesting from {st} to {et}') - records = self._get_records(table_name=table_name, st=st, et=et) - print(f'Got {len(records.index)} records from db') - else: - # If the table doesn't exist, get them from the exchange_name. - print(f'\nTable didnt exist fetching from {ex_details[2]}') - temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl - print(f'Requesting from {st} to {et}, Should be {temp} records') - records = self._populate_table(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details) - print(f'Got {len(records.index)} records from {ex_details[2]}') - - # Check if the records in the db go far enough back to satisfy the query. - first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl) - if first_timestamp: - # The records didn't go far enough back if a timestamp was returned. - print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}') - print(f'first ts on record is: {first_timestamp}') - end_time = dt.datetime.utcfromtimestamp(first_timestamp) - print(f'Requesting from {st} to {end_time}') - # Request records with open_times between [st:end_time] from the database. - records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time) - - # Check if the records received are up-to-date. - last_timestamp = query_uptodate(records=records, r_length_min=rl) - if last_timestamp: - # The query was not up-to-date if a timestamp was returned. - print(f'\nRecords were not updated. Requesting from {ex_details[2]}.') - print(f'the last record on file is: {last_timestamp}') - start_time = dt.datetime.utcfromtimestamp(last_timestamp) - print(f'Requesting from {start_time} to {et}') - # Request records with open_times between [start_time:et] from the database. - records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et) - - return records - - @staticmethod - def _get_records(table_name: str, st: dt.datetime, et: dt.datetime = None) -> pd.DataFrame: - """ - Returns all the candles newer than the provided start_datetime from the specified table. - - :param table_name: str - The database table name. Format: : __. - :param st: dt.datetime.start_datetime - The start_datetime of the first record requested. - :param et: dt.datetime - The end time of the query - :return: pd.dataframe - - """ - # Connect to the database. - with SQLite() as conn: - # Create a timestamp in milliseconds - start_stamp = unix_time_millis(st) - if et is not None: - # Create a timestamp in milliseconds - end_stamp = unix_time_millis(et) - q_str = f"SELECT * FROM '{table_name}' WHERE open_time >= {start_stamp} AND open_time <= {end_stamp};" - else: - q_str = f"SELECT * FROM '{table_name}' WHERE open_time >= {start_stamp};" - # Retrieve all the records from the table. - records = pd.read_sql(q_str, conn) - # Drop the databases primary id. - records = records.drop('id', axis=1) - # Return the data. - return records - - # @staticmethod - # def date_of_last_timestamp(table_name): - # """ - # Returns the latest timestamp stored in the db. - # TODO: Unused. - # - # :return: dt.timestamp - # """ - # # Connect to the database. - # with SQLite() as conn: - # # Get a cursor from the connection. - # cursor = conn.cursor() - # cursor.execute(f"""SELECT open_time FROM '{table_name}' ORDER BY open_time DESC LIMIT 1""") - # ts = cursor.fetchone()[0] / 1000 - # return dt.datetime.utcfromtimestamp(ts) - # - def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str, - start_datetime: object = None, end_datetime: object = None) -> pd.DataFrame: - """ - Fetches and returns all candles from the specified market, timeframe, and exchange_name. - - :param symbol: str - The symbol of the market. - :param interval: str - The timeframe. Format '' - examples: '15m', '4h' - :param exchange_name: str - The name of the exchange_name. - :param start_datetime: dt.datetime - The open_time of the first record requested. - :param end_datetime: dt.datetime - The end_time for the query. - :return: pd.DataFrame: Dataframe containing rows of candle attributes that vary - depending on the exchange_name. - For example: [open_time, open, high, low, close, volume, close_time, - quote_volume, num_trades, taker_buy_base_volume, taker_buy_quote_volume] - """ - - def fill_data_holes(records, interval): - time_span = timeframe_to_minutes(interval) - last_timestamp = None - filled_records = [] - - for _, row in records.iterrows(): - time_stamp = row['open_time'] - - if last_timestamp is None: - last_timestamp = time_stamp - filled_records.append(row) - continue - - delta_ms = time_stamp - last_timestamp - delta_minutes = (delta_ms / 1000) / 60 - - if delta_minutes > time_span: - num_missing_rec = int(delta_minutes / time_span) - step = int(delta_ms / num_missing_rec) - - for ts in range(int(last_timestamp) + step, int(time_stamp), step): - new_row = row.copy() - new_row['open_time'] = ts - filled_records.append(new_row) - - filled_records.append(row) - last_timestamp = time_stamp - - return pd.DataFrame(filled_records) - - # Default start date for fetching from the exchange_name. - if start_datetime is None: - start_datetime = dt.datetime(year=2017, month=1, day=1) - - # Default end date for fetching from the exchange_name. - if end_datetime is None: - end_datetime = dt.datetime.utcnow() - - if start_datetime > end_datetime: - raise ValueError("\ndatabase:fetch_candles_from_exchange():" - " Invalid start and end parameters: ", start_datetime, end_datetime) - - # Get a reference to the exchange - exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name) - - temp = (((unix_time_millis(end_datetime) - unix_time_millis( - start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval) - print(f'Fetching historical data {start_datetime} to {end_datetime}, Should be {temp} records') - - if start_datetime == end_datetime: - end_datetime = None - - # Request candlestick data from the exchange_name. - candles = exchange.get_historical_klines(symbol=symbol, - interval=interval, - start_dt=start_datetime, - end_dt=end_datetime) - num_rec_records = len(candles.index) - print(f'\n{num_rec_records} candles retrieved from the exchange_name.') - # Isolate the open_times from the records received. - open_times = candles.open_time - # Calculate the number of records that would fit between the min and max open time. - estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval) - if num_rec_records < estimated_num_records: - # Some records may be missing due to server maintenance periods ect. - # Fill the holes with copies of the last record received before the gap. - candles = fill_data_holes(candles, interval) - return candles diff --git a/tests/test_DataCache.py b/tests/test_DataCache.py index ad9ce95..9ef4376 100644 --- a/tests/test_DataCache.py +++ b/tests/test_DataCache.py @@ -3,15 +3,53 @@ from exchangeinterface import ExchangeInterface import unittest import pandas as pd import datetime as dt +import os +from Database import SQLite, Database +from shared_utilities import unix_time_millis class TestDataCache(unittest.TestCase): def setUp(self): - # Setup the database connection here + # Set the database connection here self.exchanges = ExchangeInterface() self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None) # This object maintains all the cached data. Pass it connection to the exchanges. + self.db_file = 'test_db.sqlite' + self.database = Database(db_file=self.db_file) + + # Create necessary tables + with SQLite(db_file=self.db_file) as con: + cursor = con.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS exchange ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS markets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT, + exchange_id INTEGER, + FOREIGN KEY (exchange_id) REFERENCES exchange(id) + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS test_table ( + id INTEGER PRIMARY KEY, + market_id INTEGER, + open_time INTEGER UNIQUE, + open REAL NOT NULL, + high REAL NOT NULL, + low REAL NOT NULL, + close REAL NOT NULL, + volume REAL NOT NULL, + FOREIGN KEY (market_id) REFERENCES markets(id) + ) + """) + self.data = DataCache(self.exchanges) + self.data.db = self.database asset, timeframe, exchange = 'BTC/USD', '2h', 'binance' self.key1 = f'{asset}_{timeframe}_{exchange}' @@ -19,8 +57,11 @@ class TestDataCache(unittest.TestCase): asset, timeframe, exchange = 'ETH/USD', '2h', 'binance' self.key2 = f'{asset}_{timeframe}_{exchange}' + def tearDown(self): + if os.path.exists(self.db_file): + os.remove(self.db_file) + def test_set_cache(self): - # Tests print('Testing set_cache flag not set:') self.data.set_cache(data='data', key=self.key1) attr = self.data.__getattribute__('cached_data') @@ -36,7 +77,6 @@ class TestDataCache(unittest.TestCase): self.assertEqual(attr[self.key1], 'more_data') def test_cache_exists(self): - # Tests print('Testing cache_exists() method:') self.assertFalse(self.data.cache_exists(key=self.key2)) self.data.set_cache(data='data', key=self.key1) @@ -44,7 +84,6 @@ class TestDataCache(unittest.TestCase): def test_update_candle_cache(self): print('Testing update_candle_cache() method:') - # Initial data df_initial = pd.DataFrame({ 'open_time': [1, 2, 3], 'open': [100, 101, 102], @@ -54,7 +93,6 @@ class TestDataCache(unittest.TestCase): 'volume': [1000, 1001, 1002] }) - # Data to be added df_new = pd.DataFrame({ 'open_time': [3, 4, 5], 'open': [102, 103, 104], @@ -96,7 +134,7 @@ class TestDataCache(unittest.TestCase): def test_get_records_since(self): print('Testing get_records_since() method:') df_initial = pd.DataFrame({ - 'open_time': [1, 2, 3], + 'open_time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=i)) for i in range(3)], 'open': [100, 101, 102], 'high': [110, 111, 112], 'low': [90, 91, 92], @@ -105,20 +143,86 @@ class TestDataCache(unittest.TestCase): }) self.data.set_cache(data=df_initial, key=self.key1) - start_datetime = dt.datetime.utcfromtimestamp(2) - result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60, ex_details=[]).sort_values(by='open_time').reset_index(drop=True) + start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=2) + result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60, + ex_details=['BTC/USD', '2h', 'binance']) expected = pd.DataFrame({ - 'open_time': [2, 3], - 'open': [101, 102], - 'high': [111, 112], - 'low': [91, 92], - 'close': [106, 107], - 'volume': [1001, 1002] + 'open_time': df_initial['open_time'][:2].values, + 'open': [100, 101], + 'high': [110, 111], + 'low': [90, 91], + 'close': [105, 106], + 'volume': [1000, 1001] }) pd.testing.assert_frame_equal(result, expected) + def test_get_records_since_from_db(self): + print('Testing get_records_since_from_db() method:') + df_initial = pd.DataFrame({ + 'market_id': [None], + 'open_time': [unix_time_millis(dt.datetime.utcnow())], + 'open': [1.0], + 'high': [1.0], + 'low': [1.0], + 'close': [1.0], + 'volume': [1.0] + }) + + with SQLite(self.db_file) as con: + df_initial.to_sql('test_table', con, if_exists='append', index=False) + + start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=1) + end_datetime = dt.datetime.utcnow() + result = self.data.get_records_since_from_db(table_name='test_table', st=start_datetime, et=end_datetime, + rl=1, ex_details=['BTC/USD', '2h', 'binance']).sort_values( + by='open_time').reset_index(drop=True) + + print("Columns in the result DataFrame:", result.columns) + print("Result DataFrame:\n", result) + + # Remove 'id' column from the result DataFrame if it exists + if 'id' in result.columns: + result = result.drop(columns=['id']) + + expected = pd.DataFrame({ + 'market_id': [None], + 'open_time': [unix_time_millis(dt.datetime.utcnow())], + 'open': [1.0], + 'high': [1.0], + 'low': [1.0], + 'close': [1.0], + 'volume': [1.0] + }) + + print("Expected DataFrame:\n", expected) + + pd.testing.assert_frame_equal(result, expected) + + def test_populate_db(self): + print('Testing _populate_db() method:') + start_time = dt.datetime.utcnow() - dt.timedelta(days=1) + end_time = dt.datetime.utcnow() + + result = self.data._populate_db(table_name='test_table', start_time=start_time, + end_time=end_time, ex_details=['BTC/USD', '2h', 'binance', 'test_guy']) + + self.assertIsInstance(result, pd.DataFrame) + self.assertFalse(result.empty) + + def test_fetch_candles_from_exchange(self): + print('Testing _fetch_candles_from_exchange() method:') + start_time = dt.datetime.utcnow() - dt.timedelta(days=1) + end_time = dt.datetime.utcnow() + + result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance', + user_name='test_guy', start_datetime=start_time, + end_datetime=end_time) + + self.assertIsInstance(result, pd.DataFrame) + self.assertFalse(result.empty) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_database.py b/tests/test_database.py index 88001bb..ebadc45 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,96 +1,219 @@ -import datetime +import unittest +import sqlite3 import pandas as pd -from database import make_query, make_insert, Database, HDict, SQLite -from exchangeinterface import ExchangeInterface -from passlib.hash import bcrypt -from sqlalchemy import create_engine, text -import config +import datetime as dt +from Database import Database, SQLite, make_query, make_insert, HDict +from shared_utilities import unix_time_millis -def test(): - # un_hashed_pass = 'password' - hasher = bcrypt.using(rounds=13) - # hashed_pass = hasher.hash(un_hashed_pass) - # print(f'password: {un_hashed_pass}') - # print(f'hashed pass: {hashed_pass}') - # print(f" right pass: {hasher.verify('password', hashed_pass)}") - # print(f" wrong pass: {hasher.verify('passWord', hashed_pass)}") - engine = create_engine("sqlite:///" + config.DB_FILE, echo=True) - with engine.connect() as conn: - default_user = pd.read_sql_query(sql=text("SELECT * FROM users WHERE user_name = 'guest'"), con=conn) - # hashed_password = default_user.password.values[0] - # print(f" verify pass: {hasher.verify('password', hashed_password)}") - username = default_user.user_name.values[0] - print(username) +class TestSQLite(unittest.TestCase): + def test_sqlite_context_manager(self): + print("\nRunning test_sqlite_context_manager...") + with SQLite(db_file='test_db.sqlite') as con: + cursor = con.cursor() + cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') + cursor.execute("INSERT INTO test_table (name) VALUES ('test')") + cursor.execute('SELECT name FROM test_table WHERE name = ?', ('test',)) + result = cursor.fetchone() + self.assertEqual(result[0], 'test') + print("SQLite context manager test passed.") -def test_make_query(): - values = {'first_field': 'first_value'} - item = 'market_id' - table = 'markets' - q_str = make_query(item=item, table=table, values=values) - print(f'\nWith one indexing field: {q_str}') +class TestDatabase(unittest.TestCase): + def setUp(self): + # Use a temporary SQLite database for testing purposes + self.db_file = 'test_db.sqlite' + self.db = Database(db_file=self.db_file) + self.connection = sqlite3.connect(self.db_file) + self.cursor = self.connection.cursor() - values = {'first_field': 'first_value', 'second_field': 'second_value'} - q_str = make_query(item=item, table=table, values=values) - print(f'\nWith two indexing fields: {q_str}') - assert q_str is not None + def tearDown(self): + self.connection.close() + import os + os.remove(self.db_file) # Remove the temporary database file after tests + + def test_execute_sql(self): + print("\nRunning test_execute_sql...") + # Drop the table if it exists to avoid OperationalError + self.cursor.execute('DROP TABLE IF EXISTS test_table') + self.connection.commit() + + sql = 'CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)' + self.db.execute_sql(sql) + + self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_table';") + result = self.cursor.fetchone() + self.assertIsNotNone(result) + print("Execute SQL test passed.") + + def test_make_query(self): + print("\nRunning test_make_query...") + query = make_query('id', 'test_table', ['name']) + expected_query = 'SELECT id FROM test_table WHERE name = ?;' + self.assertEqual(query, expected_query) + print("Make query test passed.") + + def test_make_insert(self): + print("\nRunning test_make_insert...") + insert = make_insert('test_table', ('name', 'age')) + expected_insert = "INSERT INTO test_table ('name', 'age') VALUES(?, ?);" + self.assertEqual(insert, expected_insert) + print("Make insert test passed.") + + def test_get_item_where(self): + print("\nRunning test_get_item_where...") + self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') + self.cursor.execute("INSERT INTO test_table (id, name) VALUES (1, 'test')") + self.connection.commit() + item = self.db.get_item_where('name', 'test_table', ('id', 1)) + self.assertEqual(item, 'test') + print("Get item where test passed.") + + def test_get_rows_where(self): + print("\nRunning test_get_rows_where...") + self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') + self.cursor.execute("INSERT INTO test_table (id, name) VALUES (1, 'test')") + self.connection.commit() + rows = self.db.get_rows_where('test_table', ('name', 'test')) + self.assertIsInstance(rows, pd.DataFrame) + self.assertEqual(rows.iloc[0]['name'], 'test') + print("Get rows where test passed.") + + def test_insert_dataframe(self): + print("\nRunning test_insert_dataframe...") + df = pd.DataFrame({'id': [1], 'name': ['test']}) + self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') + self.connection.commit() + self.db.insert_dataframe(df, 'test_table') + self.cursor.execute('SELECT name FROM test_table WHERE id = 1') + result = self.cursor.fetchone() + self.assertEqual(result[0], 'test') + print("Insert dataframe test passed.") + + def test_insert_row(self): + print("\nRunning test_insert_row...") + self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') + self.connection.commit() + self.db.insert_row('test_table', ('id', 'name'), (1, 'test')) + self.cursor.execute('SELECT name FROM test_table WHERE id = 1') + result = self.cursor.fetchone() + self.assertEqual(result[0], 'test') + print("Insert row test passed.") + + def test_table_exists(self): + print("\nRunning test_table_exists...") + self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)') + self.connection.commit() + exists = self.db.table_exists('test_table') + self.assertTrue(exists) + print("Table exists test passed.") + + def test_get_timestamped_records(self): + print("\nRunning test_get_timestamped_records...") + df = pd.DataFrame({ + 'open_time': [unix_time_millis(dt.datetime.utcnow())], + 'open': [1.0], + 'high': [1.0], + 'low': [1.0], + 'close': [1.0], + 'volume': [1.0] + }) + table_name = 'test_table' + self.cursor.execute(f""" + CREATE TABLE {table_name} ( + id INTEGER PRIMARY KEY, + open_time INTEGER UNIQUE, + open REAL NOT NULL, + high REAL NOT NULL, + low REAL NOT NULL, + close REAL NOT NULL, + volume REAL NOT NULL + ) + """) + self.connection.commit() + self.db.insert_dataframe(df, table_name) + st = dt.datetime.utcnow() - dt.timedelta(minutes=1) + et = dt.datetime.utcnow() + records = self.db.get_timestamped_records(table_name, 'open_time', st, et) + self.assertIsInstance(records, pd.DataFrame) + self.assertFalse(records.empty) + print("Get timestamped records test passed.") + + def test_get_from_static_table(self): + print("\nRunning test_get_from_static_table...") + self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT UNIQUE)') + self.connection.commit() + item = self.db.get_from_static_table('id', 'test_table', HDict({'name': 'test'}), create_id=True) + self.assertIsInstance(item, int) + self.cursor.execute('SELECT id FROM test_table WHERE name = ?', ('test',)) + result = self.cursor.fetchone() + self.assertEqual(item, result[0]) + print("Get from static table test passed.") + + def test_insert_candles_into_db(self): + print("\nRunning test_insert_candles_into_db...") + df = pd.DataFrame({ + 'open_time': [unix_time_millis(dt.datetime.utcnow())], + 'open': [1.0], + 'high': [1.0], + 'low': [1.0], + 'close': [1.0], + 'volume': [1.0] + }) + table_name = 'test_table' + self.cursor.execute(f""" + CREATE TABLE {table_name} ( + id INTEGER PRIMARY KEY, + market_id INTEGER, + open_time INTEGER UNIQUE, + open REAL NOT NULL, + high REAL NOT NULL, + low REAL NOT NULL, + close REAL NOT NULL, + volume REAL NOT NULL + ) + """) + self.connection.commit() + + # Create the exchange and markets tables needed for the foreign key constraints + self.cursor.execute(""" + CREATE TABLE exchange ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE + ) + """) + self.cursor.execute(""" + CREATE TABLE markets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT, + exchange_id INTEGER, + FOREIGN KEY (exchange_id) REFERENCES exchange(id) + ) + """) + self.connection.commit() + + self.db.insert_candles_into_db(df, table_name, 'BTC/USDT', 'binance') + self.cursor.execute(f'SELECT * FROM {table_name}') + result = self.cursor.fetchall() + self.assertFalse(len(result) == 0) + print("Insert candles into db test passed.") -def test_make_insert(): - table = 'markets' - values = {'first_field': 'first_value'} - q_str = make_insert(table=table, values=values) - print(f'\nWith one indexing field: {q_str}') +if __name__ == '__main__': + unittest.main() - values = {'first_field': 'first_value', 'second_field': 'second_value'} - q_str = make_insert(table=table, values=values) - print(f'\nWith two indexing fields: {q_str}') - assert q_str is not None - - -def test__table_exists(): - exchanges = ExchangeInterface() - d_obj = Database(exchanges) - exists = d_obj._table_exists('BTC/USD_5m_alpaca') - print(f'\nExists - Should be true: {exists}') - assert exists - exists = d_obj._table_exists('BTC/USD_5m_alpina') - print(f'Doesnt exist - should be false: {exists}') - assert not exists - - -def test_get_from_static_table(): - exchanges = ExchangeInterface() - d_obj = Database(exchanges) - market_id = d_obj._fetch_market_id('BTC/USD', 'alpaca') - e_id = d_obj._fetch_exchange_id('alpaca') - print(f'market id: {market_id}') - assert market_id > 0 - print(f'exchange_name ID: {e_id}') - assert e_id == 4 - - -def test_populate_table(): - """ - Populates a database table with records from the exchange_name. - :param table_name: str - The name of the table in the database. - :param start_time: datetime - The starting time to fetch the records from. - :param end_time: datetime - The end time to get the records until. - :return: pdDataframe: - The data that was downloaded. - """ - exchanges = ExchangeInterface() - d_obj = Database(exchanges) - d_obj._populate_table(table_name='BTC/USD_2h_alpaca', - start_time=datetime.datetime(year=2023, month=3, day=27, hour=6, minute=0)) - - -def test_get_records_since(): - exchanges = ExchangeInterface() - d_obj = Database(exchanges) - records = d_obj.get_records_since(table_name='BTC/USD_15m_alpaca', - st=datetime.datetime(year=2023, month=3, day=27, hour=1, minute=0), - et=datetime.datetime.utcnow(), - rl=15) - print(records) - assert records is not None +# def test(): +# # un_hashed_pass = 'password' +# hasher = bcrypt.using(rounds=13) +# # hashed_pass = hasher.hash(un_hashed_pass) +# # print(f'password: {un_hashed_pass}') +# # print(f'hashed pass: {hashed_pass}') +# # print(f" right pass: {hasher.verify('password', hashed_pass)}") +# # print(f" wrong pass: {hasher.verify('passWord', hashed_pass)}") +# engine = create_engine("sqlite:///" + config.DB_FILE, echo=True) +# with engine.connect() as conn: +# default_user = pd.read_sql_query(sql=text("SELECT * FROM users WHERE user_name = 'guest'"), con=conn) +# # hashed_password = default_user.password.values[0] +# # print(f" verify pass: {hasher.verify('password', hashed_password)}") +# username = default_user.user_name.values[0] +# print(username) diff --git a/tests/test_binance_integration.py b/tests/test_live_exchange_integration.py similarity index 99% rename from tests/test_binance_integration.py rename to tests/test_live_exchange_integration.py index 16fd952..1a0a03b 100644 --- a/tests/test_binance_integration.py +++ b/tests/test_live_exchange_integration.py @@ -12,7 +12,7 @@ class TestExchange(unittest.TestCase): @classmethod def setUpClass(cls): - exchange_name = 'binance' + exchange_name = 'kraken' cls.api_keys = None """Uncomment and Provide api keys to connect to exchange.""" # cls.api_keys = {'key': 'EXCHANGE_API_KEY', 'secret': 'EXCHANGE_API_SECRET'}