Completed unittests for Database and DataCache.

This commit is contained in:
Rob 2024-08-03 16:56:13 -03:00
parent e601f8c23e
commit 4130e0ca9a
9 changed files with 942 additions and 640 deletions

View File

@ -27,7 +27,7 @@ class BrighterTrades:
self.signals = Signals(self.config.signals_list) self.signals = Signals(self.config.signals_list)
# Object that maintains candlestick and price data. # Object that maintains candlestick and price data.
self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, database=self.data) self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, data_source=self.data)
# Object that interacts with and maintains data from available indicators # Object that interacts with and maintains data from available indicators
self.indicators = Indicators(self.candles, self.config) self.indicators = Indicators(self.candles, self.config)

View File

@ -1,25 +1,41 @@
from typing import Any, List from typing import Any, List
import pandas as pd import pandas as pd
import datetime as dt import datetime as dt
from Database import Database
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes
import logging
from database import Database # Set up logging
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
class DataCache: class DataCache:
""" """
Fetches manages data limits and optimises memory storage. Fetches manages data limits and optimises memory storage.
Handles connections and operations for the given exchanges.
Example usage:
--------------
db = DataCache(exchanges=some_exchanges_object)
""" """
def __init__(self, exchanges): def __init__(self, exchanges):
"""
Initializes the DataCache class.
# Define the cache size :param exchanges: The exchanges object handling communication with connected exchanges.
self.max_tables = 50 # Maximum amount of tables that will be cached at any given time. """
self.max_records = 1000 # Maximum number of records that will be kept per table. # Maximum number of tables to cache at any given time.
self.cached_data = {} # A dictionary that holds all the cached records. self.max_tables = 50
self.db = Database( # Maximum number of records to be kept per table.
exchanges) # The class that handles the DB interactions pass it a connection to the exchange_interface. self.max_records = 1000
# A dictionary that holds all the cached records.
self.cached_data = {}
# The class that handles the DB interactions.
self.db = Database()
# The class that handles exchange interactions.
self.exchanges = exchanges
def cache_exists(self, key: str) -> bool: def cache_exists(self, key: str) -> bool:
""" """
@ -89,27 +105,32 @@ class DataCache:
""" """
Return any records from the cache indexed by table_name that are newer than start_datetime. Return any records from the cache indexed by table_name that are newer than start_datetime.
:param ex_details: Details to pass to the server :param ex_details: List[str] - Details to pass to the server
:param key: <str>: - The dictionary table_name of the records. :param key: str - The dictionary table_name of the records.
:param start_datetime: <dt.datetime>: - The datetime of the first record requested. :param start_datetime: dt.datetime - The datetime of the first record requested.
:param record_length: The timespan of the records. :param record_length: int - The timespan of the records.
:return pd.DataFrame: - The Requested records :return: pd.DataFrame - The Requested records
Example:
--------
records = data_cache.get_records_since('BTC/USD_2h_binance', dt.datetime.utcnow() - dt.timedelta(minutes=60), 60, ['BTC/USD', '2h', 'binance'])
""" """
try: try:
# End time of query defaults to the current time. # End time of query defaults to the current time.
end_datetime = dt.datetime.utcnow() end_datetime = dt.datetime.utcnow()
if self.cache_exists(key=key): if self.cache_exists(key=key):
print('\nGetting records from cache.') logger.debug('Getting records from cache.')
# If the records exist retrieve them from the cache. # If the records exist, retrieve them from the cache.
records = self.get_cache(key) records = self.get_cache(key)
else: else:
# If they don't exist in cache, get them from the db. # If they don't exist in cache, get them from the database.
print(f'\nRecords not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}') logger.debug(
records = self.db.get_records_since(table_name=key, st=start_datetime, f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
records = self.get_records_since_from_db(table_name=key, st=start_datetime,
et=end_datetime, rl=record_length, ex_details=ex_details) et=end_datetime, rl=record_length, ex_details=ex_details)
print(f'Got {len(records.index)} records from db') logger.debug(f'Got {len(records.index)} records from DB.')
self.set_cache(data=records, key=key) self.set_cache(data=records, key=key)
# Check if the records in the cache go far enough back to satisfy the query. # Check if the records in the cache go far enough back to satisfy the query.
@ -119,13 +140,14 @@ class DataCache:
if first_timestamp: if first_timestamp:
# The records didn't go far enough back if a timestamp was returned. # The records didn't go far enough back if a timestamp was returned.
end_time = dt.datetime.utcfromtimestamp(first_timestamp) end_time = dt.datetime.utcfromtimestamp(first_timestamp)
# Request candles with open_times between [start_time:end_time] from the database. logger.debug(f'Requesting additional records from {start_datetime} to {end_time}')
print(f'requesting additional records from {start_datetime} to {end_time}') # Request additional records from the database.
additional_records = self.db.get_records_since(table_name=key, st=start_datetime, additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime,
et=end_time, rl=record_length, ex_details=ex_details) et=end_time, rl=record_length,
print(f'Got {len(additional_records.index)} additional records from db') ex_details=ex_details)
logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
if not additional_records.empty: if not additional_records.empty:
# If more records were received update the cache. # If more records were received, update the cache.
self.update_candle_cache(additional_records, key) self.update_candle_cache(additional_records, key)
# Check if the records received are up-to-date. # Check if the records received are up-to-date.
@ -134,21 +156,219 @@ class DataCache:
if last_timestamp: if last_timestamp:
# The query was not up-to-date if a timestamp was returned. # The query was not up-to-date if a timestamp was returned.
start_time = dt.datetime.utcfromtimestamp(last_timestamp) start_time = dt.datetime.utcfromtimestamp(last_timestamp)
print(f'requesting additional records from {start_time} to {end_datetime}') logger.debug(f'Requesting additional records from {start_time} to {end_datetime}')
# Request the database update its table starting from start_datetime. # Request additional records from the database.
additional_records = self.db.get_records_since(table_name=key, st=start_time, additional_records = self.get_records_since_from_db(table_name=key, st=start_time,
et=end_datetime, rl=record_length, et=end_datetime, rl=record_length,
ex_details=ex_details) ex_details=ex_details)
print(f'Got {len(additional_records.index)} additional records from db') logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
if not additional_records.empty: if not additional_records.empty:
self.update_candle_cache(additional_records, key) self.update_candle_cache(additional_records, key)
# Create a UTC timestamp. # Create a UTC timestamp.
_timestamp = unix_time_millis(start_datetime) _timestamp = unix_time_millis(start_datetime)
# Return all records equal or newer than timestamp. logger.debug(f"Start timestamp in human-readable form: {dt.datetime.utcfromtimestamp(_timestamp / 1000.0)}")
# Return all records equal to or newer than the timestamp.
result = self.get_cache(key).query('open_time >= @_timestamp') result = self.get_cache(key).query('open_time >= @_timestamp')
return result return result
except Exception as e: except Exception as e:
print(f"An error occurred: {str(e)}") logger.error(f"An error occurred: {str(e)}")
raise raise
def get_records_since_from_db(self, table_name: str, st: dt.datetime,
                              et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame:
    """
    Return records from a database table, fetching any missing data from the exchange.

    If the table does not exist, or the stored records do not reach back to ``st``
    or are not up-to-date at ``et``, the gaps are filled by downloading candles
    from the exchange (via ``_populate_db``) and merging them in.

    :param table_name: Database table name.
    :param st: Start datetime.
    :param et: End datetime.
    :param rl: Timespan in minutes each record represents.
    :param ex_details: Exchange details [symbol, interval, exchange_name, user_name].
    :return: DataFrame of records.

    Example:
    --------
    records = cache.get_records_since_from_db('test_table', start_time, end_time, 1,
                                              ['BTC/USDT', '1m', 'binance', 'user1'])
    """

    def add_data(data, tn, start_t, end_t):
        # Download the missing range, merge it in and de-duplicate on open_time.
        new_records = self._populate_db(table_name=tn, start_time=start_t, end_time=end_t,
                                        ex_details=ex_details)
        # FIX: message previously printed the literal text 'exchange_name'.
        logger.debug(f'Got {len(new_records.index)} records from {ex_details[2]}')
        if not new_records.empty:
            data = pd.concat([data, new_records], axis=0, ignore_index=True)
            data = data.drop_duplicates(subset="open_time", keep='first')
        return data

    if self.db.table_exists(table_name=table_name):
        # FIX: use the module logger instead of print, consistent with the rest of the file.
        logger.debug('Table existed, retrieving records from DB')
        logger.debug(f'Requesting from {st} to {et}')
        records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time',
                                                  st=st, et=et)
        logger.debug(f'Got {len(records.index)} records from db')
    else:
        logger.debug(f"Table didn't exist, fetching from {ex_details[2]}")
        expected = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
        logger.debug(f'Requesting from {st} to {et}, should be {expected} records')
        records = self._populate_db(table_name=table_name, start_time=st, end_time=et,
                                    ex_details=ex_details)
        logger.debug(f'Got {len(records.index)} records from {ex_details[2]}')

    # Backfill: query_satisfied returns a timestamp when the records do not reach back to st.
    first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
    if first_timestamp:
        logger.debug(f'Records did not go far enough back; first timestamp on record: {first_timestamp}')
        end_time = dt.datetime.utcfromtimestamp(first_timestamp)
        logger.debug(f'Requesting from {st} to {end_time}')
        records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)

    # Forward-fill: query_uptodate returns a timestamp when the records are not up-to-date.
    last_timestamp = query_uptodate(records=records, r_length_min=rl)
    if last_timestamp:
        logger.debug(f'Records were not up-to-date; last record on file: {last_timestamp}')
        start_time = dt.datetime.utcfromtimestamp(last_timestamp)
        logger.debug(f'Requesting from {start_time} to {et}')
        records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et)
    return records
def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: list,
                 end_time: dt.datetime = None) -> pd.DataFrame:
    """
    Populate a database table with records downloaded from the exchange.

    :param table_name: Name of the table in the database.
    :param start_time: Start time to fetch the records from.
    :param ex_details: Exchange details [symbol, interval, exchange_name, user_name].
    :param end_time: End time to fetch the records until; defaults to UTC now.
    :return: DataFrame of the data downloaded (may be empty).

    Example:
    --------
    records = cache._populate_db('test_table', start_time, ['BTC/USDT', '1m', 'binance', 'user1'])
    """
    if end_time is None:
        end_time = dt.datetime.utcnow()
    # Unpack the exchange details: symbol, interval, exchange name, user name.
    sym, inter, ex, un = ex_details
    records = self._fetch_candles_from_exchange(symbol=sym, interval=inter, exchange_name=ex,
                                                user_name=un, start_datetime=start_time,
                                                end_datetime=end_time)
    if not records.empty:
        self.db.insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex)
    else:
        # Nothing came back from the exchange; leave the table untouched.
        # FIX: use the module logger instead of print, consistent with the rest of the file.
        logger.warning(f'No records inserted {records}')
    return records
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
                                 start_datetime: dt.datetime = None,
                                 end_datetime: dt.datetime = None) -> pd.DataFrame:
    """
    Fetches and returns all candles from the specified market, timeframe, and exchange.

    :param symbol: Symbol of the market.
    :param interval: Timeframe in the format '<int><alpha>' (e.g., '15m', '4h').
    :param exchange_name: Name of the exchange.
    :param user_name: Name of the user.
    :param start_datetime: Start datetime for fetching data (defaults to 2017-01-01).
    :param end_datetime: End datetime for fetching data (defaults to UTC now).
    :return: DataFrame of candle data.
    :raises ValueError: If start_datetime is after end_datetime.

    Example:
    --------
    candles = cache._fetch_candles_from_exchange('BTC/USDT', '1m', 'binance', 'user1', start_time, end_time)
    """

    def fill_data_holes(candle_frame: pd.DataFrame, tf: str) -> pd.DataFrame:
        """
        Fill gaps in the data by inserting synthetic rows for missing periods.

        Each synthetic row is a copy of the row that follows the gap, with only
        its 'open_time' adjusted to the interpolated timestamp.

        :param candle_frame: DataFrame containing the original records.
        :param tf: Interval of the data (e.g., '1m', '5m').
        :return: DataFrame with gaps filled.
        """
        time_span = timeframe_to_minutes(tf)
        previous_ts = None
        filled = []
        logger.info(f"Starting to fill data holes for interval: {tf}")
        for _, row in candle_frame.iterrows():
            current_ts = row['open_time']
            if previous_ts is None:
                # First record: nothing to compare against yet.
                previous_ts = current_ts
                filled.append(row)
                logger.debug(f"First timestamp: {current_ts}")
                continue
            # Gap size in milliseconds and minutes between consecutive rows.
            delta_ms = current_ts - previous_ts
            delta_minutes = (delta_ms / 1000) / 60
            logger.debug(f"Timestamp: {current_ts}, Delta minutes: {delta_minutes}")
            if delta_minutes > time_span:
                # Interpolate evenly-spaced timestamps across the gap.
                num_missing_rec = int(delta_minutes / time_span)
                step = int(delta_ms / num_missing_rec)
                logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}")
                for ts in range(int(previous_ts) + step, int(current_ts), step):
                    synthetic = row.copy()
                    synthetic['open_time'] = ts
                    filled.append(synthetic)
                    logger.debug(f"Filled timestamp: {ts}")
            filled.append(row)
            previous_ts = current_ts
        logger.info("Data holes filled successfully.")
        return pd.DataFrame(filled)

    # Default start date if not provided.
    if start_datetime is None:
        start_datetime = dt.datetime(year=2017, month=1, day=1)
    # Default end date if not provided.
    if end_datetime is None:
        end_datetime = dt.datetime.utcnow()
    if start_datetime > end_datetime:
        raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")
    # Get the exchange object for this exchange/user pair.
    exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
    # Expected record count, used only for logging.
    expected = (((unix_time_millis(end_datetime) - unix_time_millis(
        start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval)
    logger.info(f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {expected}')
    # If start and end times are the same, let the exchange default the end.
    if start_datetime == end_datetime:
        end_datetime = None
    candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime,
                                             end_dt=end_datetime)
    num_rec_records = len(candles.index)
    logger.info(f'{num_rec_records} candles retrieved from the exchange.')
    if candles.empty:
        # FIX: avoid NaN arithmetic on an empty frame below.
        return candles
    # Estimate how many records the returned time span should contain.
    open_times = candles.open_time
    estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval)
    # Fill in any missing data if the retrieved data is less than expected.
    if num_rec_records < estimated_num_records:
        candles = fill_data_holes(candles, interval)
    return candles

352
src/Database.py Normal file
View File

@ -0,0 +1,352 @@
import sqlite3
from functools import lru_cache
from typing import Any
import config
import datetime as dt
import pandas as pd
from shared_utilities import unix_time_millis
class SQLite:
    """
    Context manager for SQLite database connections.

    Accepts a database file name or defaults to the file in config.DB_FILE.
    On exit the connection is committed and then closed.

    Example usage:
    --------------
    with SQLite(db_file='test_db.sqlite') as con:
        cursor = con.cursor()
        cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
    """

    def __init__(self, db_file=None):
        # Fall back to the configured default database file when none is given.
        self.db_file = db_file or config.DB_FILE
        self.connection = sqlite3.connect(self.db_file)

    def __enter__(self):
        # Hand back the raw connection so callers can create cursors as needed.
        return self.connection

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Persist any pending changes before releasing the connection.
        self.connection.commit()
        self.connection.close()
class HDict(dict):
    """
    Dictionary subclass that is hashable, so instances can serve as cache keys.

    Example usage:
    --------------
    hdict = HDict({'key1': 'value1', 'key2': 'value2'})
    hash(hdict)
    """

    def __hash__(self):
        # Hash the set of (key, value) pairs — order-independent, like dict equality.
        return hash(frozenset(self.items()))
def make_query(item: str, table: str, columns: list) -> str:
    """
    Creates a SQL select query string with the required number of placeholders.

    :param item: The field to select.
    :param table: The table to select from.
    :param columns: List of columns for the where clause (must not be empty).
    :return: The query string.
    :raises ValueError: If no columns are supplied.

    Example:
    --------
    query = make_query('id', 'test_table', ['name', 'age'])
    # Result: 'SELECT id FROM test_table WHERE name = ? AND age = ?;'
    """
    if not columns:
        # FIX: the iterator-based original raised a bare StopIteration here.
        raise ValueError("make_query requires at least one column for the WHERE clause.")
    where_clause = " AND ".join(f"{col} = ?" for col in columns)
    return f"SELECT {item} FROM {table} WHERE {where_clause};"
def make_insert(table: str, values: tuple) -> str:
    """
    Creates a SQL insert query string with the required number of placeholders.

    :param table: The table to insert into.
    :param values: Tuple of column names to insert (must not be empty).
    :return: The query string.
    :raises ValueError: If no column names are supplied.

    Example:
    --------
    insert = make_insert('test_table', ('name', 'age'))
    # Result: "INSERT INTO test_table ('name', 'age') VALUES(?, ?);"
    """
    if not values:
        # FIX: the iterator-based original raised a bare StopIteration here.
        raise ValueError("make_insert requires at least one column name.")
    column_list = ", ".join(f"'{col}'" for col in values)
    placeholders = ", ".join("?" for _ in values)
    return f"INSERT INTO {table} ({column_list}) VALUES({placeholders});"
class Database:
    """
    Database class to communicate with and maintain the database.

    Wraps an SQLite database file and provides helpers for generic queries,
    inserts and candlestick storage.

    Example usage:
    --------------
    db = Database(db_file='test_db.sqlite')
    db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
    """

    def __init__(self, db_file=None):
        """
        Initializes the Database class.

        :param db_file: Optional database file name; when None, SQLite falls
                        back to the configured default file.
        """
        self.db_file = db_file
        # Per-instance memo for get_from_static_table results.
        # FIX: replaces @lru_cache on the bound method, which keyed the cache on
        # `self` and kept every Database instance alive for the cache's lifetime.
        self._static_cache = {}

    def execute_sql(self, sql: str) -> None:
        """
        Executes a raw SQL statement.

        :param sql: SQL statement to execute.

        Example:
        --------
        db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
        """
        with SQLite(self.db_file) as con:
            cur = con.cursor()
            cur.execute(sql)

    def get_item_where(self, item_name: str, table_name: str, filter_vals: tuple) -> int:
        """
        Returns an item from a table where the filter results should isolate a single row.

        :param item_name: Name of the item to fetch.
        :param table_name: Name of the table.
        :param filter_vals: Tuple of column name and value to filter by.
        :return: The item.
        :raises ValueError: If no matching row exists.

        Example:
        --------
        item = db.get_item_where('name', 'test_table', ('id', 1))
        # Fetches the 'name' from 'test_table' where 'id' is 1
        """
        with SQLite(self.db_file) as con:
            cur = con.cursor()
            qry = make_query(item_name, table_name, [filter_vals[0]])
            cur.execute(qry, (filter_vals[1],))
            if row := cur.fetchone():
                return row[0]
            error = f"Couldn't fetch item {item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}"
            raise ValueError(error)

    def get_rows_where(self, table: str, filter_vals: tuple) -> pd.DataFrame | None:
        """
        Returns a DataFrame containing all rows of a table that meet the filter criteria.

        :param table: Name of the table.
        :param filter_vals: Tuple of column name and value to filter by.
        :return: DataFrame of the query result or None if empty.

        Example:
        --------
        rows = db.get_rows_where('test_table', ('name', 'test'))
        # Fetches all rows from 'test_table' where 'name' is 'test'
        """
        with SQLite(self.db_file) as con:
            # SECURITY FIX: bind the filter value as a parameter instead of
            # interpolating it into the SQL string (quoting/injection hazard).
            qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?"
            result = pd.read_sql(qry, con=con, params=(filter_vals[1],))
        return result if not result.empty else None

    def insert_dataframe(self, df: pd.DataFrame, table: str) -> None:
        """
        Inserts a DataFrame into a specified table.

        :param df: DataFrame to insert.
        :param table: Name of the table.

        Example:
        --------
        df = pd.DataFrame({'id': [1], 'name': ['test']})
        db.insert_dataframe(df, 'test_table')
        """
        with SQLite(self.db_file) as con:
            df.to_sql(name=table, con=con, index=False, if_exists='append')

    def insert_row(self, table: str, columns: tuple, values: tuple) -> None:
        """
        Inserts a row into a specified table.

        :param table: Name of the table.
        :param columns: Tuple of column names.
        :param values: Tuple of values to insert.

        Example:
        --------
        db.insert_row('test_table', ('id', 'name'), (1, 'test'))
        """
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            sql = make_insert(table=table, values=columns)
            cursor.execute(sql, values)

    def table_exists(self, table_name: str) -> bool:
        """
        Checks if a table exists in the database.

        :param table_name: Name of the table.
        :return: True if the table exists, False otherwise.

        Example:
        --------
        exists = db.table_exists('test_table')
        # Checks if 'test_table' exists in the database
        """
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            sql = make_query(item='name', table='sqlite_schema', columns=['type', 'name'])
            cursor.execute(sql, ('table', table_name))
            return cursor.fetchone() is not None

    def get_from_static_table(self, item: str, table: str, indexes: HDict, create_id: bool = False) -> Any:
        """
        Returns an item from a mostly-static lookup table, memoising successful lookups.

        If the item isn't listed in the table and create_id is True, a new row is
        inserted and its autoincremented value is returned. The indexes are passed
        as a hashable dictionary so results can be cached per instance.

        :param item: Name of the item requested.
        :param table: Table being queried.
        :param indexes: Hashable dictionary of indexing columns and their values.
        :param create_id: If True, create a row if it doesn't exist and return the autoincrement ID.
        :return: The content of the field, or None if not found and not created.

        Example:
        --------
        exchange_id = db.get_from_static_table('id', 'exchange', HDict({'name': 'binance'}), create_id=True)
        """
        cache_key = (item, table, indexes)
        if cache_key in self._static_cache:
            return self._static_cache[cache_key]
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            sql = make_query(item, table, list(indexes.keys()))
            cursor.execute(sql, tuple(indexes.values()))
            result = cursor.fetchone()
            if result is None and create_id:
                # Insert the new row, then re-query for its (autoincremented) value.
                sql = make_insert(table, tuple(indexes.keys()))
                cursor.execute(sql, tuple(indexes.values()))
                sql = make_query(item, table, list(indexes.keys()))
                cursor.execute(sql, tuple(indexes.values()))
                result = cursor.fetchone()
        value = result[0] if result else None
        if value is not None:
            # Only cache hits; a miss may be created by a later create_id call.
            self._static_cache[cache_key] = value
        return value

    def _fetch_exchange_id(self, exchange_name: str) -> int:
        """
        Fetches the primary ID of an exchange from the database, creating it if needed.

        :param exchange_name: Name of the exchange.
        :return: Primary ID of the exchange.

        Example:
        --------
        exchange_id = db._fetch_exchange_id('binance')
        """
        return self.get_from_static_table(item='id', table='exchange', create_id=True,
                                          indexes=HDict({'name': exchange_name}))

    def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
        """
        Returns the market ID for a trading pair listed in the database, creating it if needed.

        :param symbol: Symbol of the trading pair.
        :param exchange_name: Name of the exchange.
        :return: Market ID.

        Example:
        --------
        market_id = db._fetch_market_id('BTC/USDT', 'binance')
        """
        exchange_id = self._fetch_exchange_id(exchange_name)
        return self.get_from_static_table(item='id', table='markets', create_id=True,
                                          indexes=HDict({'symbol': symbol, 'exchange_id': exchange_id}))

    def insert_candles_into_db(self, candlesticks: pd.DataFrame, table_name: str, symbol: str,
                               exchange_name: str) -> None:
        """
        Inserts all candlesticks from a DataFrame into the database.

        NOTE: this adds a 'market_id' column to the passed DataFrame in place, so
        the caller's frame gains the column after this call.

        :param candlesticks: DataFrame of candlestick data (open_time, open, high, low, close, volume).
        :param table_name: Name of the table to insert into.
        :param symbol: Symbol of the trading pair.
        :param exchange_name: Name of the exchange.

        Example:
        --------
        df = pd.DataFrame({
            'open_time': [unix_time_millis(dt.datetime.utcnow())],
            'open': [1.0],
            'high': [1.0],
            'low': [1.0],
            'close': [1.0],
            'volume': [1.0]
        })
        db.insert_candles_into_db(df, 'test_table', 'BTC/USDT', 'binance')
        """
        market_id = self._fetch_market_id(symbol, exchange_name)
        candlesticks.insert(0, 'market_id', market_id)
        # open_time is UNIQUE ON CONFLICT IGNORE, so re-inserting overlapping
        # candles is harmless (duplicates are silently dropped).
        sql_create = f"""
            CREATE TABLE IF NOT EXISTS '{table_name}' (
                id INTEGER PRIMARY KEY,
                market_id INTEGER,
                open_time INTEGER UNIQUE ON CONFLICT IGNORE,
                open REAL NOT NULL,
                high REAL NOT NULL,
                low REAL NOT NULL,
                close REAL NOT NULL,
                volume REAL NOT NULL,
                FOREIGN KEY (market_id) REFERENCES market (id)
            )"""
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            cursor.execute(sql_create)
            candlesticks.to_sql(table_name, conn, if_exists='append', index=False)

    def get_timestamped_records(self, table_name: str, timestamp_field: str, st: dt.datetime,
                                et: dt.datetime = None) -> pd.DataFrame:
        """
        Returns records from a specified table in the database that have timestamps greater than or equal to a given
        start time and, optionally, less than or equal to a given end time.

        :param table_name: Database table name.
        :param timestamp_field: Field name that contains the timestamp.
        :param st: Start datetime.
        :param et: End datetime (optional).
        :return: DataFrame of records with the surrogate 'id' column dropped.

        Example:
        --------
        records = db.get_timestamped_records('test_table', 'open_time', start_time, end_time)
        """
        with SQLite(self.db_file) as conn:
            start_stamp = unix_time_millis(st)
            if et is not None:
                end_stamp = unix_time_millis(et)
                q_str = (
                    f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp} "
                    f"AND {timestamp_field} <= {end_stamp};"
                )
            else:
                q_str = f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp};"
            records = pd.read_sql(q_str, conn)
        # Drop the surrogate primary key; callers work with the data columns only.
        return records.drop('id', axis=1)

View File

@ -5,7 +5,7 @@ from typing import Any
from passlib.hash import bcrypt from passlib.hash import bcrypt
import pandas as pd import pandas as pd
from database import HDict from Database import HDict
class Users: class Users:

View File

@ -1,13 +1,12 @@
import datetime as dt import datetime as dt
import logging as log import logging as log
from DataCache import DataCache
from shared_utilities import timeframe_to_minutes, ts_of_n_minutes_ago from shared_utilities import timeframe_to_minutes, ts_of_n_minutes_ago
# log.basicConfig(level=log.ERROR) # log.basicConfig(level=log.ERROR)
class Candles: class Candles:
def __init__(self, exchanges, config_obj, database): def __init__(self, exchanges, config_obj, data_source):
# A reference to the app configuration # A reference to the app configuration
self.config = config_obj self.config = config_obj
@ -15,8 +14,8 @@ class Candles:
# The maximum amount of candles to load at one time. # The maximum amount of candles to load at one time.
self.max_records = self.config.app_data.get('max_data_loaded') self.max_records = self.config.app_data.get('max_data_loaded')
# This object maintains all the cached data. Pass it connection to the exchanges. # This object maintains all the cached data.
self.data = database self.data = data_source
# print('Setting the candle cache.') # print('Setting the candle cache.')
# # Populate the cache: # # Populate the cache:

View File

@ -1,496 +0,0 @@
import sqlite3
from functools import lru_cache
from typing import Any
import config
import datetime as dt
import pandas as pd
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes
class SQLite:
"""
Context manager returns a cursor. The connection is closed when
the cursor is destroyed, even if an exception is thrown.
"""
def __init__(self):
self.connection = sqlite3.connect(config.DB_FILE)
def __enter__(self):
return self.connection
def __exit__(self, exc_type, exc_val, exc_tb):
self.connection.commit()
self.connection.close()
class HDict(dict):
def __hash__(self):
return hash(frozenset(self.items()))
def make_query(item: str, table: str, columns: list) -> str:
"""
Creates a sql select string with the required number of ?'s to match the given columns.
:param item: The field to select.
:param table: The table to select.
:param columns: list - A list of database columns.
:return: str: The query string.
"""
an_itr = iter(columns)
k = next(an_itr)
where_str = f"SELECT {item} FROM {table} WHERE {k} = ?"
where_str = where_str + "".join([f" AND {k} = ?" for k in an_itr]) + ';'
return where_str
def make_insert(table: str, values: tuple) -> str:
"""
Creates a sql insert string with the required number of ?'s to match the given values.
:param table: The table to insert into.
:param values: dict - A dictionary of table_name-value pairs used to index a db query.
:return: str: The query string.
"""
itr1 = iter(values)
itr2 = iter(values)
k1 = next(itr1)
_ = next(itr2)
insert_str = f"INSERT INTO {table} ('{k1}'"
insert_str = insert_str + "".join([f", '{k1}'" for k1 in itr1]) + ") VALUES(?" + "".join(
[", ?" for _ in enumerate(itr2)]) + ");"
return insert_str
class Database:
"""
Communicates and maintains the database.
"""
def __init__(self, exchanges):
# The exchanges object handles communication with all connected exchanges.
self.exchanges = exchanges
@staticmethod
def execute_sql(sql: str) -> None:
"""
Executes a sql statement. This is for stuff I haven't created a function for yet.
:param sql: str - sql statement.
:return: None
"""
with SQLite() as con:
cur = con.cursor()
cur.execute(sql)
@staticmethod
def get_item_where(item_name: str, table_name: str, filter_vals: tuple) -> int:
"""
Returns an item from a table where the filter results should isolate a single row.
:param item_name: str - The name of the item to fetch.
:param table_name: str - The name of the table.
:param filter_vals: tuple(str, str) - The column and value to filter the results with.
:return: str - The item.
"""
with SQLite() as con:
cur = con.cursor()
qry = make_query(item_name, table_name, [filter_vals[0]])
cur.execute(qry, (filter_vals[1],))
if user_id := cur.fetchone():
return user_id[0]
else:
error = f"Couldn't fetch item{item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}"
raise ValueError(error)
@staticmethod
def get_rows_where(table: str, filter_vals: tuple) -> pd.DataFrame | None:
"""
Returns a dataframe containing all rows of a table that meet the filter criteria.
:param table: str - The name of the table.
:param filter_vals: tuple(column: str, value: str) - the criteria
:return: dataframe|None - returns the data in a dataframe or None if the query fails.
"""
with SQLite() as con:
qry = f"select * from {table} where {filter_vals[0]}='{filter_vals[1]}'"
result = pd.read_sql(qry, con=con)
if not result.empty:
return result
else:
return None
@staticmethod
def insert_dataframe(df, table):
# Connect to the database.
with SQLite() as con:
# Insert the modified user as a new record in the table.
df.to_sql(name=table, con=con, index=False, if_exists='append')
# Commit the changes to the database.
con.commit()
@staticmethod
def insert_row(table: str, columns: tuple, values: tuple) -> None:
"""
Saves user specific data from a table in the database.
:param table: str - The table to insert into
:param columns: tuple(str1, str2, ...) - The columns of the database.
:param values: tuple(val1, val2, ...) - The values to be inserted.
:return: None
"""
# Connect to the database.
with SQLite() as conn:
# Get a cursor from the sql connection.
cursor = conn.cursor()
sql = make_insert(table=table, values=columns)
cursor.execute(sql, values)
@staticmethod
def _table_exists(table_name: str) -> bool:
    """
    Returns True if table_name exists in the database.

    :param table_name: The name of the database.
    :return: bool - True|False
    """
    # Query sqlite's schema catalogue for a table registered under this name.
    lookup_sql = make_query(item='name', table='sqlite_schema', columns=['type', 'name'])
    # Connect to the database.
    with SQLite() as conn:
        cursor = conn.cursor()
        cursor.execute(lookup_sql, ('table', table_name))
        # fetchone() yields None when no such table is registered.
        return cursor.fetchone() is not None
def _populate_table(self, table_name: str, start_time: dt.datetime, ex_details: list, end_time: dt.datetime = None):
    """
    Populates a database table with records from the exchange_name.

    :param table_name: str - The name of the table in the database.
    :param start_time: datetime - The starting time to fetch the records from.
    :param ex_details: list - [symbol, interval, exchange_name, user_name].
    :param end_time: datetime - The end time to get the records until.
    :return: pd.DataFrame - The data that was downloaded.
    """
    # Unpack the request parameters. Format: <symbol>_<timeframe>_<exchange_name>.
    symbol, interval, exchange_name, user_name = ex_details
    # When no end is given, fetch everything up to the current UTC time.
    query_end = dt.datetime.utcnow() if end_time is None else end_time
    # Download the candles from the exchange_name.
    downloaded = self._fetch_candles_from_exchange(symbol=symbol, interval=interval,
                                                   exchange_name=exchange_name, user_name=user_name,
                                                   start_datetime=start_time, end_datetime=query_end)
    if downloaded.empty:
        print(f'No records inserted {downloaded}')
    else:
        # Persist any received records to the database.
        self._insert_candles_into_db(downloaded, table_name=table_name, symbol=symbol,
                                     exchange_name=exchange_name)
    return downloaded
@staticmethod
@lru_cache(maxsize=1000)
def get_from_static_table(item: str, table: str, indexes: HDict, create_id: bool = False) -> Any:
    """
    Retrieves a single item from a table. This method returns a cached result and is meant
    for fetching static data like settings, names and ID's.

    :param create_id: bool: - If True, create a row if it doesn't exist and return the autoincrement ID.
    :param item: str - The name of the item requested.
    :param table: str - The table being queried.
    :param indexes: HDict - A hashable dictionary containing the indexing columns and their values.
    :return: Any - The content of the field, or None when it doesn't exist.
    """
    # Build the lookup once; it is reused verbatim after an optional insert.
    select_sql = make_query(item, table, list(indexes.keys()))
    key_values = tuple(indexes.values())
    # Connect to the database.
    with SQLite() as conn:
        cursor = conn.cursor()
        # Retrieve the record from the db. Results come back as a tuple, e.g. (id,).
        cursor.execute(select_sql, key_values)
        result = cursor.fetchone()
        if result is None and create_id:
            # Row missing: insert the index values, then re-run the same lookup so
            # the freshly assigned autoincrement ID (or field) is returned.
            cursor.execute(make_insert(table, tuple(indexes.keys())), key_values)
            cursor.execute(select_sql, key_values)
            result = cursor.fetchone()
        # Unwrap the single column from the result tuple if a row was found.
        return result[0] if result else None
def _fetch_exchange_id(self, exchange_name: str) -> int:
    """
    Fetch the primary id of exchange_name from the database.

    :param exchange_name: str - The name of the exchange_name.
    :return: int - The primary id of the exchange_name.
    """
    # Look up (or lazily create) the exchange row and return its primary key.
    lookup_keys = HDict({'name': exchange_name})
    return self.get_from_static_table(item='id', table='exchange',
                                      create_id=True, indexes=lookup_keys)
def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
    """
    Returns the market id that belongs to a trading pair listed in the database.

    :param symbol: str - The symbol of the trading pair.
    :param exchange_name: str - The exchange_name name.
    :return: int - The market ID
    """
    # Markets are keyed by symbol plus the owning exchange's primary id.
    exchange_id = self._fetch_exchange_id(exchange_name)
    lookup_keys = HDict({'symbol': symbol, 'exchange_id': exchange_id})
    # Create the market row on first sight and hand back its id.
    return self.get_from_static_table(item='id', table='markets',
                                      create_id=True, indexes=lookup_keys)
def _insert_candles_into_db(self, candlesticks, table_name: str, symbol, exchange_name) -> None:
    """
    Insert all the candlesticks from a dataframe into the database.

    :param candlesticks: pd.DataFrame - Rows of candlestick attributes.
    :param table_name: str - The name of the table to insert into.
    :param symbol: The symbol of the trading pair.
    :param exchange_name: The name of the exchange_name.
    :return: None
    """
    # Retrieve the market id for the symbol and tag every candle with it.
    market_id = self._fetch_market_id(symbol, exchange_name)
    candlesticks.insert(0, 'market_id', market_id)
    # Schema common to all exchanges. Duplicate open_times are ignored so that
    # overlapping downloads don't raise on the UNIQUE constraint.
    # The foreign key targets 'markets' - the table _fetch_market_id() reads from
    # (it previously referenced a non-existent 'market' table).
    sql_create = f"""
        CREATE TABLE IF NOT EXISTS '{table_name}' (
            id INTEGER PRIMARY KEY,
            market_id INTEGER,
            open_time INTEGER UNIQUE ON CONFLICT IGNORE,
            open REAL NOT NULL,
            high REAL NOT NULL,
            low REAL NOT NULL,
            close REAL NOT NULL,
            volume REAL NOT NULL,
            FOREIGN KEY (market_id) REFERENCES markets (id)
        )"""
    # Connect to the database.
    with SQLite() as conn:
        # Get a cursor from the sql connection.
        cursor = conn.cursor()
        # Create the table if it doesn't exist.
        cursor.execute(sql_create)
        # Insert the candles into the table.
        candlesticks.to_sql(table_name, conn, if_exists='append', index=False)
def get_records_since(self, table_name: str, st: dt.datetime,
                      et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame:
    """
    Returns all the candles newer than the provided start_datetime from the specified table.

    Records are served from the database table when it exists; any gap at either
    end of the requested window is then filled by downloading from the exchange.

    :param ex_details: list of details to pass to the server. [symbol, interval, exchange_name]
    :param table_name: str - The database table name. Format: : <symbol>_<timeframe>_<exchange_name>.
    :param st: dt.datetime.start_datetime - The start_datetime of the first record requested.
    :param et: dt.datetime - The end time of the query
    :param rl: float - The timespan in minutes each record represents.
    :return: pd.dataframe - The records covering the requested window.
    """
    def add_data(data, tn, start_t, end_t):
        # Download [start_t:end_t] from the exchange and merge it into `data`,
        # deduplicating on open_time (first occurrence wins).
        new_records = self._populate_table(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details)
        print(f'Got {len(new_records.index)} records from exchange_name')
        if not new_records.empty:
            # Combine the new records with the previously records.
            data = pd.concat([data, new_records], axis=0, ignore_index=True)
            # Drop any duplicates from overlap.
            data = data.drop_duplicates(subset="open_time", keep='first')
        # Return the modified dataframe.
        return data

    if self._table_exists(table_name=table_name):
        # If the table exists retrieve all the records.
        print('\nTable existed retrieving records from DB')
        print(f'Requesting from {st} to {et}')
        records = self._get_records(table_name=table_name, st=st, et=et)
        print(f'Got {len(records.index)} records from db')
    else:
        # If the table doesn't exist, get them from the exchange_name.
        print(f'\nTable didnt exist fetching from {ex_details[2]}')
        # Estimated count: window length in minutes divided by the record length.
        temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
        print(f'Requesting from {st} to {et}, Should be {temp} records')
        records = self._populate_table(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details)
        print(f'Got {len(records.index)} records from {ex_details[2]}')
    # Check if the records in the db go far enough back to satisfy the query.
    first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
    if first_timestamp:
        # The records didn't go far enough back if a timestamp was returned.
        print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}')
        print(f'first ts on record is: {first_timestamp}')
        # NOTE(review): open_time is stored in milliseconds elsewhere, but
        # utcfromtimestamp() expects seconds - confirm query_satisfied's unit.
        end_time = dt.datetime.utcfromtimestamp(first_timestamp)
        print(f'Requesting from {st} to {end_time}')
        # Request records with open_times between [st:end_time] from the database.
        records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)
    # Check if the records received are up-to-date.
    last_timestamp = query_uptodate(records=records, r_length_min=rl)
    if last_timestamp:
        # The query was not up-to-date if a timestamp was returned.
        print(f'\nRecords were not updated. Requesting from {ex_details[2]}.')
        print(f'the last record on file is: {last_timestamp}')
        start_time = dt.datetime.utcfromtimestamp(last_timestamp)
        print(f'Requesting from {start_time} to {et}')
        # Request records with open_times between [start_time:et] from the database.
        records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et)
    return records
@staticmethod
def _get_records(table_name: str, st: dt.datetime, et: dt.datetime = None) -> pd.DataFrame:
    """
    Returns all the candles newer than the provided start_datetime from the specified table.

    :param table_name: str - The database table name. Format: : <symbol>_<timeframe>_<exchange_name>.
    :param st: dt.datetime.start_datetime - The start_datetime of the first record requested.
    :param et: dt.datetime - The end time of the query. When None, all records
               from st onwards are returned.
    :return: pd.DataFrame - The matching rows without the database's primary id.
    """
    # Convert the bounds to millisecond timestamps to match open_time's storage.
    start_stamp = unix_time_millis(st)
    if et is not None:
        # Bounded query. Bind both ends as parameters rather than interpolating
        # them into the SQL text. The table name cannot be a bound parameter.
        q_str = f"SELECT * FROM '{table_name}' WHERE open_time >= ? AND open_time <= ?;"
        params = (start_stamp, unix_time_millis(et))
    else:
        q_str = f"SELECT * FROM '{table_name}' WHERE open_time >= ?;"
        params = (start_stamp,)
    # Connect to the database and retrieve all the matching records.
    with SQLite() as conn:
        records = pd.read_sql(q_str, conn, params=params)
    # Drop the database's primary id before handing the data back.
    return records.drop('id', axis=1)
# @staticmethod
# def date_of_last_timestamp(table_name):
# """
# Returns the latest timestamp stored in the db.
# TODO: Unused.
#
# :return: dt.timestamp
# """
# # Connect to the database.
# with SQLite() as conn:
# # Get a cursor from the connection.
# cursor = conn.cursor()
# cursor.execute(f"""SELECT open_time FROM '{table_name}' ORDER BY open_time DESC LIMIT 1""")
# ts = cursor.fetchone()[0] / 1000
# return dt.datetime.utcfromtimestamp(ts)
#
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
                                 start_datetime: object = None, end_datetime: object = None) -> pd.DataFrame:
    """
    Fetches and returns all candles from the specified market, timeframe, and exchange_name.

    :param symbol: str - The symbol of the market.
    :param interval: str - The timeframe. Format '<int><alpha>' - examples: '15m', '4h'
    :param exchange_name: str - The name of the exchange_name.
    :param user_name: str - The user whose exchange connection should be used.
    :param start_datetime: dt.datetime - The open_time of the first record requested.
    :param end_datetime: dt.datetime - The end_time for the query.
    :return: pd.DataFrame: Dataframe containing rows of candle attributes that vary
             depending on the exchange_name.
             For example: [open_time, open, high, low, close, volume, close_time,
             quote_volume, num_trades, taker_buy_base_volume, taker_buy_quote_volume]
    :raises ValueError: If start_datetime is after end_datetime.
    """
    def fill_data_holes(records, timeframe):
        # Fill gaps (e.g. exchange maintenance windows) with copies of the last
        # record received BEFORE each gap. The previous implementation copied the
        # row AFTER the gap, contradicting the documented intent.
        time_span = timeframe_to_minutes(timeframe)
        last_timestamp = None
        prev_row = None
        filled_records = []
        for _, row in records.iterrows():
            time_stamp = row['open_time']
            if last_timestamp is None:
                # First record: nothing to compare against yet.
                last_timestamp = time_stamp
                prev_row = row
                filled_records.append(row)
                continue
            delta_ms = time_stamp - last_timestamp
            delta_minutes = (delta_ms / 1000) / 60
            if delta_minutes > time_span:
                # Gap detected: synthesize the missing rows, spaced evenly across it.
                num_missing_rec = int(delta_minutes / time_span)
                step = int(delta_ms / num_missing_rec)
                for ts in range(int(last_timestamp) + step, int(time_stamp), step):
                    new_row = prev_row.copy()
                    new_row['open_time'] = ts
                    filled_records.append(new_row)
            filled_records.append(row)
            last_timestamp = time_stamp
            prev_row = row
        return pd.DataFrame(filled_records)

    # Default start date for fetching from the exchange_name.
    if start_datetime is None:
        start_datetime = dt.datetime(year=2017, month=1, day=1)
    # Default end date for fetching from the exchange_name.
    if end_datetime is None:
        end_datetime = dt.datetime.utcnow()
    if start_datetime > end_datetime:
        raise ValueError("\ndatabase:fetch_candles_from_exchange():"
                         " Invalid start and end parameters: ", start_datetime, end_datetime)
    # Get a reference to the exchange
    exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
    # Estimated count: window length in minutes divided by the timeframe length.
    temp = (((unix_time_millis(end_datetime) - unix_time_millis(
        start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval)
    print(f'Fetching historical data {start_datetime} to {end_datetime}, Should be {temp} records')
    if start_datetime == end_datetime:
        # A zero-length window means "from start onwards" for the exchange API.
        end_datetime = None
    # Request candlestick data from the exchange_name.
    candles = exchange.get_historical_klines(symbol=symbol,
                                             interval=interval,
                                             start_dt=start_datetime,
                                             end_dt=end_datetime)
    num_rec_records = len(candles.index)
    print(f'\n{num_rec_records} candles retrieved from the exchange_name.')
    if candles.empty:
        # Nothing to gap-fill; also avoids min()/max() on an empty series below.
        return candles
    # Isolate the open_times from the records received.
    open_times = candles.open_time
    # Calculate the number of records that would fit between the min and max open time.
    estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval)
    if num_rec_records < estimated_num_records:
        # Some records may be missing due to server maintenance periods ect.
        # Fill the holes with copies of the last record received before the gap.
        candles = fill_data_holes(candles, interval)
    return candles

View File

@ -3,15 +3,53 @@ from exchangeinterface import ExchangeInterface
import unittest import unittest
import pandas as pd import pandas as pd
import datetime as dt import datetime as dt
import os
from Database import SQLite, Database
from shared_utilities import unix_time_millis
class TestDataCache(unittest.TestCase): class TestDataCache(unittest.TestCase):
def setUp(self): def setUp(self):
# Setup the database connection here # Set the database connection here
self.exchanges = ExchangeInterface() self.exchanges = ExchangeInterface()
self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None) self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)
# This object maintains all the cached data. Pass it connection to the exchanges. # This object maintains all the cached data. Pass it connection to the exchanges.
self.db_file = 'test_db.sqlite'
self.database = Database(db_file=self.db_file)
# Create necessary tables
with SQLite(db_file=self.db_file) as con:
cursor = con.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS exchange (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE
)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS markets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT,
exchange_id INTEGER,
FOREIGN KEY (exchange_id) REFERENCES exchange(id)
)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS test_table (
id INTEGER PRIMARY KEY,
market_id INTEGER,
open_time INTEGER UNIQUE,
open REAL NOT NULL,
high REAL NOT NULL,
low REAL NOT NULL,
close REAL NOT NULL,
volume REAL NOT NULL,
FOREIGN KEY (market_id) REFERENCES markets(id)
)
""")
self.data = DataCache(self.exchanges) self.data = DataCache(self.exchanges)
self.data.db = self.database
asset, timeframe, exchange = 'BTC/USD', '2h', 'binance' asset, timeframe, exchange = 'BTC/USD', '2h', 'binance'
self.key1 = f'{asset}_{timeframe}_{exchange}' self.key1 = f'{asset}_{timeframe}_{exchange}'
@ -19,8 +57,11 @@ class TestDataCache(unittest.TestCase):
asset, timeframe, exchange = 'ETH/USD', '2h', 'binance' asset, timeframe, exchange = 'ETH/USD', '2h', 'binance'
self.key2 = f'{asset}_{timeframe}_{exchange}' self.key2 = f'{asset}_{timeframe}_{exchange}'
def tearDown(self):
if os.path.exists(self.db_file):
os.remove(self.db_file)
def test_set_cache(self): def test_set_cache(self):
# Tests
print('Testing set_cache flag not set:') print('Testing set_cache flag not set:')
self.data.set_cache(data='data', key=self.key1) self.data.set_cache(data='data', key=self.key1)
attr = self.data.__getattribute__('cached_data') attr = self.data.__getattribute__('cached_data')
@ -36,7 +77,6 @@ class TestDataCache(unittest.TestCase):
self.assertEqual(attr[self.key1], 'more_data') self.assertEqual(attr[self.key1], 'more_data')
def test_cache_exists(self): def test_cache_exists(self):
# Tests
print('Testing cache_exists() method:') print('Testing cache_exists() method:')
self.assertFalse(self.data.cache_exists(key=self.key2)) self.assertFalse(self.data.cache_exists(key=self.key2))
self.data.set_cache(data='data', key=self.key1) self.data.set_cache(data='data', key=self.key1)
@ -44,7 +84,6 @@ class TestDataCache(unittest.TestCase):
def test_update_candle_cache(self): def test_update_candle_cache(self):
print('Testing update_candle_cache() method:') print('Testing update_candle_cache() method:')
# Initial data
df_initial = pd.DataFrame({ df_initial = pd.DataFrame({
'open_time': [1, 2, 3], 'open_time': [1, 2, 3],
'open': [100, 101, 102], 'open': [100, 101, 102],
@ -54,7 +93,6 @@ class TestDataCache(unittest.TestCase):
'volume': [1000, 1001, 1002] 'volume': [1000, 1001, 1002]
}) })
# Data to be added
df_new = pd.DataFrame({ df_new = pd.DataFrame({
'open_time': [3, 4, 5], 'open_time': [3, 4, 5],
'open': [102, 103, 104], 'open': [102, 103, 104],
@ -96,7 +134,7 @@ class TestDataCache(unittest.TestCase):
def test_get_records_since(self): def test_get_records_since(self):
print('Testing get_records_since() method:') print('Testing get_records_since() method:')
df_initial = pd.DataFrame({ df_initial = pd.DataFrame({
'open_time': [1, 2, 3], 'open_time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=i)) for i in range(3)],
'open': [100, 101, 102], 'open': [100, 101, 102],
'high': [110, 111, 112], 'high': [110, 111, 112],
'low': [90, 91, 92], 'low': [90, 91, 92],
@ -105,20 +143,86 @@ class TestDataCache(unittest.TestCase):
}) })
self.data.set_cache(data=df_initial, key=self.key1) self.data.set_cache(data=df_initial, key=self.key1)
start_datetime = dt.datetime.utcfromtimestamp(2) start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=2)
result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60, ex_details=[]).sort_values(by='open_time').reset_index(drop=True) result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60,
ex_details=['BTC/USD', '2h', 'binance'])
expected = pd.DataFrame({ expected = pd.DataFrame({
'open_time': [2, 3], 'open_time': df_initial['open_time'][:2].values,
'open': [101, 102], 'open': [100, 101],
'high': [111, 112], 'high': [110, 111],
'low': [91, 92], 'low': [90, 91],
'close': [106, 107], 'close': [105, 106],
'volume': [1001, 1002] 'volume': [1000, 1001]
}) })
pd.testing.assert_frame_equal(result, expected) pd.testing.assert_frame_equal(result, expected)
def test_get_records_since_from_db(self):
print('Testing get_records_since_from_db() method:')
df_initial = pd.DataFrame({
'market_id': [None],
'open_time': [unix_time_millis(dt.datetime.utcnow())],
'open': [1.0],
'high': [1.0],
'low': [1.0],
'close': [1.0],
'volume': [1.0]
})
with SQLite(self.db_file) as con:
df_initial.to_sql('test_table', con, if_exists='append', index=False)
start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=1)
end_datetime = dt.datetime.utcnow()
result = self.data.get_records_since_from_db(table_name='test_table', st=start_datetime, et=end_datetime,
rl=1, ex_details=['BTC/USD', '2h', 'binance']).sort_values(
by='open_time').reset_index(drop=True)
print("Columns in the result DataFrame:", result.columns)
print("Result DataFrame:\n", result)
# Remove 'id' column from the result DataFrame if it exists
if 'id' in result.columns:
result = result.drop(columns=['id'])
expected = pd.DataFrame({
'market_id': [None],
'open_time': [unix_time_millis(dt.datetime.utcnow())],
'open': [1.0],
'high': [1.0],
'low': [1.0],
'close': [1.0],
'volume': [1.0]
})
print("Expected DataFrame:\n", expected)
pd.testing.assert_frame_equal(result, expected)
def test_populate_db(self):
print('Testing _populate_db() method:')
start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
end_time = dt.datetime.utcnow()
result = self.data._populate_db(table_name='test_table', start_time=start_time,
end_time=end_time, ex_details=['BTC/USD', '2h', 'binance', 'test_guy'])
self.assertIsInstance(result, pd.DataFrame)
self.assertFalse(result.empty)
def test_fetch_candles_from_exchange(self):
print('Testing _fetch_candles_from_exchange() method:')
start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
end_time = dt.datetime.utcnow()
result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance',
user_name='test_guy', start_datetime=start_time,
end_datetime=end_time)
self.assertIsInstance(result, pd.DataFrame)
self.assertFalse(result.empty)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,96 +1,219 @@
import datetime import unittest
import sqlite3
import pandas as pd import pandas as pd
from database import make_query, make_insert, Database, HDict, SQLite import datetime as dt
from exchangeinterface import ExchangeInterface from Database import Database, SQLite, make_query, make_insert, HDict
from passlib.hash import bcrypt from shared_utilities import unix_time_millis
from sqlalchemy import create_engine, text
import config
def test(): class TestSQLite(unittest.TestCase):
# un_hashed_pass = 'password' def test_sqlite_context_manager(self):
hasher = bcrypt.using(rounds=13) print("\nRunning test_sqlite_context_manager...")
# hashed_pass = hasher.hash(un_hashed_pass) with SQLite(db_file='test_db.sqlite') as con:
# print(f'password: {un_hashed_pass}') cursor = con.cursor()
# print(f'hashed pass: {hashed_pass}') cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
# print(f" right pass: {hasher.verify('password', hashed_pass)}") cursor.execute("INSERT INTO test_table (name) VALUES ('test')")
# print(f" wrong pass: {hasher.verify('passWord', hashed_pass)}") cursor.execute('SELECT name FROM test_table WHERE name = ?', ('test',))
engine = create_engine("sqlite:///" + config.DB_FILE, echo=True) result = cursor.fetchone()
with engine.connect() as conn: self.assertEqual(result[0], 'test')
default_user = pd.read_sql_query(sql=text("SELECT * FROM users WHERE user_name = 'guest'"), con=conn) print("SQLite context manager test passed.")
# hashed_password = default_user.password.values[0]
# print(f" verify pass: {hasher.verify('password', hashed_password)}")
username = default_user.user_name.values[0]
print(username)
def test_make_query(): class TestDatabase(unittest.TestCase):
values = {'first_field': 'first_value'} def setUp(self):
item = 'market_id' # Use a temporary SQLite database for testing purposes
table = 'markets' self.db_file = 'test_db.sqlite'
q_str = make_query(item=item, table=table, values=values) self.db = Database(db_file=self.db_file)
print(f'\nWith one indexing field: {q_str}') self.connection = sqlite3.connect(self.db_file)
self.cursor = self.connection.cursor()
values = {'first_field': 'first_value', 'second_field': 'second_value'} def tearDown(self):
q_str = make_query(item=item, table=table, values=values) self.connection.close()
print(f'\nWith two indexing fields: {q_str}') import os
assert q_str is not None os.remove(self.db_file) # Remove the temporary database file after tests
def test_execute_sql(self):
print("\nRunning test_execute_sql...")
# Drop the table if it exists to avoid OperationalError
self.cursor.execute('DROP TABLE IF EXISTS test_table')
self.connection.commit()
sql = 'CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)'
self.db.execute_sql(sql)
self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_table';")
result = self.cursor.fetchone()
self.assertIsNotNone(result)
print("Execute SQL test passed.")
def test_make_query(self):
print("\nRunning test_make_query...")
query = make_query('id', 'test_table', ['name'])
expected_query = 'SELECT id FROM test_table WHERE name = ?;'
self.assertEqual(query, expected_query)
print("Make query test passed.")
def test_make_insert(self):
print("\nRunning test_make_insert...")
insert = make_insert('test_table', ('name', 'age'))
expected_insert = "INSERT INTO test_table ('name', 'age') VALUES(?, ?);"
self.assertEqual(insert, expected_insert)
print("Make insert test passed.")
def test_get_item_where(self):
print("\nRunning test_get_item_where...")
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
self.cursor.execute("INSERT INTO test_table (id, name) VALUES (1, 'test')")
self.connection.commit()
item = self.db.get_item_where('name', 'test_table', ('id', 1))
self.assertEqual(item, 'test')
print("Get item where test passed.")
def test_get_rows_where(self):
print("\nRunning test_get_rows_where...")
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
self.cursor.execute("INSERT INTO test_table (id, name) VALUES (1, 'test')")
self.connection.commit()
rows = self.db.get_rows_where('test_table', ('name', 'test'))
self.assertIsInstance(rows, pd.DataFrame)
self.assertEqual(rows.iloc[0]['name'], 'test')
print("Get rows where test passed.")
def test_insert_dataframe(self):
print("\nRunning test_insert_dataframe...")
df = pd.DataFrame({'id': [1], 'name': ['test']})
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
self.connection.commit()
self.db.insert_dataframe(df, 'test_table')
self.cursor.execute('SELECT name FROM test_table WHERE id = 1')
result = self.cursor.fetchone()
self.assertEqual(result[0], 'test')
print("Insert dataframe test passed.")
def test_insert_row(self):
print("\nRunning test_insert_row...")
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
self.connection.commit()
self.db.insert_row('test_table', ('id', 'name'), (1, 'test'))
self.cursor.execute('SELECT name FROM test_table WHERE id = 1')
result = self.cursor.fetchone()
self.assertEqual(result[0], 'test')
print("Insert row test passed.")
def test_table_exists(self):
print("\nRunning test_table_exists...")
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
self.connection.commit()
exists = self.db.table_exists('test_table')
self.assertTrue(exists)
print("Table exists test passed.")
def test_get_timestamped_records(self):
print("\nRunning test_get_timestamped_records...")
df = pd.DataFrame({
'open_time': [unix_time_millis(dt.datetime.utcnow())],
'open': [1.0],
'high': [1.0],
'low': [1.0],
'close': [1.0],
'volume': [1.0]
})
table_name = 'test_table'
self.cursor.execute(f"""
CREATE TABLE {table_name} (
id INTEGER PRIMARY KEY,
open_time INTEGER UNIQUE,
open REAL NOT NULL,
high REAL NOT NULL,
low REAL NOT NULL,
close REAL NOT NULL,
volume REAL NOT NULL
)
""")
self.connection.commit()
self.db.insert_dataframe(df, table_name)
st = dt.datetime.utcnow() - dt.timedelta(minutes=1)
et = dt.datetime.utcnow()
records = self.db.get_timestamped_records(table_name, 'open_time', st, et)
self.assertIsInstance(records, pd.DataFrame)
self.assertFalse(records.empty)
print("Get timestamped records test passed.")
def test_get_from_static_table(self):
print("\nRunning test_get_from_static_table...")
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT UNIQUE)')
self.connection.commit()
item = self.db.get_from_static_table('id', 'test_table', HDict({'name': 'test'}), create_id=True)
self.assertIsInstance(item, int)
self.cursor.execute('SELECT id FROM test_table WHERE name = ?', ('test',))
result = self.cursor.fetchone()
self.assertEqual(item, result[0])
print("Get from static table test passed.")
def test_insert_candles_into_db(self):
print("\nRunning test_insert_candles_into_db...")
df = pd.DataFrame({
'open_time': [unix_time_millis(dt.datetime.utcnow())],
'open': [1.0],
'high': [1.0],
'low': [1.0],
'close': [1.0],
'volume': [1.0]
})
table_name = 'test_table'
self.cursor.execute(f"""
CREATE TABLE {table_name} (
id INTEGER PRIMARY KEY,
market_id INTEGER,
open_time INTEGER UNIQUE,
open REAL NOT NULL,
high REAL NOT NULL,
low REAL NOT NULL,
close REAL NOT NULL,
volume REAL NOT NULL
)
""")
self.connection.commit()
# Create the exchange and markets tables needed for the foreign key constraints
self.cursor.execute("""
CREATE TABLE exchange (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE
)
""")
self.cursor.execute("""
CREATE TABLE markets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT,
exchange_id INTEGER,
FOREIGN KEY (exchange_id) REFERENCES exchange(id)
)
""")
self.connection.commit()
self.db.insert_candles_into_db(df, table_name, 'BTC/USDT', 'binance')
self.cursor.execute(f'SELECT * FROM {table_name}')
result = self.cursor.fetchall()
self.assertFalse(len(result) == 0)
print("Insert candles into db test passed.")
def test_make_insert(): if __name__ == '__main__':
table = 'markets' unittest.main()
values = {'first_field': 'first_value'}
q_str = make_insert(table=table, values=values)
print(f'\nWith one indexing field: {q_str}')
values = {'first_field': 'first_value', 'second_field': 'second_value'} # def test():
q_str = make_insert(table=table, values=values) # # un_hashed_pass = 'password'
print(f'\nWith two indexing fields: {q_str}') # hasher = bcrypt.using(rounds=13)
assert q_str is not None # # hashed_pass = hasher.hash(un_hashed_pass)
# # print(f'password: {un_hashed_pass}')
# # print(f'hashed pass: {hashed_pass}')
def test__table_exists(): # # print(f" right pass: {hasher.verify('password', hashed_pass)}")
exchanges = ExchangeInterface() # # print(f" wrong pass: {hasher.verify('passWord', hashed_pass)}")
d_obj = Database(exchanges) # engine = create_engine("sqlite:///" + config.DB_FILE, echo=True)
exists = d_obj._table_exists('BTC/USD_5m_alpaca') # with engine.connect() as conn:
print(f'\nExists - Should be true: {exists}') # default_user = pd.read_sql_query(sql=text("SELECT * FROM users WHERE user_name = 'guest'"), con=conn)
assert exists # # hashed_password = default_user.password.values[0]
exists = d_obj._table_exists('BTC/USD_5m_alpina') # # print(f" verify pass: {hasher.verify('password', hashed_password)}")
print(f'Doesnt exist - should be false: {exists}') # username = default_user.user_name.values[0]
assert not exists # print(username)
def test_get_from_static_table():
exchanges = ExchangeInterface()
d_obj = Database(exchanges)
market_id = d_obj._fetch_market_id('BTC/USD', 'alpaca')
e_id = d_obj._fetch_exchange_id('alpaca')
print(f'market id: {market_id}')
assert market_id > 0
print(f'exchange_name ID: {e_id}')
assert e_id == 4
def test_populate_table():
"""
Populates a database table with records from the exchange_name.
:param table_name: str - The name of the table in the database.
:param start_time: datetime - The starting time to fetch the records from.
:param end_time: datetime - The end time to get the records until.
:return: pdDataframe: - The data that was downloaded.
"""
exchanges = ExchangeInterface()
d_obj = Database(exchanges)
d_obj._populate_table(table_name='BTC/USD_2h_alpaca',
start_time=datetime.datetime(year=2023, month=3, day=27, hour=6, minute=0))
def test_get_records_since():
exchanges = ExchangeInterface()
d_obj = Database(exchanges)
records = d_obj.get_records_since(table_name='BTC/USD_15m_alpaca',
st=datetime.datetime(year=2023, month=3, day=27, hour=1, minute=0),
et=datetime.datetime.utcnow(),
rl=15)
print(records)
assert records is not None

View File

@ -12,7 +12,7 @@ class TestExchange(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
exchange_name = 'binance' exchange_name = 'kraken'
cls.api_keys = None cls.api_keys = None
"""Uncomment and Provide api keys to connect to exchange.""" """Uncomment and Provide api keys to connect to exchange."""
# cls.api_keys = {'key': 'EXCHANGE_API_KEY', 'secret': 'EXCHANGE_API_SECRET'} # cls.api_keys = {'key': 'EXCHANGE_API_KEY', 'secret': 'EXCHANGE_API_SECRET'}