Completed unittests for Database and DataCache.
This commit is contained in:
parent
e601f8c23e
commit
4130e0ca9a
|
|
@ -27,7 +27,7 @@ class BrighterTrades:
|
||||||
self.signals = Signals(self.config.signals_list)
|
self.signals = Signals(self.config.signals_list)
|
||||||
|
|
||||||
# Object that maintains candlestick and price data.
|
# Object that maintains candlestick and price data.
|
||||||
self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, database=self.data)
|
self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, data_source=self.data)
|
||||||
|
|
||||||
# Object that interacts with and maintains data from available indicators
|
# Object that interacts with and maintains data from available indicators
|
||||||
self.indicators = Indicators(self.candles, self.config)
|
self.indicators = Indicators(self.candles, self.config)
|
||||||
|
|
|
||||||
284
src/DataCache.py
284
src/DataCache.py
|
|
@ -1,25 +1,41 @@
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
|
from Database import Database
|
||||||
|
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes
|
||||||
|
import logging
|
||||||
|
|
||||||
from database import Database
|
# Set up logging
|
||||||
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DataCache:
|
class DataCache:
|
||||||
"""
|
"""
|
||||||
Fetches manages data limits and optimises memory storage.
|
Fetches manages data limits and optimises memory storage.
|
||||||
|
Handles connections and operations for the given exchanges.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
--------------
|
||||||
|
db = DataCache(exchanges=some_exchanges_object)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, exchanges):
|
def __init__(self, exchanges):
|
||||||
|
"""
|
||||||
|
Initializes the DataCache class.
|
||||||
|
|
||||||
# Define the cache size
|
:param exchanges: The exchanges object handling communication with connected exchanges.
|
||||||
self.max_tables = 50 # Maximum amount of tables that will be cached at any given time.
|
"""
|
||||||
self.max_records = 1000 # Maximum number of records that will be kept per table.
|
# Maximum number of tables to cache at any given time.
|
||||||
self.cached_data = {} # A dictionary that holds all the cached records.
|
self.max_tables = 50
|
||||||
self.db = Database(
|
# Maximum number of records to be kept per table.
|
||||||
exchanges) # The class that handles the DB interactions pass it a connection to the exchange_interface.
|
self.max_records = 1000
|
||||||
|
# A dictionary that holds all the cached records.
|
||||||
|
self.cached_data = {}
|
||||||
|
# The class that handles the DB interactions.
|
||||||
|
self.db = Database()
|
||||||
|
# The class that handles exchange interactions.
|
||||||
|
self.exchanges = exchanges
|
||||||
|
|
||||||
def cache_exists(self, key: str) -> bool:
|
def cache_exists(self, key: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
@ -89,27 +105,32 @@ class DataCache:
|
||||||
"""
|
"""
|
||||||
Return any records from the cache indexed by table_name that are newer than start_datetime.
|
Return any records from the cache indexed by table_name that are newer than start_datetime.
|
||||||
|
|
||||||
:param ex_details: Details to pass to the server
|
:param ex_details: List[str] - Details to pass to the server
|
||||||
:param key: <str>: - The dictionary table_name of the records.
|
:param key: str - The dictionary table_name of the records.
|
||||||
:param start_datetime: <dt.datetime>: - The datetime of the first record requested.
|
:param start_datetime: dt.datetime - The datetime of the first record requested.
|
||||||
:param record_length: The timespan of the records.
|
:param record_length: int - The timespan of the records.
|
||||||
|
|
||||||
:return pd.DataFrame: - The Requested records
|
:return: pd.DataFrame - The Requested records
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
records = data_cache.get_records_since('BTC/USD_2h_binance', dt.datetime.utcnow() - dt.timedelta(minutes=60), 60, ['BTC/USD', '2h', 'binance'])
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# End time of query defaults to the current time.
|
# End time of query defaults to the current time.
|
||||||
end_datetime = dt.datetime.utcnow()
|
end_datetime = dt.datetime.utcnow()
|
||||||
|
|
||||||
if self.cache_exists(key=key):
|
if self.cache_exists(key=key):
|
||||||
print('\nGetting records from cache.')
|
logger.debug('Getting records from cache.')
|
||||||
# If the records exist retrieve them from the cache.
|
# If the records exist, retrieve them from the cache.
|
||||||
records = self.get_cache(key)
|
records = self.get_cache(key)
|
||||||
else:
|
else:
|
||||||
# If they don't exist in cache, get them from the db.
|
# If they don't exist in cache, get them from the database.
|
||||||
print(f'\nRecords not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
|
logger.debug(
|
||||||
records = self.db.get_records_since(table_name=key, st=start_datetime,
|
f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
|
||||||
|
records = self.get_records_since_from_db(table_name=key, st=start_datetime,
|
||||||
et=end_datetime, rl=record_length, ex_details=ex_details)
|
et=end_datetime, rl=record_length, ex_details=ex_details)
|
||||||
print(f'Got {len(records.index)} records from db')
|
logger.debug(f'Got {len(records.index)} records from DB.')
|
||||||
self.set_cache(data=records, key=key)
|
self.set_cache(data=records, key=key)
|
||||||
|
|
||||||
# Check if the records in the cache go far enough back to satisfy the query.
|
# Check if the records in the cache go far enough back to satisfy the query.
|
||||||
|
|
@ -119,13 +140,14 @@ class DataCache:
|
||||||
if first_timestamp:
|
if first_timestamp:
|
||||||
# The records didn't go far enough back if a timestamp was returned.
|
# The records didn't go far enough back if a timestamp was returned.
|
||||||
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
|
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
|
||||||
# Request candles with open_times between [start_time:end_time] from the database.
|
logger.debug(f'Requesting additional records from {start_datetime} to {end_time}')
|
||||||
print(f'requesting additional records from {start_datetime} to {end_time}')
|
# Request additional records from the database.
|
||||||
additional_records = self.db.get_records_since(table_name=key, st=start_datetime,
|
additional_records = self.get_records_since_from_db(table_name=key, st=start_datetime,
|
||||||
et=end_time, rl=record_length, ex_details=ex_details)
|
et=end_time, rl=record_length,
|
||||||
print(f'Got {len(additional_records.index)} additional records from db')
|
ex_details=ex_details)
|
||||||
|
logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
|
||||||
if not additional_records.empty:
|
if not additional_records.empty:
|
||||||
# If more records were received update the cache.
|
# If more records were received, update the cache.
|
||||||
self.update_candle_cache(additional_records, key)
|
self.update_candle_cache(additional_records, key)
|
||||||
|
|
||||||
# Check if the records received are up-to-date.
|
# Check if the records received are up-to-date.
|
||||||
|
|
@ -134,21 +156,219 @@ class DataCache:
|
||||||
if last_timestamp:
|
if last_timestamp:
|
||||||
# The query was not up-to-date if a timestamp was returned.
|
# The query was not up-to-date if a timestamp was returned.
|
||||||
start_time = dt.datetime.utcfromtimestamp(last_timestamp)
|
start_time = dt.datetime.utcfromtimestamp(last_timestamp)
|
||||||
print(f'requesting additional records from {start_time} to {end_datetime}')
|
logger.debug(f'Requesting additional records from {start_time} to {end_datetime}')
|
||||||
# Request the database update its table starting from start_datetime.
|
# Request additional records from the database.
|
||||||
additional_records = self.db.get_records_since(table_name=key, st=start_time,
|
additional_records = self.get_records_since_from_db(table_name=key, st=start_time,
|
||||||
et=end_datetime, rl=record_length,
|
et=end_datetime, rl=record_length,
|
||||||
ex_details=ex_details)
|
ex_details=ex_details)
|
||||||
print(f'Got {len(additional_records.index)} additional records from db')
|
logger.debug(f'Got {len(additional_records.index)} additional records from DB.')
|
||||||
if not additional_records.empty:
|
if not additional_records.empty:
|
||||||
self.update_candle_cache(additional_records, key)
|
self.update_candle_cache(additional_records, key)
|
||||||
|
|
||||||
# Create a UTC timestamp.
|
# Create a UTC timestamp.
|
||||||
_timestamp = unix_time_millis(start_datetime)
|
_timestamp = unix_time_millis(start_datetime)
|
||||||
# Return all records equal or newer than timestamp.
|
logger.debug(f"Start timestamp in human-readable form: {dt.datetime.utcfromtimestamp(_timestamp / 1000.0)}")
|
||||||
|
|
||||||
|
# Return all records equal to or newer than the timestamp.
|
||||||
result = self.get_cache(key).query('open_time >= @_timestamp')
|
result = self.get_cache(key).query('open_time >= @_timestamp')
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred: {str(e)}")
|
logger.error(f"An error occurred: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def get_records_since_from_db(self, table_name: str, st: dt.datetime,
|
||||||
|
et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Returns records from a specified table that meet a criteria and ensures the records are complete.
|
||||||
|
If the records do not go back far enough or are not up-to-date, it fetches additional records
|
||||||
|
from an exchange and populates the table.
|
||||||
|
|
||||||
|
:param table_name: Database table name.
|
||||||
|
:param st: Start datetime.
|
||||||
|
:param et: End datetime.
|
||||||
|
:param rl: Timespan in minutes each record represents.
|
||||||
|
:param ex_details: Exchange details [symbol, interval, exchange_name].
|
||||||
|
:return: DataFrame of records.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
records = db.get_records_since('test_table', start_time, end_time, 1, ['BTC/USDT', '1m', 'binance'])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def add_data(data, tn, start_t, end_t):
|
||||||
|
new_records = self._populate_db(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details)
|
||||||
|
print(f'Got {len(new_records.index)} records from exchange_name')
|
||||||
|
if not new_records.empty:
|
||||||
|
data = pd.concat([data, new_records], axis=0, ignore_index=True)
|
||||||
|
data = data.drop_duplicates(subset="open_time", keep='first')
|
||||||
|
return data
|
||||||
|
|
||||||
|
if self.db.table_exists(table_name=table_name):
|
||||||
|
print('\nTable existed retrieving records from DB')
|
||||||
|
print(f'Requesting from {st} to {et}')
|
||||||
|
records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=st, et=et)
|
||||||
|
print(f'Got {len(records.index)} records from db')
|
||||||
|
else:
|
||||||
|
print(f'\nTable didnt exist fetching from {ex_details[2]}')
|
||||||
|
temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
|
||||||
|
print(f'Requesting from {st} to {et}, Should be {temp} records')
|
||||||
|
records = self._populate_db(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details)
|
||||||
|
print(f'Got {len(records.index)} records from {ex_details[2]}')
|
||||||
|
|
||||||
|
first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
|
||||||
|
if first_timestamp:
|
||||||
|
print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}')
|
||||||
|
print(f'first ts on record is: {first_timestamp}')
|
||||||
|
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
|
||||||
|
print(f'Requesting from {st} to {end_time}')
|
||||||
|
records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)
|
||||||
|
|
||||||
|
last_timestamp = query_uptodate(records=records, r_length_min=rl)
|
||||||
|
if last_timestamp:
|
||||||
|
print(f'\nRecords were not updated. Requesting from {ex_details[2]}.')
|
||||||
|
print(f'the last record on file is: {last_timestamp}')
|
||||||
|
start_time = dt.datetime.utcfromtimestamp(last_timestamp)
|
||||||
|
print(f'Requesting from {start_time} to {et}')
|
||||||
|
records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et)
|
||||||
|
|
||||||
|
return records
|
||||||
|
|
||||||
|
def _populate_db(self, table_name: str, start_time: dt.datetime, ex_details: list,
|
||||||
|
end_time: dt.datetime = None) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Populates a database table with records from the exchange.
|
||||||
|
|
||||||
|
:param table_name: Name of the table in the database.
|
||||||
|
:param start_time: Start time to fetch the records from.
|
||||||
|
:param end_time: End time to fetch the records until (optional).
|
||||||
|
:param ex_details: Exchange details [symbol, interval, exchange_name, user_name].
|
||||||
|
:return: DataFrame of the data downloaded.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
records = db._populate_table('test_table', start_time, ['BTC/USDT', '1m', 'binance', 'user1'])
|
||||||
|
"""
|
||||||
|
if end_time is None:
|
||||||
|
end_time = dt.datetime.utcnow()
|
||||||
|
sym, inter, ex, un = ex_details
|
||||||
|
records = self._fetch_candles_from_exchange(symbol=sym, interval=inter, exchange_name=ex, user_name=un,
|
||||||
|
start_datetime=start_time, end_datetime=end_time)
|
||||||
|
if not records.empty:
|
||||||
|
self.db.insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex)
|
||||||
|
else:
|
||||||
|
print(f'No records inserted {records}')
|
||||||
|
return records
|
||||||
|
|
||||||
|
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
|
||||||
|
start_datetime: dt.datetime = None,
|
||||||
|
end_datetime: dt.datetime = None) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Fetches and returns all candles from the specified market, timeframe, and exchange.
|
||||||
|
|
||||||
|
:param symbol: Symbol of the market.
|
||||||
|
:param interval: Timeframe in the format '<int><alpha>' (e.g., '15m', '4h').
|
||||||
|
:param exchange_name: Name of the exchange.
|
||||||
|
:param user_name: Name of the user.
|
||||||
|
:param start_datetime: Start datetime for fetching data (optional).
|
||||||
|
:param end_datetime: End datetime for fetching data (optional).
|
||||||
|
:return: DataFrame of candle data.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
candles = db._fetch_candles_from_exchange('BTC/USDT', '1m', 'binance', 'user1', start_time, end_time)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Fills gaps in the data by replicating the last known data point for the missing periods.
|
||||||
|
|
||||||
|
:param records: DataFrame containing the original records.
|
||||||
|
:param interval: Interval of the data (e.g., '1m', '5m').
|
||||||
|
:return: DataFrame with gaps filled.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
filled_records = fill_data_holes(df, '1m')
|
||||||
|
"""
|
||||||
|
time_span = timeframe_to_minutes(interval)
|
||||||
|
last_timestamp = None
|
||||||
|
filled_records = []
|
||||||
|
|
||||||
|
logger.info(f"Starting to fill data holes for interval: {interval}")
|
||||||
|
|
||||||
|
for index, row in records.iterrows():
|
||||||
|
time_stamp = row['open_time']
|
||||||
|
|
||||||
|
# If last_timestamp is None, this is the first record
|
||||||
|
if last_timestamp is None:
|
||||||
|
last_timestamp = time_stamp
|
||||||
|
filled_records.append(row)
|
||||||
|
logger.debug(f"First timestamp: {time_stamp}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Calculate the difference in milliseconds and minutes
|
||||||
|
delta_ms = time_stamp - last_timestamp
|
||||||
|
delta_minutes = (delta_ms / 1000) / 60
|
||||||
|
|
||||||
|
logger.debug(f"Timestamp: {time_stamp}, Delta minutes: {delta_minutes}")
|
||||||
|
|
||||||
|
# If the gap is larger than the time span of the interval, fill the gap
|
||||||
|
if delta_minutes > time_span:
|
||||||
|
num_missing_rec = int(delta_minutes / time_span)
|
||||||
|
step = int(delta_ms / num_missing_rec)
|
||||||
|
logger.debug(f"Gap detected. Filling {num_missing_rec} records with step: {step}")
|
||||||
|
|
||||||
|
for ts in range(int(last_timestamp) + step, int(time_stamp), step):
|
||||||
|
new_row = row.copy()
|
||||||
|
new_row['open_time'] = ts
|
||||||
|
filled_records.append(new_row)
|
||||||
|
logger.debug(f"Filled timestamp: {ts}")
|
||||||
|
|
||||||
|
filled_records.append(row)
|
||||||
|
last_timestamp = time_stamp
|
||||||
|
|
||||||
|
logger.info("Data holes filled successfully.")
|
||||||
|
return pd.DataFrame(filled_records)
|
||||||
|
|
||||||
|
# Default start date if not provided
|
||||||
|
if start_datetime is None:
|
||||||
|
start_datetime = dt.datetime(year=2017, month=1, day=1)
|
||||||
|
|
||||||
|
# Default end date if not provided
|
||||||
|
if end_datetime is None:
|
||||||
|
end_datetime = dt.datetime.utcnow()
|
||||||
|
|
||||||
|
# Check if start date is greater than end date
|
||||||
|
if start_datetime > end_datetime:
|
||||||
|
raise ValueError("Invalid start and end parameters: start_datetime must be before end_datetime.")
|
||||||
|
|
||||||
|
# Get the exchange object
|
||||||
|
exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
|
||||||
|
|
||||||
|
# Calculate the expected number of records
|
||||||
|
temp = (((unix_time_millis(end_datetime) - unix_time_millis(
|
||||||
|
start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval)
|
||||||
|
|
||||||
|
logger.info(f'Fetching historical data from {start_datetime} to {end_datetime}. Expected records: {temp}')
|
||||||
|
|
||||||
|
# If start and end times are the same, set end_datetime to None
|
||||||
|
if start_datetime == end_datetime:
|
||||||
|
end_datetime = None
|
||||||
|
|
||||||
|
# Fetch historical candlestick data from the exchange
|
||||||
|
candles = exchange.get_historical_klines(symbol=symbol, interval=interval, start_dt=start_datetime,
|
||||||
|
end_dt=end_datetime)
|
||||||
|
num_rec_records = len(candles.index)
|
||||||
|
|
||||||
|
logger.info(f'{num_rec_records} candles retrieved from the exchange.')
|
||||||
|
|
||||||
|
# Check if the retrieved data covers the expected time range
|
||||||
|
open_times = candles.open_time
|
||||||
|
estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval)
|
||||||
|
|
||||||
|
# Fill in any missing data if the retrieved data is less than expected
|
||||||
|
if num_rec_records < estimated_num_records:
|
||||||
|
candles = fill_data_holes(candles, interval)
|
||||||
|
|
||||||
|
return candles
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,352 @@
|
||||||
|
import sqlite3
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any
|
||||||
|
import config
|
||||||
|
import datetime as dt
|
||||||
|
import pandas as pd
|
||||||
|
from shared_utilities import unix_time_millis
|
||||||
|
|
||||||
|
|
||||||
|
class SQLite:
|
||||||
|
"""
|
||||||
|
Context manager for SQLite database connections.
|
||||||
|
Accepts a database file name or defaults to the file in config.DB_FILE.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
--------------
|
||||||
|
with SQLite(db_file='test_db.sqlite') as con:
|
||||||
|
cursor = con.cursor()
|
||||||
|
cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_file=None):
|
||||||
|
self.db_file = db_file if db_file else config.DB_FILE
|
||||||
|
self.connection = sqlite3.connect(self.db_file)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self.connection
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.connection.commit()
|
||||||
|
self.connection.close()
|
||||||
|
|
||||||
|
|
||||||
|
class HDict(dict):
|
||||||
|
"""
|
||||||
|
Hashable dictionary to use as cache keys.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
--------------
|
||||||
|
hdict = HDict({'key1': 'value1', 'key2': 'value2'})
|
||||||
|
hash(hdict)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(frozenset(self.items()))
|
||||||
|
|
||||||
|
|
||||||
|
def make_query(item: str, table: str, columns: list) -> str:
|
||||||
|
"""
|
||||||
|
Creates a SQL select query string with the required number of placeholders.
|
||||||
|
|
||||||
|
:param item: The field to select.
|
||||||
|
:param table: The table to select from.
|
||||||
|
:param columns: List of columns for the where clause.
|
||||||
|
:return: The query string.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
query = make_query('id', 'test_table', ['name', 'age'])
|
||||||
|
# Result: 'SELECT id FROM test_table WHERE name = ? AND age = ?;'
|
||||||
|
"""
|
||||||
|
an_itr = iter(columns)
|
||||||
|
k = next(an_itr)
|
||||||
|
where_str = f"SELECT {item} FROM {table} WHERE {k} = ?"
|
||||||
|
where_str += "".join([f" AND {k} = ?" for k in an_itr]) + ';'
|
||||||
|
return where_str
|
||||||
|
|
||||||
|
|
||||||
|
def make_insert(table: str, values: tuple) -> str:
|
||||||
|
"""
|
||||||
|
Creates a SQL insert query string with the required number of placeholders.
|
||||||
|
|
||||||
|
:param table: The table to insert into.
|
||||||
|
:param values: Tuple of values to insert.
|
||||||
|
:return: The query string.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
insert = make_insert('test_table', ('name', 'age'))
|
||||||
|
# Result: "INSERT INTO test_table ('name', 'age') VALUES(?, ?);"
|
||||||
|
"""
|
||||||
|
itr1 = iter(values)
|
||||||
|
itr2 = iter(values)
|
||||||
|
k1 = next(itr1)
|
||||||
|
_ = next(itr2)
|
||||||
|
insert_str = f"INSERT INTO {table} ('{k1}'"
|
||||||
|
insert_str += "".join([f", '{k1}'" for k1 in itr1]) + ") VALUES(?" + "".join(
|
||||||
|
[", ?" for _ in enumerate(itr2)]) + ");"
|
||||||
|
return insert_str
|
||||||
|
|
||||||
|
|
||||||
|
class Database:
|
||||||
|
"""
|
||||||
|
Database class to communicate and maintain the database.
|
||||||
|
Handles connections and operations for the given exchanges.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
--------------
|
||||||
|
db = Database(exchanges=some_exchanges_object, db_file='test_db.sqlite')
|
||||||
|
db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_file=None):
|
||||||
|
"""
|
||||||
|
Initializes the Database class.
|
||||||
|
|
||||||
|
:param db_file: Optional database file name.
|
||||||
|
"""
|
||||||
|
self.db_file = db_file
|
||||||
|
|
||||||
|
def execute_sql(self, sql: str) -> None:
|
||||||
|
"""
|
||||||
|
Executes a raw SQL statement.
|
||||||
|
|
||||||
|
:param sql: SQL statement to execute.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
db = Database(exchanges=some_exchanges_object, db_file='test_db.sqlite')
|
||||||
|
db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||||
|
"""
|
||||||
|
with SQLite(self.db_file) as con:
|
||||||
|
cur = con.cursor()
|
||||||
|
cur.execute(sql)
|
||||||
|
|
||||||
|
def get_item_where(self, item_name: str, table_name: str, filter_vals: tuple) -> int:
|
||||||
|
"""
|
||||||
|
Returns an item from a table where the filter results should isolate a single row.
|
||||||
|
|
||||||
|
:param item_name: Name of the item to fetch.
|
||||||
|
:param table_name: Name of the table.
|
||||||
|
:param filter_vals: Tuple of column name and value to filter by.
|
||||||
|
:return: The item.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
item = db.get_item_where('name', 'test_table', ('id', 1))
|
||||||
|
# Fetches the 'name' from 'test_table' where 'id' is 1
|
||||||
|
"""
|
||||||
|
with SQLite(self.db_file) as con:
|
||||||
|
cur = con.cursor()
|
||||||
|
qry = make_query(item_name, table_name, [filter_vals[0]])
|
||||||
|
cur.execute(qry, (filter_vals[1],))
|
||||||
|
if user_id := cur.fetchone():
|
||||||
|
return user_id[0]
|
||||||
|
else:
|
||||||
|
error = f"Couldn't fetch item {item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}"
|
||||||
|
raise ValueError(error)
|
||||||
|
|
||||||
|
def get_rows_where(self, table: str, filter_vals: tuple) -> pd.DataFrame | None:
|
||||||
|
"""
|
||||||
|
Returns a DataFrame containing all rows of a table that meet the filter criteria.
|
||||||
|
|
||||||
|
:param table: Name of the table.
|
||||||
|
:param filter_vals: Tuple of column name and value to filter by.
|
||||||
|
:return: DataFrame of the query result or None if empty.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
rows = db.get_rows_where('test_table', ('name', 'test'))
|
||||||
|
# Fetches all rows from 'test_table' where 'name' is 'test'
|
||||||
|
"""
|
||||||
|
with SQLite(self.db_file) as con:
|
||||||
|
qry = f"SELECT * FROM {table} WHERE {filter_vals[0]}='{filter_vals[1]}'"
|
||||||
|
result = pd.read_sql(qry, con=con)
|
||||||
|
if not result.empty:
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def insert_dataframe(self, df: pd.DataFrame, table: str) -> None:
|
||||||
|
"""
|
||||||
|
Inserts a DataFrame into a specified table.
|
||||||
|
|
||||||
|
:param df: DataFrame to insert.
|
||||||
|
:param table: Name of the table.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
df = pd.DataFrame({'id': [1], 'name': ['test']})
|
||||||
|
db.insert_dataframe(df, 'test_table')
|
||||||
|
"""
|
||||||
|
with SQLite(self.db_file) as con:
|
||||||
|
df.to_sql(name=table, con=con, index=False, if_exists='append')
|
||||||
|
|
||||||
|
def insert_row(self, table: str, columns: tuple, values: tuple) -> None:
|
||||||
|
"""
|
||||||
|
Inserts a row into a specified table.
|
||||||
|
|
||||||
|
:param table: Name of the table.
|
||||||
|
:param columns: Tuple of column names.
|
||||||
|
:param values: Tuple of values to insert.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
db.insert_row('test_table', ('id', 'name'), (1, 'test'))
|
||||||
|
"""
|
||||||
|
with SQLite(self.db_file) as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
sql = make_insert(table=table, values=columns)
|
||||||
|
cursor.execute(sql, values)
|
||||||
|
|
||||||
|
def table_exists(self, table_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if a table exists in the database.
|
||||||
|
|
||||||
|
:param table_name: Name of the table.
|
||||||
|
:return: True if the table exists, False otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
exists = db._table_exists('test_table')
|
||||||
|
# Checks if 'test_table' exists in the database
|
||||||
|
"""
|
||||||
|
with SQLite(self.db_file) as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
sql = make_query(item='name', table='sqlite_schema', columns=['type', 'name'])
|
||||||
|
cursor.execute(sql, ('table', table_name))
|
||||||
|
result = cursor.fetchone()
|
||||||
|
return result is not None
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1000)
|
||||||
|
def get_from_static_table(self, item: str, table: str, indexes: HDict, create_id: bool = False) -> Any:
|
||||||
|
"""
|
||||||
|
Returns the row id of an item from a table specified. If the item isn't listed in the table,
|
||||||
|
it will insert the item into a new row and return the autoincremented id. The item received as a hashable
|
||||||
|
dictionary so the results can be cached.
|
||||||
|
|
||||||
|
:param item: Name of the item requested.
|
||||||
|
:param table: Table being queried.
|
||||||
|
:param indexes: Hashable dictionary of indexing columns and their values.
|
||||||
|
:param create_id: If True, create a row if it doesn't exist and return the autoincrement ID.
|
||||||
|
:return: The content of the field.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
exchange_id = db.get_from_static_table('id', 'exchange', HDict({'name': 'binance'}), create_id=True)
|
||||||
|
"""
|
||||||
|
with SQLite(self.db_file) as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
sql = make_query(item, table, list(indexes.keys()))
|
||||||
|
cursor.execute(sql, tuple(indexes.values()))
|
||||||
|
result = cursor.fetchone()
|
||||||
|
if result is None and create_id:
|
||||||
|
sql = make_insert(table, tuple(indexes.keys()))
|
||||||
|
cursor.execute(sql, tuple(indexes.values()))
|
||||||
|
sql = make_query(item, table, list(indexes.keys()))
|
||||||
|
cursor.execute(sql, tuple(indexes.values()))
|
||||||
|
result = cursor.fetchone()
|
||||||
|
return result[0] if result else None
|
||||||
|
|
||||||
|
def _fetch_exchange_id(self, exchange_name: str) -> int:
|
||||||
|
"""
|
||||||
|
Fetches the primary ID of an exchange from the database.
|
||||||
|
|
||||||
|
:param exchange_name: Name of the exchange.
|
||||||
|
:return: Primary ID of the exchange.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
exchange_id = db._fetch_exchange_id('binance')
|
||||||
|
"""
|
||||||
|
return self.get_from_static_table(item='id', table='exchange', create_id=True,
|
||||||
|
indexes=HDict({'name': exchange_name}))
|
||||||
|
|
||||||
|
def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
|
||||||
|
"""
|
||||||
|
Returns the market ID for a trading pair listed in the database.
|
||||||
|
|
||||||
|
:param symbol: Symbol of the trading pair.
|
||||||
|
:param exchange_name: Name of the exchange.
|
||||||
|
:return: Market ID.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
market_id = db._fetch_market_id('BTC/USDT', 'binance')
|
||||||
|
"""
|
||||||
|
exchange_id = self._fetch_exchange_id(exchange_name)
|
||||||
|
market_id = self.get_from_static_table(item='id', table='markets', create_id=True,
|
||||||
|
indexes=HDict({'symbol': symbol, 'exchange_id': exchange_id}))
|
||||||
|
return market_id
|
||||||
|
|
||||||
|
def insert_candles_into_db(self, candlesticks: pd.DataFrame, table_name: str, symbol: str,
|
||||||
|
exchange_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Inserts all candlesticks from a DataFrame into the database.
|
||||||
|
|
||||||
|
:param candlesticks: DataFrame of candlestick data.
|
||||||
|
:param table_name: Name of the table to insert into.
|
||||||
|
:param symbol: Symbol of the trading pair.
|
||||||
|
:param exchange_name: Name of the exchange.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
||||||
|
'open': [1.0],
|
||||||
|
'high': [1.0],
|
||||||
|
'low': [1.0],
|
||||||
|
'close': [1.0],
|
||||||
|
'volume': [1.0]
|
||||||
|
})
|
||||||
|
db._insert_candles_into_db(df, 'test_table', 'BTC/USDT', 'binance')
|
||||||
|
"""
|
||||||
|
market_id = self._fetch_market_id(symbol, exchange_name)
|
||||||
|
candlesticks.insert(0, 'market_id', market_id)
|
||||||
|
sql_create = f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS '{table_name}' (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
market_id INTEGER,
|
||||||
|
open_time INTEGER UNIQUE ON CONFLICT IGNORE,
|
||||||
|
open REAL NOT NULL,
|
||||||
|
high REAL NOT NULL,
|
||||||
|
low REAL NOT NULL,
|
||||||
|
close REAL NOT NULL,
|
||||||
|
volume REAL NOT NULL,
|
||||||
|
FOREIGN KEY (market_id) REFERENCES market (id)
|
||||||
|
)"""
|
||||||
|
with SQLite(self.db_file) as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(sql_create)
|
||||||
|
candlesticks.to_sql(table_name, conn, if_exists='append', index=False)
|
||||||
|
|
||||||
|
def get_timestamped_records(self, table_name: str, timestamp_field: str, st: dt.datetime,
|
||||||
|
et: dt.datetime = None) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Returns records from a specified table in the database that have timestamps greater than or equal to a given
|
||||||
|
start time and, optionally, less than or equal to a given end time.
|
||||||
|
|
||||||
|
:param table_name: Database table name.
|
||||||
|
:param timestamp_field: Field name that contains the timestamp.
|
||||||
|
:param st: Start datetime.
|
||||||
|
:param et: End datetime (optional).
|
||||||
|
:return: DataFrame of records.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
records = db.get_timestamped_records('test_table', 'open_time', start_time, end_time)
|
||||||
|
"""
|
||||||
|
with SQLite(self.db_file) as conn:
|
||||||
|
start_stamp = unix_time_millis(st)
|
||||||
|
if et is not None:
|
||||||
|
end_stamp = unix_time_millis(et)
|
||||||
|
q_str = (
|
||||||
|
f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp} "
|
||||||
|
f"AND {timestamp_field} <= {end_stamp};"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q_str = f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= {start_stamp};"
|
||||||
|
records = pd.read_sql(q_str, conn)
|
||||||
|
records = records.drop('id', axis=1)
|
||||||
|
return records
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import Any
|
||||||
|
|
||||||
from passlib.hash import bcrypt
|
from passlib.hash import bcrypt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from database import HDict
|
from Database import HDict
|
||||||
|
|
||||||
|
|
||||||
class Users:
|
class Users:
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
import logging as log
|
import logging as log
|
||||||
from DataCache import DataCache
|
|
||||||
from shared_utilities import timeframe_to_minutes, ts_of_n_minutes_ago
|
from shared_utilities import timeframe_to_minutes, ts_of_n_minutes_ago
|
||||||
|
|
||||||
|
|
||||||
# log.basicConfig(level=log.ERROR)
|
# log.basicConfig(level=log.ERROR)
|
||||||
|
|
||||||
class Candles:
|
class Candles:
|
||||||
def __init__(self, exchanges, config_obj, database):
|
def __init__(self, exchanges, config_obj, data_source):
|
||||||
|
|
||||||
# A reference to the app configuration
|
# A reference to the app configuration
|
||||||
self.config = config_obj
|
self.config = config_obj
|
||||||
|
|
@ -15,8 +14,8 @@ class Candles:
|
||||||
# The maximum amount of candles to load at one time.
|
# The maximum amount of candles to load at one time.
|
||||||
self.max_records = self.config.app_data.get('max_data_loaded')
|
self.max_records = self.config.app_data.get('max_data_loaded')
|
||||||
|
|
||||||
# This object maintains all the cached data. Pass it connection to the exchanges.
|
# This object maintains all the cached data.
|
||||||
self.data = database
|
self.data = data_source
|
||||||
|
|
||||||
# print('Setting the candle cache.')
|
# print('Setting the candle cache.')
|
||||||
# # Populate the cache:
|
# # Populate the cache:
|
||||||
|
|
|
||||||
496
src/database.py
496
src/database.py
|
|
@ -1,496 +0,0 @@
|
||||||
import sqlite3
|
|
||||||
from functools import lru_cache
|
|
||||||
from typing import Any
|
|
||||||
import config
|
|
||||||
import datetime as dt
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes
|
|
||||||
|
|
||||||
|
|
||||||
class SQLite:
|
|
||||||
"""
|
|
||||||
Context manager returns a cursor. The connection is closed when
|
|
||||||
the cursor is destroyed, even if an exception is thrown.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.connection = sqlite3.connect(config.DB_FILE)
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self.connection
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
self.connection.commit()
|
|
||||||
self.connection.close()
|
|
||||||
|
|
||||||
|
|
||||||
class HDict(dict):
|
|
||||||
def __hash__(self):
|
|
||||||
return hash(frozenset(self.items()))
|
|
||||||
|
|
||||||
|
|
||||||
def make_query(item: str, table: str, columns: list) -> str:
|
|
||||||
"""
|
|
||||||
Creates a sql select string with the required number of ?'s to match the given columns.
|
|
||||||
|
|
||||||
:param item: The field to select.
|
|
||||||
:param table: The table to select.
|
|
||||||
:param columns: list - A list of database columns.
|
|
||||||
:return: str: The query string.
|
|
||||||
"""
|
|
||||||
an_itr = iter(columns)
|
|
||||||
k = next(an_itr)
|
|
||||||
where_str = f"SELECT {item} FROM {table} WHERE {k} = ?"
|
|
||||||
where_str = where_str + "".join([f" AND {k} = ?" for k in an_itr]) + ';'
|
|
||||||
return where_str
|
|
||||||
|
|
||||||
|
|
||||||
def make_insert(table: str, values: tuple) -> str:
|
|
||||||
"""
|
|
||||||
Creates a sql insert string with the required number of ?'s to match the given values.
|
|
||||||
|
|
||||||
:param table: The table to insert into.
|
|
||||||
:param values: dict - A dictionary of table_name-value pairs used to index a db query.
|
|
||||||
:return: str: The query string.
|
|
||||||
"""
|
|
||||||
itr1 = iter(values)
|
|
||||||
itr2 = iter(values)
|
|
||||||
k1 = next(itr1)
|
|
||||||
_ = next(itr2)
|
|
||||||
insert_str = f"INSERT INTO {table} ('{k1}'"
|
|
||||||
insert_str = insert_str + "".join([f", '{k1}'" for k1 in itr1]) + ") VALUES(?" + "".join(
|
|
||||||
[", ?" for _ in enumerate(itr2)]) + ");"
|
|
||||||
return insert_str
|
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
|
||||||
"""
|
|
||||||
Communicates and maintains the database.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, exchanges):
|
|
||||||
# The exchanges object handles communication with all connected exchanges.
|
|
||||||
self.exchanges = exchanges
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def execute_sql(sql: str) -> None:
|
|
||||||
"""
|
|
||||||
Executes a sql statement. This is for stuff I haven't created a function for yet.
|
|
||||||
|
|
||||||
:param sql: str - sql statement.
|
|
||||||
:return: None
|
|
||||||
"""
|
|
||||||
with SQLite() as con:
|
|
||||||
cur = con.cursor()
|
|
||||||
cur.execute(sql)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_item_where(item_name: str, table_name: str, filter_vals: tuple) -> int:
|
|
||||||
"""
|
|
||||||
Returns an item from a table where the filter results should isolate a single row.
|
|
||||||
|
|
||||||
:param item_name: str - The name of the item to fetch.
|
|
||||||
:param table_name: str - The name of the table.
|
|
||||||
:param filter_vals: tuple(str, str) - The column and value to filter the results with.
|
|
||||||
:return: str - The item.
|
|
||||||
"""
|
|
||||||
with SQLite() as con:
|
|
||||||
cur = con.cursor()
|
|
||||||
qry = make_query(item_name, table_name, [filter_vals[0]])
|
|
||||||
cur.execute(qry, (filter_vals[1],))
|
|
||||||
if user_id := cur.fetchone():
|
|
||||||
return user_id[0]
|
|
||||||
else:
|
|
||||||
error = f"Couldn't fetch item{item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}"
|
|
||||||
raise ValueError(error)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_rows_where(table: str, filter_vals: tuple) -> pd.DataFrame | None:
|
|
||||||
"""
|
|
||||||
Returns a dataframe containing all rows of a table that meet the filter criteria.
|
|
||||||
|
|
||||||
:param table: str - The name of the table.
|
|
||||||
:param filter_vals: tuple(column: str, value: str) - the criteria
|
|
||||||
:return: dataframe|None - returns the data in a dataframe or None if the query fails.
|
|
||||||
"""
|
|
||||||
with SQLite() as con:
|
|
||||||
qry = f"select * from {table} where {filter_vals[0]}='{filter_vals[1]}'"
|
|
||||||
result = pd.read_sql(qry, con=con)
|
|
||||||
if not result.empty:
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def insert_dataframe(df, table):
|
|
||||||
# Connect to the database.
|
|
||||||
with SQLite() as con:
|
|
||||||
# Insert the modified user as a new record in the table.
|
|
||||||
df.to_sql(name=table, con=con, index=False, if_exists='append')
|
|
||||||
# Commit the changes to the database.
|
|
||||||
con.commit()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def insert_row(table: str, columns: tuple, values: tuple) -> None:
|
|
||||||
"""
|
|
||||||
Saves user specific data from a table in the database.
|
|
||||||
|
|
||||||
:param table: str - The table to insert into
|
|
||||||
:param columns: tuple(str1, str2, ...) - The columns of the database.
|
|
||||||
:param values: tuple(val1, val2, ...) - The values to be inserted.
|
|
||||||
:return: None
|
|
||||||
"""
|
|
||||||
# Connect to the database.
|
|
||||||
with SQLite() as conn:
|
|
||||||
# Get a cursor from the sql connection.
|
|
||||||
cursor = conn.cursor()
|
|
||||||
sql = make_insert(table=table, values=columns)
|
|
||||||
cursor.execute(sql, values)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _table_exists(table_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
Returns True if table_name exists in the database.
|
|
||||||
|
|
||||||
:param table_name: The name of the database.
|
|
||||||
:return: bool - True|False
|
|
||||||
"""
|
|
||||||
# Connect to the database.
|
|
||||||
with SQLite() as conn:
|
|
||||||
# Get a cursor from the sql connection.
|
|
||||||
cursor = conn.cursor()
|
|
||||||
# sql = f"SELECT name FROM sqlite_schema WHERE type = 'table' AND name = '{table_name}';"
|
|
||||||
sql = make_query(item='name', table='sqlite_schema', columns=['type', 'name'])
|
|
||||||
# Check if the table exists.
|
|
||||||
cursor.execute(sql, ('table', table_name))
|
|
||||||
# Fetch the results from the cursor.
|
|
||||||
result = cursor.fetchone()
|
|
||||||
if not result:
|
|
||||||
# If the table doesn't exist return False.
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _populate_table(self, table_name: str, start_time: dt.datetime, ex_details: list, end_time: dt.datetime = None):
|
|
||||||
"""
|
|
||||||
Populates a database table with records from the exchange_name.
|
|
||||||
:param table_name: str - The name of the table in the database.
|
|
||||||
:param start_time: datetime - The starting time to fetch the records from.
|
|
||||||
:param end_time: datetime - The end time to get the records until.
|
|
||||||
:return: pdDataframe: - The data that was downloaded.
|
|
||||||
"""
|
|
||||||
# Set the default end_time to UTC now.
|
|
||||||
if end_time is None:
|
|
||||||
end_time = dt.datetime.utcnow()
|
|
||||||
# Fetch the records from the exchange_name.
|
|
||||||
# Extract the parameters from the details. Format: <symbol>_<timeframe>_<exchange_name>.
|
|
||||||
sym, inter, ex, un = ex_details
|
|
||||||
records = self._fetch_candles_from_exchange(symbol=sym, interval=inter, exchange_name=ex, user_name=un,
|
|
||||||
start_datetime=start_time, end_datetime=end_time)
|
|
||||||
# Update the database.
|
|
||||||
if not records.empty:
|
|
||||||
# Inert into the database any received records.
|
|
||||||
self._insert_candles_into_db(records, table_name=table_name, symbol=sym, exchange_name=ex)
|
|
||||||
else:
|
|
||||||
print(f'No records inserted {records}')
|
|
||||||
return records
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@lru_cache(maxsize=1000)
|
|
||||||
def get_from_static_table(item: str, table: str, indexes: HDict, create_id: bool = False) -> Any:
|
|
||||||
"""
|
|
||||||
Retrieves a single item from a table. This method returns a cached result and is ment
|
|
||||||
for fetching static data like settings, names and ID's.
|
|
||||||
|
|
||||||
:param create_id: bool: - If True, create a row if it doesn't exist and return the autoincrement ID.
|
|
||||||
:param item: str - The name of the item requested.
|
|
||||||
:param table: str - The table being queried.
|
|
||||||
:param indexes: str - A hashable dictionary containing the indexing columns and their values.
|
|
||||||
:return: Any - The content of the field.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Connect to the database.
|
|
||||||
with SQLite() as conn:
|
|
||||||
# Get a cursor from the sql connection.
|
|
||||||
cursor = conn.cursor()
|
|
||||||
# Retrieve the record from the db.
|
|
||||||
sql = make_query(item, table, list(indexes.keys()))
|
|
||||||
cursor.execute(sql, tuple(indexes.values()))
|
|
||||||
# The result is returned as tuple. Example: (id,)
|
|
||||||
result = cursor.fetchone()
|
|
||||||
|
|
||||||
if result is None and create_id is True:
|
|
||||||
# Insert the indexes into the db.
|
|
||||||
sql = make_insert(table, tuple(indexes.keys()))
|
|
||||||
cursor.execute(sql, tuple(indexes.values()))
|
|
||||||
# Retrieve the record from the db.
|
|
||||||
sql = make_query(item, table, list(indexes.keys()))
|
|
||||||
cursor.execute(sql, tuple(indexes.values()))
|
|
||||||
# Get the first element of the tuple received from sql query.
|
|
||||||
result = cursor.fetchone()
|
|
||||||
|
|
||||||
# Return the result from the tuple if it exists.
|
|
||||||
if result:
|
|
||||||
return result[0]
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _fetch_exchange_id(self, exchange_name: str) -> int:
|
|
||||||
"""
|
|
||||||
Fetch the primary id of exchange_name from the database.
|
|
||||||
|
|
||||||
:param exchange_name: str - The name of the exchange_name.
|
|
||||||
:return: int - The primary id of the exchange_name.
|
|
||||||
"""
|
|
||||||
return self.get_from_static_table(item='id', table='exchange', create_id=True,indexes=HDict({'name': exchange_name}))
|
|
||||||
|
|
||||||
def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
|
|
||||||
"""
|
|
||||||
Returns the market id that belongs to a trading pair listed in the database.
|
|
||||||
|
|
||||||
:param symbol: str - The symbol of the trading pair.
|
|
||||||
:param exchange_name: str - The exchange_name name.
|
|
||||||
:return: int - The market ID
|
|
||||||
"""
|
|
||||||
# Fetch the id of the exchange_name.
|
|
||||||
exchange_id = self._fetch_exchange_id(exchange_name)
|
|
||||||
|
|
||||||
# Ask the db for the market_id. Tell it to create one if it doesn't already exist.
|
|
||||||
market_id = self.get_from_static_table(item='id', table='markets', create_id=True,
|
|
||||||
indexes=HDict({'symbol': symbol, 'exchange_id': exchange_id}))
|
|
||||||
# Return the market id.
|
|
||||||
return market_id
|
|
||||||
|
|
||||||
def _insert_candles_into_db(self, candlesticks, table_name: str, symbol, exchange_name) -> None:
|
|
||||||
"""
|
|
||||||
Insert all the candlesticks from a dataframe into the database.
|
|
||||||
|
|
||||||
:param exchange_name: The name of the exchange_name.
|
|
||||||
:param symbol: The symbol of the trading pair.
|
|
||||||
:param candlesticks: pd.dataframe - A rows of candlestick attributes.
|
|
||||||
:param table_name: str - The name of the table to inset.
|
|
||||||
:return: None
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Retrieve the market id for the symbol.
|
|
||||||
market_id = self._fetch_market_id(symbol, exchange_name)
|
|
||||||
# Insert the market id into the dataframe.
|
|
||||||
candlesticks.insert(0, 'market_id', market_id)
|
|
||||||
# Create a table schema. todo delete these line if not needed anymore
|
|
||||||
# # Get a list of all the columns in the dataframe.
|
|
||||||
# columns = list(candlesticks.columns.values)
|
|
||||||
# # Isolate any extra columns specific to individual exchanges.
|
|
||||||
# # The carriage return and tabs are unnecessary, they just tidy output for debugging.
|
|
||||||
# columns = ',\n\t\t\t\t\t'.join(columns[7:], )
|
|
||||||
# # Define the columns common with all exchanges and append any extras columns.
|
|
||||||
sql_create = f"""
|
|
||||||
CREATE TABLE IF NOT EXISTS '{table_name}' (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
market_id INTEGER,
|
|
||||||
open_time INTEGER UNIQUE ON CONFLICT IGNORE,
|
|
||||||
open REAL NOT NULL,
|
|
||||||
high REAL NOT NULL,
|
|
||||||
low REAL NOT NULL,
|
|
||||||
close REAL NOT NULL,
|
|
||||||
volume REAL NOT NULL,
|
|
||||||
FOREIGN KEY (market_id) REFERENCES market (id)
|
|
||||||
)"""
|
|
||||||
# Connect to the database.
|
|
||||||
with SQLite() as conn:
|
|
||||||
# Get a cursor from the sql connection.
|
|
||||||
cursor = conn.cursor()
|
|
||||||
# Create the table if it doesn't exist.
|
|
||||||
cursor.execute(sql_create)
|
|
||||||
# Insert the candles into the table.
|
|
||||||
candlesticks.to_sql(table_name, conn, if_exists='append', index=False)
|
|
||||||
return
|
|
||||||
|
|
||||||
def get_records_since(self, table_name: str, st: dt.datetime,
|
|
||||||
et: dt.datetime, rl: float, ex_details: list) -> pd.DataFrame:
|
|
||||||
"""
|
|
||||||
Returns all the candles newer than the provided start_datetime from the specified table.
|
|
||||||
|
|
||||||
:param ex_details: list of details to pass to the server. [symbol, interval, exchange_name]
|
|
||||||
:param table_name: str - The database table name. Format: : <symbol>_<timeframe>_<exchange_name>.
|
|
||||||
:param st: dt.datetime.start_datetime - The start_datetime of the first record requested.
|
|
||||||
:param et: dt.datetime - The end time of the query
|
|
||||||
:param rl: float - The timespan in minutes each record represents.
|
|
||||||
:return: pd.dataframe -
|
|
||||||
"""
|
|
||||||
|
|
||||||
def add_data(data, tn, start_t, end_t):
|
|
||||||
new_records = self._populate_table(table_name=tn, start_time=start_t, end_time=end_t, ex_details=ex_details)
|
|
||||||
print(f'Got {len(new_records.index)} records from exchange_name')
|
|
||||||
if not new_records.empty:
|
|
||||||
# Combine the new records with the previously records.
|
|
||||||
data = pd.concat([data, new_records], axis=0, ignore_index=True)
|
|
||||||
# Drop any duplicates from overlap.
|
|
||||||
data = data.drop_duplicates(subset="open_time", keep='first')
|
|
||||||
# Return the modified dataframe.
|
|
||||||
return data
|
|
||||||
|
|
||||||
if self._table_exists(table_name=table_name):
|
|
||||||
# If the table exists retrieve all the records.
|
|
||||||
print('\nTable existed retrieving records from DB')
|
|
||||||
print(f'Requesting from {st} to {et}')
|
|
||||||
records = self._get_records(table_name=table_name, st=st, et=et)
|
|
||||||
print(f'Got {len(records.index)} records from db')
|
|
||||||
else:
|
|
||||||
# If the table doesn't exist, get them from the exchange_name.
|
|
||||||
print(f'\nTable didnt exist fetching from {ex_details[2]}')
|
|
||||||
temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
|
|
||||||
print(f'Requesting from {st} to {et}, Should be {temp} records')
|
|
||||||
records = self._populate_table(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details)
|
|
||||||
print(f'Got {len(records.index)} records from {ex_details[2]}')
|
|
||||||
|
|
||||||
# Check if the records in the db go far enough back to satisfy the query.
|
|
||||||
first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
|
|
||||||
if first_timestamp:
|
|
||||||
# The records didn't go far enough back if a timestamp was returned.
|
|
||||||
print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}')
|
|
||||||
print(f'first ts on record is: {first_timestamp}')
|
|
||||||
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
|
|
||||||
print(f'Requesting from {st} to {end_time}')
|
|
||||||
# Request records with open_times between [st:end_time] from the database.
|
|
||||||
records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)
|
|
||||||
|
|
||||||
# Check if the records received are up-to-date.
|
|
||||||
last_timestamp = query_uptodate(records=records, r_length_min=rl)
|
|
||||||
if last_timestamp:
|
|
||||||
# The query was not up-to-date if a timestamp was returned.
|
|
||||||
print(f'\nRecords were not updated. Requesting from {ex_details[2]}.')
|
|
||||||
print(f'the last record on file is: {last_timestamp}')
|
|
||||||
start_time = dt.datetime.utcfromtimestamp(last_timestamp)
|
|
||||||
print(f'Requesting from {start_time} to {et}')
|
|
||||||
# Request records with open_times between [start_time:et] from the database.
|
|
||||||
records = add_data(data=records, tn=table_name, start_t=start_time, end_t=et)
|
|
||||||
|
|
||||||
return records
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_records(table_name: str, st: dt.datetime, et: dt.datetime = None) -> pd.DataFrame:
|
|
||||||
"""
|
|
||||||
Returns all the candles newer than the provided start_datetime from the specified table.
|
|
||||||
|
|
||||||
:param table_name: str - The database table name. Format: : <symbol>_<timeframe>_<exchange_name>.
|
|
||||||
:param st: dt.datetime.start_datetime - The start_datetime of the first record requested.
|
|
||||||
:param et: dt.datetime - The end time of the query
|
|
||||||
:return: pd.dataframe -
|
|
||||||
"""
|
|
||||||
# Connect to the database.
|
|
||||||
with SQLite() as conn:
|
|
||||||
# Create a timestamp in milliseconds
|
|
||||||
start_stamp = unix_time_millis(st)
|
|
||||||
if et is not None:
|
|
||||||
# Create a timestamp in milliseconds
|
|
||||||
end_stamp = unix_time_millis(et)
|
|
||||||
q_str = f"SELECT * FROM '{table_name}' WHERE open_time >= {start_stamp} AND open_time <= {end_stamp};"
|
|
||||||
else:
|
|
||||||
q_str = f"SELECT * FROM '{table_name}' WHERE open_time >= {start_stamp};"
|
|
||||||
# Retrieve all the records from the table.
|
|
||||||
records = pd.read_sql(q_str, conn)
|
|
||||||
# Drop the databases primary id.
|
|
||||||
records = records.drop('id', axis=1)
|
|
||||||
# Return the data.
|
|
||||||
return records
|
|
||||||
|
|
||||||
# @staticmethod
|
|
||||||
# def date_of_last_timestamp(table_name):
|
|
||||||
# """
|
|
||||||
# Returns the latest timestamp stored in the db.
|
|
||||||
# TODO: Unused.
|
|
||||||
#
|
|
||||||
# :return: dt.timestamp
|
|
||||||
# """
|
|
||||||
# # Connect to the database.
|
|
||||||
# with SQLite() as conn:
|
|
||||||
# # Get a cursor from the connection.
|
|
||||||
# cursor = conn.cursor()
|
|
||||||
# cursor.execute(f"""SELECT open_time FROM '{table_name}' ORDER BY open_time DESC LIMIT 1""")
|
|
||||||
# ts = cursor.fetchone()[0] / 1000
|
|
||||||
# return dt.datetime.utcfromtimestamp(ts)
|
|
||||||
#
|
|
||||||
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
|
|
||||||
start_datetime: object = None, end_datetime: object = None) -> pd.DataFrame:
|
|
||||||
"""
|
|
||||||
Fetches and returns all candles from the specified market, timeframe, and exchange_name.
|
|
||||||
|
|
||||||
:param symbol: str - The symbol of the market.
|
|
||||||
:param interval: str - The timeframe. Format '<int><alpha>' - examples: '15m', '4h'
|
|
||||||
:param exchange_name: str - The name of the exchange_name.
|
|
||||||
:param start_datetime: dt.datetime - The open_time of the first record requested.
|
|
||||||
:param end_datetime: dt.datetime - The end_time for the query.
|
|
||||||
:return: pd.DataFrame: Dataframe containing rows of candle attributes that vary
|
|
||||||
depending on the exchange_name.
|
|
||||||
For example: [open_time, open, high, low, close, volume, close_time,
|
|
||||||
quote_volume, num_trades, taker_buy_base_volume, taker_buy_quote_volume]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def fill_data_holes(records, interval):
|
|
||||||
time_span = timeframe_to_minutes(interval)
|
|
||||||
last_timestamp = None
|
|
||||||
filled_records = []
|
|
||||||
|
|
||||||
for _, row in records.iterrows():
|
|
||||||
time_stamp = row['open_time']
|
|
||||||
|
|
||||||
if last_timestamp is None:
|
|
||||||
last_timestamp = time_stamp
|
|
||||||
filled_records.append(row)
|
|
||||||
continue
|
|
||||||
|
|
||||||
delta_ms = time_stamp - last_timestamp
|
|
||||||
delta_minutes = (delta_ms / 1000) / 60
|
|
||||||
|
|
||||||
if delta_minutes > time_span:
|
|
||||||
num_missing_rec = int(delta_minutes / time_span)
|
|
||||||
step = int(delta_ms / num_missing_rec)
|
|
||||||
|
|
||||||
for ts in range(int(last_timestamp) + step, int(time_stamp), step):
|
|
||||||
new_row = row.copy()
|
|
||||||
new_row['open_time'] = ts
|
|
||||||
filled_records.append(new_row)
|
|
||||||
|
|
||||||
filled_records.append(row)
|
|
||||||
last_timestamp = time_stamp
|
|
||||||
|
|
||||||
return pd.DataFrame(filled_records)
|
|
||||||
|
|
||||||
# Default start date for fetching from the exchange_name.
|
|
||||||
if start_datetime is None:
|
|
||||||
start_datetime = dt.datetime(year=2017, month=1, day=1)
|
|
||||||
|
|
||||||
# Default end date for fetching from the exchange_name.
|
|
||||||
if end_datetime is None:
|
|
||||||
end_datetime = dt.datetime.utcnow()
|
|
||||||
|
|
||||||
if start_datetime > end_datetime:
|
|
||||||
raise ValueError("\ndatabase:fetch_candles_from_exchange():"
|
|
||||||
" Invalid start and end parameters: ", start_datetime, end_datetime)
|
|
||||||
|
|
||||||
# Get a reference to the exchange
|
|
||||||
exchange = self.exchanges.get_exchange(ename=exchange_name, uname=user_name)
|
|
||||||
|
|
||||||
temp = (((unix_time_millis(end_datetime) - unix_time_millis(
|
|
||||||
start_datetime)) / 1000) / 60) / timeframe_to_minutes(interval)
|
|
||||||
print(f'Fetching historical data {start_datetime} to {end_datetime}, Should be {temp} records')
|
|
||||||
|
|
||||||
if start_datetime == end_datetime:
|
|
||||||
end_datetime = None
|
|
||||||
|
|
||||||
# Request candlestick data from the exchange_name.
|
|
||||||
candles = exchange.get_historical_klines(symbol=symbol,
|
|
||||||
interval=interval,
|
|
||||||
start_dt=start_datetime,
|
|
||||||
end_dt=end_datetime)
|
|
||||||
num_rec_records = len(candles.index)
|
|
||||||
print(f'\n{num_rec_records} candles retrieved from the exchange_name.')
|
|
||||||
# Isolate the open_times from the records received.
|
|
||||||
open_times = candles.open_time
|
|
||||||
# Calculate the number of records that would fit between the min and max open time.
|
|
||||||
estimated_num_records = (((open_times.max() - open_times.min()) / 1000) / 60) / timeframe_to_minutes(interval)
|
|
||||||
if num_rec_records < estimated_num_records:
|
|
||||||
# Some records may be missing due to server maintenance periods ect.
|
|
||||||
# Fill the holes with copies of the last record received before the gap.
|
|
||||||
candles = fill_data_holes(candles, interval)
|
|
||||||
return candles
|
|
||||||
|
|
@ -3,15 +3,53 @@ from exchangeinterface import ExchangeInterface
|
||||||
import unittest
|
import unittest
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
|
import os
|
||||||
|
from Database import SQLite, Database
|
||||||
|
from shared_utilities import unix_time_millis
|
||||||
|
|
||||||
|
|
||||||
class TestDataCache(unittest.TestCase):
|
class TestDataCache(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Setup the database connection here
|
# Set the database connection here
|
||||||
self.exchanges = ExchangeInterface()
|
self.exchanges = ExchangeInterface()
|
||||||
self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)
|
self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)
|
||||||
# This object maintains all the cached data. Pass it connection to the exchanges.
|
# This object maintains all the cached data. Pass it connection to the exchanges.
|
||||||
|
self.db_file = 'test_db.sqlite'
|
||||||
|
self.database = Database(db_file=self.db_file)
|
||||||
|
|
||||||
|
# Create necessary tables
|
||||||
|
with SQLite(db_file=self.db_file) as con:
|
||||||
|
cursor = con.cursor()
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS exchange (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT UNIQUE
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS markets (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
symbol TEXT,
|
||||||
|
exchange_id INTEGER,
|
||||||
|
FOREIGN KEY (exchange_id) REFERENCES exchange(id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS test_table (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
market_id INTEGER,
|
||||||
|
open_time INTEGER UNIQUE,
|
||||||
|
open REAL NOT NULL,
|
||||||
|
high REAL NOT NULL,
|
||||||
|
low REAL NOT NULL,
|
||||||
|
close REAL NOT NULL,
|
||||||
|
volume REAL NOT NULL,
|
||||||
|
FOREIGN KEY (market_id) REFERENCES markets(id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
self.data = DataCache(self.exchanges)
|
self.data = DataCache(self.exchanges)
|
||||||
|
self.data.db = self.database
|
||||||
|
|
||||||
asset, timeframe, exchange = 'BTC/USD', '2h', 'binance'
|
asset, timeframe, exchange = 'BTC/USD', '2h', 'binance'
|
||||||
self.key1 = f'{asset}_{timeframe}_{exchange}'
|
self.key1 = f'{asset}_{timeframe}_{exchange}'
|
||||||
|
|
@ -19,8 +57,11 @@ class TestDataCache(unittest.TestCase):
|
||||||
asset, timeframe, exchange = 'ETH/USD', '2h', 'binance'
|
asset, timeframe, exchange = 'ETH/USD', '2h', 'binance'
|
||||||
self.key2 = f'{asset}_{timeframe}_{exchange}'
|
self.key2 = f'{asset}_{timeframe}_{exchange}'
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
if os.path.exists(self.db_file):
|
||||||
|
os.remove(self.db_file)
|
||||||
|
|
||||||
def test_set_cache(self):
|
def test_set_cache(self):
|
||||||
# Tests
|
|
||||||
print('Testing set_cache flag not set:')
|
print('Testing set_cache flag not set:')
|
||||||
self.data.set_cache(data='data', key=self.key1)
|
self.data.set_cache(data='data', key=self.key1)
|
||||||
attr = self.data.__getattribute__('cached_data')
|
attr = self.data.__getattribute__('cached_data')
|
||||||
|
|
@ -36,7 +77,6 @@ class TestDataCache(unittest.TestCase):
|
||||||
self.assertEqual(attr[self.key1], 'more_data')
|
self.assertEqual(attr[self.key1], 'more_data')
|
||||||
|
|
||||||
def test_cache_exists(self):
|
def test_cache_exists(self):
|
||||||
# Tests
|
|
||||||
print('Testing cache_exists() method:')
|
print('Testing cache_exists() method:')
|
||||||
self.assertFalse(self.data.cache_exists(key=self.key2))
|
self.assertFalse(self.data.cache_exists(key=self.key2))
|
||||||
self.data.set_cache(data='data', key=self.key1)
|
self.data.set_cache(data='data', key=self.key1)
|
||||||
|
|
@ -44,7 +84,6 @@ class TestDataCache(unittest.TestCase):
|
||||||
|
|
||||||
def test_update_candle_cache(self):
|
def test_update_candle_cache(self):
|
||||||
print('Testing update_candle_cache() method:')
|
print('Testing update_candle_cache() method:')
|
||||||
# Initial data
|
|
||||||
df_initial = pd.DataFrame({
|
df_initial = pd.DataFrame({
|
||||||
'open_time': [1, 2, 3],
|
'open_time': [1, 2, 3],
|
||||||
'open': [100, 101, 102],
|
'open': [100, 101, 102],
|
||||||
|
|
@ -54,7 +93,6 @@ class TestDataCache(unittest.TestCase):
|
||||||
'volume': [1000, 1001, 1002]
|
'volume': [1000, 1001, 1002]
|
||||||
})
|
})
|
||||||
|
|
||||||
# Data to be added
|
|
||||||
df_new = pd.DataFrame({
|
df_new = pd.DataFrame({
|
||||||
'open_time': [3, 4, 5],
|
'open_time': [3, 4, 5],
|
||||||
'open': [102, 103, 104],
|
'open': [102, 103, 104],
|
||||||
|
|
@ -96,7 +134,7 @@ class TestDataCache(unittest.TestCase):
|
||||||
def test_get_records_since(self):
|
def test_get_records_since(self):
|
||||||
print('Testing get_records_since() method:')
|
print('Testing get_records_since() method:')
|
||||||
df_initial = pd.DataFrame({
|
df_initial = pd.DataFrame({
|
||||||
'open_time': [1, 2, 3],
|
'open_time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=i)) for i in range(3)],
|
||||||
'open': [100, 101, 102],
|
'open': [100, 101, 102],
|
||||||
'high': [110, 111, 112],
|
'high': [110, 111, 112],
|
||||||
'low': [90, 91, 92],
|
'low': [90, 91, 92],
|
||||||
|
|
@ -105,20 +143,86 @@ class TestDataCache(unittest.TestCase):
|
||||||
})
|
})
|
||||||
|
|
||||||
self.data.set_cache(data=df_initial, key=self.key1)
|
self.data.set_cache(data=df_initial, key=self.key1)
|
||||||
start_datetime = dt.datetime.utcfromtimestamp(2)
|
start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=2)
|
||||||
result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60, ex_details=[]).sort_values(by='open_time').reset_index(drop=True)
|
result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60,
|
||||||
|
ex_details=['BTC/USD', '2h', 'binance'])
|
||||||
|
|
||||||
expected = pd.DataFrame({
|
expected = pd.DataFrame({
|
||||||
'open_time': [2, 3],
|
'open_time': df_initial['open_time'][:2].values,
|
||||||
'open': [101, 102],
|
'open': [100, 101],
|
||||||
'high': [111, 112],
|
'high': [110, 111],
|
||||||
'low': [91, 92],
|
'low': [90, 91],
|
||||||
'close': [106, 107],
|
'close': [105, 106],
|
||||||
'volume': [1001, 1002]
|
'volume': [1000, 1001]
|
||||||
})
|
})
|
||||||
|
|
||||||
pd.testing.assert_frame_equal(result, expected)
|
pd.testing.assert_frame_equal(result, expected)
|
||||||
|
|
||||||
|
def test_get_records_since_from_db(self):
|
||||||
|
print('Testing get_records_since_from_db() method:')
|
||||||
|
df_initial = pd.DataFrame({
|
||||||
|
'market_id': [None],
|
||||||
|
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
||||||
|
'open': [1.0],
|
||||||
|
'high': [1.0],
|
||||||
|
'low': [1.0],
|
||||||
|
'close': [1.0],
|
||||||
|
'volume': [1.0]
|
||||||
|
})
|
||||||
|
|
||||||
|
with SQLite(self.db_file) as con:
|
||||||
|
df_initial.to_sql('test_table', con, if_exists='append', index=False)
|
||||||
|
|
||||||
|
start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=1)
|
||||||
|
end_datetime = dt.datetime.utcnow()
|
||||||
|
result = self.data.get_records_since_from_db(table_name='test_table', st=start_datetime, et=end_datetime,
|
||||||
|
rl=1, ex_details=['BTC/USD', '2h', 'binance']).sort_values(
|
||||||
|
by='open_time').reset_index(drop=True)
|
||||||
|
|
||||||
|
print("Columns in the result DataFrame:", result.columns)
|
||||||
|
print("Result DataFrame:\n", result)
|
||||||
|
|
||||||
|
# Remove 'id' column from the result DataFrame if it exists
|
||||||
|
if 'id' in result.columns:
|
||||||
|
result = result.drop(columns=['id'])
|
||||||
|
|
||||||
|
expected = pd.DataFrame({
|
||||||
|
'market_id': [None],
|
||||||
|
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
||||||
|
'open': [1.0],
|
||||||
|
'high': [1.0],
|
||||||
|
'low': [1.0],
|
||||||
|
'close': [1.0],
|
||||||
|
'volume': [1.0]
|
||||||
|
})
|
||||||
|
|
||||||
|
print("Expected DataFrame:\n", expected)
|
||||||
|
|
||||||
|
pd.testing.assert_frame_equal(result, expected)
|
||||||
|
|
||||||
|
def test_populate_db(self):
|
||||||
|
print('Testing _populate_db() method:')
|
||||||
|
start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
|
||||||
|
end_time = dt.datetime.utcnow()
|
||||||
|
|
||||||
|
result = self.data._populate_db(table_name='test_table', start_time=start_time,
|
||||||
|
end_time=end_time, ex_details=['BTC/USD', '2h', 'binance', 'test_guy'])
|
||||||
|
|
||||||
|
self.assertIsInstance(result, pd.DataFrame)
|
||||||
|
self.assertFalse(result.empty)
|
||||||
|
|
||||||
|
def test_fetch_candles_from_exchange(self):
|
||||||
|
print('Testing _fetch_candles_from_exchange() method:')
|
||||||
|
start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
|
||||||
|
end_time = dt.datetime.utcnow()
|
||||||
|
|
||||||
|
result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance',
|
||||||
|
user_name='test_guy', start_datetime=start_time,
|
||||||
|
end_datetime=end_time)
|
||||||
|
|
||||||
|
self.assertIsInstance(result, pd.DataFrame)
|
||||||
|
self.assertFalse(result.empty)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -1,96 +1,219 @@
|
||||||
import datetime
|
import unittest
|
||||||
|
import sqlite3
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from database import make_query, make_insert, Database, HDict, SQLite
|
import datetime as dt
|
||||||
from exchangeinterface import ExchangeInterface
|
from Database import Database, SQLite, make_query, make_insert, HDict
|
||||||
from passlib.hash import bcrypt
|
from shared_utilities import unix_time_millis
|
||||||
from sqlalchemy import create_engine, text
|
|
||||||
import config
|
|
||||||
|
|
||||||
|
|
||||||
def test():
|
class TestSQLite(unittest.TestCase):
|
||||||
# un_hashed_pass = 'password'
|
def test_sqlite_context_manager(self):
|
||||||
hasher = bcrypt.using(rounds=13)
|
print("\nRunning test_sqlite_context_manager...")
|
||||||
# hashed_pass = hasher.hash(un_hashed_pass)
|
with SQLite(db_file='test_db.sqlite') as con:
|
||||||
# print(f'password: {un_hashed_pass}')
|
cursor = con.cursor()
|
||||||
# print(f'hashed pass: {hashed_pass}')
|
cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||||
# print(f" right pass: {hasher.verify('password', hashed_pass)}")
|
cursor.execute("INSERT INTO test_table (name) VALUES ('test')")
|
||||||
# print(f" wrong pass: {hasher.verify('passWord', hashed_pass)}")
|
cursor.execute('SELECT name FROM test_table WHERE name = ?', ('test',))
|
||||||
engine = create_engine("sqlite:///" + config.DB_FILE, echo=True)
|
result = cursor.fetchone()
|
||||||
with engine.connect() as conn:
|
self.assertEqual(result[0], 'test')
|
||||||
default_user = pd.read_sql_query(sql=text("SELECT * FROM users WHERE user_name = 'guest'"), con=conn)
|
print("SQLite context manager test passed.")
|
||||||
# hashed_password = default_user.password.values[0]
|
|
||||||
# print(f" verify pass: {hasher.verify('password', hashed_password)}")
|
|
||||||
username = default_user.user_name.values[0]
|
|
||||||
print(username)
|
|
||||||
|
|
||||||
|
|
||||||
def test_make_query():
|
class TestDatabase(unittest.TestCase):
|
||||||
values = {'first_field': 'first_value'}
|
def setUp(self):
|
||||||
item = 'market_id'
|
# Use a temporary SQLite database for testing purposes
|
||||||
table = 'markets'
|
self.db_file = 'test_db.sqlite'
|
||||||
q_str = make_query(item=item, table=table, values=values)
|
self.db = Database(db_file=self.db_file)
|
||||||
print(f'\nWith one indexing field: {q_str}')
|
self.connection = sqlite3.connect(self.db_file)
|
||||||
|
self.cursor = self.connection.cursor()
|
||||||
|
|
||||||
values = {'first_field': 'first_value', 'second_field': 'second_value'}
|
def tearDown(self):
|
||||||
q_str = make_query(item=item, table=table, values=values)
|
self.connection.close()
|
||||||
print(f'\nWith two indexing fields: {q_str}')
|
import os
|
||||||
assert q_str is not None
|
os.remove(self.db_file) # Remove the temporary database file after tests
|
||||||
|
|
||||||
|
def test_execute_sql(self):
|
||||||
|
print("\nRunning test_execute_sql...")
|
||||||
|
# Drop the table if it exists to avoid OperationalError
|
||||||
|
self.cursor.execute('DROP TABLE IF EXISTS test_table')
|
||||||
|
self.connection.commit()
|
||||||
|
|
||||||
|
sql = 'CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)'
|
||||||
|
self.db.execute_sql(sql)
|
||||||
|
|
||||||
|
self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_table';")
|
||||||
|
result = self.cursor.fetchone()
|
||||||
|
self.assertIsNotNone(result)
|
||||||
|
print("Execute SQL test passed.")
|
||||||
|
|
||||||
|
def test_make_query(self):
|
||||||
|
print("\nRunning test_make_query...")
|
||||||
|
query = make_query('id', 'test_table', ['name'])
|
||||||
|
expected_query = 'SELECT id FROM test_table WHERE name = ?;'
|
||||||
|
self.assertEqual(query, expected_query)
|
||||||
|
print("Make query test passed.")
|
||||||
|
|
||||||
|
def test_make_insert(self):
|
||||||
|
print("\nRunning test_make_insert...")
|
||||||
|
insert = make_insert('test_table', ('name', 'age'))
|
||||||
|
expected_insert = "INSERT INTO test_table ('name', 'age') VALUES(?, ?);"
|
||||||
|
self.assertEqual(insert, expected_insert)
|
||||||
|
print("Make insert test passed.")
|
||||||
|
|
||||||
|
def test_get_item_where(self):
|
||||||
|
print("\nRunning test_get_item_where...")
|
||||||
|
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||||
|
self.cursor.execute("INSERT INTO test_table (id, name) VALUES (1, 'test')")
|
||||||
|
self.connection.commit()
|
||||||
|
item = self.db.get_item_where('name', 'test_table', ('id', 1))
|
||||||
|
self.assertEqual(item, 'test')
|
||||||
|
print("Get item where test passed.")
|
||||||
|
|
||||||
|
def test_get_rows_where(self):
|
||||||
|
print("\nRunning test_get_rows_where...")
|
||||||
|
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||||
|
self.cursor.execute("INSERT INTO test_table (id, name) VALUES (1, 'test')")
|
||||||
|
self.connection.commit()
|
||||||
|
rows = self.db.get_rows_where('test_table', ('name', 'test'))
|
||||||
|
self.assertIsInstance(rows, pd.DataFrame)
|
||||||
|
self.assertEqual(rows.iloc[0]['name'], 'test')
|
||||||
|
print("Get rows where test passed.")
|
||||||
|
|
||||||
|
def test_insert_dataframe(self):
|
||||||
|
print("\nRunning test_insert_dataframe...")
|
||||||
|
df = pd.DataFrame({'id': [1], 'name': ['test']})
|
||||||
|
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||||
|
self.connection.commit()
|
||||||
|
self.db.insert_dataframe(df, 'test_table')
|
||||||
|
self.cursor.execute('SELECT name FROM test_table WHERE id = 1')
|
||||||
|
result = self.cursor.fetchone()
|
||||||
|
self.assertEqual(result[0], 'test')
|
||||||
|
print("Insert dataframe test passed.")
|
||||||
|
|
||||||
|
def test_insert_row(self):
|
||||||
|
print("\nRunning test_insert_row...")
|
||||||
|
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||||
|
self.connection.commit()
|
||||||
|
self.db.insert_row('test_table', ('id', 'name'), (1, 'test'))
|
||||||
|
self.cursor.execute('SELECT name FROM test_table WHERE id = 1')
|
||||||
|
result = self.cursor.fetchone()
|
||||||
|
self.assertEqual(result[0], 'test')
|
||||||
|
print("Insert row test passed.")
|
||||||
|
|
||||||
|
def test_table_exists(self):
|
||||||
|
print("\nRunning test_table_exists...")
|
||||||
|
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
||||||
|
self.connection.commit()
|
||||||
|
exists = self.db.table_exists('test_table')
|
||||||
|
self.assertTrue(exists)
|
||||||
|
print("Table exists test passed.")
|
||||||
|
|
||||||
|
def test_get_timestamped_records(self):
|
||||||
|
print("\nRunning test_get_timestamped_records...")
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
||||||
|
'open': [1.0],
|
||||||
|
'high': [1.0],
|
||||||
|
'low': [1.0],
|
||||||
|
'close': [1.0],
|
||||||
|
'volume': [1.0]
|
||||||
|
})
|
||||||
|
table_name = 'test_table'
|
||||||
|
self.cursor.execute(f"""
|
||||||
|
CREATE TABLE {table_name} (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
open_time INTEGER UNIQUE,
|
||||||
|
open REAL NOT NULL,
|
||||||
|
high REAL NOT NULL,
|
||||||
|
low REAL NOT NULL,
|
||||||
|
close REAL NOT NULL,
|
||||||
|
volume REAL NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
self.connection.commit()
|
||||||
|
self.db.insert_dataframe(df, table_name)
|
||||||
|
st = dt.datetime.utcnow() - dt.timedelta(minutes=1)
|
||||||
|
et = dt.datetime.utcnow()
|
||||||
|
records = self.db.get_timestamped_records(table_name, 'open_time', st, et)
|
||||||
|
self.assertIsInstance(records, pd.DataFrame)
|
||||||
|
self.assertFalse(records.empty)
|
||||||
|
print("Get timestamped records test passed.")
|
||||||
|
|
||||||
|
def test_get_from_static_table(self):
|
||||||
|
print("\nRunning test_get_from_static_table...")
|
||||||
|
self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT UNIQUE)')
|
||||||
|
self.connection.commit()
|
||||||
|
item = self.db.get_from_static_table('id', 'test_table', HDict({'name': 'test'}), create_id=True)
|
||||||
|
self.assertIsInstance(item, int)
|
||||||
|
self.cursor.execute('SELECT id FROM test_table WHERE name = ?', ('test',))
|
||||||
|
result = self.cursor.fetchone()
|
||||||
|
self.assertEqual(item, result[0])
|
||||||
|
print("Get from static table test passed.")
|
||||||
|
|
||||||
|
def test_insert_candles_into_db(self):
|
||||||
|
print("\nRunning test_insert_candles_into_db...")
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'open_time': [unix_time_millis(dt.datetime.utcnow())],
|
||||||
|
'open': [1.0],
|
||||||
|
'high': [1.0],
|
||||||
|
'low': [1.0],
|
||||||
|
'close': [1.0],
|
||||||
|
'volume': [1.0]
|
||||||
|
})
|
||||||
|
table_name = 'test_table'
|
||||||
|
self.cursor.execute(f"""
|
||||||
|
CREATE TABLE {table_name} (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
market_id INTEGER,
|
||||||
|
open_time INTEGER UNIQUE,
|
||||||
|
open REAL NOT NULL,
|
||||||
|
high REAL NOT NULL,
|
||||||
|
low REAL NOT NULL,
|
||||||
|
close REAL NOT NULL,
|
||||||
|
volume REAL NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
self.connection.commit()
|
||||||
|
|
||||||
|
# Create the exchange and markets tables needed for the foreign key constraints
|
||||||
|
self.cursor.execute("""
|
||||||
|
CREATE TABLE exchange (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT UNIQUE
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
self.cursor.execute("""
|
||||||
|
CREATE TABLE markets (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
symbol TEXT,
|
||||||
|
exchange_id INTEGER,
|
||||||
|
FOREIGN KEY (exchange_id) REFERENCES exchange(id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
self.connection.commit()
|
||||||
|
|
||||||
|
self.db.insert_candles_into_db(df, table_name, 'BTC/USDT', 'binance')
|
||||||
|
self.cursor.execute(f'SELECT * FROM {table_name}')
|
||||||
|
result = self.cursor.fetchall()
|
||||||
|
self.assertFalse(len(result) == 0)
|
||||||
|
print("Insert candles into db test passed.")
|
||||||
|
|
||||||
|
|
||||||
def test_make_insert():
|
if __name__ == '__main__':
|
||||||
table = 'markets'
|
unittest.main()
|
||||||
values = {'first_field': 'first_value'}
|
|
||||||
q_str = make_insert(table=table, values=values)
|
|
||||||
print(f'\nWith one indexing field: {q_str}')
|
|
||||||
|
|
||||||
values = {'first_field': 'first_value', 'second_field': 'second_value'}
|
# def test():
|
||||||
q_str = make_insert(table=table, values=values)
|
# # un_hashed_pass = 'password'
|
||||||
print(f'\nWith two indexing fields: {q_str}')
|
# hasher = bcrypt.using(rounds=13)
|
||||||
assert q_str is not None
|
# # hashed_pass = hasher.hash(un_hashed_pass)
|
||||||
|
# # print(f'password: {un_hashed_pass}')
|
||||||
|
# # print(f'hashed pass: {hashed_pass}')
|
||||||
def test__table_exists():
|
# # print(f" right pass: {hasher.verify('password', hashed_pass)}")
|
||||||
exchanges = ExchangeInterface()
|
# # print(f" wrong pass: {hasher.verify('passWord', hashed_pass)}")
|
||||||
d_obj = Database(exchanges)
|
# engine = create_engine("sqlite:///" + config.DB_FILE, echo=True)
|
||||||
exists = d_obj._table_exists('BTC/USD_5m_alpaca')
|
# with engine.connect() as conn:
|
||||||
print(f'\nExists - Should be true: {exists}')
|
# default_user = pd.read_sql_query(sql=text("SELECT * FROM users WHERE user_name = 'guest'"), con=conn)
|
||||||
assert exists
|
# # hashed_password = default_user.password.values[0]
|
||||||
exists = d_obj._table_exists('BTC/USD_5m_alpina')
|
# # print(f" verify pass: {hasher.verify('password', hashed_password)}")
|
||||||
print(f'Doesnt exist - should be false: {exists}')
|
# username = default_user.user_name.values[0]
|
||||||
assert not exists
|
# print(username)
|
||||||
|
|
||||||
|
|
||||||
def test_get_from_static_table():
|
|
||||||
exchanges = ExchangeInterface()
|
|
||||||
d_obj = Database(exchanges)
|
|
||||||
market_id = d_obj._fetch_market_id('BTC/USD', 'alpaca')
|
|
||||||
e_id = d_obj._fetch_exchange_id('alpaca')
|
|
||||||
print(f'market id: {market_id}')
|
|
||||||
assert market_id > 0
|
|
||||||
print(f'exchange_name ID: {e_id}')
|
|
||||||
assert e_id == 4
|
|
||||||
|
|
||||||
|
|
||||||
def test_populate_table():
|
|
||||||
"""
|
|
||||||
Populates a database table with records from the exchange_name.
|
|
||||||
:param table_name: str - The name of the table in the database.
|
|
||||||
:param start_time: datetime - The starting time to fetch the records from.
|
|
||||||
:param end_time: datetime - The end time to get the records until.
|
|
||||||
:return: pdDataframe: - The data that was downloaded.
|
|
||||||
"""
|
|
||||||
exchanges = ExchangeInterface()
|
|
||||||
d_obj = Database(exchanges)
|
|
||||||
d_obj._populate_table(table_name='BTC/USD_2h_alpaca',
|
|
||||||
start_time=datetime.datetime(year=2023, month=3, day=27, hour=6, minute=0))
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_records_since():
|
|
||||||
exchanges = ExchangeInterface()
|
|
||||||
d_obj = Database(exchanges)
|
|
||||||
records = d_obj.get_records_since(table_name='BTC/USD_15m_alpaca',
|
|
||||||
st=datetime.datetime(year=2023, month=3, day=27, hour=1, minute=0),
|
|
||||||
et=datetime.datetime.utcnow(),
|
|
||||||
rl=15)
|
|
||||||
print(records)
|
|
||||||
assert records is not None
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ class TestExchange(unittest.TestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
exchange_name = 'binance'
|
exchange_name = 'kraken'
|
||||||
cls.api_keys = None
|
cls.api_keys = None
|
||||||
"""Uncomment and Provide api keys to connect to exchange."""
|
"""Uncomment and Provide api keys to connect to exchange."""
|
||||||
# cls.api_keys = {'key': 'EXCHANGE_API_KEY', 'secret': 'EXCHANGE_API_SECRET'}
|
# cls.api_keys = {'key': 'EXCHANGE_API_KEY', 'secret': 'EXCHANGE_API_SECRET'}
|
||||||
Loading…
Reference in New Issue