brighter-trading/src/Database.py

321 lines
13 KiB
Python

import sqlite3
from functools import lru_cache
from typing import Any, Dict, List, Tuple
import config
import datetime as dt
import pandas as pd
from shared_utilities import unix_time_millis
class SQLite:
    """
    Context manager for SQLite database connections.

    Accepts a database file name or defaults to the file in config.DB_FILE.
    When the ``with`` body finishes normally the transaction is committed.
    If the body raises, the transaction is rolled back so half-finished writes
    are not persisted. The connection is closed in both cases, and any
    exception from the body propagates to the caller.

    Example usage:
    --------------
    with SQLite(db_file='test_db.sqlite') as con:
        cursor = con.cursor()
        cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
    """

    def __init__(self, db_file: str = None):
        """
        :param db_file: Path of the SQLite database file; falls back to config.DB_FILE when falsy.
        """
        self.db_file = db_file if db_file else config.DB_FILE
        # The connection is opened in the constructor and exposed as .connection.
        self.connection = sqlite3.connect(self.db_file)

    def __enter__(self) -> sqlite3.Connection:
        """Return the open sqlite3 connection for use inside the ``with`` block."""
        return self.connection

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """
        Commit on success, roll back on error, and always close the connection.

        :return: False, so an exception raised in the ``with`` body is re-raised.
        """
        try:
            if exc_type is None:
                self.connection.commit()
            else:
                # Do not persist work from a block that failed part-way through.
                self.connection.rollback()
        finally:
            self.connection.close()
        return False
class HDict(dict):
    """
    A dictionary that can be hashed, so it can be used as a data/cache key.

    The hash is built from the set of (key, value) pairs, so two HDicts with
    the same contents hash equally regardless of insertion order. Every value
    must itself be hashable, and an instance should not be mutated while it is
    being used as a key.

    Example usage:
    --------------
    hdict = HDict({'key1': 'value1', 'key2': 'value2'})
    hash(hdict)
    """

    def __hash__(self) -> int:
        pairs = frozenset((key, value) for key, value in self.items())
        return hash(pairs)
def make_query(item: str, table: str, columns: List[str]) -> str:
    """
    Creates a SQL select query string with one ``col = ?`` placeholder per filter column.

    The conditions are joined with AND. When ``columns`` is empty, the WHERE
    clause is omitted instead of emitting the invalid ``WHERE ;``.
    Only values are parameterized: ``item``, ``table`` and ``columns`` are
    interpolated verbatim and must therefore come from trusted code.

    :param item: The field (or expression) to select.
    :param table: The table to select from.
    :param columns: List of columns for the where clause.
    :return: The query string.
    """
    if not columns:
        return f"SELECT {item} FROM {table};"
    conditions = " AND ".join(f"{col} = ?" for col in columns)
    return f"SELECT {item} FROM {table} WHERE {conditions};"
def make_insert(table: str, columns: Tuple[str, ...], replace: bool = False) -> str:
    """
    Creates a SQL insert query string with the required number of placeholders.

    The table and column names are wrapped in double quotes, and any embedded
    double quote is doubled (standard SQL identifier escaping). When
    ``columns`` is empty, ``DEFAULT VALUES`` is emitted, because ``() VALUES ()``
    is not valid SQL.

    :param table: The table to insert into.
    :param columns: Tuple of column names.
    :param replace: If True, generate ``INSERT OR REPLACE`` instead of ``INSERT``.
    :return: The query string.
    """
    def quote(name: str) -> str:
        # Double-quote an identifier, escaping embedded quotes by doubling them.
        return '"' + str(name).replace('"', '""') + '"'

    verb = 'INSERT OR REPLACE INTO' if replace else 'INSERT INTO'
    if not columns:
        return f'{verb} {quote(table)} DEFAULT VALUES;'
    col_names = ", ".join(quote(col) for col in columns)
    placeholders = ", ".join("?" for _ in columns)
    return f'{verb} {quote(table)} ({col_names}) VALUES ({placeholders});'
class Database:
    """
    Database class to communicate and maintain the database.
    Handles connections and operations for the given exchanges.

    Each method opens and closes its own connection through the SQLite
    context manager.

    Example usage:
    --------------
    db = Database(db_file='test_db.sqlite')
    db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
    """

    def __init__(self, db_file: str = None):
        """
        :param db_file: Path of the SQLite database file. None defers to the default
            chosen by the SQLite context manager (config.DB_FILE).
        """
        self.db_file = db_file

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Double-quote an SQL identifier, doubling any embedded double quotes."""
        return '"' + str(name).replace('"', '""') + '"'

    def execute_sql(self, sql: str, params: list = None) -> None:
        """
        Executes a raw SQL statement with optional parameters.

        :param sql: SQL statement to execute.
        :param params: Optional sequence (list or tuple) of values bound to the statement's placeholders.
        """
        # sqlite3 rejects None as a parameter set, so substitute an empty sequence.
        bound = () if params is None else params
        with SQLite(self.db_file) as con:
            con.cursor().execute(sql, bound)

    def get_item_where(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any:
        """
        Returns an item from a table where the filter results should isolate a single row.

        :param item_name: Name of the item to fetch.
        :param table_name: Name of the table.
        :param filter_vals: Tuple of column name and value to filter by.
        :return: The item from the first matching row.
        :raises ValueError: If no row matches the filter.
        """
        with SQLite(self.db_file) as con:
            cur = con.cursor()
            qry = make_query(item_name, table_name, [filter_vals[0]])
            cur.execute(qry, (filter_vals[1],))
            row = cur.fetchone()
        if row is None:
            error = f"Couldn't fetch item {item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}"
            raise ValueError(error)
        return row[0]

    def get_rows_where(self, table: str, filter_vals: List[Tuple[str, Any]]) -> pd.DataFrame | None:
        """
        Returns a DataFrame containing all rows of a table that meet the filter criteria.

        A list value produces an ``IN (...)`` condition; any other value produces ``= ?``.
        All conditions are combined with AND. Table and column names are
        interpolated verbatim, so they must come from trusted code.

        :param table: Name of the table.
        :param filter_vals: List of tuples containing column names and values to filter by.
        :return: DataFrame of the query result, or None if it is empty or the query fails
                 (e.g. the table or column does not exist).
        """
        try:
            with SQLite(self.db_file) as con:
                where_clauses = []
                params = []
                for col, val in filter_vals:
                    if isinstance(val, list):
                        placeholders = ', '.join('?' for _ in val)
                        where_clauses.append(f"{col} IN ({placeholders})")
                        params.extend(val)
                    else:
                        where_clauses.append(f"{col} = ?")
                        params.append(val)
                where_clause = " AND ".join(where_clauses)
                qry = f"SELECT * FROM {table} WHERE {where_clause}"
                result = pd.read_sql(qry, con, params=params)
                return result if not result.empty else None
        except (sqlite3.OperationalError, pd.errors.DatabaseError) as e:
            # Best-effort lookup: report the problem and signal "no rows" to the caller.
            print(f"Error querying table '{table}' with filters {filter_vals}: {e}")
            return None

    def insert_dataframe(self, df: pd.DataFrame, table: str) -> int:
        """
        Inserts a DataFrame into a specified table and returns the last inserted row's ID.

        :param df: DataFrame to insert (its index is not written).
        :param table: Name of the table; it is appended to, or created by pandas if missing.
        :return: The auto-incremented ID of the last inserted row.
        """
        with SQLite(self.db_file) as con:
            df.to_sql(name=table, con=con, index=False, if_exists='append')
            # last_insert_rowid() is per-connection, so read it on the same connection.
            last_id = con.execute('SELECT last_insert_rowid()').fetchone()[0]
        return last_id

    def insert_row(self, table: str, columns: Tuple[str, ...], values: Tuple[Any, ...]) -> int:
        """
        Inserts a row into a specified table and returns the auto-incremented ID.

        :param table: Name of the table.
        :param columns: Tuple of column names.
        :param values: Tuple of values to insert, in the same order as ``columns``.
        :return: The auto-incremented ID of the inserted row.
        """
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            cursor.execute(make_insert(table=table, columns=columns), values)
            return cursor.lastrowid

    def table_exists(self, table_name: str) -> bool:
        """
        Checks if a table exists in the database.

        :param table_name: Name of the table.
        :return: True if the table exists, False otherwise.
        """
        # 'sqlite_master' works on every SQLite version; the 'sqlite_schema'
        # alias only exists from SQLite 3.33.0 onward.
        sql = make_query(item='name', table='sqlite_master', columns=['type', 'name'])
        with SQLite(self.db_file) as conn:
            row = conn.cursor().execute(sql, ('table', table_name)).fetchone()
        return row is not None

    @lru_cache(maxsize=1000)
    def get_from_static_table(self, item: str, table: str, indexes: dict, create_id: bool = False) -> Any:
        """
        Returns the value of ``item`` from the first row of ``table`` whose index columns
        match ``indexes``.

        If no row matches and ``create_id`` is True, a row containing the index
        values is inserted and its auto-incremented rowid is returned. That rowid
        equals ``item`` only when ``item`` is the table's integer primary key.

        Results are memoized with lru_cache, keyed on (self, item, table, indexes,
        create_id). This has two consequences. First, ``indexes`` must be hashable
        (e.g. an HDict). Second, a None result (no row, ``create_id`` False) is
        cached as well, so a row inserted later will not be seen until
        ``get_from_static_table.cache_clear()`` is called.

        :param item: Name of the item requested.
        :param table: Table being queried.
        :param indexes: Hashable dictionary of indexing columns and their values.
        :param create_id: If True, create a row if it doesn't exist and return the autoincrement ID.
        :return: The content of the field, the new rowid, or None.
        """
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            sql = make_query(item, table, list(indexes.keys()))
            cursor.execute(sql, tuple(indexes.values()))
            result = cursor.fetchone()
            if result is None and create_id:
                sql = make_insert(table, tuple(indexes.keys()))
                cursor.execute(sql, tuple(indexes.values()))
                result = cursor.lastrowid
            else:
                result = result[0] if result else None
        return result

    def _fetch_exchange_id(self, exchange_name: str) -> int:
        """
        Fetches the primary ID of an exchange, creating the exchange row if it is missing.

        :param exchange_name: Name of the exchange.
        :return: Primary ID of the exchange.
        """
        return self.get_from_static_table(item='id', table='exchange', create_id=True,
                                          indexes=HDict({'name': exchange_name}))

    def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
        """
        Returns the market ID for a trading pair, creating the market row if it is missing.

        :param symbol: Symbol of the trading pair.
        :param exchange_name: Name of the exchange.
        :return: Market ID.
        """
        exchange_id = self._fetch_exchange_id(exchange_name)
        return self.get_from_static_table(item='id', table='markets', create_id=True,
                                          indexes=HDict({'symbol': symbol, 'exchange_id': exchange_id}))

    def insert_candles_into_db(self, candlesticks: pd.DataFrame, table_name: str, symbol: str,
                               exchange_name: str) -> None:
        """
        Inserts all candlesticks from a DataFrame into the database.

        The target table is created if it doesn't exist. Rows whose ``time`` already
        exists are silently skipped (``UNIQUE ON CONFLICT IGNORE``). The caller's
        DataFrame is not modified; the ``market_id`` column is added to a copy.

        :param candlesticks: DataFrame of candlestick data.
        :param table_name: Name of the table to insert into.
        :param symbol: Symbol of the trading pair.
        :param exchange_name: Name of the exchange.
        """
        market_id = self._fetch_market_id(symbol, exchange_name)

        # Work on a copy so the caller's DataFrame is left untouched.
        candles = candlesticks.copy()
        if 'market_id' in candles.columns:
            candles['market_id'] = market_id
        else:
            candles.insert(0, 'market_id', market_id)

        # The foreign key targets the 'markets' table used by _fetch_market_id.
        sql_create = f"""
        CREATE TABLE IF NOT EXISTS {self._quote_identifier(table_name)} (
            id INTEGER PRIMARY KEY,
            market_id INTEGER,
            time INTEGER UNIQUE ON CONFLICT IGNORE,
            open REAL NOT NULL,
            high REAL NOT NULL,
            low REAL NOT NULL,
            close REAL NOT NULL,
            volume REAL NOT NULL,
            FOREIGN KEY (market_id) REFERENCES markets (id)
        )"""
        with SQLite(self.db_file) as conn:
            conn.cursor().execute(sql_create)
            candles.to_sql(table_name, conn, if_exists='append', index=False)

    def get_timestamped_records(self, table_name: str, timestamp_field: str, st: dt.datetime,
                                et: dt.datetime = None) -> pd.DataFrame:
        """
        Returns records whose timestamp is >= ``st`` and, when ``et`` is given, <= ``et``.

        Both bounds are converted with ``unix_time_millis`` before comparison.

        :param table_name: Database table name.
        :param timestamp_field: Field name that contains the timestamp (interpolated verbatim).
        :param st: Start datetime.
        :param et: End datetime (optional).
        :return: DataFrame of records.
        """
        quoted_table = self._quote_identifier(table_name)
        start_stamp = unix_time_millis(st)
        if et is not None:
            end_stamp = unix_time_millis(et)
            q_str = (
                f"SELECT * FROM {quoted_table} WHERE {timestamp_field} >= ? "
                f"AND {timestamp_field} <= ?;"
            )
            params = (start_stamp, end_stamp)
        else:
            q_str = f"SELECT * FROM {quoted_table} WHERE {timestamp_field} >= ?;"
            params = (start_stamp,)
        with SQLite(self.db_file) as conn:
            records = pd.read_sql(q_str, conn, params=params)
        # TODO: the 'id' column may need to be dropped here (records.drop('id', axis=1)).
        return records