Completed unittests for Database and DataCache.

This commit is contained in:
Rob 2024-08-03 16:56:13 -03:00
parent e601f8c23e
commit 4130e0ca9a
9 changed files with 942 additions and 640 deletions

View File

@ -27,7 +27,7 @@ class BrighterTrades:
self.signals = Signals(self.config.signals_list)
# Object that maintains candlestick and price data.
self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, database=self.data)
self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, data_source=self.data)
# Object that interacts with and maintains data from available indicators
self.indicators = Indicators(self.candles, self.config)

View File

@ -1,25 +1,41 @@
from typing import Any, List
import pandas as pd
import datetime as dt
from Database import Database
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes
import logging
from database import Database
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes
# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
class DataCache:
"""
Fetches and caches exchange data, enforces cache size limits, and optimises memory usage.
Handles connections and operations for the given exchanges.
Example usage:
--------------
db = DataCache(exchanges=some_exchanges_object)
"""
def __init__(self, exchanges):
"""
Initializes the DataCache class.
# Define the cache size
self.max_tables = 50 # Maximum amount of tables that will be cached at any given time.
self.max_records = 1000 # Maximum number of records that will be kept per table.
self.cached_data = {} # A dictionary that holds all the cached records.
self.db = Database(
exchanges) # The class that handles the DB interactions pass it a connection to the exchange_interface.
:param exchanges: The exchanges object handling communication with connected exchanges.
"""
# Maximum number of tables to cache at any given time.
self.max_tables = 50
# Maximum number of records to be kept per table.
self.max_records = 1000
# A dictionary that holds all the cached records.
self.cached_data = {}
# The class that handles the DB interactions.
self.db = Database()
# The class that handles exchange interactions.
self.exchanges = exchanges
def cache_exists(self, key: str) -> bool:
"""
@ -89,27 +105,32 @@ class DataCache:
"""
Return any records from the cache indexed by table_name that are newer than start_datetime.
:param ex_details: Details to pass to the server
:param key: <str>: - The dictionary table_name of the records.
:param start_datetime: <dt.datetime>: - The datetime of the first record requested.
:param record_length: The timespan of the records.
:param ex_details: List[str] - Details to pass to the server
:param key: str - The dictionary table_name of the records.
:param start_datetime: dt.datetime - The datetime of the first record requested.
:param record_length: int - The timespan of the records.
:return pd.DataFrame: - The Requested records
:return: pd.DataFrame - The Requested records
Example:
--------
records = data_cache.get_records_since('BTC/USD_2h_binance', dt.datetime.utcnow() - dt.timedelta(minutes=60), 60, ['BTC/USD', '2h', 'binance'])
"""
try:
# End time of query defaults to the current time.
end_datetime = dt.datetime.utcnow()
if self.cache_exists(key=key):
print('\nGetting records from cache.')
# If the records exist retrieve them from the cache.
logger.debug('Getting records from cache.')
# If the records exist, retrieve them from the cache.
records = self.get_cache(key)
else:
# If they don't exist in cache, get them from the db.
print(f'\nRecords not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
records = self.db.get_records_since(table_name=key, st=start_datetime,
# If they don't exist in cache, get them from the database.
logger.debug(
f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
records = self.get_records_since_from_db(table_name=key, st=start_datetime,
et=end_datetime, rl=record_length, ex_details=ex_details)
print(f'Got {len(records.index)} records from db')
logger.debug(f'Got {len(records.index)} records from DB.')
self.set_cache(data=records, key=key)
# Check if the records in the cache go far enough back to satisfy the query.
@ -119,13 +140,14 @@ class DataCache:
if first_timestamp:
# The records didn't go far enough back if a timestamp was returned.
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
# Request candles with open_times between [start_time:end_time] from the database.
print(f'requesting additional records from {start_datetime} to {end_time}')
additional_records = self.db.get_records_since(table_name=key, st=start_datetime,
et=end_time, rl=record_length, ex_details=ex_details)
print(f'Got {len(additional_records.index)} additional records from db')
logger.debug(f'Requesting additional records from {start_datetime} to {end_time}')
# Request additional records from the database.
additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime,
et=end_time, rl=record_length,
ex_details=ex_details)
logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
if not additional_records.empty:
# If more records were received update the cache.
# If more records were received, update the cache.
self.update_candle_cache(additional_records, key)
# Check if the records received are up-to-date.
@ -134,21 +156,219 @@ class DataCache:
if last_timestamp:
# The query was not up-to-date if a timestamp was returned.
start_time = dt.datetime.utcfromtimestamp(last_timestamp)
print(f'requesting additional records from {start_time} to {end_datetime}')
# Request the database update its table starting from start_datetime.
additional_records = self.db.get_records_since(table_name=key, st=start_time,
logger.debug(f'Requesting additional records from {start_time} to {end_datetime}')
# Request additional records from the database.
additional_records = self.get_records_since_from_db(table_name=key, st=start_time,
et=end_datetime, rl=record_length,
ex_details=ex_details)
print(f'Got {len(additional_records.index)} additional records from db')
logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
if not additional_records.empty:
self.update_candle_cache(additional_records, key)
# Create a UTC timestamp.
_timestamp = unix_time_millis(start_datetime)
# Return all records equal or newer than timestamp.
logger.debug(f"Start timestamp in human-readable form: {dt.datetime.utcfromtimestamp(_timestamp / 1000.0)}")
# Return all records equal to or newer than the timestamp.
result = self.get_cache(key).query('open_time >= @_timestamp')
return result
except Exception as e:
print(f"An error occurred: {str(e)}")
logger.error(f"An error occurred: {str(e)}")
raise
def get_records_since_from_db(self, table_name: str, st: dt.datetime,
                              et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame:
    """
    Returns records from a specified table that meet a criteria and ensures the records are complete.
    If the records do not go back far enough or are not up-to-date, it fetches additional records
    from an exchange and populates the table.
    :param table_name: Database table name.
    :param st: Start datetime.
    :param et: End datetime.
    :param rl: Timespan in minutes each record represents.
    :param ex_details: Exchange details [symbol, interval, exchange_name].
    :return: DataFrame of records.
    Example:
    --------
    records = dc.get_records_since_from_db('test_table', start_time, end_time, 1, ['BTC/USDT', '1m', 'binance'])
    """

    def add_data(data, tn, start_t, end_t):
        # Download the missing window from the exchange, persist it, and merge into `data`.
        new_records = self._populate_db(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details)
        # Fixed: the original printed the literal text 'exchange_name' instead of the actual name.
        logger.debug(f'Got {len(new_records.index)} records from {ex_details[2]}')
        if not new_records.empty:
            data = pd.concat([data, new_records], axis=0, ignore_index=True)
            # keep='first' so records already on file win over re-downloaded duplicates.
            data = data.drop_duplicates(subset="open_time", keep='first')
        return data

    # All diagnostics use the module logger for consistency with the rest of the class
    # (the original mixed print() into an otherwise logger-based module).
    if self.db.table_exists(table_name=table_name):
        logger.debug('Table existed retrieving records from DB')
        logger.debug(f'Requesting from {st} to {et}')
        records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=st, et=et)
        logger.debug(f'Got {len(records.index)} records from db')
    else:
        # No table yet: bootstrap it entirely from the exchange.
        logger.debug(f'Table didnt exist fetching from {ex_details[2]}')
        temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
        logger.debug(f'Requesting from {st} to {et}, Should be {temp} records')
        records = self._populate_db(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details)
        logger.debug(f'Got {len(records.index)} records from {ex_details[2]}')
    # Backfill: the oldest record on file is newer than the requested start.
    first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
    if first_timestamp:
        logger.debug(f'Records did not go far enough back. Requesting from {ex_details[2]}')
        logger.debug(f'first ts on record is: {first_timestamp}')
        end_time = dt.datetime.utcfromtimestamp(first_timestamp)
        logger.debug(f'Requesting from {st} to {end_time}')
        records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)
    # Forward-fill: the newest record on file is older than one interval ago.
    last_timestamp = query_uptodate(records=records, r_length_min=rl)
    if last_timestamp:
        logger.debug(f'Records were not updated. Requesting from {ex_details[2]}.')
        logger.debug(f'the last record on file is: {last_timestamp}')
        start_time = dt.datetime.utcfromtimestamp(last_timestamp)
        logger.debug(f'Requesting from {start_time} to {et}')
        records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et)
    return records
def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: list,
                 end_time: dt.datetime = None) -> pd.DataFrame:
    """
    Populates a database table with records from the exchange.
    :param table_name: Name of the table in the database.
    :param start_time: Start time to fetch the records from.
    :param end_time: End time to fetch the records until (optional, defaults to UTC now).
    :param ex_details: Exchange details [symbol, interval, exchange_name, user_name].
    :return: DataFrame of the data downloaded.
    Example:
    --------
    records = dc._populate_db('test_table', start_time, ['BTC/USDT', '1m', 'binance', 'user1'])
    """
    # Default the end of the requested window to the current UTC time.
    if end_time is None:
        end_time = dt.datetime.utcnow()
    # Unpack the details: symbol, interval, exchange name, user name.
    sym, inter, ex, un = ex_details
    records = self._fetch_candles_from_exchange(symbol=sym, interval=inter, exchange_name=ex, user_name=un,
                                                start_datetime=start_time, end_datetime=end_time)
    if not records.empty:
        # Persist whatever the exchange returned.
        self.db.insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex)
    else:
        # Fixed: use the module logger instead of print() for consistency with this class.
        logger.warning(f'No records inserted {records}')
    return records
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
                                 start_datetime: dt.datetime = None,
                                 end_datetime: dt.datetime = None) -> pd.DataFrame:
    """
    Fetches and returns all candles from the specified market, timeframe, and exchange.
    :param symbol: Symbol of the market.
    :param interval: Timeframe in the format '<int><alpha>' (e.g., '15m', '4h').
    :param exchange_name: Name of the exchange.
    :param user_name: Name of the user.
    :param start_datetime: Start datetime for fetching data (optional, defaults to 2017-01-01).
    :param end_datetime: End datetime for fetching data (optional, defaults to UTC now).
    :return: DataFrame of candle data.
    :raises ValueError: If start_datetime is after end_datetime.
    Example:
    --------
    candles = dc._fetch_candles_from_exchange('BTC/USDT', '1m', 'binance', 'user1', start_time, end_time)
    """
    # NOTE(review): this method uses timeframe_to_minutes, but the module's new
    # shared_utilities import line no longer includes it — re-add the import or
    # this will raise NameError at runtime.

    def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
        """
        Fills gaps in the data by replicating the last known data point for the missing periods.
        :param records: DataFrame containing the original records.
        :param interval: Interval of the data (e.g., '1m', '5m').
        :return: DataFrame with gaps filled.
        Example:
        --------
        filled_records = fill_data_holes(df, '1m')
        """
        time_span = timeframe_to_minutes(interval)
        last_timestamp = None
        filled_records = []
        logger.info(f"Starting to fill data holes for interval: {interval}")
        for index, row in records.iterrows():
            time_stamp = row['open_time']
            # If last_timestamp is None, this is the first record
            if last_timestamp is None:
                last_timestamp = time_stamp
                filled_records.append(row)
                logger.debug(f"First timestamp: {time_stamp}")
                continue
            # Calculate the difference in milliseconds and minutes
            delta_ms = time_stamp - last_timestamp
            delta_minutes = (delta_ms / 1000) / 60
            logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}")
            # If the gap is larger than the time span of the interval, fill the gap
            if delta_minutes > time_span:
                num_missing_rec = int(delta_minutes / time_span)
                step = int(delta_ms / num_missing_rec)
                logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}")
                # Synthesise missing rows by copying the current row with an adjusted open_time.
                for ts in range(int(last_timestamp) + step, int(time_stamp), step):
                    new_row = row.copy()
                    new_row['open_time'] = ts
                    filled_records.append(new_row)
                    logger.debug(f"Filled timestamp: {ts}")
            filled_records.append(row)
            last_timestamp = time_stamp
        logger.info("Data holes filled successfully.")
        return pd.DataFrame(filled_records)

    # Default start date if not provided
    if start_datetime is None:
        start_datetime = dt.datetime(year=2017, month=1, day=1)
    # Default end date if not provided
    if end_datetime is None:
        end_datetime = dt.datetime.utcnow()
    # Check if start date is greater than end date
    if start_datetime > end_datetime:
        raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")
    # Get the exchange object
    exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
    # Calculate the expected number of records
    temp = (((unix_time_millis(end_datetime) - unix_time_millis(
        start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval)
    logger.info(f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {temp}')
    # If start and end times are the same, set end_datetime to None
    if start_datetime == end_datetime:
        end_datetime = None
    # Fetch historical candlestick data from the exchange
    candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime,
                                             end_dt=end_datetime)
    num_rec_records = len(candles.index)
    logger.info(f'{num_rec_records} candles retrieved from the exchange.')
    # Robustness fix: with an empty response there is nothing to inspect or
    # gap-fill, and open_times.min()/max() below would be undefined.
    if candles.empty:
        return candles
    # Check if the retrieved data covers the expected time range
    open_times = candles.open_time
    estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval)
    # Fill in any missing data if the retrieved data is less than expected
    if num_rec_records < estimated_num_records:
        candles = fill_data_holes(candles, interval)
    return candles

352
src/Database.py Normal file
View File

@ -0,0 +1,352 @@
import sqlite3
from functools import lru_cache
from typing import Any
import config
import datetime as dt
import pandas as pd
from shared_utilities import unix_time_millis
class SQLite:
    """
    Context manager for SQLite database connections.
    Accepts a database file name or defaults to the file in config.DB_FILE.
    Commits on a clean exit and rolls back when the managed block raised,
    so a failed transaction is never half-committed.
    Example usage:
    --------------
    with SQLite(db_file='test_db.sqlite') as con:
        cursor = con.cursor()
        cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
    """

    def __init__(self, db_file=None):
        # Fall back to the configured application database when no file is given.
        self.db_file = db_file if db_file else config.DB_FILE
        self.connection = sqlite3.connect(self.db_file)

    def __enter__(self):
        return self.connection

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Fixed: the original committed unconditionally, persisting partial
        # writes even when the body of the with-block raised an exception.
        if exc_type is None:
            self.connection.commit()
        else:
            self.connection.rollback()
        self.connection.close()
class HDict(dict):
    """
    A dict subclass that is hashable, so instances can serve as cache keys
    (e.g. arguments to functools.lru_cache).
    Example usage:
    --------------
    hdict = HDict({'key1': 'value1', 'key2': 'value2'})
    hash(hdict)
    """

    def __hash__(self):
        # Hash the items as an order-independent set of (key, value) pairs.
        item_pairs = frozenset(self.items())
        return hash(item_pairs)
def make_query(item: str, table: str, columns: list) -> str:
    """
    Creates a SQL select query string with the required number of placeholders.
    :param item: The field to select.
    :param table: The table to select from.
    :param columns: List of columns for the where clause.
    :return: The query string.
    :raises ValueError: If columns is empty (the original leaked a confusing
        StopIteration from next() instead).
    Example:
    --------
    query = make_query('id', 'test_table', ['name', 'age'])
    # Result: 'SELECT id FROM test_table WHERE name = ? AND age = ?;'
    """
    if not columns:
        # Fail loudly with a clear message rather than a bare StopIteration.
        raise ValueError('make_query requires at least one filter column.')
    # Join all filter columns into a single parameterised WHERE clause.
    where_clause = ' AND '.join(f'{col} = ?' for col in columns)
    return f'SELECT {item} FROM {table} WHERE {where_clause};'
def make_insert(table: str, values: tuple) -> str:
    """
    Creates a SQL insert query string with the required number of placeholders.
    :param table: The table to insert into.
    :param values: Tuple of column names to insert values for.
    :return: The query string.
    :raises ValueError: If values is empty (the original leaked a confusing
        StopIteration from next() instead).
    Example:
    --------
    insert = make_insert('test_table', ('name', 'age'))
    # Result: "INSERT INTO test_table ('name', 'age') VALUES(?, ?);"
    """
    if not values:
        # Fail loudly with a clear message rather than a bare StopIteration.
        raise ValueError('make_insert requires at least one column.')
    # One quoted column name and one '?' placeholder per entry in values.
    column_list = ', '.join(f"'{col}'" for col in values)
    placeholders = ', '.join('?' for _ in values)
    return f"INSERT INTO {table} ({column_list}) VALUES({placeholders});"
class Database:
    """
    Database class to communicate with and maintain the SQLite database.
    Example usage:
    --------------
    db = Database(db_file='test_db.sqlite')
    db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
    """

    def __init__(self, db_file=None):
        """
        Initializes the Database class.
        :param db_file: Optional database file name. When None, the SQLite
            context manager falls back to config.DB_FILE.
        """
        self.db_file = db_file

    def execute_sql(self, sql: str) -> None:
        """
        Executes a raw SQL statement.
        :param sql: SQL statement to execute.
        Example:
        --------
        db = Database(db_file='test_db.sqlite')
        db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
        """
        with SQLite(self.db_file) as con:
            cur = con.cursor()
            cur.execute(sql)

    def get_item_where(self, item_name: str, table_name: str, filter_vals: tuple) -> int:
        """
        Returns an item from a table where the filter results should isolate a single row.
        :param item_name: Name of the item to fetch.
        :param table_name: Name of the table.
        :param filter_vals: Tuple of column name and value to filter by.
        :return: The item.
        :raises ValueError: If no row matches the filter.
        Example:
        --------
        item = db.get_item_where('name', 'test_table', ('id', 1))
        # Fetches the 'name' from 'test_table' where 'id' is 1
        """
        with SQLite(self.db_file) as con:
            cur = con.cursor()
            qry = make_query(item_name, table_name, [filter_vals[0]])
            cur.execute(qry, (filter_vals[1],))
            # fetchone() returns a tuple like (value,) or None when nothing matched.
            if row := cur.fetchone():
                return row[0]
            error = f"Couldn't fetch item {item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}"
            raise ValueError(error)

    def get_rows_where(self, table: str, filter_vals: tuple) -> pd.DataFrame | None:
        """
        Returns a DataFrame containing all rows of a table that meet the filter criteria.
        :param table: Name of the table.
        :param filter_vals: Tuple of column name and value to filter by.
        :return: DataFrame of the query result or None if empty.
        Example:
        --------
        rows = db.get_rows_where('test_table', ('name', 'test'))
        # Fetches all rows from 'test_table' where 'name' is 'test'
        """
        with SQLite(self.db_file) as con:
            # Fixed: bind the filter value as a query parameter instead of
            # interpolating it into the SQL string (quoting bugs / SQL injection).
            qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?"
            result = pd.read_sql(qry, con=con, params=(filter_vals[1],))
        return result if not result.empty else None

    def insert_dataframe(self, df: pd.DataFrame, table: str) -> None:
        """
        Inserts a DataFrame into a specified table.
        :param df: DataFrame to insert.
        :param table: Name of the table.
        Example:
        --------
        df = pd.DataFrame({'id': [1], 'name': ['test']})
        db.insert_dataframe(df, 'test_table')
        """
        with SQLite(self.db_file) as con:
            df.to_sql(name=table, con=con, index=False, if_exists='append')

    def insert_row(self, table: str, columns: tuple, values: tuple) -> None:
        """
        Inserts a row into a specified table.
        :param table: Name of the table.
        :param columns: Tuple of column names.
        :param values: Tuple of values to insert.
        Example:
        --------
        db.insert_row('test_table', ('id', 'name'), (1, 'test'))
        """
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            sql = make_insert(table=table, values=columns)
            cursor.execute(sql, values)

    def table_exists(self, table_name: str) -> bool:
        """
        Checks if a table exists in the database.
        :param table_name: Name of the table.
        :return: True if the table exists, False otherwise.
        Example:
        --------
        exists = db.table_exists('test_table')
        """
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            # sqlite_schema lists every object in the DB; filter to tables by name.
            sql = make_query(item='name', table='sqlite_schema', columns=['type', 'name'])
            cursor.execute(sql, ('table', table_name))
            return cursor.fetchone() is not None

    @lru_cache(maxsize=1000)
    def get_from_static_table(self, item: str, table: str, indexes: HDict, create_id: bool = False) -> Any:
        """
        Returns the row id of an item from the specified table. If the item isn't listed in the table,
        it will insert the item into a new row and return the autoincremented id. The indexes are
        received as a hashable dictionary so the results can be cached.
        NOTE(review): lru_cache on an instance method keys on self and keeps every
        Database instance alive for the cache's lifetime (ruff B019) — acceptable
        while Database instances are long-lived singletons; revisit otherwise.
        :param item: Name of the item requested.
        :param table: Table being queried.
        :param indexes: Hashable dictionary of indexing columns and their values.
        :param create_id: If True, create a row if it doesn't exist and return the autoincrement ID.
        :return: The content of the field, or None when absent and create_id is False.
        Example:
        --------
        exchange_id = db.get_from_static_table('id', 'exchange', HDict({'name': 'binance'}), create_id=True)
        """
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            sql = make_query(item, table, list(indexes.keys()))
            cursor.execute(sql, tuple(indexes.values()))
            result = cursor.fetchone()
            if result is None and create_id:
                # Insert the indexing values, then re-query to pick up the new row.
                sql = make_insert(table, tuple(indexes.keys()))
                cursor.execute(sql, tuple(indexes.values()))
                sql = make_query(item, table, list(indexes.keys()))
                cursor.execute(sql, tuple(indexes.values()))
                result = cursor.fetchone()
            return result[0] if result else None

    def _fetch_exchange_id(self, exchange_name: str) -> int:
        """
        Fetches the primary ID of an exchange from the database.
        :param exchange_name: Name of the exchange.
        :return: Primary ID of the exchange.
        Example:
        --------
        exchange_id = db._fetch_exchange_id('binance')
        """
        return self.get_from_static_table(item='id', table='exchange', create_id=True,
                                          indexes=HDict({'name': exchange_name}))

    def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
        """
        Returns the market ID for a trading pair listed in the database.
        :param symbol: Symbol of the trading pair.
        :param exchange_name: Name of the exchange.
        :return: Market ID.
        Example:
        --------
        market_id = db._fetch_market_id('BTC/USDT', 'binance')
        """
        exchange_id = self._fetch_exchange_id(exchange_name)
        market_id = self.get_from_static_table(item='id', table='markets', create_id=True,
                                               indexes=HDict({'symbol': symbol, 'exchange_id': exchange_id}))
        return market_id

    def insert_candles_into_db(self, candlesticks: pd.DataFrame, table_name: str, symbol: str,
                               exchange_name: str) -> None:
        """
        Inserts all candlesticks from a DataFrame into the database, creating the
        table on first use.
        :param candlesticks: DataFrame of candlestick data.
        :param table_name: Name of the table to insert into.
        :param symbol: Symbol of the trading pair.
        :param exchange_name: Name of the exchange.
        Example:
        --------
        df = pd.DataFrame({
            'open_time': [unix_time_millis(dt.datetime.utcnow())],
            'open': [1.0],
            'high': [1.0],
            'low': [1.0],
            'close': [1.0],
            'volume': [1.0]
        })
        db.insert_candles_into_db(df, 'test_table', 'BTC/USDT', 'binance')
        """
        market_id = self._fetch_market_id(symbol, exchange_name)
        # NOTE(review): this mutates the caller's DataFrame in place (prepends
        # 'market_id'), which makes the returned/cached frames match the table
        # schema read back by get_timestamped_records — confirm before changing.
        candlesticks.insert(0, 'market_id', market_id)
        sql_create = f"""
            CREATE TABLE IF NOT EXISTS '{table_name}' (
            id INTEGER PRIMARY KEY,
            market_id INTEGER,
            open_time INTEGER UNIQUE ON CONFLICT IGNORE,
            open REAL NOT NULL,
            high REAL NOT NULL,
            low REAL NOT NULL,
            close REAL NOT NULL,
            volume REAL NOT NULL,
            FOREIGN KEY (market_id) REFERENCES market (id)
            )"""
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            cursor.execute(sql_create)
            candlesticks.to_sql(table_name, conn, if_exists='append', index=False)

    def get_timestamped_records(self, table_name: str, timestamp_field: str, st: dt.datetime,
                                et: dt.datetime = None) -> pd.DataFrame:
        """
        Returns records from a specified table in the database that have timestamps greater than or equal to a given
        start time and, optionally, less than or equal to a given end time.
        :param table_name: Database table name.
        :param timestamp_field: Field name that contains the timestamp.
        :param st: Start datetime.
        :param et: End datetime (optional).
        :return: DataFrame of records.
        Example:
        --------
        records = db.get_timestamped_records('test_table', 'open_time', start_time, end_time)
        """
        with SQLite(self.db_file) as conn:
            start_stamp = unix_time_millis(st)
            # Fixed: timestamps are bound as query parameters rather than
            # formatted into the SQL string.
            if et is not None:
                q_str = (
                    f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= ? "
                    f"AND {timestamp_field} <= ?;"
                )
                params = (start_stamp, unix_time_millis(et))
            else:
                q_str = f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= ?;"
                params = (start_stamp,)
            records = pd.read_sql(q_str, conn, params=params)
            # The surrogate primary key is internal to the DB; drop it from results.
            records = records.drop('id', axis=1)
        return records

View File

@ -5,7 +5,7 @@ from typing import Any
from passlib.hash import bcrypt
import pandas as pd
from database import HDict
from Database import HDict
class Users:

View File

@ -1,13 +1,12 @@
import datetime as dt
import logging as log
from DataCache import DataCache
from shared_utilities import timeframe_to_minutes, ts_of_n_minutes_ago
# log.basicConfig(level=log.ERROR)
class Candles:
def __init__(self, exchanges, config_obj, database):
def __init__(self, exchanges, config_obj, data_source):
# A reference to the app configuration
self.config = config_obj
@ -15,8 +14,8 @@ class Candles:
# The maximum amount of candles to load at one time.
self.max_records = self.config.app_data.get('max_data_loaded')
# This object maintains all the cached data. Pass it connection to the exchanges.
self.data = database
# This object maintains all the cached data.
self.data = data_source
# print('Setting the candle cache.')
# # Populate the cache:

View File

@ -1,496 +0,0 @@
import sqlite3
from functools import lru_cache
from typing import Any
import config
import datetime as dt
import pandas as pd
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes
class SQLite:
    """
    Context manager that yields a sqlite3 connection to the database file
    configured in config.DB_FILE. On exit the connection is committed and
    closed — note this commits even if an exception was thrown inside the
    managed block.
    """
    def __init__(self):
        # Always connects to the application database; no per-instance file override.
        self.connection = sqlite3.connect(config.DB_FILE)
    def __enter__(self):
        # Yields the connection object (not a cursor).
        return self.connection
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Commits unconditionally, then closes — even when exc_type is set.
        self.connection.commit()
        self.connection.close()
class HDict(dict):
    """Hashable dict (hash of the frozenset of its items) usable as a cache key."""
    def __hash__(self):
        return hash(frozenset(self.items()))
def make_query(item: str, table: str, columns: list) -> str:
    """
    Builds a parameterised SELECT statement with one '?' per given column.
    :param item: The field to select.
    :param table: The table to select from.
    :param columns: list - The database columns used in the WHERE clause.
    :return: str: The query string.
    """
    remaining = iter(columns)
    # Start the WHERE clause with the first column, then append the rest.
    query = f"SELECT {item} FROM {table} WHERE {next(remaining)} = ?"
    for column in remaining:
        query += f" AND {column} = ?"
    return query + ';'
def make_insert(table: str, values: tuple) -> str:
    """
    Builds a parameterised INSERT statement with one '?' per given column.
    :param table: The table to insert into.
    :param values: tuple - The column names to insert values for.
    :return: str: The query string.
    """
    names = iter(values)
    first = next(names)
    remaining = list(names)
    # Quoted column names, comma-separated, starting with the first one.
    column_part = f"'{first}'" + "".join(f", '{name}'" for name in remaining)
    # One placeholder per column.
    placeholder_part = "?" + ", ?" * len(remaining)
    return f"INSERT INTO {table} ({column_part}) VALUES({placeholder_part});"
class Database:
"""
Communicates and maintains the database.
"""
def __init__(self, exchanges):
# The exchanges object handles communication with all connected exchanges.
self.exchanges = exchanges
@staticmethod
def execute_sql(sql: str) -> None:
"""
Executes a sql statement. This is for stuff I haven't created a function for yet.
:param sql: str - sql statement.
:return: None
"""
with SQLite() as con:
cur = con.cursor()
cur.execute(sql)
@staticmethod
def get_item_where(item_name: str, table_name: str, filter_vals: tuple) -> int:
"""
Returns an item from a table where the filter results should isolate a single row.
:param item_name: str - The name of the item to fetch.
:param table_name: str - The name of the table.
:param filter_vals: tuple(str, str) - The column and value to filter the results with.
:return: str - The item.
"""
with SQLite() as con:
cur = con.cursor()
qry = make_query(item_name, table_name, [filter_vals[0]])
cur.execute(qry, (filter_vals[1],))
if user_id := cur.fetchone():
return user_id[0]
else:
error = f"Couldn't fetch item{item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}"
raise ValueError(error)
@staticmethod
def get_rows_where(table: str, filter_vals: tuple) -> pd.DataFrame | None:
"""
Returns a dataframe containing all rows of a table that meet the filter criteria.
:param table: str - The name of the table.
:param filter_vals: tuple(column: str, value: str) - the criteria
:return: dataframe|None - returns the data in a dataframe or None if the query fails.
"""
with SQLite() as con:
qry = f"select * from {table} where {filter_vals[0]}='{filter_vals[1]}'"
result = pd.read_sql(qry, con=con)
if not result.empty:
return result
else:
return None
@staticmethod
def insert_dataframe(df, table):
# Connect to the database.
with SQLite() as con:
# Insert the modified user as a new record in the table.
df.to_sql(name=table, con=con, index=False, if_exists='append')
# Commit the changes to the database.
con.commit()
@staticmethod
def insert_row(table: str, columns: tuple, values: tuple) -> None:
"""
Saves user specific data from a table in the database.
:param table: str - The table to insert into
:param columns: tuple(str1, str2, ...) - The columns of the database.
:param values: tuple(val1, val2, ...) - The values to be inserted.
:return: None
"""
# Connect to the database.
with SQLite() as conn:
# Get a cursor from the sql connection.
cursor = conn.cursor()
sql = make_insert(table=table, values=columns)
cursor.execute(sql, values)
@staticmethod
def _table_exists(table_name: str) -> bool:
"""
Returns True if table_name exists in the database.
:param table_name: The name of the database.
:return: bool - True|False
"""
# Connect to the database.
with SQLite() as conn:
# Get a cursor from the sql connection.
cursor = conn.cursor()
# sql = f"SELECT name FROM sqlite_schema WHERE type = 'table' AND name = '{table_name}';"
sql = make_query(item='name', table='sqlite_schema', columns=['type', 'name'])
# Check if the table exists.
cursor.execute(sql, ('table', table_name))
# Fetch the results from the cursor.
result = cursor.fetchone()
if not result:
# If the table doesn't exist return False.
return False
return True
def _populate_table(self, table_name: str, start_time: dt.datetime, ex_details: list, end_time: dt.datetime = None):
"""
Populates a database table with records from the exchange_name.
:param table_name: str - The name of the table in the database.
:param start_time: datetime - The starting time to fetch the records from.
:param end_time: datetime - The end time to get the records until.
:return: pdDataframe: - The data that was downloaded.
"""
# Set the default end_time to UTC now.
if end_time is None:
end_time = dt.datetime.utcnow()
# Fetch the records from the exchange_name.
# Extract the parameters from the details. Format: <symbol>_<timeframe>_<exchange_name>.
sym, inter, ex, un = ex_details
records = self._fetch_candles_from_exchange(symbol=sym, interval=inter, exchange_name=ex, user_name=un,
start_datetime=start_time, end_datetime=end_time)
# Update the database.
if not records.empty:
# Inert into the database any received records.
self._insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex)
else:
print(f'No records inserted {records}')
return records
@staticmethod
@lru_cache(maxsize=1000)
def get_from_static_table(item: str, table: str, indexes: HDict, create_id: bool = False) -> Any:
    """
    Retrieves a single item from a table. This method returns a cached result and is meant
    for fetching static data like settings, names and ID's.

    NOTE(review): results — including None for a missing row when create_id is
    False — are memoized per (item, table, indexes, create_id) by lru_cache, so
    a row inserted later through another code path will not be seen by this
    method for the lifetime of the cache entry. Confirm callers only use this
    for genuinely static data.

    :param create_id: bool - If True, create a row if it doesn't exist and return the autoincrement ID.
    :param item: str - The name of the item requested.
    :param table: str - The table being queried.
    :param indexes: HDict - A hashable dictionary containing the indexing columns and their values
                    (must be hashable so lru_cache can key on it).
    :return: Any - The content of the field, or None if absent and create_id is False.
    """
    # Connect to the database.
    with SQLite() as conn:
        # Get a cursor from the sql connection.
        cursor = conn.cursor()
        # Retrieve the record from the db.
        sql = make_query(item, table, list(indexes.keys()))
        cursor.execute(sql, tuple(indexes.values()))
        # The result is returned as tuple. Example: (id,)
        result = cursor.fetchone()
        if result is None and create_id is True:
            # Row missing: insert the indexing values to create it.
            sql = make_insert(table, tuple(indexes.keys()))
            cursor.execute(sql, tuple(indexes.values()))
            # Re-run the original query so the freshly created row is returned.
            sql = make_query(item, table, list(indexes.keys()))
            cursor.execute(sql, tuple(indexes.values()))
            # Get the first element of the tuple received from sql query.
            result = cursor.fetchone()
        # Return the result from the tuple if it exists.
        if result:
            return result[0]
        else:
            return None
def _fetch_exchange_id(self, exchange_name: str) -> int:
    """
    Look up the primary key of an exchange row, creating the row if needed.

    :param exchange_name: str - The name of the exchange.
    :return: int - The primary id of the exchange.
    """
    lookup_keys = HDict({'name': exchange_name})
    # Delegate to the cached static-table helper; create_id=True inserts the
    # row and returns its autoincrement id when it does not already exist.
    return self.get_from_static_table(item='id', table='exchange',
                                      indexes=lookup_keys, create_id=True)
def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
    """
    Returns the market id that belongs to a trading pair listed in the database.

    :param symbol: str - The symbol of the trading pair.
    :param exchange_name: str - The exchange name.
    :return: int - The market ID.
    """
    # Resolve the exchange name to its primary key first.
    exchange_id = self._fetch_exchange_id(exchange_name)
    lookup_keys = HDict({'symbol': symbol, 'exchange_id': exchange_id})
    # Ask the db for the market id, creating the row when it is not present.
    return self.get_from_static_table(item='id', table='markets',
                                      indexes=lookup_keys, create_id=True)
def _insert_candles_into_db(self, candlesticks, table_name: str, symbol, exchange_name) -> None:
    """
    Insert all the candlesticks from a dataframe into the database.

    Creates the destination table on first use; duplicate open_times are
    silently ignored via the UNIQUE ON CONFLICT IGNORE constraint.

    :param candlesticks: pd.DataFrame - Rows of candlestick attributes.
        NOTE: mutated in place — a 'market_id' column is prepended.
    :param table_name: str - The name of the table to insert into.
    :param symbol: The symbol of the trading pair.
    :param exchange_name: The name of the exchange.
    :return: None
    """
    # Retrieve (or create) the market id for this symbol/exchange pair.
    market_id = self._fetch_market_id(symbol, exchange_name)
    # Tag every row with the market it belongs to.
    candlesticks.insert(0, 'market_id', market_id)
    # Schema common to all exchanges. The foreign key targets the 'markets'
    # table used by _fetch_market_id (previously referenced the non-existent
    # 'market' table).
    sql_create = f"""
        CREATE TABLE IF NOT EXISTS '{table_name}' (
            id INTEGER PRIMARY KEY,
            market_id INTEGER,
            open_time INTEGER UNIQUE ON CONFLICT IGNORE,
            open REAL NOT NULL,
            high REAL NOT NULL,
            low REAL NOT NULL,
            close REAL NOT NULL,
            volume REAL NOT NULL,
            FOREIGN KEY (market_id) REFERENCES markets (id)
        )"""
    # Connect to the database.
    with SQLite() as conn:
        cursor = conn.cursor()
        # Create the table if it doesn't exist.
        cursor.execute(sql_create)
        # Append the candles; index=False keeps the dataframe index out of the table.
        candlesticks.to_sql(table_name, conn, if_exists='append', index=False)
    return
def get_records_since(self, table_name: str, st: dt.datetime,
                      et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame:
    """
    Returns all the candles between st and et from the specified table,
    downloading any missing leading or trailing records from the exchange.

    The flow is: load (or bootstrap) the table, then backfill older records
    if the data does not reach back to st, then top up newer records if the
    data is not up to date with et.

    :param ex_details: list of details to pass to the server:
        [symbol, interval, exchange_name, ...] — index 2 is used for logging.
    :param table_name: str - The database table name. Format: <symbol>_<timeframe>_<exchange_name>.
    :param st: dt.datetime - The open time of the first record requested.
    :param et: dt.datetime - The end time of the query.
    :param rl: float - The timespan in minutes each record represents.
    :return: pd.DataFrame - The records covering the requested window.
    """
    def add_data(data, tn, start_t, end_t):
        # Download records for [start_t:end_t] and merge them into 'data'.
        new_records = self._populate_table(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details)
        print(f'Got {len(new_records.index)} records from exchange_name')
        if not new_records.empty:
            # Combine the new records with the previously held records.
            data = pd.concat([data, new_records], axis=0, ignore_index=True)
            # Drop any duplicates from overlap at the window boundary.
            data = data.drop_duplicates(subset="open_time", keep='first')
        # Return the modified dataframe.
        return data
    if self._table_exists(table_name=table_name):
        # If the table exists retrieve all the records.
        print('\nTable existed retrieving records from DB')
        print(f'Requesting from {st} to {et}')
        records = self._get_records(table_name=table_name, st=st, et=et)
        print(f'Got {len(records.index)} records from db')
    else:
        # If the table doesn't exist, bootstrap it from the exchange.
        print(f'\nTable didnt exist fetching from {ex_details[2]}')
        # Expected record count = window length in minutes / record length.
        temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
        print(f'Requesting from {st} to {et}, Should be {temp} records')
        records = self._populate_table(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details)
        print(f'Got {len(records.index)} records from {ex_details[2]}')
    # Check if the records in the db go far enough back to satisfy the query.
    first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
    if first_timestamp:
        # The records didn't go far enough back if a timestamp was returned.
        print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}')
        print(f'first ts on record is: {first_timestamp}')
        end_time = dt.datetime.utcfromtimestamp(first_timestamp)
        print(f'Requesting from {st} to {end_time}')
        # Backfill: fetch records with open_times between [st:end_time] from the exchange.
        records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)
    # Check if the records received are up-to-date.
    last_timestamp = query_uptodate(records=records, r_length_min=rl)
    if last_timestamp:
        # The query was not up-to-date if a timestamp was returned.
        print(f'\nRecords were not updated. Requesting from {ex_details[2]}.')
        print(f'the last record on file is: {last_timestamp}')
        start_time = dt.datetime.utcfromtimestamp(last_timestamp)
        print(f'Requesting from {start_time} to {et}')
        # Top up: fetch records with open_times between [start_time:et] from the exchange.
        records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et)
    return records
@staticmethod
def _get_records(table_name: str, st: dt.datetime, et: dt.datetime = None) -> pd.DataFrame:
    """
    Returns all the candles newer than the provided start datetime from the specified table.

    :param table_name: str - The database table name. Format: <symbol>_<timeframe>_<exchange_name>.
    :param st: dt.datetime - The open time of the first record requested.
    :param et: dt.datetime - The end time of the query. When None, all records
               from st onward are returned.
    :return: pd.DataFrame - The matching records without the database id column.
    """
    # Create a timestamp in milliseconds for the start of the window.
    start_stamp = unix_time_millis(st)
    # Connect to the database.
    with SQLite() as conn:
        # The table name cannot be bound as a parameter; timestamps are
        # passed as bound parameters instead of being interpolated into
        # the SQL string.
        if et is not None:
            end_stamp = unix_time_millis(et)
            q_str = f"SELECT * FROM '{table_name}' WHERE open_time >= ? AND open_time <= ?;"
            records = pd.read_sql(q_str, conn, params=(start_stamp, end_stamp))
        else:
            q_str = f"SELECT * FROM '{table_name}' WHERE open_time >= ?;"
            records = pd.read_sql(q_str, conn, params=(start_stamp,))
        # Drop the database's primary id.
        records = records.drop('id', axis=1)
        # Return the data.
        return records
# @staticmethod
# def date_of_last_timestamp(table_name):
# """
# Returns the latest timestamp stored in the db.
# TODO: Unused.
#
# :return: dt.timestamp
# """
# # Connect to the database.
# with SQLite() as conn:
# # Get a cursor from the connection.
# cursor = conn.cursor()
# cursor.execute(f"""SELECT open_time FROM '{table_name}' ORDER BY open_time DESC LIMIT 1""")
# ts = cursor.fetchone()[0] / 1000
# return dt.datetime.utcfromtimestamp(ts)
#
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
                                 start_datetime: object = None, end_datetime: object = None) -> pd.DataFrame:
    """
    Fetches and returns all candles from the specified market, timeframe, and exchange.

    :param symbol: str - The symbol of the market.
    :param interval: str - The timeframe. Format '<int><alpha>' - examples: '15m', '4h'
    :param exchange_name: str - The name of the exchange.
    :param user_name: str - The user whose exchange connection should be used.
    :param start_datetime: dt.datetime - The open_time of the first record requested.
                           Defaults to 2017-01-01 when omitted.
    :param end_datetime: dt.datetime - The end time for the query. Defaults to UTC now.
    :raises ValueError: If start_datetime is after end_datetime.
    :return: pd.DataFrame - Rows of candle attributes that vary depending on the exchange.
             For example: [open_time, open, high, low, close, volume, close_time,
             quote_volume, num_trades, taker_buy_base_volume, taker_buy_quote_volume]
    """
    def fill_data_holes(records, interval):
        # Fill gaps in the candle series with copies of the last record
        # received BEFORE each gap. (The previous implementation copied the
        # record *after* the gap backwards, contradicting the documented
        # intent of carrying the pre-gap record forward.)
        time_span = timeframe_to_minutes(interval)
        prev_row = None
        filled_records = []
        for _, row in records.iterrows():
            time_stamp = row['open_time']
            if prev_row is not None:
                delta_ms = time_stamp - prev_row['open_time']
                delta_minutes = (delta_ms / 1000) / 60
                if delta_minutes > time_span:
                    # Number of record slots spanned by the gap.
                    num_missing_rec = int(delta_minutes / time_span)
                    step = int(delta_ms / num_missing_rec)
                    for ts in range(int(prev_row['open_time']) + step, int(time_stamp), step):
                        new_row = prev_row.copy()
                        new_row['open_time'] = ts
                        filled_records.append(new_row)
            filled_records.append(row)
            prev_row = row
        return pd.DataFrame(filled_records)
    # Default start date for fetching from the exchange.
    if start_datetime is None:
        start_datetime = dt.datetime(year=2017, month=1, day=1)
    # Default end date for fetching from the exchange.
    if end_datetime is None:
        end_datetime = dt.datetime.utcnow()
    if start_datetime > end_datetime:
        raise ValueError("\ndatabase:fetch_candles_from_exchange():"
                         " Invalid start and end parameters: ", start_datetime, end_datetime)
    # Get a reference to the exchange connection for this user.
    exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
    # Expected record count = window length in minutes / record length.
    temp = (((unix_time_millis(end_datetime) - unix_time_millis(
        start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval)
    print(f'Fetching historical data {start_datetime} to {end_datetime}, Should be {temp} records')
    if start_datetime == end_datetime:
        end_datetime = None
    # Request candlestick data from the exchange.
    candles = exchange.get_historical_klines(symbol=symbol,
                                             interval=interval,
                                             start_dt=start_datetime,
                                             end_dt=end_datetime)
    num_rec_records = len(candles.index)
    print(f'\n{num_rec_records} candles retrieved from the exchange_name.')
    # An empty response has nothing to gap-fill (and may not even carry an
    # open_time column), so skip the hole check in that case.
    if not candles.empty:
        # Isolate the open_times from the records received.
        open_times = candles.open_time
        # Number of records that would fit between the min and max open time.
        estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval)
        if num_rec_records < estimated_num_records:
            # Some records may be missing due to server maintenance periods etc.
            # Fill the holes with copies of the last record received before the gap.
            candles = fill_data_holes(candles, interval)
    return candles

View File

@ -3,15 +3,53 @@ from exchangeinterface import ExchangeInterface
import unittest
import pandas as pd
import datetime as dt
import os
from Database import SQLite, Database
from shared_utilities import unix_time_millis
class TestDataCache(unittest.TestCase):
def setUp(self):
# Setup the database connection here
# Set the database connection here
self.exchanges = ExchangeInterface()
self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)
# This object maintains all the cached data. Pass it connection to the exchanges.
self.db_file = 'test_db.sqlite'
self.database = Database(db_file=self.db_file)
# Create necessary tables
with SQLite(db_file=self.db_file) as con:
cursor = con.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS exchange (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE
)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS markets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT,
exchange_id INTEGER,
FOREIGN KEY (exchange_id) REFERENCES exchange(id)
)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS test_table (
id INTEGER PRIMARY KEY,
market_id INTEGER,
open_time INTEGER UNIQUE,
open REAL NOT NULL,
high REAL NOT NULL,
low REAL NOT NULL,
close REAL NOT NULL,
volume REAL NOT NULL,
FOREIGN KEY (market_id) REFERENCES markets(id)
)
""")
self.data = DataCache(self.exchanges)
self.data.db = self.database
asset, timeframe, exchange = 'BTC/USD', '2h', 'binance'
self.key1 = f'{asset}_{timeframe}_{exchange}'
@ -19,8 +57,11 @@ class TestDataCache(unittest.TestCase):
asset, timeframe, exchange = 'ETH/USD', '2h', 'binance'
self.key2 = f'{asset}_{timeframe}_{exchange}'
def tearDown(self):
if os.path.exists(self.db_file):
os.remove(self.db_file)
def test_set_cache(self):
# Tests
print('Testing set_cache flag not set:')
self.data.set_cache(data='data', key=self.key1)
attr = self.data.__getattribute__('cached_data')
@ -36,7 +77,6 @@ class TestDataCache(unittest.TestCase):
self.assertEqual(attr[self.key1], 'more_data')
def test_cache_exists(self):
# Tests
print('Testing cache_exists() method:')
self.assertFalse(self.data.cache_exists(key=self.key2))
self.data.set_cache(data='data', key=self.key1)
@ -44,7 +84,6 @@ class TestDataCache(unittest.TestCase):
def test_update_candle_cache(self):
print('Testing update_candle_cache() method:')
# Initial data
df_initial = pd.DataFrame({
'open_time': [1, 2, 3],
'open': [100, 101, 102],
@ -54,7 +93,6 @@ class TestDataCache(unittest.TestCase):
'volume': [1000, 1001, 1002]
})
# Data to be added
df_new = pd.DataFrame({
'open_time': [3, 4, 5],
'open': [102, 103, 104],
@ -96,7 +134,7 @@ class TestDataCache(unittest.TestCase):
def test_get_records_since(self):
print('Testing get_records_since() method:')
df_initial = pd.DataFrame({
'open_time': [1, 2, 3],
'open_time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=i)) for i in range(3)],
'open': [100, 101, 102],
'high': [110, 111, 112],
'low': [90, 91, 92],
@ -105,20 +143,86 @@ class TestDataCache(unittest.TestCase):
})
self.data.set_cache(data=df_initial, key=self.key1)
start_datetime = dt.datetime.utcfromtimestamp(2)
result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60, ex_details=[]).sort_values(by='open_time').reset_index(drop=True)
start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=2)
result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60,
ex_details=['BTC/USD', '2h', 'binance'])
expected = pd.DataFrame({
'open_time': [2, 3],
'open': [101, 102],
'high': [111, 112],
'low': [91, 92],
'close': [106, 107],
'volume': [1001, 1002]
'open_time': df_initial['open_time'][:2].values,
'open': [100, 101],
'high': [110, 111],
'low': [90, 91],
'close': [105, 106],
'volume': [1000, 1001]
})
pd.testing.assert_frame_equal(result, expected)
def test_get_records_since_from_db(self):
print('Testing get_records_since_from_db() method:')
df_initial = pd.DataFrame({
'market_id': [None],
'open_time': [unix_time_millis(dt.datetime.utcnow())],
'open': [1.0],
'high': [1.0],
'low': [1.0],
'close': [1.0],
'volume': [1.0]
})
with SQLite(self.db_file) as con:
df_initial.to_sql('test_table', con, if_exists='append', index=False)
start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=1)
end_datetime = dt.datetime.utcnow()
result = self.data.get_records_since_from_db(table_name='test_table', st=start_datetime, et=end_datetime,
rl=1, ex_details=['BTC/USD', '2h', 'binance']).sort_values(
by='open_time').reset_index(drop=True)
print("Columns in the result DataFrame:", result.columns)
print("Result DataFrame:\n", result)
# Remove 'id' column from the result DataFrame if it exists
if 'id' in result.columns:
result = result.drop(columns=['id'])
expected = pd.DataFrame({
'market_id': [None],
'open_time': [unix_time_millis(dt.datetime.utcnow())],
'open': [1.0],
'high': [1.0],
'low': [1.0],
'close': [1.0],
'volume': [1.0]
})
print("Expected DataFrame:\n", expected)
pd.testing.assert_frame_equal(result, expected)
def test_populate_db(self):
print('Testing _populate_db() method:')
start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
end_time = dt.datetime.utcnow()
result = self.data._populate_db(table_name='test_table', start_time=start_time,
end_time=end_time, ex_details=['BTC/USD', '2h', 'binance', 'test_guy'])
self.assertIsInstance(result, pd.DataFrame)
self.assertFalse(result.empty)
def test_fetch_candles_from_exchange(self):
print('Testing _fetch_candles_from_exchange() method:')
start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
end_time = dt.datetime.utcnow()
result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance',
user_name='test_guy', start_datetime=start_time,
end_datetime=end_time)
self.assertIsInstance(result, pd.DataFrame)
self.assertFalse(result.empty)
if __name__ == '__main__':
unittest.main()

View File

@ -1,96 +1,219 @@
import datetime
import unittest
import sqlite3
import pandas as pd
from database import make_query, make_insert, Database, HDict, SQLite
from exchangeinterface import ExchangeInterface
from passlib.hash import bcrypt
from sqlalchemy import create_engine, text
import config
import datetime as dt
from Database import Database, SQLite, make_query, make_insert, HDict
from shared_utilities import unix_time_millis
def test():
# un_hashed_pass = 'password'
hasher = bcrypt.using(rounds=13)
# hashed_pass = hasher.hash(un_hashed_pass)
# print(f'password: {un_hashed_pass}')
# print(f'hashed pass: {hashed_pass}')
# print(f" right pass: {hasher.verify('password', hashed_pass)}")
# print(f" wrong pass: {hasher.verify('passWord', hashed_pass)}")
engine = create_engine("sqlite:///" + config.DB_FILE, echo=True)
with engine.connect() as conn:
default_user = pd.read_sql_query(sql=text("SELECT * FROM users WHERE user_name = 'guest'"), con=conn)
# hashed_password = default_user.password.values[0]
# print(f" verify pass: {hasher.verify('password', hashed_password)}")
username = default_user.user_name.values[0]
print(username)
class TestSQLite(unittest.TestCase):
def test_sqlite_context_manager(self):
print("\nRunning test_sqlite_context_manager...")
with SQLite(db_file='test_db.sqlite') as con:
cursor = con.cursor()
cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
cursor.execute("INSERT INTO test_table (name) VALUES ('test')")
cursor.execute('SELECT name FROM test_table WHERE name = ?', ('test',))
result = cursor.fetchone()
self.assertEqual(result[0], 'test')
print("SQLite context manager test passed.")
def test_make_query():
values = {'first_field': 'first_value'}
item = 'market_id'
table = 'markets'
q_str = make_query(item=item, table=table, values=values)
print(f'\nWith one indexing field: {q_str}')
class TestDatabase(unittest.TestCase):
def setUp(self):
# Use a temporary SQLite database for testing purposes
self.db_file = 'test_db.sqlite'
self.db = Database(db_file=self.db_file)
self.connection = sqlite3.connect(self.db_file)
self.cursor = self.connection.cursor()
values = {'first_field': 'first_value', 'second_field': 'second_value'}
q_str = make_query(item=item, table=table, values=values)
print(f'\nWith two indexing fields: {q_str}')
assert q_str is not None
def tearDown(self):
    """Close the connection and delete the temporary database file."""
    self.connection.close()
    import os
    os.remove(self.db_file)  # Remove the temporary database file after tests
def test_execute_sql(self):
    """execute_sql() should run arbitrary DDL against the database."""
    print("\nRunning test_execute_sql...")
    # Drop the table if it exists to avoid OperationalError
    self.cursor.execute('DROP TABLE IF EXISTS test_table')
    self.connection.commit()
    sql = 'CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)'
    self.db.execute_sql(sql)
    # Confirm the table is now visible in the schema catalogue.
    self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_table';")
    result = self.cursor.fetchone()
    self.assertIsNotNone(result)
    print("Execute SQL test passed.")
def test_make_query(self):
    """make_query() should emit a parameterized SELECT statement."""
    print("\nRunning test_make_query...")
    query = make_query('id', 'test_table', ['name'])
    expected_query = 'SELECT id FROM test_table WHERE name = ?;'
    self.assertEqual(query, expected_query)
    print("Make query test passed.")
def test_make_insert(self):
    """make_insert() should emit a parameterized INSERT statement."""
    print("\nRunning test_make_insert...")
    insert = make_insert('test_table', ('name', 'age'))
    expected_insert = "INSERT INTO test_table ('name', 'age') VALUES(?, ?);"
    self.assertEqual(insert, expected_insert)
    print("Make insert test passed.")
def test_get_item_where(self):
    """get_item_where() should return a single field value matched by the predicate."""
    print("\nRunning test_get_item_where...")
    # Seed one row to look up.
    self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
    self.cursor.execute("INSERT INTO test_table (id, name) VALUES (1, 'test')")
    self.connection.commit()
    item = self.db.get_item_where('name', 'test_table', ('id', 1))
    self.assertEqual(item, 'test')
    print("Get item where test passed.")
def test_get_rows_where(self):
print("\nRunning test_get_rows_where...")
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
self.cursor.execute("INSERT INTO test_table (id, name) VALUES (1, 'test')")
self.connection.commit()
rows = self.db.get_rows_where('test_table', ('name', 'test'))
self.assertIsInstance(rows, pd.DataFrame)
self.assertEqual(rows.iloc[0]['name'], 'test')
print("Get rows where test passed.")
def test_insert_dataframe(self):
print("\nRunning test_insert_dataframe...")
df = pd.DataFrame({'id': [1], 'name': ['test']})
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
self.connection.commit()
self.db.insert_dataframe(df, 'test_table')
self.cursor.execute('SELECT name FROM test_table WHERE id = 1')
result = self.cursor.fetchone()
self.assertEqual(result[0], 'test')
print("Insert dataframe test passed.")
def test_insert_row(self):
    """insert_row() should add a row that is readable via a plain SELECT."""
    print("\nRunning test_insert_row...")
    self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
    self.connection.commit()
    self.db.insert_row('test_table', ('id', 'name'), (1, 'test'))
    # Verify the row landed via an independent cursor.
    self.cursor.execute('SELECT name FROM test_table WHERE id = 1')
    result = self.cursor.fetchone()
    self.assertEqual(result[0], 'test')
    print("Insert row test passed.")
def test_table_exists(self):
    """table_exists() should report True for a freshly created table."""
    print("\nRunning test_table_exists...")
    self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
    self.connection.commit()
    exists = self.db.table_exists('test_table')
    self.assertTrue(exists)
    print("Table exists test passed.")
def test_get_timestamped_records(self):
print("\nRunning test_get_timestamped_records...")
df = pd.DataFrame({
'open_time': [unix_time_millis(dt.datetime.utcnow())],
'open': [1.0],
'high': [1.0],
'low': [1.0],
'close': [1.0],
'volume': [1.0]
})
table_name = 'test_table'
self.cursor.execute(f"""
CREATE TABLE {table_name} (
id INTEGER PRIMARY KEY,
open_time INTEGER UNIQUE,
open REAL NOT NULL,
high REAL NOT NULL,
low REAL NOT NULL,
close REAL NOT NULL,
volume REAL NOT NULL
)
""")
self.connection.commit()
self.db.insert_dataframe(df, table_name)
st = dt.datetime.utcnow() - dt.timedelta(minutes=1)
et = dt.datetime.utcnow()
records = self.db.get_timestamped_records(table_name, 'open_time', st, et)
self.assertIsInstance(records, pd.DataFrame)
self.assertFalse(records.empty)
print("Get timestamped records test passed.")
def test_get_from_static_table(self):
print("\nRunning test_get_from_static_table...")
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT UNIQUE)')
self.connection.commit()
item = self.db.get_from_static_table('id', 'test_table', HDict({'name': 'test'}), create_id=True)
self.assertIsInstance(item, int)
self.cursor.execute('SELECT id FROM test_table WHERE name = ?', ('test',))
result = self.cursor.fetchone()
self.assertEqual(item, result[0])
print("Get from static table test passed.")
def test_insert_candles_into_db(self):
print("\nRunning test_insert_candles_into_db...")
df = pd.DataFrame({
'open_time': [unix_time_millis(dt.datetime.utcnow())],
'open': [1.0],
'high': [1.0],
'low': [1.0],
'close': [1.0],
'volume': [1.0]
})
table_name = 'test_table'
self.cursor.execute(f"""
CREATE TABLE {table_name} (
id INTEGER PRIMARY KEY,
market_id INTEGER,
open_time INTEGER UNIQUE,
open REAL NOT NULL,
high REAL NOT NULL,
low REAL NOT NULL,
close REAL NOT NULL,
volume REAL NOT NULL
)
""")
self.connection.commit()
# Create the exchange and markets tables needed for the foreign key constraints
self.cursor.execute("""
CREATE TABLE exchange (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE
)
""")
self.cursor.execute("""
CREATE TABLE markets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT,
exchange_id INTEGER,
FOREIGN KEY (exchange_id) REFERENCES exchange(id)
)
""")
self.connection.commit()
self.db.insert_candles_into_db(df, table_name, 'BTC/USDT', 'binance')
self.cursor.execute(f'SELECT * FROM {table_name}')
result = self.cursor.fetchall()
self.assertFalse(len(result) == 0)
print("Insert candles into db test passed.")
def test_make_insert():
table = 'markets'
values = {'first_field': 'first_value'}
q_str = make_insert(table=table, values=values)
print(f'\nWith one indexing field: {q_str}')
if __name__ == '__main__':
unittest.main()
values = {'first_field': 'first_value', 'second_field': 'second_value'}
q_str = make_insert(table=table, values=values)
print(f'\nWith two indexing fields: {q_str}')
assert q_str is not None
def test__table_exists():
exchanges = ExchangeInterface()
d_obj = Database(exchanges)
exists = d_obj._table_exists('BTC/USD_5m_alpaca')
print(f'\nExists - Should be true: {exists}')
assert exists
exists = d_obj._table_exists('BTC/USD_5m_alpina')
print(f'Doesnt exist - should be false: {exists}')
assert not exists
def test_get_from_static_table():
exchanges = ExchangeInterface()
d_obj = Database(exchanges)
market_id = d_obj._fetch_market_id('BTC/USD', 'alpaca')
e_id = d_obj._fetch_exchange_id('alpaca')
print(f'market id: {market_id}')
assert market_id > 0
print(f'exchange_name ID: {e_id}')
assert e_id == 4
def test_populate_table():
"""
Populates a database table with records from the exchange_name.
:param table_name: str - The name of the table in the database.
:param start_time: datetime - The starting time to fetch the records from.
:param end_time: datetime - The end time to get the records until.
:return: pdDataframe: - The data that was downloaded.
"""
exchanges = ExchangeInterface()
d_obj = Database(exchanges)
d_obj._populate_table(table_name='BTC/USD_2h_alpaca',
start_time=datetime.datetime(year=2023, month=3, day=27, hour=6, minute=0))
def test_get_records_since():
exchanges = ExchangeInterface()
d_obj = Database(exchanges)
records = d_obj.get_records_since(table_name='BTC/USD_15m_alpaca',
st=datetime.datetime(year=2023, month=3, day=27, hour=1, minute=0),
et=datetime.datetime.utcnow(),
rl=15)
print(records)
assert records is not None
# def test():
# # un_hashed_pass = 'password'
# hasher = bcrypt.using(rounds=13)
# # hashed_pass = hasher.hash(un_hashed_pass)
# # print(f'password: {un_hashed_pass}')
# # print(f'hashed pass: {hashed_pass}')
# # print(f" right pass: {hasher.verify('password', hashed_pass)}")
# # print(f" wrong pass: {hasher.verify('passWord', hashed_pass)}")
# engine = create_engine("sqlite:///" + config.DB_FILE, echo=True)
# with engine.connect() as conn:
# default_user = pd.read_sql_query(sql=text("SELECT * FROM users WHERE user_name = 'guest'"), con=conn)
# # hashed_password = default_user.password.values[0]
# # print(f" verify pass: {hasher.verify('password', hashed_password)}")
# username = default_user.user_name.values[0]
# print(username)

View File

@ -12,7 +12,7 @@ class TestExchange(unittest.TestCase):
@classmethod
def setUpClass(cls):
exchange_name = 'binance'
exchange_name = 'kraken'
cls.api_keys = None
"""Uncomment and Provide api keys to connect to exchange."""
# cls.api_keys = {'key': 'EXCHANGE_API_KEY', 'secret': 'EXCHANGE_API_SECRET'}