225 lines
8.9 KiB
Python
225 lines
8.9 KiB
Python
import unittest
|
|
import sqlite3
|
|
import pandas as pd
|
|
import datetime as dt
|
|
from Database import Database, SQLite, make_query, make_insert, HDict
|
|
from shared_utilities import unix_time_millis
|
|
|
|
|
|
def utcnow() -> dt.datetime:
    """Give the current instant as an aware ``datetime`` in the UTC zone.

    Using an aware value keeps comparisons and arithmetic with other
    UTC-aware timestamps well-defined (naive and aware datetimes cannot be
    mixed).
    """
    utc_zone = dt.timezone.utc
    return dt.datetime.now(tz=utc_zone)
|
|
|
|
|
|
class TestSQLite(unittest.TestCase):
    """Tests for the ``SQLite`` connection context manager."""

    # Dedicated file name so this test never shares on-disk state with
    # ``TestDatabase``, which uses ``test_db.sqlite`` and creates the same
    # ``test_table``.
    db_file = 'test_sqlite_cm.sqlite'

    @staticmethod
    def _remove_db_file(path):
        """Delete *path* if it exists; a missing file is not an error."""
        import os
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def test_sqlite_context_manager(self):
        print("\nRunning test_sqlite_context_manager...")
        # A file left over from an earlier run would already contain
        # test_table and make CREATE TABLE fail, so start from scratch and
        # delete the file again once the test has finished.
        self._remove_db_file(self.db_file)
        self.addCleanup(self._remove_db_file, self.db_file)

        with SQLite(db_file=self.db_file) as con:
            cursor = con.cursor()
            cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
            cursor.execute("INSERT INTO test_table (name) VALUES ('test')")
            cursor.execute('SELECT name FROM test_table WHERE name = ?', ('test',))
            result = cursor.fetchone()
            # Fail cleanly (instead of a TypeError) if no row came back.
            self.assertIsNotNone(result)
            self.assertEqual(result[0], 'test')
        print("SQLite context manager test passed.")
|
|
|
|
|
|
class TestDatabase(unittest.TestCase):
    """Integration tests for ``Database`` and its SQL-building helpers.

    Each test gets a fresh ``test_db.sqlite`` file. ``self.db`` is the object
    under test; ``self.connection`` / ``self.cursor`` are an independent
    ``sqlite3`` connection used to create fixture tables and verify results.
    """

    @staticmethod
    def _remove_db_file(path):
        """Delete *path* if it exists; a missing file is not an error."""
        import os
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def setUp(self):
        # Use a temporary SQLite database for testing purposes.
        self.db_file = 'test_db.sqlite'
        # A file left behind by an interrupted run (or by another test using
        # the same name) may already contain test_table, which would make
        # the CREATE TABLE statements in these tests fail. Start empty.
        self._remove_db_file(self.db_file)
        self.db = Database(db_file=self.db_file)
        self.connection = sqlite3.connect(self.db_file)
        self.cursor = self.connection.cursor()

    def tearDown(self):
        self.connection.close()
        # Remove the temporary database file after each test.
        self._remove_db_file(self.db_file)

    def test_execute_sql(self):
        print("\nRunning test_execute_sql...")
        # Defensive: guarantee the CREATE below cannot hit
        # "table already exists".
        self.cursor.execute('DROP TABLE IF EXISTS test_table')
        self.connection.commit()

        sql = 'CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)'
        self.db.execute_sql(sql)

        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_table';")
        result = self.cursor.fetchone()
        self.assertIsNotNone(result)
        print("Execute SQL test passed.")

    def test_make_query(self):
        print("\nRunning test_make_query...")
        query = make_query('id', 'test_table', ['name'])
        expected_query = 'SELECT id FROM test_table WHERE name = ?;'
        self.assertEqual(query, expected_query)
        print("Make query test passed.")

    def test_make_insert(self):
        print("\nRunning test_make_insert...")
        insert = make_insert('test_table', ('name', 'age'))
        expected_insert = 'INSERT INTO "test_table" ("name", "age") VALUES (?, ?);'
        self.assertEqual(insert, expected_insert)
        print("Make insert test passed.")

    def test_get_item_where(self):
        print("\nRunning test_get_item_where...")
        self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
        self.cursor.execute("INSERT INTO test_table (id, name) VALUES (1, 'test')")
        self.connection.commit()
        item = self.db.get_item_where('name', 'test_table', ('id', 1))
        self.assertEqual(item, 'test')
        print("Get item where test passed.")

    def test_get_rows_where(self):
        print("\nRunning test_get_rows_where...")
        self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
        self.cursor.execute("INSERT INTO test_table (id, name) VALUES (1, 'test')")
        self.connection.commit()
        rows = self.db.get_rows_where('test_table', [('name', 'test')])
        self.assertIsInstance(rows, pd.DataFrame)
        self.assertEqual(rows.iloc[0]['name'], 'test')
        print("Get rows where test passed.")

    def test_insert_dataframe(self):
        print("\nRunning test_insert_dataframe...")
        df = pd.DataFrame({'id': [1], 'name': ['test']})
        self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
        self.connection.commit()
        self.db.insert_dataframe(df, 'test_table')
        self.cursor.execute('SELECT name FROM test_table WHERE id = 1')
        result = self.cursor.fetchone()
        # Fail cleanly (instead of a TypeError) if the row was not written.
        self.assertIsNotNone(result)
        self.assertEqual(result[0], 'test')
        print("Insert dataframe test passed.")

    def test_insert_row(self):
        print("\nRunning test_insert_row...")
        self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
        self.connection.commit()
        self.db.insert_row('test_table', ('id', 'name'), (1, 'test'))
        self.cursor.execute('SELECT name FROM test_table WHERE id = 1')
        result = self.cursor.fetchone()
        self.assertIsNotNone(result)
        self.assertEqual(result[0], 'test')
        print("Insert row test passed.")

    def test_table_exists(self):
        print("\nRunning test_table_exists...")
        self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)')
        self.connection.commit()
        exists = self.db.table_exists('test_table')
        self.assertTrue(exists)
        print("Table exists test passed.")

    def test_get_timestamped_records(self):
        print("\nRunning test_get_timestamped_records...")
        df = pd.DataFrame({
            'time': [unix_time_millis(utcnow())],
            'open': [1.0],
            'high': [1.0],
            'low': [1.0],
            'close': [1.0],
            'volume': [1.0],
        })
        table_name = 'test_table'
        self.cursor.execute(f"""
            CREATE TABLE {table_name} (
                id INTEGER PRIMARY KEY,
                time INTEGER UNIQUE,
                open REAL NOT NULL,
                high REAL NOT NULL,
                low REAL NOT NULL,
                close REAL NOT NULL,
                volume REAL NOT NULL
            )
        """)
        self.connection.commit()
        self.db.insert_dataframe(df, table_name)
        # The window is computed after the insert, so the inserted
        # timestamp falls inside [now - 1 min, now].
        st = utcnow() - dt.timedelta(minutes=1)
        et = utcnow()
        records = self.db.get_timestamped_records(table_name, 'time', st, et)
        self.assertIsInstance(records, pd.DataFrame)
        self.assertFalse(records.empty)
        print("Get timestamped records test passed.")

    def test_get_from_static_table(self):
        print("\nRunning test_get_from_static_table...")
        self.cursor.execute('CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT UNIQUE)')
        self.connection.commit()
        item = self.db.get_from_static_table('id', 'test_table', HDict({'name': 'test'}), create_id=True)
        self.assertIsInstance(item, int)
        self.cursor.execute('SELECT id FROM test_table WHERE name = ?', ('test',))
        result = self.cursor.fetchone()
        self.assertIsNotNone(result)
        self.assertEqual(item, result[0])
        print("Get from static table test passed.")

    def test_insert_candles_into_db(self):
        print("\nRunning test_insert_candles_into_db...")
        df = pd.DataFrame({
            'time': [unix_time_millis(utcnow())],
            'open': [1.0],
            'high': [1.0],
            'low': [1.0],
            'close': [1.0],
            'volume': [1.0],
        })
        table_name = 'test_table'
        self.cursor.execute(f"""
            CREATE TABLE {table_name} (
                id INTEGER PRIMARY KEY,
                market_id INTEGER,
                time INTEGER UNIQUE,
                open REAL NOT NULL,
                high REAL NOT NULL,
                low REAL NOT NULL,
                close REAL NOT NULL,
                volume REAL NOT NULL
            )
        """)
        self.connection.commit()

        # Supporting exchange/markets tables (markets.exchange_id references
        # exchange.id). NOTE(review): presumably insert_candles_into_db
        # resolves 'BTC/USDT' on 'binance' to a market_id through these
        # tables — confirm against Database.
        self.cursor.execute("""
            CREATE TABLE exchange (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT UNIQUE
            )
        """)
        self.cursor.execute("""
            CREATE TABLE markets (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                symbol TEXT,
                exchange_id INTEGER,
                FOREIGN KEY (exchange_id) REFERENCES exchange(id)
            )
        """)
        self.connection.commit()

        self.db.insert_candles_into_db(df, table_name, 'BTC/USDT', 'binance')
        self.cursor.execute(f'SELECT * FROM {table_name}')
        result = self.cursor.fetchall()
        self.assertTrue(result, "expected at least one candle row to be inserted")
        print("Insert candles into db test passed.")
|
|
|
|
|
|
# Run the test cases in this module when it is executed as a script
# (e.g. ``python <this file>``); importing the module runs nothing.
if __name__ == "__main__":
    unittest.main()
|
|
|
|
# def test():
|
|
# # un_hashed_pass = 'password'
|
|
# hasher = bcrypt.using(rounds=13)
|
|
# # hashed_pass = hasher.hash(un_hashed_pass)
|
|
# # print(f'password: {un_hashed_pass}')
|
|
# # print(f'hashed pass: {hashed_pass}')
|
|
# # print(f" right pass: {hasher.verify('password', hashed_pass)}")
|
|
# # print(f" wrong pass: {hasher.verify('passWord', hashed_pass)}")
|
|
# engine = create_engine("sqlite:///" + config.DB_FILE, echo=True)
|
|
# with engine.connect() as conn:
|
|
# default_user = pd.read_sql_query(sql=text("SELECT * FROM users WHERE user_name = 'guest'"), con=conn)
|
|
# # hashed_password = default_user.password.values[0]
|
|
# # print(f" verify pass: {hasher.verify('password', hashed_password)}")
|
|
# username = default_user.user_name.values[0]
|
|
# print(username)
|