# Python source file (321 lines, 13 KiB).
import sqlite3
|
|
from functools import lru_cache
|
|
from typing import Any, Dict, List, Tuple
|
|
import config
|
|
import datetime as dt
|
|
import pandas as pd
|
|
from shared_utilities import unix_time_millis
|
|
|
|
|
|
class SQLite:
|
|
"""
|
|
Context manager for SQLite database connections.
|
|
Accepts a database file name or defaults to the file in config.DB_FILE.
|
|
|
|
Example usage:
|
|
--------------
|
|
with SQLite(db_file='test_db.sqlite') as con:
|
|
cursor = con.cursor()
|
|
cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
|
|
"""
|
|
|
|
def __init__(self, db_file: str = None):
|
|
self.db_file = db_file if db_file else config.DB_FILE
|
|
self.connection = sqlite3.connect(self.db_file)
|
|
|
|
def __enter__(self):
|
|
return self.connection
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
self.connection.commit()
|
|
self.connection.close()
|
|
|
|
|
|
class HDict(dict):
    """
    Dictionary subclass that can be hashed, so it can be used as a data key
    (e.g. as an argument to an ``lru_cache``-decorated function).

    The hash is computed from the set of (key, value) pairs, so it does not
    depend on insertion order; every value must itself be hashable.

    Example usage:
    --------------
    hdict = HDict({'key1': 'value1', 'key2': 'value2'})
    hash(hdict)
    """

    def __hash__(self) -> int:
        pairs = frozenset(self.items())
        return hash(pairs)
|
|
|
|
|
|
def make_query(item: str, table: str, columns: List[str]) -> str:
    """
    Creates a SQL select query string with one ``?`` placeholder per filter column.

    Identifiers (*item*, *table*, *columns*) are interpolated verbatim, without
    quoting or escaping, so they must come from trusted code and never from
    user input. Only the filter values are meant to be bound as parameters.

    :param item: The field (or expression) to select.
    :param table: The table to select from.
    :param columns: Columns for the WHERE clause, combined with AND. When empty,
                    the WHERE clause is omitted and every row is selected.
    :return: The query string.
    """
    if not columns:
        # "WHERE ;" would be a syntax error; no filter means an unfiltered select.
        return f"SELECT {item} FROM {table};"
    conditions = " AND ".join(f"{col} = ?" for col in columns)
    return f"SELECT {item} FROM {table} WHERE {conditions};"
|
|
|
|
|
|
def make_insert(table: str, columns: Tuple[str, ...], replace: bool = False) -> str:
    """
    Creates a SQL insert query string with one ``?`` placeholder per column.

    Table and column names are wrapped in double quotes, with any embedded
    double quote doubled, so names containing spaces, keywords or quote
    characters produce valid SQL.

    :param table: The table to insert into.
    :param columns: Tuple of column names. When empty, a ``DEFAULT VALUES``
                    insert is produced, so every column takes its default.
    :param replace: If True, emit ``INSERT OR REPLACE`` instead of ``INSERT``.
    :return: The query string.
    """
    def quote(name: str) -> str:
        # SQL identifier quoting: wrap in "..." and escape " as "".
        return '"' + str(name).replace('"', '""') + '"'

    verb = "INSERT OR REPLACE" if replace else "INSERT"
    if not columns:
        # "() VALUES ()" is invalid SQL; DEFAULT VALUES is the valid equivalent.
        return f"{verb} INTO {quote(table)} DEFAULT VALUES;"
    col_names = ", ".join(quote(col) for col in columns)
    placeholders = ", ".join("?" for _ in columns)
    return f"{verb} INTO {quote(table)} ({col_names}) VALUES ({placeholders});"
|
|
|
|
|
|
|
|
class Database:
    """
    Database class to communicate with and maintain the database.

    Handles connections and operations for the given exchanges. Every method
    opens its own short-lived connection through the ``SQLite`` context
    manager, so an instance holds no open connection between calls.

    Table and column names passed to these methods are interpolated into SQL
    text; only values are bound as parameters. Identifiers must therefore
    come from trusted code.

    Example usage:
    --------------
    db = Database(db_file='test_db.sqlite')
    db.execute_sql('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
    """

    def __init__(self, db_file: str | None = None):
        # None is forwarded to SQLite(), which then falls back to config.DB_FILE.
        self.db_file = db_file

    def execute_sql(self, sql: str, params: list | tuple | dict | None = None) -> None:
        """
        Executes a raw SQL statement with optional parameters.

        :param sql: SQL statement to execute.
        :param params: Optional sequence (for ``?`` placeholders) or mapping (for
                       named placeholders) of values to bind. ``None`` means none.
        """
        with SQLite(self.db_file) as con:
            # sqlite3 rejects None as a parameter set ("parameters are of
            # unsupported type"), so substitute an empty tuple.
            con.cursor().execute(sql, () if params is None else params)

    def get_item_where(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any:
        """
        Returns an item from a table where the filter results should isolate a single row.

        If several rows match, the first row returned by SQLite is used.

        :param item_name: Name of the column to fetch.
        :param table_name: Name of the table.
        :param filter_vals: ``(column_name, value)`` pair to filter by.
        :return: The value of *item_name* in the matching row.
        :raises ValueError: If no row matches the filter.
        """
        column, value = filter_vals[0], filter_vals[1]
        with SQLite(self.db_file) as con:
            qry = make_query(item_name, table_name, [column])
            row = con.execute(qry, (value,)).fetchone()
        if row is None:
            raise ValueError(
                f"Couldn't fetch item {item_name} from {table_name} where {column} = {value}"
            )
        return row[0]

    def get_rows_where(self, table: str, filter_vals: List[Tuple[str, Any]]) -> pd.DataFrame | None:
        """
        Returns a DataFrame containing all rows of a table that meet the filter criteria.

        Filters are combined with AND. A list or tuple value becomes an
        ``IN (...)`` condition; any other value becomes an equality test.
        An empty filter list selects every row.

        :param table: Name of the table.
        :param filter_vals: List of ``(column_name, value)`` tuples to filter by.
        :return: DataFrame of the query result, or None if it is empty or the
                 query fails (e.g. the table or a column does not exist).
        """
        where_clauses = []
        params: List[Any] = []
        for col, val in filter_vals:
            if isinstance(val, (list, tuple)):
                # Sequences can't be bound as a single value; expand into IN (?, ?, ...).
                placeholders = ', '.join('?' for _ in val)
                where_clauses.append(f"{col} IN ({placeholders})")
                params.extend(val)
            else:
                where_clauses.append(f"{col} = ?")
                params.append(val)

        qry = f"SELECT * FROM {table}"
        if where_clauses:
            # Without filters, a bare "WHERE" would be a syntax error.
            qry += " WHERE " + " AND ".join(where_clauses)

        try:
            with SQLite(self.db_file) as con:
                result = pd.read_sql(qry, con, params=params)
        except (sqlite3.OperationalError, pd.errors.DatabaseError) as e:
            print(f"Error querying table '{table}' with filters {filter_vals}: {e}")
            return None

        return result if not result.empty else None

    def insert_dataframe(self, df: pd.DataFrame, table: str) -> int:
        """
        Inserts a DataFrame into a specified table and returns the last inserted row's ID.

        The table is created by pandas if it does not already exist
        (``if_exists='append'``); the DataFrame index is not written.

        :param df: DataFrame to insert.
        :param table: Name of the table.
        :return: The value of SQLite's ``last_insert_rowid()`` on this connection
                 after the insert (for an empty DataFrame no row is inserted, so
                 this is not a row from *df*).
        """
        with SQLite(self.db_file) as con:
            df.to_sql(name=table, con=con, index=False, if_exists='append')
            # Same connection as the insert, so last_insert_rowid() refers to it.
            return con.execute('SELECT last_insert_rowid()').fetchone()[0]

    def insert_row(self, table: str, columns: Tuple[str, ...], values: Tuple[Any, ...]) -> int:
        """
        Inserts a row into a specified table and returns the auto-incremented ID.

        :param table: Name of the table.
        :param columns: Tuple of column names.
        :param values: Tuple of values to insert, in the same order as *columns*.
        :return: The auto-incremented ID of the inserted row.
        """
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            cursor.execute(make_insert(table=table, columns=columns), values)
        return cursor.lastrowid

    def table_exists(self, table_name: str) -> bool:
        """
        Checks if a table exists in the database.

        :param table_name: Name of the table.
        :return: True if the table exists, False otherwise.
        """
        with SQLite(self.db_file) as conn:
            # sqlite_master works on every SQLite version; the sqlite_schema
            # alias is only available from SQLite 3.33.0 onwards.
            sql = make_query(item='name', table='sqlite_master', columns=['type', 'name'])
            row = conn.execute(sql, ('table', table_name)).fetchone()
        return row is not None

    @lru_cache(maxsize=1000)
    def get_from_static_table(self, item: str, table: str, indexes: dict, create_id: bool = False) -> Any:
        """
        Returns a field of the row matching *indexes* in a static lookup table.

        If no row matches and *create_id* is True, a new row is inserted with
        the index columns and its auto-incremented id is returned. *indexes*
        must be hashable (pass an ``HDict``) because results are cached.

        NOTE: results are memoized per (instance, item, table, indexes, create_id),
        including ``None`` for a missing row when *create_id* is False. A row added
        later by other code is therefore not seen through this method until
        ``get_from_static_table.cache_clear()`` is called. The cache also holds a
        reference to every Database instance that called it.

        :param item: Name of the column requested.
        :param table: Table being queried.
        :param indexes: Hashable dictionary of indexing columns and their values.
        :param create_id: If True, create a row if it doesn't exist and return the autoincrement ID.
        :return: The content of the field, the new row id, or None if not found.
        """
        keys = tuple(indexes.keys())
        values = tuple(indexes.values())
        with SQLite(self.db_file) as conn:
            cursor = conn.cursor()
            cursor.execute(make_query(item, table, list(keys)), values)
            row = cursor.fetchone()
            if row is not None:
                return row[0]
            if not create_id:
                return None
            cursor.execute(make_insert(table, keys), values)
            # Committed by the SQLite context manager on exit.
            return cursor.lastrowid

    def _fetch_exchange_id(self, exchange_name: str) -> int:
        """
        Fetches the primary ID of an exchange, creating the exchange row if missing.

        :param exchange_name: Name of the exchange.
        :return: Primary ID of the exchange.
        """
        return self.get_from_static_table(item='id', table='exchange', create_id=True,
                                          indexes=HDict({'name': exchange_name}))

    def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
        """
        Returns the markets ID for a trading pair, creating the row if missing.

        :param symbol: Symbol of the trading pair.
        :param exchange_name: Name of the exchange.
        :return: Market ID.
        """
        exchange_id = self._fetch_exchange_id(exchange_name)
        return self.get_from_static_table(item='id', table='markets', create_id=True,
                                          indexes=HDict({'symbol': symbol, 'exchange_id': exchange_id}))

    def insert_candles_into_db(self, candlesticks: pd.DataFrame, table_name: str, symbol: str,
                               exchange_name: str) -> None:
        """
        Inserts all candlesticks from a DataFrame into the database.

        The table is created if needed. Every row is tagged with the market id
        of (*symbol*, *exchange_name*). Rows whose ``time`` already exists in the
        table are skipped (``UNIQUE ON CONFLICT IGNORE``). The caller's DataFrame
        is not modified.

        NOTE(review): ``time`` is unique across the whole table, not per market,
        so each table presumably holds a single market — confirm with callers.

        :param candlesticks: DataFrame of candlestick data.
        :param table_name: Name of the table to insert into.
        :param symbol: Symbol of the trading pair.
        :param exchange_name: Name of the exchange.
        """
        market_id = self._fetch_market_id(symbol, exchange_name)

        # Work on a copy so the caller's DataFrame is left untouched.
        candles = candlesticks.copy()
        if 'market_id' in candles.columns:
            candles['market_id'] = market_id
        else:
            candles.insert(0, 'market_id', market_id)

        quoted_table = '"' + table_name.replace('"', '""') + '"'
        # The foreign key targets 'markets', the table _fetch_market_id populates.
        sql_create = f"""
            CREATE TABLE IF NOT EXISTS {quoted_table} (
                id INTEGER PRIMARY KEY,
                market_id INTEGER,
                time INTEGER UNIQUE ON CONFLICT IGNORE,
                open REAL NOT NULL,
                high REAL NOT NULL,
                low REAL NOT NULL,
                close REAL NOT NULL,
                volume REAL NOT NULL,
                FOREIGN KEY (market_id) REFERENCES markets (id)
            )"""
        with SQLite(self.db_file) as conn:
            conn.execute(sql_create)
            candles.to_sql(table_name, conn, if_exists='append', index=False)

    def get_timestamped_records(self, table_name: str, timestamp_field: str, st: dt.datetime,
                                et: dt.datetime | None = None) -> pd.DataFrame:
        """
        Returns records whose timestamp is at or after *st* and, if *et* is
        given, at or before *et*.

        Both bounds are converted with ``unix_time_millis`` before comparison.

        :param table_name: Database table name.
        :param timestamp_field: Field name that contains the timestamp.
        :param st: Start datetime (inclusive).
        :param et: End datetime (inclusive, optional).
        :return: DataFrame of records.
        """
        quoted_table = '"' + table_name.replace('"', '""') + '"'
        q_str = f"SELECT * FROM {quoted_table} WHERE {timestamp_field} >= ?"
        params: Tuple[Any, ...] = (unix_time_millis(st),)
        if et is not None:
            q_str += f" AND {timestamp_field} <= ?"
            params += (unix_time_millis(et),)

        with SQLite(self.db_file) as conn:
            records = pd.read_sql(q_str + ";", conn, params=params)

        # TODO: the 'id' column is returned as-is; records.drop('id', axis=1) may be needed later.
        return records
|