Completed unittests for Database and DataCache.
This commit is contained in:
parent
e601f8c23e
commit
4130e0ca9a
|
|
@ -27,7 +27,7 @@ class BrighterTrades:
|
|||
self.signals = Signals(self.config.signals_list)
|
||||
|
||||
# Object that maintains candlestick and price data.
|
||||
self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, database=self.data)
|
||||
self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, data_source=self.data)
|
||||
|
||||
# Object that interacts with and maintains data from available indicators
|
||||
self.indicators = Indicators(self.candles, self.config)
|
||||
|
|
|
|||
292
src/DataCache.py
292
src/DataCache.py
|
|
@ -1,25 +1,41 @@
|
|||
from typing import Any, List
|
||||
|
||||
import pandas as pd
|
||||
import datetime as dt
|
||||
from Database import Database
|
||||
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes
|
||||
import logging
|
||||
|
||||
from database import Database
|
||||
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataCache:
|
||||
"""
|
||||
Fetches manages data limits and optimises memory storage.
|
||||
Fetches manages data limits and optimises memory storage.
|
||||
Handles connections and operations for the given exchanges.
|
||||
|
||||
Example usage:
|
||||
--------------
|
||||
db = DataCache(exchanges=some_exchanges_object)
|
||||
"""
|
||||
|
||||
def __init__(self, exchanges):
|
||||
"""
|
||||
Initializes the DataCache class.
|
||||
|
||||
# Define the cache size
|
||||
self.max_tables = 50 # Maximum amount of tables that will be cached at any given time.
|
||||
self.max_records = 1000 # Maximum number of records that will be kept per table.
|
||||
self.cached_data = {} # A dictionary that holds all the cached records.
|
||||
self.db = Database(
|
||||
exchanges) # The class that handles the DB interactions pass it a connection to the exchange_interface.
|
||||
:param exchanges: The exchanges object handling communication with connected exchanges.
|
||||
"""
|
||||
# Maximum number of tables to cache at any given time.
|
||||
self.max_tables = 50
|
||||
# Maximum number of records to be kept per table.
|
||||
self.max_records = 1000
|
||||
# A dictionary that holds all the cached records.
|
||||
self.cached_data = {}
|
||||
# The class that handles the DB interactions.
|
||||
self.db = Database()
|
||||
# The class that handles exchange interactions.
|
||||
self.exchanges = exchanges
|
||||
|
||||
def cache_exists(self, key: str) -> bool:
|
||||
"""
|
||||
|
|
@ -89,27 +105,32 @@ class DataCache:
|
|||
"""
|
||||
Return any records from the cache indexed by table_name that are newer than start_datetime.
|
||||
|
||||
:param ex_details: Details to pass to the server
|
||||
:param key: <str>: - The dictionary table_name of the records.
|
||||
:param start_datetime: <dt.datetime>: - The datetime of the first record requested.
|
||||
:param record_length: The timespan of the records.
|
||||
:param ex_details: List[str] - Details to pass to the server
|
||||
:param key: str - The dictionary table_name of the records.
|
||||
:param start_datetime: dt.datetime - The datetime of the first record requested.
|
||||
:param record_length: int - The timespan of the records.
|
||||
|
||||
:return pd.DataFrame: - The Requested records
|
||||
:return: pd.DataFrame - The Requested records
|
||||
|
||||
Example:
|
||||
--------
|
||||
records = data_cache.get_records_since('BTC/USD_2h_binance', dt.datetime.utcnow() - dt.timedelta(minutes=60), 60, ['BTC/USD', '2h', 'binance'])
|
||||
"""
|
||||
try:
|
||||
# End time of query defaults to the current time.
|
||||
end_datetime = dt.datetime.utcnow()
|
||||
|
||||
if self.cache_exists(key=key):
|
||||
print('\nGetting records from cache.')
|
||||
# If the records exist retrieve them from the cache.
|
||||
logger.debug('Getting records from cache.')
|
||||
# If the records exist, retrieve them from the cache.
|
||||
records = self.get_cache(key)
|
||||
else:
|
||||
# If they don't exist in cache, get them from the db.
|
||||
print(f'\nRecords not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
|
||||
records = self.db.get_records_since(table_name=key, st=start_datetime,
|
||||
et=end_datetime, rl=record_length, ex_details=ex_details)
|
||||
print(f'Got {len(records.index)} records from db')
|
||||
# If they don't exist in cache, get them from the database.
|
||||
logger.debug(
|
||||
f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
|
||||
records = self.get_records_since_from_db(table_name=key, st=start_datetime,
|
||||
et=end_datetime, rl=record_length, ex_details=ex_details)
|
||||
logger.debug(f'Got {len(records.index)} records from DB.')
|
||||
self.set_cache(data=records, key=key)
|
||||
|
||||
# Check if the records in the cache go far enough back to satisfy the query.
|
||||
|
|
@ -119,13 +140,14 @@ class DataCache:
|
|||
if first_timestamp:
|
||||
# The records didn't go far enough back if a timestamp was returned.
|
||||
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
|
||||
# Request candles with open_times between [start_time:end_time] from the database.
|
||||
print(f'requesting additional records from {start_datetime} to {end_time}')
|
||||
additional_records = self.db.get_records_since(table_name=key, st=start_datetime,
|
||||
et=end_time, rl=record_length, ex_details=ex_details)
|
||||
print(f'Got {len(additional_records.index)} additional records from db')
|
||||
logger.debug(f'Requesting additional records from {start_datetime} to {end_time}')
|
||||
# Request additional records from the database.
|
||||
additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime,
|
||||
et=end_time, rl=record_length,
|
||||
ex_details=ex_details)
|
||||
logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
|
||||
if not additional_records.empty:
|
||||
# If more records were received update the cache.
|
||||
# If more records were received, update the cache.
|
||||
self.update_candle_cache(additional_records, key)
|
||||
|
||||
# Check if the records received are up-to-date.
|
||||
|
|
@ -134,21 +156,219 @@ class DataCache:
|
|||
if last_timestamp:
|
||||
# The query was not up-to-date if a timestamp was returned.
|
||||
start_time = dt.datetime.utcfromtimestamp(last_timestamp)
|
||||
print(f'requesting additional records from {start_time} to {end_datetime}')
|
||||
# Request the database update its table starting from start_datetime.
|
||||
additional_records = self.db.get_records_since(table_name=key, st=start_time,
|
||||
et=end_datetime, rl=record_length,
|
||||
ex_details=ex_details)
|
||||
print(f'Got {len(additional_records.index)} additional records from db')
|
||||
logger.debug(f'Requesting additional records from {start_time} to {end_datetime}')
|
||||
# Request additional records from the database.
|
||||
additional_records = self.get_records_since_from_db(table_name=key, st=start_time,
|
||||
et=end_datetime, rl=record_length,
|
||||
ex_details=ex_details)
|
||||
logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
|
||||
if not additional_records.empty:
|
||||
self.update_candle_cache(additional_records, key)
|
||||
|
||||
# Create a UTC timestamp.
|
||||
_timestamp = unix_time_millis(start_datetime)
|
||||
# Return all records equal or newer than timestamp.
|
||||
logger.debug(f"Start timestamp in human-readable form: {dt.datetime.utcfromtimestamp(_timestamp / 1000.0)}")
|
||||
|
||||
# Return all records equal to or newer than the timestamp.
|
||||
result = self.get_cache(key).query('open_time >= @_timestamp')
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {str(e)}")
|
||||
logger.error(f"An error occurred: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_records_since_from_db(self, table_name: str, st: dt.datetime,
|
||||
et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame:
|
||||
"""
|
||||
Returns records from a specified table that meet a criteria and ensures the records are complete.
|
||||
If the records do not go back far enough or are not up-to-date, it fetches additional records
|
||||
from an exchange and populates the table.
|
||||
|
||||
:param table_name: Database table name.
|
||||
:param st: Start datetime.
|
||||
:param et: End datetime.
|
||||
:param rl: Timespan in minutes each record represents.
|
||||
:param ex_details: Exchange details [symbol, interval, exchange_name].
|
||||
:return: DataFrame of records.
|
||||
|
||||
Example:
|
||||
--------
|
||||
records = db.get_records_since('test_table', start_time, end_time, 1, ['BTC/USDT', '1m', 'binance'])
|
||||
"""
|
||||
|
||||
def add_data(data, tn, start_t, end_t):
|
||||
new_records = self._populate_db(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details)
|
||||
print(f'Got {len(new_records.index)} records from exchange_name')
|
||||
if not new_records.empty:
|
||||
data = pd.concat([data, new_records], axis=0, ignore_index=True)
|
||||
data = data.drop_duplicates(subset="open_time", keep='first')
|
||||
return data
|
||||
|
||||
if self.db.table_exists(table_name=table_name):
|
||||
print('\nTable existed retrieving records from DB')
|
||||
print(f'Requesting from {st} to {et}')
|
||||
records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=st, et=et)
|
||||
print(f'Got {len(records.index)} records from db')
|
||||
else:
|
||||
print(f'\nTable didnt exist fetching from {ex_details[2]}')
|
||||
temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
|
||||
print(f'Requesting from {st} to {et}, Should be {temp} records')
|
||||
records = self._populate_db(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details)
|
||||
print(f'Got {len(records.index)} records from {ex_details[2]}')
|
||||
|
||||
first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
|
||||
if first_timestamp:
|
||||
print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}')
|
||||
print(f'first ts on record is: {first_timestamp}')
|
||||
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
|
||||
print(f'Requesting from {st} to {end_time}')
|
||||
records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)
|
||||
|
||||
last_timestamp = query_uptodate(records=records, r_length_min=rl)
|
||||
if last_timestamp:
|
||||
print(f'\nRecords were not updated. Requesting from {ex_details[2]}.')
|
||||
print(f'the last record on file is: {last_timestamp}')
|
||||
start_time = dt.datetime.utcfromtimestamp(last_timestamp)
|
||||
print(f'Requesting from {start_time} to {et}')
|
||||
records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et)
|
||||
|
||||
return records
|
||||
|
||||
def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: list,
|
||||
end_time: dt.datetime = None) -> pd.DataFrame:
|
||||
"""
|
||||
Populates a database table with records from the exchange.
|
||||
|
||||
:param table_name: Name of the table in the database.
|
||||
:param start_time: Start time to fetch the records from.
|
||||
:param end_time: End time to fetch the records until (optional).
|
||||
:param ex_details: Exchange details [symbol, interval, exchange_name, user_name].
|
||||
:return: DataFrame of the data downloaded.
|
||||
|
||||
Example:
|
||||
--------
|
||||
records = db._populate_table('test_table', start_time, ['BTC/USDT', '1m', 'binance', 'user1'])
|
||||
"""
|
||||
if end_time is None:
|
||||
end_time = dt.datetime.utcnow()
|
||||
sym, inter, ex, un = ex_details
|
||||
records = self._fetch_candles_from_exchange(symbol=sym, interval=inter, exchange_name=ex, user_name=un,
|
||||
start_datetime=start_time, end_datetime=end_time)
|
||||
if not records.empty:
|
||||
self.db.insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex)
|
||||
else:
|
||||
print(f'No records inserted {records}')
|
||||
return records
|
||||
|
||||
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
|
||||
start_datetime: dt.datetime = None,
|
||||
end_datetime: dt.datetime = None) -> pd.DataFrame:
|
||||
"""
|
||||
Fetches and returns all candles from the specified market, timeframe, and exchange.
|
||||
|
||||
:param symbol: Symbol of the market.
|
||||
:param interval: Timeframe in the format '<int><alpha>' (e.g., '15m', '4h').
|
||||
:param exchange_name: Name of the exchange.
|
||||
:param user_name: Name of the user.
|
||||
:param start_datetime: Start datetime for fetching data (optional).
|
||||
:param end_datetime: End datetime for fetching data (optional).
|
||||
:return: DataFrame of candle data.
|
||||
|
||||
Example:
|
||||
--------
|
||||
candles = db._fetch_candles_from_exchange('BTC/USDT', '1m', 'binance', 'user1', start_time, end_time)
|
||||
"""
|
||||
|
||||
def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
|
||||
"""
|
||||
Fills gaps in the data by replicating the last known data point for the missing periods.
|
||||
|
||||
:param records: DataFrame containing the original records.
|
||||
:param interval: Interval of the data (e.g., '1m', '5m').
|
||||
:return: DataFrame with gaps filled.
|
||||
|
||||
Example:
|
||||
--------
|
||||
filled_records = fill_data_holes(df, '1m')
|
||||
"""
|
||||
time_span = timeframe_to_minutes(interval)
|
||||
last_timestamp = None
|
||||
filled_records = []
|
||||
|
||||
logger.info(f"Starting to fill data holes for interval: {interval}")
|
||||
|
||||
for index, row in records.iterrows():
|
||||
time_stamp = row['open_time']
|
||||
|
||||
# If last_timestamp is None, this is the first record
|
||||
if last_timestamp is None:
|
||||
last_timestamp = time_stamp
|
||||
filled_records.append(row)
|
||||
logger.debug(f"First timestamp: {time_stamp}")
|
||||
continue
|
||||
|
||||
# Calculate the difference in milliseconds and minutes
|
||||
delta_ms = time_stamp - last_timestamp
|
||||
delta_minutes = (delta_ms / 1000) / 60
|
||||
|
||||
logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}")
|
||||
|
||||
# If the gap is larger than the time span of the interval, fill the gap
|
||||
if delta_minutes > time_span:
|
||||
num_missing_rec = int(delta_minutes / time_span)
|
||||
step = int(delta_ms / num_missing_rec)
|
||||
logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}")
|
||||
|
||||
for ts in range(int(last_timestamp) + step, int(time_stamp), step):
|
||||
new_row = row.copy()
|
||||
new_row['open_time'] = ts
|
||||
filled_records.append(new_row)
|
||||
logger.debug(f"Filled timestamp: {ts}")
|
||||
|
||||
filled_records.append(row)
|
||||
last_timestamp = time_stamp
|
||||
|
||||
logger.info("Data holes filled successfully.")
|
||||
return pd.DataFrame(filled_records)
|
||||
|
||||
# Default start date if not provided
|
||||
if start_datetime is None:
|
||||
start_datetime = dt.datetime(year=2017, month=1, day=1)
|
||||
|
||||
# Default end date if not provided
|
||||
if end_datetime is None:
|
||||
end_datetime = dt.datetime.utcnow()
|
||||
|
||||
# Check if start date is greater than end date
|
||||
if start_datetime > end_datetime:
|
||||
raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")
|
||||
|
||||
# Get the exchange object
|
||||
exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
|
||||
|
||||
# Calculate the expected number of records
|
||||
temp = (((unix_time_millis(end_datetime) - unix_time_millis(
|
||||
start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval)
|
||||
|
||||
logger.info(f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {temp}')
|
||||
|
||||
# If start and end times are the same, set end_datetime to None
|
||||
if start_datetime == end_datetime:
|
||||
end_datetime = None
|
||||
|
||||
# Fetch historical candlestick data from the exchange
|
||||
candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime,
|
||||
end_dt=end_datetime)
|
||||
num_rec_records = len(candles.index)
|
||||
|
||||
logger.info(f'{num_rec_records} candles retrieved from the exchange.')
|
||||
|
||||
# Check if the retrieved data covers the expected time range
|
||||
open_times = candles.open_time
|
||||
estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval)
|
||||
|
||||
# Fill in any missing data if the retrieved data is less than expected
|
||||
if num_rec_records < estimated_num_records:
|
||||
candles = fill_data_holes(candles, interval)
|
||||
|
||||
return candles
|
||||
|
|
|
|||
|
|
@ -0,0 +1,352 @@
|
|||
import sqlite3
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
import config
|
||||
import datetime as dt
|
||||
import pandas as pd
|
||||
from shared_utilities import unix_time_millis
|
||||
|
||||
|
||||
class SQLite:
|
||||
"""
|
||||
Context manager for SQLite database connections.
|
||||
Accepts a database file name or defaults to the file in config.DB_FILE.
|
||||
|
||||
Example usage:
|
||||
--------------
|
||||
with SQLite(db_file='test_db.sqlite') as con:
|
||||
cursor = con.cursor()
|
||||
cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||
"""
|
||||
|
||||
def __init__(self, db_file=None):
|
||||
self.db_file = db_file if db_file else config.DB_FILE
|
||||
self.connection = sqlite3.connect(self.db_file)
|
||||
|
||||
def __enter__(self):
|
||||
return self.connection
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.connection.commit()
|
||||
self.connection.close()
|
||||
|
||||
|
||||
class HDict(dict):
|
||||
"""
|
||||
Hashable dictionary to use as cache keys.
|
||||
|
||||
Example usage:
|
||||
--------------
|
||||
hdict = HDict({'key1': 'value1', 'key2': 'value2'})
|
||||
hash(hdict)
|
||||
"""
|
||||
|
||||
def __hash__(self):
|
||||
return hash(frozenset(self.items()))
|
||||
|
||||
|
||||
def make_query(item: str, table: str, columns: list) -> str:
|
||||
"""
|
||||
Creates a SQL select query string with the required number of placeholders.
|
||||
|
||||
:param item: The field to select.
|
||||
:param table: The table to select from.
|
||||
:param columns: List of columns for the where clause.
|
||||
:return: The query string.
|
||||
|
||||
Example:
|
||||
--------
|
||||
query = make_query('id', 'test_table', ['name', 'age'])
|
||||
# Result: 'SELECT id FROM test_table WHERE name = ? AND age = ?;'
|
||||
"""
|
||||
an_itr = iter(columns)
|
||||
k = next(an_itr)
|
||||
where_str = f"SELECT {item} FROM {table} WHERE {k} = ?"
|
||||
where_str += "".join([f" AND {k} = ?" for k in an_itr]) + ';'
|
||||
return where_str
|
||||
|
||||
|
||||
def make_insert(table: str, values: tuple) -> str:
|
||||
"""
|
||||
Creates a SQL insert query string with the required number of placeholders.
|
||||
|
||||
:param table: The table to insert into.
|
||||
:param values: Tuple of values to insert.
|
||||
:return: The query string.
|
||||
|
||||
Example:
|
||||
--------
|
||||
insert = make_insert('test_table', ('name', 'age'))
|
||||
# Result: "INSERT INTO test_table ('name', 'age') VALUES(?, ?);"
|
||||
"""
|
||||
itr1 = iter(values)
|
||||
itr2 = iter(values)
|
||||
k1 = next(itr1)
|
||||
_ = next(itr2)
|
||||
insert_str = f"INSERT INTO {table} ('{k1}'"
|
||||
insert_str += "".join([f", '{k1}'" for k1 in itr1]) + ") VALUES(?" + "".join(
|
||||
[", ?" for _ in enumerate(itr2)]) + ");"
|
||||
return insert_str
|
||||
|
||||
|
||||
class Database:
|
||||
"""
|
||||
Database class to communicate and maintain the database.
|
||||
Handles connections and operations for the given exchanges.
|
||||
|
||||
Example usage:
|
||||
--------------
|
||||
db = Database(exchanges=some_exchanges_object, db_file='test_db.sqlite')
|
||||
db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||
"""
|
||||
|
||||
def __init__(self, db_file=None):
|
||||
"""
|
||||
Initializes the Database class.
|
||||
|
||||
:param db_file: Optional database file name.
|
||||
"""
|
||||
self.db_file = db_file
|
||||
|
||||
def execute_sql(self, sql: str) -> None:
|
||||
"""
|
||||
Executes a raw SQL statement.
|
||||
|
||||
:param sql: SQL statement to execute.
|
||||
|
||||
Example:
|
||||
--------
|
||||
db = Database(exchanges=some_exchanges_object, db_file='test_db.sqlite')
|
||||
db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||
"""
|
||||
with SQLite(self.db_file) as con:
|
||||
cur = con.cursor()
|
||||
cur.execute(sql)
|
||||
|
||||
def get_item_where(self, item_name: str, table_name: str, filter_vals: tuple) -> int:
|
||||
"""
|
||||
Returns an item from a table where the filter results should isolate a single row.
|
||||
|
||||
:param item_name: Name of the item to fetch.
|
||||
:param table_name: Name of the table.
|
||||
:param filter_vals: Tuple of column name and value to filter by.
|
||||
:return: The item.
|
||||
|
||||
Example:
|
||||
--------
|
||||
item = db.get_item_where('name', 'test_table', ('id', 1))
|
||||
# Fetches the 'name' from 'test_table' where 'id' is 1
|
||||
"""
|
||||
with SQLite(self.db_file) as con:
|
||||
cur = con.cursor()
|
||||
qry = make_query(item_name, table_name, [filter_vals[0]])
|
||||
cur.execute(qry, (filter_vals[1],))
|
||||
if user_id := cur.fetchone():
|
||||
return user_id[0]
|
||||
else:
|
||||
error = f"Couldn't fetch item {item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}"
|
||||
raise ValueError(error)
|
||||
|
||||
def get_rows_where(self, table: str, filter_vals: tuple) -> pd.DataFrame | None:
|
||||
"""
|
||||
Returns a DataFrame containing all rows of a table that meet the filter criteria.
|
||||
|
||||
:param table: Name of the table.
|
||||
:param filter_vals: Tuple of column name and value to filter by.
|
||||
:return: DataFrame of the query result or None if empty.
|
||||
|
||||
Example:
|
||||
--------
|
||||
rows = db.get_rows_where('test_table', ('name', 'test'))
|
||||
# Fetches all rows from 'test_table' where 'name' is 'test'
|
||||
"""
|
||||
with SQLite(self.db_file) as con:
|
||||
qry = f"SELECT * FROM {table} WHERE {filter_vals[0]}='{filter_vals[1]}'"
|
||||
result = pd.read_sql(qry, con=con)
|
||||
if not result.empty:
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
||||
def insert_dataframe(self, df: pd.DataFrame, table: str) -> None:
|
||||
"""
|
||||
Inserts a DataFrame into a specified table.
|
||||
|
||||
:param df: DataFrame to insert.
|
||||
:param table: Name of the table.
|
||||
|
||||
Example:
|
||||
--------
|
||||
df = pd.DataFrame({'id': [1], 'name': ['test']})
|
||||
db.insert_dataframe(df, 'test_table')
|
||||
"""
|
||||
with SQLite(self.db_file) as con:
|
||||
df.to_sql(name=table, con=con, index=False, if_exists='append')
|
||||
|
||||
def insert_row(self, table: str, columns: tuple, values: tuple) -> None:
|
||||
"""
|
||||
Inserts a row into a specified table.
|
||||
|
||||
:param table: Name of the table.
|
||||
:param columns: Tuple of column names.
|
||||
:param values: Tuple of values to insert.
|
||||
|
||||
Example:
|
||||
--------
|
||||
db.insert_row('test_table', ('id', 'name'), (1, 'test'))
|
||||
"""
|
||||
with SQLite(self.db_file) as conn:
|
||||
cursor = conn.cursor()
|
||||
sql = make_insert(table=table, values=columns)
|
||||
cursor.execute(sql, values)
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
|
||||
"""
|
||||
Checks if a table exists in the database.
|
||||
|
||||
:param table_name: Name of the table.
|
||||
:return: True if the table exists, False otherwise.
|
||||
|
||||
Example:
|
||||
--------
|
||||
exists = db._table_exists('test_table')
|
||||
# Checks if 'test_table' exists in the database
|
||||
"""
|
||||
with SQLite(self.db_file) as conn:
|
||||
cursor = conn.cursor()
|
||||
sql = make_query(item='name', table='sqlite_schema', columns=['type', 'name'])
|
||||
cursor.execute(sql, ('table', table_name))
|
||||
result = cursor.fetchone()
|
||||
return result is not None
|
||||
|
||||
@lru_cache(maxsize=1000)
|
||||
def get_from_static_table(self, item: str, table: str, indexes: HDict, create_id: bool = False) -> Any:
|
||||
"""
|
||||
Returns the row id of an item from a table specified. If the item isn't listed in the table,
|
||||
it will insert the item into a new row and return the autoincremented id. The item received as a hashable
|
||||
dictionary so the results can be cached.
|
||||
|
||||
:param item: Name of the item requested.
|
||||
:param table: Table being queried.
|
||||
:param indexes: Hashable dictionary of indexing columns and their values.
|
||||
:param create_id: If True, create a row if it doesn't exist and return the autoincrement ID.
|
||||
:return: The content of the field.
|
||||
|
||||
Example:
|
||||
--------
|
||||
exchange_id = db.get_from_static_table('id', 'exchange', HDict({'name': 'binance'}), create_id=True)
|
||||
"""
|
||||
with SQLite(self.db_file) as conn:
|
||||
cursor = conn.cursor()
|
||||
sql = make_query(item, table, list(indexes.keys()))
|
||||
cursor.execute(sql, tuple(indexes.values()))
|
||||
result = cursor.fetchone()
|
||||
if result is None and create_id:
|
||||
sql = make_insert(table, tuple(indexes.keys()))
|
||||
cursor.execute(sql, tuple(indexes.values()))
|
||||
sql = make_query(item, table, list(indexes.keys()))
|
||||
cursor.execute(sql, tuple(indexes.values()))
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
def _fetch_exchange_id(self, exchange_name: str) -> int:
|
||||
"""
|
||||
Fetches the primary ID of an exchange from the database.
|
||||
|
||||
:param exchange_name: Name of the exchange.
|
||||
:return: Primary ID of the exchange.
|
||||
|
||||
Example:
|
||||
--------
|
||||
exchange_id = db._fetch_exchange_id('binance')
|
||||
"""
|
||||
return self.get_from_static_table(item='id', table='exchange', create_id=True,
|
||||
indexes=HDict({'name': exchange_name}))
|
||||
|
||||
def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
|
||||
"""
|
||||
Returns the market ID for a trading pair listed in the database.
|
||||
|
||||
:param symbol: Symbol of the trading pair.
|
||||
:param exchange_name: Name of the exchange.
|
||||
:return: Market ID.
|
||||
|
||||
Example:
|
||||
--------
|
||||
market_id = db._fetch_market_id('BTC/USDT', 'binance')
|
||||
"""
|
||||
exchange_id = self._fetch_exchange_id(exchange_name)
|
||||
market_id = self.get_from_static_table(item='id', table='markets', create_id=True,
|
||||
indexes=HDict({'symbol': symbol, 'exchange_id': exchange_id}))
|
||||
return market_id
|
||||
|
||||
def insert_candles_into_db(self, candlesticks: pd.DataFrame, table_name: str, symbol: str,
|
||||
exchange_name: str) -> None:
|
||||
"""
|
||||
Inserts all candlesticks from a DataFrame into the database.
|
||||
|
||||
:param candlesticks: DataFrame of candlestick data.
|
||||
:param table_name: Name of the table to insert into.
|
||||
:param symbol: Symbol of the trading pair.
|
||||
:param exchange_name: Name of the exchange.
|
||||
|
||||
Example:
|
||||
--------
|
||||
df = pd.DataFrame({
|
||||
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
||||
'open': [1.0],
|
||||
'high': [1.0],
|
||||
'low': [1.0],
|
||||
'close': [1.0],
|
||||
'volume': [1.0]
|
||||
})
|
||||
db._insert_candles_into_db(df, 'test_table', 'BTC/USDT', 'binance')
|
||||
"""
|
||||
market_id = self._fetch_market_id(symbol, exchange_name)
|
||||
candlesticks.insert(0, 'market_id', market_id)
|
||||
sql_create = f"""
|
||||
CREATE TABLE IF NOT EXISTS '{table_name}' (
|
||||
id INTEGER PRIMARY KEY,
|
||||
market_id INTEGER,
|
||||
open_time INTEGER UNIQUE ON CONFLICT IGNORE,
|
||||
open REAL NOT NULL,
|
||||
high REAL NOT NULL,
|
||||
low REAL NOT NULL,
|
||||
close REAL NOT NULL,
|
||||
volume REAL NOT NULL,
|
||||
FOREIGN KEY (market_id) REFERENCES market (id)
|
||||
)"""
|
||||
with SQLite(self.db_file) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql_create)
|
||||
candlesticks.to_sql(table_name, conn, if_exists='append', index=False)
|
||||
|
||||
def get_timestamped_records(self, table_name: str, timestamp_field: str, st: dt.datetime,
|
||||
et: dt.datetime = None) -> pd.DataFrame:
|
||||
"""
|
||||
Returns records from a specified table in the database that have timestamps greater than or equal to a given
|
||||
start time and, optionally, less than or equal to a given end time.
|
||||
|
||||
:param table_name: Database table name.
|
||||
:param timestamp_field: Field name that contains the timestamp.
|
||||
:param st: Start datetime.
|
||||
:param et: End datetime (optional).
|
||||
:return: DataFrame of records.
|
||||
|
||||
Example:
|
||||
--------
|
||||
records = db.get_timestamped_records('test_table', 'open_time', start_time, end_time)
|
||||
"""
|
||||
with SQLite(self.db_file) as conn:
|
||||
start_stamp = unix_time_millis(st)
|
||||
if et is not None:
|
||||
end_stamp = unix_time_millis(et)
|
||||
q_str = (
|
||||
f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp} "
|
||||
f"AND {timestamp_field} <= {end_stamp};"
|
||||
)
|
||||
else:
|
||||
q_str = f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp};"
|
||||
records = pd.read_sql(q_str, conn)
|
||||
records = records.drop('id', axis=1)
|
||||
return records
|
||||
|
|
@ -5,7 +5,7 @@ from typing import Any
|
|||
|
||||
from passlib.hash import bcrypt
|
||||
import pandas as pd
|
||||
from database import HDict
|
||||
from Database import HDict
|
||||
|
||||
|
||||
class Users:
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
import datetime as dt
|
||||
import logging as log
|
||||
from DataCache import DataCache
|
||||
from shared_utilities import timeframe_to_minutes, ts_of_n_minutes_ago
|
||||
|
||||
|
||||
# log.basicConfig(level=log.ERROR)
|
||||
|
||||
class Candles:
|
||||
def __init__(self, exchanges, config_obj, database):
|
||||
def __init__(self, exchanges, config_obj, data_source):
|
||||
|
||||
# A reference to the app configuration
|
||||
self.config = config_obj
|
||||
|
|
@ -15,8 +14,8 @@ class Candles:
|
|||
# The maximum amount of candles to load at one time.
|
||||
self.max_records = self.config.app_data.get('max_data_loaded')
|
||||
|
||||
# This object maintains all the cached data. Pass it connection to the exchanges.
|
||||
self.data = database
|
||||
# This object maintains all the cached data.
|
||||
self.data = data_source
|
||||
|
||||
# print('Setting the candle cache.')
|
||||
# # Populate the cache:
|
||||
|
|
|
|||
496
src/database.py
496
src/database.py
|
|
@ -1,496 +0,0 @@
|
|||
import sqlite3
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
import config
|
||||
import datetime as dt
|
||||
import pandas as pd
|
||||
|
||||
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes
|
||||
|
||||
|
||||
class SQLite:
|
||||
"""
|
||||
Context manager returns a cursor. The connection is closed when
|
||||
the cursor is destroyed, even if an exception is thrown.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.connection = sqlite3.connect(config.DB_FILE)
|
||||
|
||||
def __enter__(self):
|
||||
return self.connection
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.connection.commit()
|
||||
self.connection.close()
|
||||
|
||||
|
||||
class HDict(dict):
|
||||
def __hash__(self):
|
||||
return hash(frozenset(self.items()))
|
||||
|
||||
|
||||
def make_query(item: str, table: str, columns: list) -> str:
|
||||
"""
|
||||
Creates a sql select string with the required number of ?'s to match the given columns.
|
||||
|
||||
:param item: The field to select.
|
||||
:param table: The table to select.
|
||||
:param columns: list - A list of database columns.
|
||||
:return: str: The query string.
|
||||
"""
|
||||
an_itr = iter(columns)
|
||||
k = next(an_itr)
|
||||
where_str = f"SELECT {item} FROM {table} WHERE {k} = ?"
|
||||
where_str = where_str + "".join([f" AND {k} = ?" for k in an_itr]) + ';'
|
||||
return where_str
|
||||
|
||||
|
||||
def make_insert(table: str, values: tuple) -> str:
|
||||
"""
|
||||
Creates a sql insert string with the required number of ?'s to match the given values.
|
||||
|
||||
:param table: The table to insert into.
|
||||
:param values: dict - A dictionary of table_name-value pairs used to index a db query.
|
||||
:return: str: The query string.
|
||||
"""
|
||||
itr1 = iter(values)
|
||||
itr2 = iter(values)
|
||||
k1 = next(itr1)
|
||||
_ = next(itr2)
|
||||
insert_str = f"INSERT INTO {table} ('{k1}'"
|
||||
insert_str = insert_str + "".join([f", '{k1}'" for k1 in itr1]) + ") VALUES(?" + "".join(
|
||||
[", ?" for _ in enumerate(itr2)]) + ");"
|
||||
return insert_str
|
||||
|
||||
|
||||
class Database:
|
||||
"""
|
||||
Communicates and maintains the database.
|
||||
"""
|
||||
|
||||
def __init__(self, exchanges):
|
||||
# The exchanges object handles communication with all connected exchanges.
|
||||
self.exchanges = exchanges
|
||||
|
||||
@staticmethod
|
||||
def execute_sql(sql: str) -> None:
|
||||
"""
|
||||
Executes a sql statement. This is for stuff I haven't created a function for yet.
|
||||
|
||||
:param sql: str - sql statement.
|
||||
:return: None
|
||||
"""
|
||||
with SQLite() as con:
|
||||
cur = con.cursor()
|
||||
cur.execute(sql)
|
||||
|
||||
@staticmethod
|
||||
def get_item_where(item_name: str, table_name: str, filter_vals: tuple) -> int:
|
||||
"""
|
||||
Returns an item from a table where the filter results should isolate a single row.
|
||||
|
||||
:param item_name: str - The name of the item to fetch.
|
||||
:param table_name: str - The name of the table.
|
||||
:param filter_vals: tuple(str, str) - The column and value to filter the results with.
|
||||
:return: str - The item.
|
||||
"""
|
||||
with SQLite() as con:
|
||||
cur = con.cursor()
|
||||
qry = make_query(item_name, table_name, [filter_vals[0]])
|
||||
cur.execute(qry, (filter_vals[1],))
|
||||
if user_id := cur.fetchone():
|
||||
return user_id[0]
|
||||
else:
|
||||
error = f"Couldn't fetch item{item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}"
|
||||
raise ValueError(error)
|
||||
|
||||
@staticmethod
|
||||
def get_rows_where(table: str, filter_vals: tuple) -> pd.DataFrame | None:
|
||||
"""
|
||||
Returns a dataframe containing all rows of a table that meet the filter criteria.
|
||||
|
||||
:param table: str - The name of the table.
|
||||
:param filter_vals: tuple(column: str, value: str) - the criteria
|
||||
:return: dataframe|None - returns the data in a dataframe or None if the query fails.
|
||||
"""
|
||||
with SQLite() as con:
|
||||
qry = f"select * from {table} where {filter_vals[0]}='{filter_vals[1]}'"
|
||||
result = pd.read_sql(qry, con=con)
|
||||
if not result.empty:
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def insert_dataframe(df, table):
|
||||
# Connect to the database.
|
||||
with SQLite() as con:
|
||||
# Insert the modified user as a new record in the table.
|
||||
df.to_sql(name=table, con=con, index=False, if_exists='append')
|
||||
# Commit the changes to the database.
|
||||
con.commit()
|
||||
|
||||
@staticmethod
|
||||
def insert_row(table: str, columns: tuple, values: tuple) -> None:
|
||||
"""
|
||||
Saves user specific data from a table in the database.
|
||||
|
||||
:param table: str - The table to insert into
|
||||
:param columns: tuple(str1, str2, ...) - The columns of the database.
|
||||
:param values: tuple(val1, val2, ...) - The values to be inserted.
|
||||
:return: None
|
||||
"""
|
||||
# Connect to the database.
|
||||
with SQLite() as conn:
|
||||
# Get a cursor from the sql connection.
|
||||
cursor = conn.cursor()
|
||||
sql = make_insert(table=table, values=columns)
|
||||
cursor.execute(sql, values)
|
||||
|
||||
@staticmethod
|
||||
def _table_exists(table_name: str) -> bool:
|
||||
"""
|
||||
Returns True if table_name exists in the database.
|
||||
|
||||
:param table_name: The name of the database.
|
||||
:return: bool - True|False
|
||||
"""
|
||||
# Connect to the database.
|
||||
with SQLite() as conn:
|
||||
# Get a cursor from the sql connection.
|
||||
cursor = conn.cursor()
|
||||
# sql = f"SELECT name FROM sqlite_schema WHERE type = 'table' AND name = '{table_name}';"
|
||||
sql = make_query(item='name', table='sqlite_schema', columns=['type', 'name'])
|
||||
# Check if the table exists.
|
||||
cursor.execute(sql, ('table', table_name))
|
||||
# Fetch the results from the cursor.
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
# If the table doesn't exist return False.
|
||||
return False
|
||||
return True
|
||||
|
||||
def _populate_table(self, table_name: str, start_time: dt.datetime, ex_details: list, end_time: dt.datetime = None):
|
||||
"""
|
||||
Populates a database table with records from the exchange_name.
|
||||
:param table_name: str - The name of the table in the database.
|
||||
:param start_time: datetime - The starting time to fetch the records from.
|
||||
:param end_time: datetime - The end time to get the records until.
|
||||
:return: pdDataframe: - The data that was downloaded.
|
||||
"""
|
||||
# Set the default end_time to UTC now.
|
||||
if end_time is None:
|
||||
end_time = dt.datetime.utcnow()
|
||||
# Fetch the records from the exchange_name.
|
||||
# Extract the parameters from the details. Format: <symbol>_<timeframe>_<exchange_name>.
|
||||
sym, inter, ex, un = ex_details
|
||||
records = self._fetch_candles_from_exchange(symbol=sym, interval=inter, exchange_name=ex, user_name=un,
|
||||
start_datetime=start_time, end_datetime=end_time)
|
||||
# Update the database.
|
||||
if not records.empty:
|
||||
# Inert into the database any received records.
|
||||
self._insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex)
|
||||
else:
|
||||
print(f'No records inserted {records}')
|
||||
return records
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=1000)
|
||||
def get_from_static_table(item: str, table: str, indexes: HDict, create_id: bool = False) -> Any:
|
||||
"""
|
||||
Retrieves a single item from a table. This method returns a cached result and is ment
|
||||
for fetching static data like settings, names and ID's.
|
||||
|
||||
:param create_id: bool: - If True, create a row if it doesn't exist and return the autoincrement ID.
|
||||
:param item: str - The name of the item requested.
|
||||
:param table: str - The table being queried.
|
||||
:param indexes: str - A hashable dictionary containing the indexing columns and their values.
|
||||
:return: Any - The content of the field.
|
||||
"""
|
||||
|
||||
# Connect to the database.
|
||||
with SQLite() as conn:
|
||||
# Get a cursor from the sql connection.
|
||||
cursor = conn.cursor()
|
||||
# Retrieve the record from the db.
|
||||
sql = make_query(item, table, list(indexes.keys()))
|
||||
cursor.execute(sql, tuple(indexes.values()))
|
||||
# The result is returned as tuple. Example: (id,)
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result is None and create_id is True:
|
||||
# Insert the indexes into the db.
|
||||
sql = make_insert(table, tuple(indexes.keys()))
|
||||
cursor.execute(sql, tuple(indexes.values()))
|
||||
# Retrieve the record from the db.
|
||||
sql = make_query(item, table, list(indexes.keys()))
|
||||
cursor.execute(sql, tuple(indexes.values()))
|
||||
# Get the first element of the tuple received from sql query.
|
||||
result = cursor.fetchone()
|
||||
|
||||
# Return the result from the tuple if it exists.
|
||||
if result:
|
||||
return result[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _fetch_exchange_id(self, exchange_name: str) -> int:
|
||||
"""
|
||||
Fetch the primary id of exchange_name from the database.
|
||||
|
||||
:param exchange_name: str - The name of the exchange_name.
|
||||
:return: int - The primary id of the exchange_name.
|
||||
"""
|
||||
return self.get_from_static_table(item='id', table='exchange', create_id=True,indexes=HDict({'name': exchange_name}))
|
||||
|
||||
def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
|
||||
"""
|
||||
Returns the market id that belongs to a trading pair listed in the database.
|
||||
|
||||
:param symbol: str - The symbol of the trading pair.
|
||||
:param exchange_name: str - The exchange_name name.
|
||||
:return: int - The market ID
|
||||
"""
|
||||
# Fetch the id of the exchange_name.
|
||||
exchange_id = self._fetch_exchange_id(exchange_name)
|
||||
|
||||
# Ask the db for the market_id. Tell it to create one if it doesn't already exist.
|
||||
market_id = self.get_from_static_table(item='id', table='markets', create_id=True,
|
||||
indexes=HDict({'symbol': symbol, 'exchange_id': exchange_id}))
|
||||
# Return the market id.
|
||||
return market_id
|
||||
|
||||
def _insert_candles_into_db(self, candlesticks, table_name: str, symbol, exchange_name) -> None:
|
||||
"""
|
||||
Insert all the candlesticks from a dataframe into the database.
|
||||
|
||||
:param exchange_name: The name of the exchange_name.
|
||||
:param symbol: The symbol of the trading pair.
|
||||
:param candlesticks: pd.dataframe - A rows of candlestick attributes.
|
||||
:param table_name: str - The name of the table to inset.
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# Retrieve the market id for the symbol.
|
||||
market_id = self._fetch_market_id(symbol, exchange_name)
|
||||
# Insert the market id into the dataframe.
|
||||
candlesticks.insert(0, 'market_id', market_id)
|
||||
# Create a table schema. todo delete these line if not needed anymore
|
||||
# # Get a list of all the columns in the dataframe.
|
||||
# columns = list(candlesticks.columns.values)
|
||||
# # Isolate any extra columns specific to individual exchanges.
|
||||
# # The carriage return and tabs are unnecessary, they just tidy output for debugging.
|
||||
# columns = ',\n\t\t\t\t\t'.join(columns[7:], )
|
||||
# # Define the columns common with all exchanges and append any extras columns.
|
||||
sql_create = f"""
|
||||
CREATE TABLE IF NOT EXISTS '{table_name}' (
|
||||
id INTEGER PRIMARY KEY,
|
||||
market_id INTEGER,
|
||||
open_time INTEGER UNIQUE ON CONFLICT IGNORE,
|
||||
open REAL NOT NULL,
|
||||
high REAL NOT NULL,
|
||||
low REAL NOT NULL,
|
||||
close REAL NOT NULL,
|
||||
volume REAL NOT NULL,
|
||||
FOREIGN KEY (market_id) REFERENCES market (id)
|
||||
)"""
|
||||
# Connect to the database.
|
||||
with SQLite() as conn:
|
||||
# Get a cursor from the sql connection.
|
||||
cursor = conn.cursor()
|
||||
# Create the table if it doesn't exist.
|
||||
cursor.execute(sql_create)
|
||||
# Insert the candles into the table.
|
||||
candlesticks.to_sql(table_name, conn, if_exists='append', index=False)
|
||||
return
|
||||
|
||||
def get_records_since(self, table_name: str, st: dt.datetime,
|
||||
et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame:
|
||||
"""
|
||||
Returns all the candles newer than the provided start_datetime from the specified table.
|
||||
|
||||
:param ex_details: list of details to pass to the server. [symbol, interval, exchange_name]
|
||||
:param table_name: str - The database table name. Format: : <symbol>_<timeframe>_<exchange_name>.
|
||||
:param st: dt.datetime.start_datetime - The start_datetime of the first record requested.
|
||||
:param et: dt.datetime - The end time of the query
|
||||
:param rl: float - The timespan in minutes each record represents.
|
||||
:return: pd.dataframe -
|
||||
"""
|
||||
|
||||
def add_data(data, tn, start_t, end_t):
|
||||
new_records = self._populate_table(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details)
|
||||
print(f'Got {len(new_records.index)} records from exchange_name')
|
||||
if not new_records.empty:
|
||||
# Combine the new records with the previously records.
|
||||
data = pd.concat([data, new_records], axis=0, ignore_index=True)
|
||||
# Drop any duplicates from overlap.
|
||||
data = data.drop_duplicates(subset="open_time", keep='first')
|
||||
# Return the modified dataframe.
|
||||
return data
|
||||
|
||||
if self._table_exists(table_name=table_name):
|
||||
# If the table exists retrieve all the records.
|
||||
print('\nTable existed retrieving records from DB')
|
||||
print(f'Requesting from {st} to {et}')
|
||||
records = self._get_records(table_name=table_name, st=st, et=et)
|
||||
print(f'Got {len(records.index)} records from db')
|
||||
else:
|
||||
# If the table doesn't exist, get them from the exchange_name.
|
||||
print(f'\nTable didnt exist fetching from {ex_details[2]}')
|
||||
temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
|
||||
print(f'Requesting from {st} to {et}, Should be {temp} records')
|
||||
records = self._populate_table(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details)
|
||||
print(f'Got {len(records.index)} records from {ex_details[2]}')
|
||||
|
||||
# Check if the records in the db go far enough back to satisfy the query.
|
||||
first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
|
||||
if first_timestamp:
|
||||
# The records didn't go far enough back if a timestamp was returned.
|
||||
print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}')
|
||||
print(f'first ts on record is: {first_timestamp}')
|
||||
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
|
||||
print(f'Requesting from {st} to {end_time}')
|
||||
# Request records with open_times between [st:end_time] from the database.
|
||||
records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)
|
||||
|
||||
# Check if the records received are up-to-date.
|
||||
last_timestamp = query_uptodate(records=records, r_length_min=rl)
|
||||
if last_timestamp:
|
||||
# The query was not up-to-date if a timestamp was returned.
|
||||
print(f'\nRecords were not updated. Requesting from {ex_details[2]}.')
|
||||
print(f'the last record on file is: {last_timestamp}')
|
||||
start_time = dt.datetime.utcfromtimestamp(last_timestamp)
|
||||
print(f'Requesting from {start_time} to {et}')
|
||||
# Request records with open_times between [start_time:et] from the database.
|
||||
records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et)
|
||||
|
||||
return records
|
||||
|
||||
@staticmethod
|
||||
def _get_records(table_name: str, st: dt.datetime, et: dt.datetime = None) -> pd.DataFrame:
|
||||
"""
|
||||
Returns all the candles newer than the provided start_datetime from the specified table.
|
||||
|
||||
:param table_name: str - The database table name. Format: : <symbol>_<timeframe>_<exchange_name>.
|
||||
:param st: dt.datetime.start_datetime - The start_datetime of the first record requested.
|
||||
:param et: dt.datetime - The end time of the query
|
||||
:return: pd.dataframe -
|
||||
"""
|
||||
# Connect to the database.
|
||||
with SQLite() as conn:
|
||||
# Create a timestamp in milliseconds
|
||||
start_stamp = unix_time_millis(st)
|
||||
if et is not None:
|
||||
# Create a timestamp in milliseconds
|
||||
end_stamp = unix_time_millis(et)
|
||||
q_str = f"SELECT * FROM '{table_name}' WHERE open_time >= {start_stamp} AND open_time <= {end_stamp};"
|
||||
else:
|
||||
q_str = f"SELECT * FROM '{table_name}' WHERE open_time >= {start_stamp};"
|
||||
# Retrieve all the records from the table.
|
||||
records = pd.read_sql(q_str, conn)
|
||||
# Drop the databases primary id.
|
||||
records = records.drop('id', axis=1)
|
||||
# Return the data.
|
||||
return records
|
||||
|
||||
# @staticmethod
|
||||
# def date_of_last_timestamp(table_name):
|
||||
# """
|
||||
# Returns the latest timestamp stored in the db.
|
||||
# TODO: Unused.
|
||||
#
|
||||
# :return: dt.timestamp
|
||||
# """
|
||||
# # Connect to the database.
|
||||
# with SQLite() as conn:
|
||||
# # Get a cursor from the connection.
|
||||
# cursor = conn.cursor()
|
||||
# cursor.execute(f"""SELECT open_time FROM '{table_name}' ORDER BY open_time DESC LIMIT 1""")
|
||||
# ts = cursor.fetchone()[0] / 1000
|
||||
# return dt.datetime.utcfromtimestamp(ts)
|
||||
#
|
||||
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
|
||||
start_datetime: object = None, end_datetime: object = None) -> pd.DataFrame:
|
||||
"""
|
||||
Fetches and returns all candles from the specified market, timeframe, and exchange_name.
|
||||
|
||||
:param symbol: str - The symbol of the market.
|
||||
:param interval: str - The timeframe. Format '<int><alpha>' - examples: '15m', '4h'
|
||||
:param exchange_name: str - The name of the exchange_name.
|
||||
:param start_datetime: dt.datetime - The open_time of the first record requested.
|
||||
:param end_datetime: dt.datetime - The end_time for the query.
|
||||
:return: pd.DataFrame: Dataframe containing rows of candle attributes that vary
|
||||
depending on the exchange_name.
|
||||
For example: [open_time, open, high, low, close, volume, close_time,
|
||||
quote_volume, num_trades, taker_buy_base_volume, taker_buy_quote_volume]
|
||||
"""
|
||||
|
||||
def fill_data_holes(records, interval):
|
||||
time_span = timeframe_to_minutes(interval)
|
||||
last_timestamp = None
|
||||
filled_records = []
|
||||
|
||||
for _, row in records.iterrows():
|
||||
time_stamp = row['open_time']
|
||||
|
||||
if last_timestamp is None:
|
||||
last_timestamp = time_stamp
|
||||
filled_records.append(row)
|
||||
continue
|
||||
|
||||
delta_ms = time_stamp - last_timestamp
|
||||
delta_minutes = (delta_ms / 1000) / 60
|
||||
|
||||
if delta_minutes > time_span:
|
||||
num_missing_rec = int(delta_minutes / time_span)
|
||||
step = int(delta_ms / num_missing_rec)
|
||||
|
||||
for ts in range(int(last_timestamp) + step, int(time_stamp), step):
|
||||
new_row = row.copy()
|
||||
new_row['open_time'] = ts
|
||||
filled_records.append(new_row)
|
||||
|
||||
filled_records.append(row)
|
||||
last_timestamp = time_stamp
|
||||
|
||||
return pd.DataFrame(filled_records)
|
||||
|
||||
# Default start date for fetching from the exchange_name.
|
||||
if start_datetime is None:
|
||||
start_datetime = dt.datetime(year=2017, month=1, day=1)
|
||||
|
||||
# Default end date for fetching from the exchange_name.
|
||||
if end_datetime is None:
|
||||
end_datetime = dt.datetime.utcnow()
|
||||
|
||||
if start_datetime > end_datetime:
|
||||
raise ValueError("\ndatabase:fetch_candles_from_exchange():"
|
||||
" Invalid start and end parameters: ", start_datetime, end_datetime)
|
||||
|
||||
# Get a reference to the exchange
|
||||
exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
|
||||
|
||||
temp = (((unix_time_millis(end_datetime) - unix_time_millis(
|
||||
start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval)
|
||||
print(f'Fetching historical data {start_datetime} to {end_datetime}, Should be {temp} records')
|
||||
|
||||
if start_datetime == end_datetime:
|
||||
end_datetime = None
|
||||
|
||||
# Request candlestick data from the exchange_name.
|
||||
candles = exchange.get_historical_klines(symbol=symbol,
|
||||
interval=interval,
|
||||
start_dt=start_datetime,
|
||||
end_dt=end_datetime)
|
||||
num_rec_records = len(candles.index)
|
||||
print(f'\n{num_rec_records} candles retrieved from the exchange_name.')
|
||||
# Isolate the open_times from the records received.
|
||||
open_times = candles.open_time
|
||||
# Calculate the number of records that would fit between the min and max open time.
|
||||
estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval)
|
||||
if num_rec_records < estimated_num_records:
|
||||
# Some records may be missing due to server maintenance periods ect.
|
||||
# Fill the holes with copies of the last record received before the gap.
|
||||
candles = fill_data_holes(candles, interval)
|
||||
return candles
|
||||
|
|
@ -3,15 +3,53 @@ from exchangeinterface import ExchangeInterface
|
|||
import unittest
|
||||
import pandas as pd
|
||||
import datetime as dt
|
||||
import os
|
||||
from Database import SQLite, Database
|
||||
from shared_utilities import unix_time_millis
|
||||
|
||||
|
||||
class TestDataCache(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Setup the database connection here
|
||||
# Set the database connection here
|
||||
self.exchanges = ExchangeInterface()
|
||||
self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)
|
||||
# This object maintains all the cached data. Pass it connection to the exchanges.
|
||||
self.db_file = 'test_db.sqlite'
|
||||
self.database = Database(db_file=self.db_file)
|
||||
|
||||
# Create necessary tables
|
||||
with SQLite(db_file=self.db_file) as con:
|
||||
cursor = con.cursor()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS exchange (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT UNIQUE
|
||||
)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS markets (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
symbol TEXT,
|
||||
exchange_id INTEGER,
|
||||
FOREIGN KEY (exchange_id) REFERENCES exchange(id)
|
||||
)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS test_table (
|
||||
id INTEGER PRIMARY KEY,
|
||||
market_id INTEGER,
|
||||
open_time INTEGER UNIQUE,
|
||||
open REAL NOT NULL,
|
||||
high REAL NOT NULL,
|
||||
low REAL NOT NULL,
|
||||
close REAL NOT NULL,
|
||||
volume REAL NOT NULL,
|
||||
FOREIGN KEY (market_id) REFERENCES markets(id)
|
||||
)
|
||||
""")
|
||||
|
||||
self.data = DataCache(self.exchanges)
|
||||
self.data.db = self.database
|
||||
|
||||
asset, timeframe, exchange = 'BTC/USD', '2h', 'binance'
|
||||
self.key1 = f'{asset}_{timeframe}_{exchange}'
|
||||
|
|
@ -19,8 +57,11 @@ class TestDataCache(unittest.TestCase):
|
|||
asset, timeframe, exchange = 'ETH/USD', '2h', 'binance'
|
||||
self.key2 = f'{asset}_{timeframe}_{exchange}'
|
||||
|
||||
def tearDown(self):
|
||||
if os.path.exists(self.db_file):
|
||||
os.remove(self.db_file)
|
||||
|
||||
def test_set_cache(self):
|
||||
# Tests
|
||||
print('Testing set_cache flag not set:')
|
||||
self.data.set_cache(data='data', key=self.key1)
|
||||
attr = self.data.__getattribute__('cached_data')
|
||||
|
|
@ -36,7 +77,6 @@ class TestDataCache(unittest.TestCase):
|
|||
self.assertEqual(attr[self.key1], 'more_data')
|
||||
|
||||
def test_cache_exists(self):
|
||||
# Tests
|
||||
print('Testing cache_exists() method:')
|
||||
self.assertFalse(self.data.cache_exists(key=self.key2))
|
||||
self.data.set_cache(data='data', key=self.key1)
|
||||
|
|
@ -44,7 +84,6 @@ class TestDataCache(unittest.TestCase):
|
|||
|
||||
def test_update_candle_cache(self):
|
||||
print('Testing update_candle_cache() method:')
|
||||
# Initial data
|
||||
df_initial = pd.DataFrame({
|
||||
'open_time': [1, 2, 3],
|
||||
'open': [100, 101, 102],
|
||||
|
|
@ -54,7 +93,6 @@ class TestDataCache(unittest.TestCase):
|
|||
'volume': [1000, 1001, 1002]
|
||||
})
|
||||
|
||||
# Data to be added
|
||||
df_new = pd.DataFrame({
|
||||
'open_time': [3, 4, 5],
|
||||
'open': [102, 103, 104],
|
||||
|
|
@ -96,7 +134,7 @@ class TestDataCache(unittest.TestCase):
|
|||
def test_get_records_since(self):
|
||||
print('Testing get_records_since() method:')
|
||||
df_initial = pd.DataFrame({
|
||||
'open_time': [1, 2, 3],
|
||||
'open_time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=i)) for i in range(3)],
|
||||
'open': [100, 101, 102],
|
||||
'high': [110, 111, 112],
|
||||
'low': [90, 91, 92],
|
||||
|
|
@ -105,20 +143,86 @@ class TestDataCache(unittest.TestCase):
|
|||
})
|
||||
|
||||
self.data.set_cache(data=df_initial, key=self.key1)
|
||||
start_datetime = dt.datetime.utcfromtimestamp(2)
|
||||
result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60, ex_details=[]).sort_values(by='open_time').reset_index(drop=True)
|
||||
start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=2)
|
||||
result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60,
|
||||
ex_details=['BTC/USD', '2h', 'binance'])
|
||||
|
||||
expected = pd.DataFrame({
|
||||
'open_time': [2, 3],
|
||||
'open': [101, 102],
|
||||
'high': [111, 112],
|
||||
'low': [91, 92],
|
||||
'close': [106, 107],
|
||||
'volume': [1001, 1002]
|
||||
'open_time': df_initial['open_time'][:2].values,
|
||||
'open': [100, 101],
|
||||
'high': [110, 111],
|
||||
'low': [90, 91],
|
||||
'close': [105, 106],
|
||||
'volume': [1000, 1001]
|
||||
})
|
||||
|
||||
pd.testing.assert_frame_equal(result, expected)
|
||||
|
||||
def test_get_records_since_from_db(self):
|
||||
print('Testing get_records_since_from_db() method:')
|
||||
df_initial = pd.DataFrame({
|
||||
'market_id': [None],
|
||||
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
||||
'open': [1.0],
|
||||
'high': [1.0],
|
||||
'low': [1.0],
|
||||
'close': [1.0],
|
||||
'volume': [1.0]
|
||||
})
|
||||
|
||||
with SQLite(self.db_file) as con:
|
||||
df_initial.to_sql('test_table', con, if_exists='append', index=False)
|
||||
|
||||
start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=1)
|
||||
end_datetime = dt.datetime.utcnow()
|
||||
result = self.data.get_records_since_from_db(table_name='test_table', st=start_datetime, et=end_datetime,
|
||||
rl=1, ex_details=['BTC/USD', '2h', 'binance']).sort_values(
|
||||
by='open_time').reset_index(drop=True)
|
||||
|
||||
print("Columns in the result DataFrame:", result.columns)
|
||||
print("Result DataFrame:\n", result)
|
||||
|
||||
# Remove 'id' column from the result DataFrame if it exists
|
||||
if 'id' in result.columns:
|
||||
result = result.drop(columns=['id'])
|
||||
|
||||
expected = pd.DataFrame({
|
||||
'market_id': [None],
|
||||
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
||||
'open': [1.0],
|
||||
'high': [1.0],
|
||||
'low': [1.0],
|
||||
'close': [1.0],
|
||||
'volume': [1.0]
|
||||
})
|
||||
|
||||
print("Expected DataFrame:\n", expected)
|
||||
|
||||
pd.testing.assert_frame_equal(result, expected)
|
||||
|
||||
def test_populate_db(self):
|
||||
print('Testing _populate_db() method:')
|
||||
start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
|
||||
end_time = dt.datetime.utcnow()
|
||||
|
||||
result = self.data._populate_db(table_name='test_table', start_time=start_time,
|
||||
end_time=end_time, ex_details=['BTC/USD', '2h', 'binance', 'test_guy'])
|
||||
|
||||
self.assertIsInstance(result, pd.DataFrame)
|
||||
self.assertFalse(result.empty)
|
||||
|
||||
def test_fetch_candles_from_exchange(self):
|
||||
print('Testing _fetch_candles_from_exchange() method:')
|
||||
start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
|
||||
end_time = dt.datetime.utcnow()
|
||||
|
||||
result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance',
|
||||
user_name='test_guy', start_datetime=start_time,
|
||||
end_datetime=end_time)
|
||||
|
||||
self.assertIsInstance(result, pd.DataFrame)
|
||||
self.assertFalse(result.empty)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,96 +1,219 @@
|
|||
import datetime
|
||||
import unittest
|
||||
import sqlite3
|
||||
import pandas as pd
|
||||
from database import make_query, make_insert, Database, HDict, SQLite
|
||||
from exchangeinterface import ExchangeInterface
|
||||
from passlib.hash import bcrypt
|
||||
from sqlalchemy import create_engine, text
|
||||
import config
|
||||
import datetime as dt
|
||||
from Database import Database, SQLite, make_query, make_insert, HDict
|
||||
from shared_utilities import unix_time_millis
|
||||
|
||||
|
||||
def test():
|
||||
# un_hashed_pass = 'password'
|
||||
hasher = bcrypt.using(rounds=13)
|
||||
# hashed_pass = hasher.hash(un_hashed_pass)
|
||||
# print(f'password: {un_hashed_pass}')
|
||||
# print(f'hashed pass: {hashed_pass}')
|
||||
# print(f" right pass: {hasher.verify('password', hashed_pass)}")
|
||||
# print(f" wrong pass: {hasher.verify('passWord', hashed_pass)}")
|
||||
engine = create_engine("sqlite:///" + config.DB_FILE, echo=True)
|
||||
with engine.connect() as conn:
|
||||
default_user = pd.read_sql_query(sql=text("SELECT * FROM users WHERE user_name = 'guest'"), con=conn)
|
||||
# hashed_password = default_user.password.values[0]
|
||||
# print(f" verify pass: {hasher.verify('password', hashed_password)}")
|
||||
username = default_user.user_name.values[0]
|
||||
print(username)
|
||||
class TestSQLite(unittest.TestCase):
|
||||
def test_sqlite_context_manager(self):
|
||||
print("\nRunning test_sqlite_context_manager...")
|
||||
with SQLite(db_file='test_db.sqlite') as con:
|
||||
cursor = con.cursor()
|
||||
cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||
cursor.execute("INSERT INTO test_table (name) VALUES ('test')")
|
||||
cursor.execute('SELECT name FROM test_table WHERE name = ?', ('test',))
|
||||
result = cursor.fetchone()
|
||||
self.assertEqual(result[0], 'test')
|
||||
print("SQLite context manager test passed.")
|
||||
|
||||
|
||||
def test_make_query():
|
||||
values = {'first_field': 'first_value'}
|
||||
item = 'market_id'
|
||||
table = 'markets'
|
||||
q_str = make_query(item=item, table=table, values=values)
|
||||
print(f'\nWith one indexing field: {q_str}')
|
||||
class TestDatabase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Use a temporary SQLite database for testing purposes
|
||||
self.db_file = 'test_db.sqlite'
|
||||
self.db = Database(db_file=self.db_file)
|
||||
self.connection = sqlite3.connect(self.db_file)
|
||||
self.cursor = self.connection.cursor()
|
||||
|
||||
values = {'first_field': 'first_value', 'second_field': 'second_value'}
|
||||
q_str = make_query(item=item, table=table, values=values)
|
||||
print(f'\nWith two indexing fields: {q_str}')
|
||||
assert q_str is not None
|
||||
def tearDown(self):
|
||||
self.connection.close()
|
||||
import os
|
||||
os.remove(self.db_file) # Remove the temporary database file after tests
|
||||
|
||||
def test_execute_sql(self):
|
||||
print("\nRunning test_execute_sql...")
|
||||
# Drop the table if it exists to avoid OperationalError
|
||||
self.cursor.execute('DROP TABLE IF EXISTS test_table')
|
||||
self.connection.commit()
|
||||
|
||||
sql = 'CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)'
|
||||
self.db.execute_sql(sql)
|
||||
|
||||
self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_table';")
|
||||
result = self.cursor.fetchone()
|
||||
self.assertIsNotNone(result)
|
||||
print("Execute SQL test passed.")
|
||||
|
||||
def test_make_query(self):
|
||||
print("\nRunning test_make_query...")
|
||||
query = make_query('id', 'test_table', ['name'])
|
||||
expected_query = 'SELECT id FROM test_table WHERE name = ?;'
|
||||
self.assertEqual(query, expected_query)
|
||||
print("Make query test passed.")
|
||||
|
||||
def test_make_insert(self):
|
||||
print("\nRunning test_make_insert...")
|
||||
insert = make_insert('test_table', ('name', 'age'))
|
||||
expected_insert = "INSERT INTO test_table ('name', 'age') VALUES(?, ?);"
|
||||
self.assertEqual(insert, expected_insert)
|
||||
print("Make insert test passed.")
|
||||
|
||||
def test_get_item_where(self):
|
||||
print("\nRunning test_get_item_where...")
|
||||
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||
self.cursor.execute("INSERT INTO test_table (id, name) VALUES (1, 'test')")
|
||||
self.connection.commit()
|
||||
item = self.db.get_item_where('name', 'test_table', ('id', 1))
|
||||
self.assertEqual(item, 'test')
|
||||
print("Get item where test passed.")
|
||||
|
||||
def test_get_rows_where(self):
|
||||
print("\nRunning test_get_rows_where...")
|
||||
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||
self.cursor.execute("INSERT INTO test_table (id, name) VALUES (1, 'test')")
|
||||
self.connection.commit()
|
||||
rows = self.db.get_rows_where('test_table', ('name', 'test'))
|
||||
self.assertIsInstance(rows, pd.DataFrame)
|
||||
self.assertEqual(rows.iloc[0]['name'], 'test')
|
||||
print("Get rows where test passed.")
|
||||
|
||||
def test_insert_dataframe(self):
|
||||
print("\nRunning test_insert_dataframe...")
|
||||
df = pd.DataFrame({'id': [1], 'name': ['test']})
|
||||
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||
self.connection.commit()
|
||||
self.db.insert_dataframe(df, 'test_table')
|
||||
self.cursor.execute('SELECT name FROM test_table WHERE id = 1')
|
||||
result = self.cursor.fetchone()
|
||||
self.assertEqual(result[0], 'test')
|
||||
print("Insert dataframe test passed.")
|
||||
|
||||
def test_insert_row(self):
|
||||
print("\nRunning test_insert_row...")
|
||||
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||
self.connection.commit()
|
||||
self.db.insert_row('test_table', ('id', 'name'), (1, 'test'))
|
||||
self.cursor.execute('SELECT name FROM test_table WHERE id = 1')
|
||||
result = self.cursor.fetchone()
|
||||
self.assertEqual(result[0], 'test')
|
||||
print("Insert row test passed.")
|
||||
|
||||
def test_table_exists(self):
|
||||
print("\nRunning test_table_exists...")
|
||||
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||
self.connection.commit()
|
||||
exists = self.db.table_exists('test_table')
|
||||
self.assertTrue(exists)
|
||||
print("Table exists test passed.")
|
||||
|
||||
def test_get_timestamped_records(self):
|
||||
print("\nRunning test_get_timestamped_records...")
|
||||
df = pd.DataFrame({
|
||||
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
||||
'open': [1.0],
|
||||
'high': [1.0],
|
||||
'low': [1.0],
|
||||
'close': [1.0],
|
||||
'volume': [1.0]
|
||||
})
|
||||
table_name = 'test_table'
|
||||
self.cursor.execute(f"""
|
||||
CREATE TABLE {table_name} (
|
||||
id INTEGER PRIMARY KEY,
|
||||
open_time INTEGER UNIQUE,
|
||||
open REAL NOT NULL,
|
||||
high REAL NOT NULL,
|
||||
low REAL NOT NULL,
|
||||
close REAL NOT NULL,
|
||||
volume REAL NOT NULL
|
||||
)
|
||||
""")
|
||||
self.connection.commit()
|
||||
self.db.insert_dataframe(df, table_name)
|
||||
st = dt.datetime.utcnow() - dt.timedelta(minutes=1)
|
||||
et = dt.datetime.utcnow()
|
||||
records = self.db.get_timestamped_records(table_name, 'open_time', st, et)
|
||||
self.assertIsInstance(records, pd.DataFrame)
|
||||
self.assertFalse(records.empty)
|
||||
print("Get timestamped records test passed.")
|
||||
|
||||
def test_get_from_static_table(self):
|
||||
print("\nRunning test_get_from_static_table...")
|
||||
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT UNIQUE)')
|
||||
self.connection.commit()
|
||||
item = self.db.get_from_static_table('id', 'test_table', HDict({'name': 'test'}), create_id=True)
|
||||
self.assertIsInstance(item, int)
|
||||
self.cursor.execute('SELECT id FROM test_table WHERE name = ?', ('test',))
|
||||
result = self.cursor.fetchone()
|
||||
self.assertEqual(item, result[0])
|
||||
print("Get from static table test passed.")
|
||||
|
||||
def test_insert_candles_into_db(self):
|
||||
print("\nRunning test_insert_candles_into_db...")
|
||||
df = pd.DataFrame({
|
||||
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
||||
'open': [1.0],
|
||||
'high': [1.0],
|
||||
'low': [1.0],
|
||||
'close': [1.0],
|
||||
'volume': [1.0]
|
||||
})
|
||||
table_name = 'test_table'
|
||||
self.cursor.execute(f"""
|
||||
CREATE TABLE {table_name} (
|
||||
id INTEGER PRIMARY KEY,
|
||||
market_id INTEGER,
|
||||
open_time INTEGER UNIQUE,
|
||||
open REAL NOT NULL,
|
||||
high REAL NOT NULL,
|
||||
low REAL NOT NULL,
|
||||
close REAL NOT NULL,
|
||||
volume REAL NOT NULL
|
||||
)
|
||||
""")
|
||||
self.connection.commit()
|
||||
|
||||
# Create the exchange and markets tables needed for the foreign key constraints
|
||||
self.cursor.execute("""
|
||||
CREATE TABLE exchange (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT UNIQUE
|
||||
)
|
||||
""")
|
||||
self.cursor.execute("""
|
||||
CREATE TABLE markets (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
symbol TEXT,
|
||||
exchange_id INTEGER,
|
||||
FOREIGN KEY (exchange_id) REFERENCES exchange(id)
|
||||
)
|
||||
""")
|
||||
self.connection.commit()
|
||||
|
||||
self.db.insert_candles_into_db(df, table_name, 'BTC/USDT', 'binance')
|
||||
self.cursor.execute(f'SELECT * FROM {table_name}')
|
||||
result = self.cursor.fetchall()
|
||||
self.assertFalse(len(result) == 0)
|
||||
print("Insert candles into db test passed.")
|
||||
|
||||
|
||||
def test_make_insert():
|
||||
table = 'markets'
|
||||
values = {'first_field': 'first_value'}
|
||||
q_str = make_insert(table=table, values=values)
|
||||
print(f'\nWith one indexing field: {q_str}')
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
values = {'first_field': 'first_value', 'second_field': 'second_value'}
|
||||
q_str = make_insert(table=table, values=values)
|
||||
print(f'\nWith two indexing fields: {q_str}')
|
||||
assert q_str is not None
|
||||
|
||||
|
||||
def test__table_exists():
|
||||
exchanges = ExchangeInterface()
|
||||
d_obj = Database(exchanges)
|
||||
exists = d_obj._table_exists('BTC/USD_5m_alpaca')
|
||||
print(f'\nExists - Should be true: {exists}')
|
||||
assert exists
|
||||
exists = d_obj._table_exists('BTC/USD_5m_alpina')
|
||||
print(f'Doesnt exist - should be false: {exists}')
|
||||
assert not exists
|
||||
|
||||
|
||||
def test_get_from_static_table():
|
||||
exchanges = ExchangeInterface()
|
||||
d_obj = Database(exchanges)
|
||||
market_id = d_obj._fetch_market_id('BTC/USD', 'alpaca')
|
||||
e_id = d_obj._fetch_exchange_id('alpaca')
|
||||
print(f'market id: {market_id}')
|
||||
assert market_id > 0
|
||||
print(f'exchange_name ID: {e_id}')
|
||||
assert e_id == 4
|
||||
|
||||
|
||||
def test_populate_table():
|
||||
"""
|
||||
Populates a database table with records from the exchange_name.
|
||||
:param table_name: str - The name of the table in the database.
|
||||
:param start_time: datetime - The starting time to fetch the records from.
|
||||
:param end_time: datetime - The end time to get the records until.
|
||||
:return: pdDataframe: - The data that was downloaded.
|
||||
"""
|
||||
exchanges = ExchangeInterface()
|
||||
d_obj = Database(exchanges)
|
||||
d_obj._populate_table(table_name='BTC/USD_2h_alpaca',
|
||||
start_time=datetime.datetime(year=2023, month=3, day=27, hour=6, minute=0))
|
||||
|
||||
|
||||
def test_get_records_since():
|
||||
exchanges = ExchangeInterface()
|
||||
d_obj = Database(exchanges)
|
||||
records = d_obj.get_records_since(table_name='BTC/USD_15m_alpaca',
|
||||
st=datetime.datetime(year=2023, month=3, day=27, hour=1, minute=0),
|
||||
et=datetime.datetime.utcnow(),
|
||||
rl=15)
|
||||
print(records)
|
||||
assert records is not None
|
||||
# def test():
|
||||
# # un_hashed_pass = 'password'
|
||||
# hasher = bcrypt.using(rounds=13)
|
||||
# # hashed_pass = hasher.hash(un_hashed_pass)
|
||||
# # print(f'password: {un_hashed_pass}')
|
||||
# # print(f'hashed pass: {hashed_pass}')
|
||||
# # print(f" right pass: {hasher.verify('password', hashed_pass)}")
|
||||
# # print(f" wrong pass: {hasher.verify('passWord', hashed_pass)}")
|
||||
# engine = create_engine("sqlite:///" + config.DB_FILE, echo=True)
|
||||
# with engine.connect() as conn:
|
||||
# default_user = pd.read_sql_query(sql=text("SELECT * FROM users WHERE user_name = 'guest'"), con=conn)
|
||||
# # hashed_password = default_user.password.values[0]
|
||||
# # print(f" verify pass: {hasher.verify('password', hashed_password)}")
|
||||
# username = default_user.user_name.values[0]
|
||||
# print(username)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class TestExchange(unittest.TestCase):
|
|||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
exchange_name = 'binance'
|
||||
exchange_name = 'kraken'
|
||||
cls.api_keys = None
|
||||
"""Uncomment and Provide api keys to connect to exchange."""
|
||||
# cls.api_keys = {'key': 'EXCHANGE_API_KEY', 'secret': 'EXCHANGE_API_SECRET'}
|
||||
Loading…
Reference in New Issue