import datetime as dt
import sqlite3
from functools import lru_cache
from typing import Any, Dict, List, Tuple

import pandas as pd

import config
from shared_utilities import unix_time_millis


class SQLite:
    """
    Context manager for SQLite database connections.

    Accepts a database file name or defaults to the file in config.DB_FILE.
    On a clean exit the pending transaction is committed; if the ``with`` body
    raised, the transaction is rolled back instead. The connection is always
    closed on exit.

    Example usage:
    --------------
    with SQLite(db_file='test_db.sqlite') as con:
        cursor = con.cursor()
        cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
    """

    def __init__(self, db_file: str | None = None):
        self.db_file = db_file if db_file else config.DB_FILE
        # NOTE: the connection is opened here rather than in __enter__, so an
        # SQLite object created outside a ``with`` statement stays open.
        self.connection = sqlite3.connect(self.db_file)

    def __enter__(self) -> sqlite3.Connection:
        return self.connection

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        try:
            if exc_type is None:
                self.connection.commit()
            else:
                # Don't persist a half-finished transaction when the body raised.
                self.connection.rollback()
        finally:
            # Close even if commit()/rollback() itself raises.
            self.connection.close()


class HDict(dict):
    """
    Hashable dictionary to use as data keys (e.g. as lru_cache arguments).

    All values must themselves be hashable. Do not mutate an HDict after it
    has been used as a key, because that changes its hash.

    Example usage:
    --------------
    hdict = HDict({'key1': 'value1', 'key2': 'value2'})
    hash(hdict)
    """

    def __hash__(self) -> int:
        return hash(frozenset(self.items()))


def make_query(item: str, table: str, columns: List[str]) -> str:
    """
    Creates a SQL select query string with the required number of placeholders.

    Each column becomes a ``col = ?`` condition, joined with AND. Identifiers
    are interpolated directly, so they must come from trusted code, not user input.

    :param item: The field to select.
    :param table: The table to select from.
    :param columns: List of columns for the where clause.
    :return: The query string.
    """
    placeholders = " AND ".join([f"{col} = ?" for col in columns])
    return f"SELECT {item} FROM {table} WHERE {placeholders};"


def make_insert(table: str, columns: Tuple[str, ...], replace: bool = False) -> str:
    """
    Creates a SQL insert query string with the required number of placeholders.

    :param table: The table to insert into.
    :param columns: Tuple of column names.
    :param replace: If True, emit ``INSERT OR REPLACE`` instead of a plain ``INSERT``.
    :return: The query string.
    """
    col_names = ", ".join([f'"{col}"' for col in columns])  # Use double quotes for column names
    placeholders = ", ".join(["?" for _ in columns])
    if replace:
        return f'INSERT OR REPLACE INTO "{table}" ({col_names}) VALUES ({placeholders});'
    return f'INSERT INTO "{table}" ({col_names}) VALUES ({placeholders});'


class Database:
    """
    Database class to communicate and maintain the database.

    Handles connections and operations for the given exchanges. Every method
    opens its own short-lived connection through ``SQLite``.

    Example usage:
    --------------
    db = Database(db_file='test_db.sqlite')
    db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
    """

    def __init__(self, db_file: str | None = None):
        # None means "use config.DB_FILE" (resolved inside SQLite).
        self.db_file = db_file

    def execute_sql(self, sql: str, params: list | tuple | None = None) -> None:
        """
        Executes a raw SQL statement with optional parameters.

        :param sql: SQL statement to execute.
        :param params: Optional sequence of parameters to bind to the statement's
            placeholders. Omit (or pass None) for statements without placeholders.
        """
        with SQLite(self.db_file) as con:
            cur = con.cursor()
            # sqlite3 rejects an explicit None for the parameters argument
            # ("parameters are of unsupported type"), so bind an empty tuple instead.
            cur.execute(sql, params if params is not None else ())

    def get_item_where(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any:
        """
        Returns an item from a table where the filter results should isolate a single row.

        :param item_name: Name of the item to fetch.
        :param table_name: Name of the table.
        :param filter_vals: Tuple of column name and value to filter by.
        :return: The item (first matching row only).
        :raises ValueError: If no row matches the filter.
        """
        with SQLite(self.db_file) as con:
            cur = con.cursor()
            qry = make_query(item_name, table_name, [filter_vals[0]])
            cur.execute(qry, (filter_vals[1],))
            if result := cur.fetchone():
                return result[0]
            else:
                error = f"Couldn't fetch item {item_name} from {table_name} where {filter_vals[0]} = {filter_vals[1]}"
                raise ValueError(error)

    def get_rows_where(self, table: str, filter_vals: List[Tuple[str, Any]]) -> pd.DataFrame | None:
        """
        Returns a DataFrame containing all rows of a table that meet the filter criteria.

        Filters are combined with AND. A list value produces an ``IN (...)``
        condition; any other value produces ``col = ?``. Values are bound as
        parameters, but table/column names are interpolated and must be trusted.

        :param table: Name of the table.
        :param filter_vals: List of tuples containing column names and values to filter by.
        :return: DataFrame of the query result, or None if the result is empty or the
            query fails (e.g. a missing table/column, or an empty filter_vals list,
            which yields an invalid WHERE clause). Failures are printed.
        """
        try:
            with SQLite(self.db_file) as con:
                where_clauses = []
                params = []

                # Construct the WHERE clause, handling lists for 'IN' conditions
                for col, val in filter_vals:
                    if isinstance(val, list):
                        # If the value is a list, use the 'IN' clause
                        placeholders = ', '.join('?' for _ in val)
                        where_clauses.append(f"{col} IN ({placeholders})")
                        params.extend(val)  # Extend the parameters with the list values
                    else:
                        where_clauses.append(f"{col} = ?")
                        params.append(val)

                # Prepare and execute the query with the constructed WHERE clause
                where_clause = " AND ".join(where_clauses)
                qry = f"SELECT * FROM {table} WHERE {where_clause}"
                result = pd.read_sql(qry, con, params=params)

                return result if not result.empty else None
        except (sqlite3.OperationalError, pd.errors.DatabaseError) as e:
            # Log the error or handle it appropriately
            print(f"Error querying table '{table}' with filters {filter_vals}: {e}")
            return None

    def insert_dataframe(self, df: pd.DataFrame, table: str) -> int:
        """
        Inserts a DataFrame into a specified table and returns the last inserted row's ID.

        The table is created by pandas if it does not exist (``if_exists='append'``).

        :param df: DataFrame to insert.
        :param table: Name of the table.
        :return: The auto-incremented ID of the last inserted row.
        """
        with SQLite(self.db_file) as con:
            # Insert the DataFrame into the specified table
            df.to_sql(name=table, con=con, index=False, if_exists='append')

            # Fetch the last inserted row ID (same connection, so it reflects the insert above)
            cursor = con.execute('SELECT last_insert_rowid()')
            last_id = cursor.fetchone()[0]

        return last_id

    def insert_row(self, table: str, columns: Tuple[str, ...], values: Tuple[Any, ...]) -> int:
        """
        Inserts a row into a specified table and returns the auto-incremented ID.

        :param table: Name of the table.
        :param columns: Tuple of column names.
        :param values: Tuple of values to insert, in the same order as ``columns``.
        :return: The auto-incremented ID of the inserted row.
        """
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            sql = make_insert(table=table, columns=columns)
            cursor.execute(sql, values)
            # Return the auto-incremented ID
            return cursor.lastrowid

    def table_exists(self, table_name: str) -> bool:
        """
        Checks if a table exists in the database.

        :param table_name: Name of the table.
        :return: True if the table exists, False otherwise.
        """
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            # sqlite_master is available on every SQLite version; the sqlite_schema
            # alias only exists from SQLite 3.33.0 onward.
            sql = make_query(item='name', table='sqlite_master', columns=['type', 'name'])
            cursor.execute(sql, ('table', table_name))
            result = cursor.fetchone()
        return result is not None

    # NOTE(review): lru_cache on an instance method keys on ``self`` and keeps every
    # Database instance alive for the lifetime of the cache (ruff B019). It also
    # caches None results (create_id=False, row missing), which go stale if the
    # row is inserted later.
    @lru_cache(maxsize=1000)
    def get_from_static_table(self, item: str, table: str, indexes: dict, create_id: bool = False) -> Any:
        """
        Returns the value of ``item`` from the row matching ``indexes``.

        If no row matches and ``create_id`` is True, a new row is inserted with the
        index columns/values and its auto-incremented row ID is returned. The indexes
        are received as a hashable dictionary (HDict) so the results can be cached.

        :param item: Name of the item requested.
        :param table: Table being queried.
        :param indexes: Hashable dictionary of indexing columns and their values.
        :param create_id: If True, create a row if it doesn't exist and return the autoincrement ID.
        :return: The content of the field, the new row ID, or None if no row matches
            and ``create_id`` is False.
        """
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            sql = make_query(item, table, list(indexes.keys()))
            cursor.execute(sql, tuple(indexes.values()))
            result = cursor.fetchone()

            if result is None and create_id:
                sql = make_insert(table, tuple(indexes.keys()))
                cursor.execute(sql, tuple(indexes.values()))
                result = cursor.lastrowid  # Get the last inserted row ID
            else:
                result = result[0] if result else None

        return result

    def _fetch_exchange_id(self, exchange_name: str) -> int:
        """
        Fetches the primary ID of an exchange from the database.

        The exchange row is created if it does not exist yet.

        :param exchange_name: Name of the exchange.
        :return: Primary ID of the exchange.
        """
        return self.get_from_static_table(item='id', table='exchange', create_id=True,
                                          indexes=HDict({'name': exchange_name}))

    def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
        """
        Returns the markets ID for a trading pair listed in the database.

        The exchange and market rows are created if they do not exist yet.

        :param symbol: Symbol of the trading pair.
        :param exchange_name: Name of the exchange.
        :return: Market ID.
        """
        exchange_id = self._fetch_exchange_id(exchange_name)
        market_id = self.get_from_static_table(item='id', table='markets', create_id=True,
                                               indexes=HDict({'symbol': symbol, 'exchange_id': exchange_id}))
        return market_id

    def insert_candles_into_db(self, candlesticks: pd.DataFrame, table_name: str, symbol: str,
                               exchange_name: str) -> None:
        """
        Inserts all candlesticks from a DataFrame into the database.

        The target table is created if needed. Rows whose ``time`` already exists
        are silently skipped (``UNIQUE ON CONFLICT IGNORE``).

        NOTE: ``candlesticks`` is modified in place: a ``market_id`` column is
        set or inserted as the first column.

        :param candlesticks: DataFrame of candlestick data.
        :param table_name: Name of the table to insert into.
        :param symbol: Symbol of the trading pair.
        :param exchange_name: Name of the exchange.
        """
        market_id = self._fetch_market_id(symbol, exchange_name)

        # Check if 'market_id' column already exists in the DataFrame
        if 'market_id' in candlesticks.columns:
            # If it exists, set its value to the fetched market_id
            candlesticks['market_id'] = market_id
        else:
            # If it doesn't exist, insert it as the first column
            candlesticks.insert(0, 'market_id', market_id)

        # The foreign key points at the 'markets' table used by _fetch_market_id.
        # CREATE TABLE IF NOT EXISTS does not alter tables created earlier.
        sql_create = f"""
        CREATE TABLE IF NOT EXISTS '{table_name}' (
            id INTEGER PRIMARY KEY,
            market_id INTEGER,
            time INTEGER UNIQUE ON CONFLICT IGNORE,
            open REAL NOT NULL,
            high REAL NOT NULL,
            low REAL NOT NULL,
            close REAL NOT NULL,
            volume REAL NOT NULL,
            FOREIGN KEY (market_id) REFERENCES markets (id)
        )"""

        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            cursor.execute(sql_create)
            candlesticks.to_sql(table_name, conn, if_exists='append', index=False)

    def get_timestamped_records(self, table_name: str, timestamp_field: str, st: dt.datetime,
                                et: dt.datetime | None = None) -> pd.DataFrame:
        """
        Returns records with timestamps in an inclusive time range.

        Selects records from a specified table whose timestamps are greater than
        or equal to a given start time and, optionally, less than or equal to a
        given end time.

        ``st``/``et`` are converted with ``unix_time_millis``, so ``timestamp_field``
        presumably stores epoch milliseconds (confirm against the schema).

        :param table_name: Database table name.
        :param timestamp_field: Field name that contains the timestamp.
        :param st: Start datetime (inclusive).
        :param et: End datetime (inclusive, optional).
        :return: DataFrame of records (possibly empty).
        """
        with SQLite(self.db_file) as conn:
            start_stamp = unix_time_millis(st)
            if et is not None:
                end_stamp = unix_time_millis(et)
                q_str = (
                    f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= ? "
                    f"AND {timestamp_field} <= ?;"
                )
                records = pd.read_sql(q_str, conn, params=(start_stamp, end_stamp))
            else:
                q_str = f"SELECT * FROM '{table_name}' WHERE {timestamp_field} >= ?;"
                records = pd.read_sql(q_str, conn, params=(start_stamp,))

            # TODO: may need to restore `records = records.drop('id', axis=1)` later.

        return records