From 8361efd96534f613077956d0598e168c03314ab1 Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 19 Aug 2024 23:10:13 -0300 Subject: [PATCH] Refactored DataCache, Database and Users. All db interactions are now all inside Database. All data request from Users now go through DataCache. Expanded DataCache test to include all methods. All DataCache tests pass. --- .gitignore | 2 +- src/BrighterTrades.py | 68 +-- src/Configuration.py | 375 ++++++++---- src/DataCache_v2.py | 308 ++++++++-- src/Database.py | 24 +- src/Signals.py | 9 +- src/Strategies.py | 28 +- src/Users.py | 856 ++++++++++++++-------------- src/app.py | 3 +- src/archived_code/DataCache.py | 30 +- src/candles.py | 36 +- src/indicators.py | 20 +- src/maintenence/debuging_testing.py | 55 +- src/trade.py | 7 +- tests/test_DataCache_v2.py | 336 +++++++++-- tests/test_Users.py | 2 +- 16 files changed, 1383 insertions(+), 776 deletions(-) diff --git a/.gitignore b/.gitignore index 4234c28..aa480d3 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ # Ignore Flask session data flask_session/ -# Ignore testing cache +# Ignore testing data .pytest_cache/ # Ignore databases diff --git a/src/BrighterTrades.py b/src/BrighterTrades.py index b9c85b8..7755dc7 100644 --- a/src/BrighterTrades.py +++ b/src/BrighterTrades.py @@ -1,6 +1,6 @@ - from typing import Any +from Users import Users from DataCache_v2 import DataCache from Strategies import Strategies from backtesting import Backtester @@ -20,25 +20,29 @@ class BrighterTrades: # Object that interacts with the persistent data. self.data = DataCache(self.exchanges) - # Configuration and settings for the user app and charts - self.config = Configuration(cache=self.data) + # Configuration for the app + self.config = Configuration() - # Object that maintains signals. Initialize with any signals loaded from file. - self.signals = Signals(self.config.signals_list) + # The object that manages users in the system. 
+ self.users = Users(data_cache=self.data) + + # Object that maintains signals. + self.signals = Signals(self.config) # Object that maintains candlestick and price data. - self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, data_source=self.data) + self.candles = Candles(users=self.users, exchanges=self.exchanges, data_source=self.data, + config=self.config) # Object that interacts with and maintains data from available indicators - self.indicators = Indicators(self.candles, self.config) + self.indicators = Indicators(self.candles, self.users) # Object that maintains the trades data - self.trades = Trades(self.config.trades) + self.trades = Trades(self.users) # The Trades object needs to connect to an exchange_interface. self.trades.connect_exchanges(exchanges=self.exchanges) # Object that maintains the strategies data - self.strategies = Strategies(self.config.strategies_list, self.trades) + self.strategies = Strategies(self.data, self.trades) # Object responsible for testing trade and strategies data. 
self.backtester = Backtester() @@ -56,8 +60,8 @@ class BrighterTrades: raise ValueError("Missing required arguments for 'create_new_user'") try: - self.config.users.create_new_user(email=email, username=username, password=password) - login_successful = self.config.users.log_in_user(username=username, password=password) + self.users.create_new_user(email=email, username=username, password=password) + login_successful = self.users.log_in_user(username=username, password=password) return login_successful except Exception as e: # Handle specific exceptions or log the error @@ -77,11 +81,11 @@ class BrighterTrades: try: if cmd == 'logout': - return self.config.users.log_out_user(username=user_name) + return self.users.log_out_user(username=user_name) elif cmd == 'login': if password is None: raise ValueError("Password is required for login.") - return self.config.users.log_in_user(username=user_name, password=password) + return self.users.log_in_user(username=user_name, password=password) except Exception as e: # Handle specific exceptions or log the error raise ValueError("Error during user login/logout: " + str(e)) @@ -98,19 +102,19 @@ class BrighterTrades: if info == 'Chart View': try: - return self.config.users.get_chart_view(user_name=user_name) + return self.users.get_chart_view(user_name=user_name) except Exception as e: # Handle specific exceptions or log the error raise ValueError("Error retrieving chart view information: " + str(e)) elif info == 'Is logged in?': try: - return self.config.users.is_logged_in(user_name=user_name) + return self.users.is_logged_in(user_name=user_name) except Exception as e: # Handle specific exceptions or log the error raise ValueError("Error checking logged in status: " + str(e)) elif info == 'User_id': try: - return self.config.users.get_id(user_name=user_name) + return self.users.get_id(user_name=user_name) except Exception as e: # Handle specific exceptions or log the error raise ValueError("Error fetching id: " + str(e)) @@ 
-181,10 +185,10 @@ class BrighterTrades: :param default_keys: default API keys. :return: bool - True on success. """ - active_exchanges = self.config.users.get_exchanges(user_name, category='active_exchanges') + active_exchanges = self.users.get_exchanges(user_name, category='active_exchanges') success = False for exchange in active_exchanges: - keys = self.config.users.get_api_keys(user_name, exchange) + keys = self.users.get_api_keys(user_name, exchange) success = self.connect_or_config_exchange(user_name=user_name, exchange_name=exchange, api_keys=keys) @@ -202,7 +206,7 @@ class BrighterTrades: :param user_name: str - The name of the user making the query. """ - chart_view = self.config.users.get_chart_view(user_name=user_name) + chart_view = self.users.get_chart_view(user_name=user_name) indicator_types = self.indicators.indicator_types available_indicators = self.indicators.get_indicator_list(user_name) @@ -231,19 +235,20 @@ class BrighterTrades: :return: A dictionary containing the requested data. 
""" - chart_view = self.config.users.get_chart_view(user_name=user_name) + chart_view = self.users.get_chart_view(user_name=user_name) exchange = self.exchanges.get_exchange(ename=chart_view.get('exchange'), uname=user_name) # noinspection PyDictCreation r_data = {} - r_data['title'] = self.config.app_data.get('application_title', '') + r_data['title'] = self.config.get_setting('application_title') r_data['chart_interval'] = chart_view.get('timeframe', '') r_data['selected_exchange'] = chart_view.get('exchange', '') r_data['intervals'] = exchange.intervals if exchange else [] r_data['symbols'] = exchange.get_symbols() if exchange else {} r_data['available_exchanges'] = self.exchanges.get_available_exchanges() or [] r_data['connected_exchanges'] = self.exchanges.get_connected_exchanges(user_name) or [] - r_data['configured_exchanges'] = self.config.users.get_exchanges(user_name, category='configured_exchanges') or [] + r_data['configured_exchanges'] = self.users.get_exchanges( + user_name, category='configured_exchanges') or [] r_data['my_balances'] = self.exchanges.get_all_balances(user_name) or {} r_data['indicator_types'] = self.indicators.indicator_types or [] r_data['indicator_list'] = self.indicators.get_indicator_list(user_name) or [] @@ -295,7 +300,7 @@ class BrighterTrades: return "The new signal must have a 'name' attribute." self.signals.new_signal(data) - self.config.update_data('signals', self.signals.get_signals('dict')) + self.config.set_setting('signals_list', self.signals.get_signals('dict')) return data def received_new_strategy(self, data: dict) -> str | dict: @@ -309,7 +314,7 @@ class BrighterTrades: return "The new strategy must have a 'name' attribute." 
self.strategies.new_strategy(data) - self.config.update_data('strategies', self.strategies.get_strategies('dict')) + self.config.set_setting('strategies', self.strategies.get_strategies('dict')) return data def delete_strategy(self, strategy_name: str) -> None: @@ -377,10 +382,9 @@ class BrighterTrades: success = self.exchanges.connect_exchange(exchange_name=exchange_name, user_name=user_name, api_keys=api_keys) if success: - self.config.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set') + self.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set') if api_keys: - self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, - user_name=user_name) + self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name) return True else: return False # Failed to connect @@ -391,7 +395,7 @@ class BrighterTrades: else: # Exchange is already connected, update API keys if provided if api_keys: - self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name) + self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name) return True # Already connected def close_trade(self, trade_id): @@ -470,19 +474,19 @@ class BrighterTrades: if setting == 'interval': interval_state = params['timeframe'] - self.config.users.set_chart_view(values=interval_state, specific_property='timeframe', user_name=user_name) + self.users.set_chart_view(values=interval_state, specific_property='timeframe', user_name=user_name) elif setting == 'trading_pair': trading_pair = params['symbol'] - self.config.users.set_chart_view(values=trading_pair, specific_property='market', user_name=user_name) + self.users.set_chart_view(values=trading_pair, specific_property='market', user_name=user_name) elif setting == 'exchange': exchange_name = params['exchange_name'] # Get the first result of a list of available symbols from this exchange_name. 
market = self.exchanges.get_exchange(ename=exchange_name, uname=user_name).get_symbols()[0] # Change the exchange in the chart view and pass it a default market to display. - self.config.users.set_chart_view(values=exchange_name, specific_property='exchange_name', - user_name=user_name, default_market=market) + self.users.set_chart_view(values=exchange_name, specific_property='exchange_name', + user_name=user_name, default_market=market) elif setting == 'toggle_indicator': indicators_to_toggle = params.getlist('indicator') diff --git a/src/Configuration.py b/src/Configuration.py index 67b82d3..70f7974 100644 --- a/src/Configuration.py +++ b/src/Configuration.py @@ -1,130 +1,279 @@ -import pandas import yaml -from Signals import Signals -from indicators import Indicators -from Users import Users +import os +import time +import logging +from typing import Any + +logger = logging.getLogger(__name__) class Configuration: - def __init__(self, cache): - # ************** Default values************** + """ + Manages the application's settings, loading and saving them to a YAML file. + - Automatically loads the settings from file when instantiated. + - Settings can be added or modified with set_setting('key', value) + or by editing the file directly. + """ - self.app_data = { - 'application_title': 'BrighterTrades', # The title of our program. - 'max_data_loaded': 1000 # The maximum number of candles to store in memory. + def __init__(self, config_file='config.yml'): + """Initializes with default settings and loads saved data.""" + self.config_file = config_file + self.save_in_progress = False # Persistent flag to prevent infinite recursion + self._set_default_settings() + self.manage_config('load') + + def _set_default_settings(self): + """Set default settings.""" + self.settings = { + 'application_title': 'BrighterTrades', + 'max_data_loaded': 1000 } - # The object that interacts with the database. 
- self.data = cache.db + def _generate_config_data(self): + """Generates a list of settings to be saved to the config file.""" + return self.settings.copy() - # The object that manages users in the system. - self.users = Users(data_cache=cache) - - # The name of the file that stores saved_data - self.config_FN = 'config.yml' - - # A list of all the available Signals. - # Calls a static method of Signals that initializes a default list. - self.signals_list = Signals.get_signals_defaults() - - # list of all the available strategies. - self.strategies_list = [] - - # list of trades. - self.trades = [] - - # The data that will be saved and loaded from file . - self.saved_data = None - - # Load any saved data from file - self.config_and_states('load') - - def update_data(self, data_type, data): + def manage_config(self, cmd: str): """ - Replace current list of data sets with an updated list. - - :param data_type: The data being replaced - :param data: The replacement data. - :return: None. + Loads or saves settings to the config file according to cmd: 'load' | 'save' """ - if data_type == 'strategies': - self.strategies_list = data - elif data_type == 'signals': - self.signals_list = data - elif data_type == 'trades': - self.trades = data - else: - raise ValueError(f'Configuration: update_data(): Unsupported data_type: {data_type}') - # Save it to file. - self.config_and_states('save') - def remove(self, what, name): - # Removes by name an item from a list in saved data. + def _apply_loaded_settings(data): + """Updates settings from loaded data and marks if saving is needed.""" + needs_save = False + for key, value in data.items(): + if key not in self.settings or self.settings[key] != value: + self.settings[key] = value + needs_save = True + return needs_save - # Trades are indexed by unique_id, while Signals and Strategies are indexed by name. 
- if what == 'trades': - prop = 'unique_id' - else: - prop = 'name' + def _load_config_from_file(filepath): + try: + with open(filepath, "r") as file_descriptor: + return yaml.safe_load(file_descriptor) + except yaml.YAMLError: + timestamp = time.strftime("%Y%m%d-%H%M%S") + backup_path = f"{filepath}.{timestamp}.backup" + os.rename(filepath, backup_path) + logging.warning(f"Corrupt YAML file detected. Backup saved to {backup_path}") + logging.info(f"Recreating the configuration file with default settings.") + return None - for obj in self.saved_data[what]: - if obj[prop] == name: - self.saved_data[what].remove(obj) - break - # Save it to file. - self.config_and_states('save') - - def config_and_states(self, cmd): - """Loads or saves configurable data to the file set in self.config_FN""" - - # The data stored and retrieved from file session. - self.saved_data = { - 'signals': self.signals_list, - 'strategies': self.strategies_list, - 'trades': self.trades - } - - def set_loaded_values(): - # Sets the values in the saved_data object. 
- - if 'signals' in self.saved_data: - self.signals_list = self.saved_data['signals'] - - if 'strategies' in self.saved_data: - self.strategies_list = self.saved_data['strategies'] - - if 'trades' in self.saved_data: - self.trades = self.saved_data['trades'] - - def load_configuration(filepath): - """load file data""" - with open(filepath, "r") as file_descriptor: - data = yaml.safe_load(file_descriptor) - return data - - def save_configuration(filepath, data): - """Saves file data""" - with open(filepath, "w") as file_descriptor: - yaml.dump(data, file_descriptor) + def _save_config_to_file(filepath, data): + try: + with open(filepath, "w") as file_descriptor: + yaml.dump(data, file_descriptor) + except (IOError, OSError) as e: + logging.error(f"Failed to save configuration to {filepath}: {e}") + raise ValueError(f"Failed to save configuration to {filepath}: {e}") if cmd == 'load': - # If load_configuration() finds a file it overwrites - # the saved_data object otherwise it creates a new file - # with the defaults contained in saved_data> - - # If file exist load the values. - try: - self.saved_data = load_configuration(self.config_FN) - set_loaded_values() - # If file doesn't exist create a file and save the default values. - except IOError: - save_configuration(self.config_FN, self.saved_data) - + if os.path.exists(self.config_file): + loaded_data = _load_config_from_file(self.config_file) + if loaded_data is None: # Corrupt file case, recreate it + _save_config_to_file(self.config_file, self._generate_config_data()) + elif _apply_loaded_settings(loaded_data) and not self.save_in_progress: + self.save_in_progress = True + self.manage_config('save') + self.save_in_progress = False + else: + logging.info(f"Configuration file not found. Creating a new one at {self.config_file}.") + _save_config_to_file(self.config_file, self._generate_config_data()) elif cmd == 'save': - try: - # Write saved_data to the file. 
- save_configuration(self.config_FN, self.saved_data) - except IOError: - raise ValueError("save_configuration(): Couldn't save the file.") + _save_config_to_file(self.config_file, self._generate_config_data()) else: - raise ValueError('save_configuration(): Invalid command received.') + raise ValueError('manage_config(): Invalid command.') + + def reset_settings_to_defaults(self): + """Resets settings to default values.""" + self._set_default_settings() + self.manage_config('save') + + def get_setting(self, key: str) -> Any: + """ + Returns the value of the specified setting or None if the key is not found. + """ + return self.settings.get(key, None) + + def set_setting(self, key: str, value: Any): + """ + Receives a key and value of any setting and saves the configuration. + """ + self.settings[key] = value + self.manage_config('save') + +# import yaml +# from Signals import Signals +# +# +# class Configuration: +# """ +# Configuration class manages the application's settings, +# signals, strategies, and trades. It loads and saves these +# configurations to a YAML file. +# +# Attributes: +# app_data (dict): Default application settings like title and maximum data load. +# data (object): Database interaction object from the data. +# config_FN (str): Filename for storing and loading configurations. +# signals_list (list): List of available signals initialized by the Signals class. +# strategies_list (list): List of strategies available for the application. +# trades (list): List of trades managed by the application. +# saved_data (dict): Data structure for saving/loading configuration data. +# """ +# +# def __init__(self, data): +# """ +# Initializes the Configuration object with default values and loads saved data. +# +# Args: +# data (object): Cache object with a database attribute to interact with the database. +# """ +# # ************** Default values ************** +# # Application metadata such as title and maximum data to be loaded. 
+# self.app_data = { +# 'application_title': 'BrighterTrades', # The title of the program. +# 'max_data_loaded': 1000 # Maximum number of candles to store in memory. +# } +# +# # The database object for interaction. +# self.data = data.db +# +# # Name of the configuration file. +# self.config_FN = 'config.yml' +# +# # List of available signals initialized with default values. +# self.signals_list = Signals.get_signals_defaults() +# +# # List to hold available strategies. +# self.strategies_list = [] +# +# # List to hold trades. +# self.trades = [] +# +# # Placeholder for data loaded from or to be saved to the file. +# self.saved_data = None +# +# # Load any saved data from the configuration file. +# self.config_and_states('load') +# +# def update_data(self, data_type, data): +# """ +# Replace the current list of data sets with an updated list. +# +# Args: +# data_type (str): Type of data to be updated ('strategies', 'signals', or 'trades'). +# data (list): The new data to replace the old one. +# +# Raises: +# ValueError: If the provided data_type is not supported. +# """ +# if data_type == 'strategies': +# self.strategies_list = data +# elif data_type == 'signals': +# self.signals_list = data +# elif data_type == 'trades': +# self.trades = data +# else: +# raise ValueError(f'Configuration: update_data(): Unsupported data_type: {data_type}') +# +# # Save the updated data to the configuration file. +# self.config_and_states('save') +# +# def remove(self, what, name): +# """ +# Removes an item by name from the saved data list. +# +# Args: +# what (str): Type of data to remove ('strategies', 'signals', or 'trades'). +# name (str): The name or unique_id of the item to remove. +# +# Raises: +# ValueError: If the item with the specified name is not found. +# """ +# # Determine the property to match based on the type of data. +# if what == 'trades': +# prop = 'unique_id' +# else: +# prop = 'name' +# +# # Remove the item from the list if it matches the name or unique_id. 
+# for obj in self.saved_data[what]: +# if obj[prop] == name: +# self.saved_data[what].remove(obj) +# break +# +# # Save the updated data to the configuration file. +# self.config_and_states('save') +# +# def config_and_states(self, cmd): +# """ +# Loads or saves configurable data to the file set in self.config_FN. +# +# Args: +# cmd (str): Command to either 'load' or 'save' the configuration data. +# +# Raises: +# ValueError: If the command is neither 'load' nor 'save'. +# """ +# +# # Data structure to hold the current state of signals, strategies, and trades. +# self.saved_data = { +# 'signals': self.signals_list, +# 'strategies': self.strategies_list, +# 'trades': self.trades +# } +# +# def set_loaded_values(): +# """Sets the values in the saved_data object to the class attributes.""" +# if 'signals' in self.saved_data: +# self.signals_list = self.saved_data['signals'] +# +# if 'strategies' in self.saved_data: +# self.strategies_list = self.saved_data['strategies'] +# +# if 'trades' in self.saved_data: +# self.trades = self.saved_data['trades'] +# +# def load_configuration(filepath): +# """ +# Load configuration data from a YAML file. +# +# Args: +# filepath (str): Path to the configuration file. +# +# Returns: +# dict: Loaded configuration data. +# """ +# with open(filepath, "r") as file_descriptor: +# data = yaml.safe_load(file_descriptor) +# return data +# +# def save_configuration(filepath, data): +# """ +# Save configuration data to a YAML file. +# +# Args: +# filepath (str): Path to the configuration file. +# data (dict): Data to save. +# """ +# with open(filepath, "w") as file_descriptor: +# yaml.dump(data, file_descriptor) +# +# if cmd == 'load': +# try: +# # Attempt to load the configuration from the file. +# self.saved_data = load_configuration(self.config_FN) +# set_loaded_values() +# except IOError: +# # If the file doesn't exist, save the default values. 
+# save_configuration(self.config_FN, self.saved_data) +# +# elif cmd == 'save': +# try: +# # Save the current state to the configuration file. +# save_configuration(self.config_FN, self.saved_data) +# except IOError: +# raise ValueError("save_configuration(): Couldn't save the file.") +# else: +# raise ValueError('save_configuration(): Invalid command received.') diff --git a/src/DataCache_v2.py b/src/DataCache_v2.py index d8f0a74..d9a2861 100644 --- a/src/DataCache_v2.py +++ b/src/DataCache_v2.py @@ -1,4 +1,5 @@ -from typing import List, Any +import json +from typing import List, Any, Tuple import pandas as pd import datetime as dt import logging @@ -59,23 +60,31 @@ def estimate_record_count(start_time, end_time, timeframe: str) -> int: return int(expected_records) -def generate_expected_timestamps(start_datetime: dt.datetime, end_datetime: dt.datetime, - timeframe: str) -> pd.DatetimeIndex: - if start_datetime.tzinfo is None: - raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") - if end_datetime.tzinfo is None: - raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.") - - delta = timeframe_to_timedelta(timeframe) - if isinstance(delta, pd.Timedelta): - return pd.date_range(start=start_datetime, end=end_datetime, freq=delta) - elif isinstance(delta, pd.DateOffset): - current = start_datetime - timestamps = [] - while current <= end_datetime: - timestamps.append(current) - current += delta - return pd.DatetimeIndex(timestamps) +# def generate_expected_timestamps(start_datetime: dt.datetime, end_datetime: dt.datetime, +# timeframe: str) -> pd.DatetimeIndex: +# """ +# What it says. Todo: confirm this is unused and archive. +# +# :param start_datetime: +# :param end_datetime: +# :param timeframe: +# :return: +# """ +# if start_datetime.tzinfo is None: +# raise ValueError("start_datetime is timezone naive. 
Please provide a timezone-aware datetime.") +# if end_datetime.tzinfo is None: +# raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.") +# +# delta = timeframe_to_timedelta(timeframe) +# if isinstance(delta, pd.Timedelta): +# return pd.date_range(start=start_datetime, end=end_datetime, freq=delta) +# elif isinstance(delta, pd.DateOffset): +# current = start_datetime +# timestamps = [] +# while current <= end_datetime: +# timestamps.append(current) +# current += delta +# return pd.DatetimeIndex(timestamps) class DataCache: @@ -84,10 +93,170 @@ class DataCache: def __init__(self, exchanges): self.db = Database() self.exchanges = exchanges - self.cached_data = {} + # Single DataFrame for all cached data + self.cache = pd.DataFrame(columns=['key', 'data']) # Assuming 'key' and 'data' are necessary logger.info("DataCache initialized.") + def fetch_cached_rows(self, table: str, filter_vals: Tuple[str, Any]) -> pd.DataFrame | None: + """ + Retrieves rows from the cache if available; otherwise, queries the database and caches the result. + + :param table: Name of the database table to query. + :param filter_vals: A tuple containing the column name and the value to filter by. + :return: A DataFrame containing the requested rows, or None if no matching rows are found. + """ + # Construct a filter condition for the cache based on the table name and filter values. + cache_filter = (self.cache['table'] == table) & (self.cache[filter_vals[0]] == filter_vals[1]) + cached_rows = self.cache[cache_filter] + + # If the data is found in the cache, return it. + if not cached_rows.empty: + return cached_rows + + # If the data is not found in the cache, query the database. + rows = self.db.get_rows_where(table, filter_vals) + if rows is not None: + # Tag the rows with the table name and add them to the cache. 
+ rows['table'] = table + self.cache = pd.concat([self.cache, rows]) + return rows + + def remove_row(self, filter_vals: Tuple[str, Any], additional_filter: Tuple[str, Any] = None, + remove_from_db: bool = True, table: str = None) -> None: + """ + Removes a specific row from the cache and optionally from the database based on filter criteria. + + :param filter_vals: A tuple containing the column name and the value to filter by. + :param additional_filter: An optional additional filter to apply. + :param remove_from_db: If True, also removes the row from the database. Default is True. + :param table: The name of the table from which to remove the row in the database (optional). + """ + logger.debug( + f"Removing row from cache: filter={filter_vals}," + f" additional_filter={additional_filter}, remove_from_db={remove_from_db}, table={table}") + + # Construct the filter condition for the cache + cache_filter = (self.cache[filter_vals[0]] == filter_vals[1]) + + if additional_filter: + cache_filter = cache_filter & (self.cache[additional_filter[0]] == additional_filter[1]) + + # Remove the row from the cache + self.cache = self.cache.drop(self.cache[cache_filter].index) + logger.info(f"Row removed from cache: filter={filter_vals}") + + if remove_from_db and table: + # Construct the SQL query to delete from the database + sql = f"DELETE FROM {table} WHERE {filter_vals[0]} = ?" + params = [filter_vals[1]] + + if additional_filter: + sql += f" AND {additional_filter[0]} = ?" + params.append(additional_filter[1]) + + # Execute the SQL query to remove the row from the database + self.db.execute_sql(sql, tuple(params)) + logger.info( + f"Row removed from database: table={table}, filter={filter_vals}," + f" additional_filter={additional_filter}") + + def is_attr_taken(self, table: str, attr: str, val: Any) -> bool: + """ + Checks if a specific attribute in a table is already taken. + + :param table: The name of the table to check. 
+ :param attr: The attribute to check (e.g., 'user_name', 'email'). + :param val: The value of the attribute to check. + :return: True if the attribute is already taken, False otherwise. + """ + # Fetch rows from the specified table where the attribute matches the given value + result = self.fetch_cached_rows(table=table, filter_vals=(attr, val)) + return result is not None and not result.empty + + def fetch_cached_item(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any: + """ + Retrieves a specific item from the cache or database, caching the result if necessary. + + :param item_name: The name of the column to retrieve. + :param table_name: The name of the table where the item is stored. + :param filter_vals: A tuple containing the column name and the value to filter by. + :return: The value of the requested item. + :raises ValueError: If the item is not found in either the cache or the database. + """ + # Fetch the relevant rows + rows = self.fetch_cached_rows(table_name, filter_vals) + if rows is not None and not rows.empty: + # Return the specific item from the first matching row. + return rows.iloc[0][item_name] + + # If the item is not found, raise an error. + raise ValueError(f"Item {item_name} not found in {table_name} where {filter_vals[0]} = {filter_vals[1]}") + + def modify_cached_row(self, table: str, filter_vals: Tuple[str, Any], field_name: str, new_data: Any) -> None: + """ + Modifies a specific field in a row within the cache and updates the database accordingly. + + :param table: The name of the table where the data is stored. + :param filter_vals: A tuple containing the column name and the value to filter by. + :param field_name: The field to be updated. + :param new_data: The new data to be set. 
+ """ + # Retrieve the row from the cache or database + row = self.fetch_cached_rows(table, filter_vals) + + if row is None or row.empty: + raise ValueError(f"Row not found in cache or database for {filter_vals[0]} = {filter_vals[1]}") + + # Modify the specified field + if isinstance(new_data, str): + row.loc[0, field_name] = new_data + else: + # If new_data is not a string, it’s converted to a JSON string before being inserted into the DataFrame. + row.loc[0, field_name] = json.dumps(new_data) + + # Update the cache by removing the old entry and adding the modified row + self.cache = self.cache.drop( + self.cache[(self.cache['table'] == table) & (self.cache[filter_vals[0]] == filter_vals[1])].index + ) + self.cache = pd.concat([self.cache, row]) + + # Update the database with the modified row + self.db.insert_dataframe(row.drop(columns='id'), table) + + def insert_data(self, df: pd.DataFrame, table: str, skip_cache: bool = False) -> None: + """ + Inserts data into the specified table in the database, with an option to skip cache insertion. + + :param df: The DataFrame containing the data to insert. + :param table: The name of the table where the data should be inserted. + :param skip_cache: If True, skips inserting the data into the cache. Default is False. + """ + # Insert the data into the database + self.db.insert_dataframe(df=df, table=table) + + # Optionally insert the data into the cache + if not skip_cache: + df['table'] = table # Add table name for cache identification + self.cache = pd.concat([self.cache, df]) + + def insert_row(self, table: str, columns: tuple, values: tuple) -> None: + """ + Inserts a single row into the specified table in the database. + + :param table: The name of the table where the row should be inserted. + :param columns: A tuple of column names corresponding to the values. + :param values: A tuple of values to insert into the specified columns. 
+ """ + self.db.insert_row(table=table, columns=columns, values=values) + def get_records_since(self, start_datetime: dt.datetime, ex_details: List[str]) -> pd.DataFrame: + """ + This gets up-to-date records from a specified market and exchange. + + :param start_datetime: The approximate time the first record should represent. + :param ex_details: The user exchange and market. + :return: The records. + """ if self.TYPECHECKING_ENABLED: if not isinstance(start_datetime, dt.datetime): raise TypeError("start_datetime must be a datetime object") @@ -106,12 +275,20 @@ class DataCache: 'end_datetime': end_datetime, 'ex_details': ex_details, } - return self.get_or_fetch_from('cache', **args) + return self._get_or_fetch_from('data', **args) except Exception as e: logger.error(f"An error occurred: {str(e)}") raise - def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame: + def _get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame: + """ + Fetches market records from a resource stack (data, database, exchange). + fills incomplete request by fetching down the stack then updates the rest. + + :param target: Starting point for the fetch. ['data', 'database', 'exchange'] + :param kwargs: Details and credentials for the request. + :return: Records in a dataframe. + """ start_datetime = kwargs.get('start_datetime') if start_datetime.tzinfo is None: raise ValueError("start_datetime is timezone naive. 
Please provide a timezone-aware datetime.") @@ -144,12 +321,12 @@ class DataCache: key = self._make_key(ex_details) combined_data = pd.DataFrame() - if target == 'cache': - resources = [self.get_candles_from_cache, self.get_from_database, self.get_from_server] + if target == 'data': + resources = [self._get_candles_from_cache, self._get_from_database, self._get_from_server] elif target == 'database': - resources = [self.get_from_database, self.get_from_server] + resources = [self._get_from_database, self._get_from_server] elif target == 'server': - resources = [self.get_from_server] + resources = [self._get_from_server] else: raise ValueError('Not a valid Target!') @@ -165,11 +342,11 @@ class DataCache: combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='open_time').sort_values( by='open_time') - is_complete, request_criteria = self.data_complete(combined_data, **request_criteria) + is_complete, request_criteria = self._data_complete(combined_data, **request_criteria) if is_complete: - if fetch_method in [self.get_from_database, self.get_from_server]: - self.update_candle_cache(combined_data, key) - if fetch_method == self.get_from_server: + if fetch_method in [self._get_from_database, self._get_from_server]: + self._update_candle_cache(combined_data, key) + if fetch_method == self._get_from_server: self._populate_db(ex_details, combined_data) return combined_data @@ -178,7 +355,7 @@ class DataCache: logger.error('Unable to fetch the requested data.') return combined_data if not combined_data.empty else pd.DataFrame() - def get_candles_from_cache(self, **kwargs) -> pd.DataFrame: + def _get_candles_from_cache(self, **kwargs) -> pd.DataFrame: start_datetime = kwargs.get('start_datetime') if start_datetime.tzinfo is None: raise ValueError("start_datetime is timezone naive. 
Please provide a timezone-aware datetime.") @@ -200,7 +377,7 @@ class DataCache: raise ValueError("Missing required arguments") key = self._make_key(ex_details) - logger.debug('Getting records from cache.') + logger.debug('Getting records from data.') df = self.get_cache(key) if df is None: logger.debug("Cache records didn't exist.") @@ -210,7 +387,7 @@ class DataCache: df['open_time'] <= unix_time_millis(end_datetime))].reset_index(drop=True) return df_filtered - def get_from_database(self, **kwargs) -> pd.DataFrame: + def _get_from_database(self, **kwargs) -> pd.DataFrame: start_datetime = kwargs.get('start_datetime') if start_datetime.tzinfo is None: raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") @@ -240,7 +417,7 @@ class DataCache: return self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=start_datetime, et=end_datetime) - def get_from_server(self, **kwargs) -> pd.DataFrame: + def _get_from_server(self, **kwargs) -> pd.DataFrame: symbol = kwargs.get('ex_details')[0] interval = kwargs.get('ex_details')[1] exchange_name = kwargs.get('ex_details')[2] @@ -272,7 +449,7 @@ class DataCache: end_datetime) @staticmethod - def data_complete(data: pd.DataFrame, **kwargs) -> (bool, dict): + def _data_complete(data: pd.DataFrame, **kwargs) -> (bool, dict): """ Checks if the data completely satisfies the request. 
@@ -341,17 +518,21 @@ class DataCache: return True, kwargs def cache_exists(self, key: str) -> bool: - return key in self.cached_data + return key in self.cache['key'].values def get_cache(self, key: str) -> Any | None: - if key not in self.cached_data: - logger.warning(f"The requested cache key({key}) doesn't exist!") + # Check if the key exists in the cache + if key not in self.cache['key'].values: + logger.warning(f"The requested data key ({key}) doesn't exist!") return None - return self.cached_data[key] - def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None: - logger.debug('Updating cache with new records.') - # Concatenate the new records with the existing cache + # Retrieve the data associated with the key + result = self.cache[self.cache['key'] == key]['data'].iloc[0] + return result + + def _update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None: + logger.debug('Updating data with new records.') + # Concatenate the new records with the existing data records = pd.concat([self.get_cache(key), more_records], axis=0, ignore_index=True) # Drop duplicates based on 'open_time' and keep the first occurrence records = records.drop_duplicates(subset="open_time", keep='first') @@ -359,24 +540,49 @@ class DataCache: records = records.sort_values(by='open_time').reset_index(drop=True) # Reindex 'id' to ensure the expected order records['id'] = range(1, len(records) + 1) - # Set the updated DataFrame back to cache + # Set the updated DataFrame back to data self.set_cache(data=records, key=key) def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None: """ - Updates a dictionary stored in cache. + Updates a dictionary stored in the DataFrame cache. - :param data: The data to insert into cache. - :param cache_key: The cache index key for the dictionary. - :param dict_key: The dictionary key for the data. + :param data: The data to insert into the dictionary. + :param cache_key: The cache key for the dictionary. 
+ :param dict_key: The key within the dictionary to update. :return: None """ - self.cached_data[cache_key].update({dict_key: data}) + # Locate the row in the DataFrame that matches the cache_key + cache_index = self.cache.index[self.cache['key'] == cache_key] + + if not cache_index.empty: + # Update the dictionary stored in the 'data' column + cache_dict = self.cache.at[cache_index[0], 'data'] + + if isinstance(cache_dict, dict): + cache_dict[dict_key] = data + + # Ensure the DataFrame is updated with the new dictionary + self.cache.at[cache_index[0], 'data'] = cache_dict + else: + raise ValueError(f"Expected a dictionary in cache, but found {type(cache_dict)}.") + else: + raise KeyError(f"Cache key '{cache_key}' not found.") def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None: - if do_not_overwrite and key in self.cached_data: + if do_not_overwrite and key in self.cache['key'].values: return - self.cached_data[key] = data + + # Build the new cache entry as a single-row DataFrame + new_row = pd.DataFrame({'key': [key], 'data': [data]}) + + # If the key already exists, drop the old entry + self.cache = self.cache[self.cache['key'] != key] + + # Append the new row to the cache + self.cache = pd.concat([self.cache, new_row], ignore_index=True) + + print(f'Current Cache: {self.cache}') logger.debug(f'Cache set for key: {key}') def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str, @@ -422,12 +628,12 @@ class DataCache: if num_rec_records < estimated_num_records: logger.info('Detected gaps in the data, attempting to fill missing records.') - candles = self.fill_data_holes(candles, interval) + candles = self._fill_data_holes(candles, interval) return candles @staticmethod - def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame: + def _fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame: time_span = timeframe_to_timedelta(interval).total_seconds() / 60 last_timestamp = 
None filled_records = [] diff --git a/src/Database.py b/src/Database.py index 27f1aa4..b0477c1 100644 --- a/src/Database.py +++ b/src/Database.py @@ -33,7 +33,7 @@ class SQLite: class HDict(dict): """ - Hashable dictionary to use as cache keys. + Hashable dictionary to use as lookup keys (e.g. for memoized queries). Example usage: -------------- @@ -85,15 +85,16 @@ class Database: def __init__(self, db_file: str = None): self.db_file = db_file - def execute_sql(self, sql: str) -> None: + def execute_sql(self, sql: str, params: tuple = ()) -> None: """ - Executes a raw SQL statement. + Executes a raw SQL statement with optional parameters. :param sql: SQL statement to execute. + :param params: Optional tuple of parameters to pass with the SQL statement. """ with SQLite(self.db_file) as con: cur = con.cursor() - cur.execute(sql) + cur.execute(sql, params) def get_item_where(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any: """ @@ -120,12 +121,17 @@ :param table: Name of the table. :param filter_vals: Tuple of column name and value to filter by. - :return: DataFrame of the query result or None if empty. + :return: DataFrame of the query result or None if empty or column does not exist. """ - with SQLite(self.db_file) as con: - qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?" - result = pd.read_sql(qry, con, params=(filter_vals[1],)) - return result if not result.empty else None + try: + with SQLite(self.db_file) as con: + qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?" 
+ result = pd.read_sql(qry, con, params=(filter_vals[1],)) + return result if not result.empty else None + except (sqlite3.OperationalError, pd.errors.DatabaseError) as e: + # Log the error or handle it appropriately + print(f"Error querying table '{table}' for column '{filter_vals[0]}': {e}") + return None def insert_dataframe(self, df: pd.DataFrame, table: str) -> None: """ diff --git a/src/Signals.py b/src/Signals.py index bdec03b..fe99d8e 100644 --- a/src/Signals.py +++ b/src/Signals.py @@ -52,11 +52,18 @@ class Signal: class Signals: - def __init__(self, loaded_signals=None): + def __init__(self, config): # list of Signal objects. self.signals = [] + # load a list of existing signals from file. + loaded_signals = config.get_setting('signals_list') + if loaded_signals is None: + # Populate the list and file with defaults defined in this class. + loaded_signals = self.get_signals_defaults() + config.set_setting('signals_list', loaded_signals) + # Initialize signals with loaded data. if loaded_signals is not None: self.create_signal_from_dic(loaded_signals) diff --git a/src/Strategies.py b/src/Strategies.py index 64301f9..fb0d303 100644 --- a/src/Strategies.py +++ b/src/Strategies.py @@ -1,5 +1,5 @@ import json - +from DataCache_v2 import DataCache class Strategy: def __init__(self, **args): @@ -154,14 +154,26 @@ class Strategy: class Strategies: - def __init__(self, loaded_strats, trades): + def __init__(self, data: DataCache, trades): + # Handles database connections + self.data = data + # Reference to the trades object that maintains all trading actions and data. self.trades = trades - # A list of all the Strategies created. - self.strat_list = [] - # Initialise all the stately objects with the data saved to file. 
- for entry in loaded_strats: - self.strat_list.append(Strategy(**entry)) + + def get_all_strategy_names(self) -> list | None: + """Return a list of all strategies in the database""" + self.data._get_from_database() + # Load existing Strategies from file. + loaded_strategies = config.get_setting('strategies') + if loaded_strategies is None: + # Populate the list and file with defaults defined in this class. + loaded_strategies = self.get_strategy_defaults() + config.set_setting('strategies', loaded_strategies) + + for entry in loaded_strategies: + # Initialise all the strategy objects with data from file. + self.strat_list.append(Strategy(**entry)) return None def new_strategy(self, data): # Create an instance of the new Strategy. @@ -238,6 +250,7 @@ class Strategies: published strategies and evaluates conditions against the data. This function returns a list of strategies and action commands. """ + def process_strategy(strategy): action, cmd = strategy.evaluate_strategy(signals) if action != 'do_nothing': @@ -262,4 +275,3 @@ class Strategies: return False else: return return_obj - diff --git a/src/Users.py b/src/Users.py index 7bcb84a..6f450ec 100644 --- a/src/Users.py +++ b/src/Users.py @@ -2,531 +2,399 @@ import datetime as dt import json import random from typing import Any - from passlib.hash import bcrypt import pandas as pd -from Database import HDict +from DataCache_v2 import DataCache -class Users: +class BaseUser: """ - Manages user data and states. + Handles basic user data retrieval and manipulation. + This is the base class for all user-related operations. """ - def __init__(self, data_cache): - - # The object that handles all cached data. - self.cache = data_cache - - # The unique identifiers for any guests in the system. - self.cache.set_cache(data=[], key='guest_suffixes', do_not_overwrite=True) - - # Maximum number of guest allowed in the system at one time. - self.max_guests = 100 - - # A place to cache data for users logged into the system. 
- self.cache.set_cache(data={}, key='cached_users', do_not_overwrite=True) - - # The class contains methods that interact with the database. - self.db = data_cache.db - - # Clear the status of any users that were signed in before the application last shut down. - # Configured it to recover the user, but I may put this back in. - # self.log_out_all_users() - - def get_indicators(self, user_name: str) -> pd.DataFrame | None: + def __init__(self, data_cache: DataCache): """ - Return a dataframe containing all the indicators for a specific user. + Initialize the BaseUser with caching and database interaction capabilities. - :param user_name: The name of the user. - :return: pd.Dataframe - Indicator attributes and properties or None if query fails. + :param data_cache: Object responsible for managing cached data and database interaction. """ - user_id = self.get_id(user_name) - df = self.db.get_rows_where(table='indicators', filter_vals=('creator', user_id)) - - # Convert the 'source' and 'properties' columns from strings to dictionaries - if df is not None: - df['source'] = df['source'].apply(json.loads) - df['properties'] = df['properties'].apply(json.loads) - - return df + self.data = data_cache def get_id(self, user_name: str) -> int: """ - Gets user id from the users table in the database. - :param user_name: str - The name of the user. - :return: The id of the user. + Retrieves the user ID based on the username. + + :param user_name: The name of the user. + :return: The ID of the user as an integer. """ - return self.db.get_item_where(item_name='id', table_name='users', filter_vals=('user_name', user_name)) + return self.data.fetch_cached_item( + item_name='id', + table_name='users', + filter_vals=('user_name', user_name) + ) - def get_username(self, u_id: int) -> str: + def _remove_user_from_memory(self, user_name: str) -> None: """ - Gets user_name from the users table in the database. - :param u_id: int - The id of the user. 
- :return: str - The name of the user. + Private method to remove a user's data from the cache (memory). + + :param user_name: The name of the user to remove from the cache. """ - return self.db.get_item_where(item_name='user_name', table_name='users', filter_vals=('id', u_id)) + # Use DataCache to remove the user from the cache only + self.data.remove_row( + table='users', + filter_vals=('user_name', user_name) + ) - def save_indicators(self, indicators: pd.DataFrame) -> None: + def delete_user(self, user_name: str) -> None: """ - Store one or many indicators in the database. + Deletes the user from both the cache and the database. - :param indicators: pd.dataframe - Indicator attributes and properties. - :return: None. + :param user_name: The name of the user to delete. """ - for _, indicator in indicators.iterrows(): - src_string = json.dumps(indicator['source']) - prop_string = json.dumps(indicator['properties']) - values = (indicator['creator'], indicator['name'], indicator['visible'], - indicator['kind'], src_string, prop_string) - columns = ('creator', 'name', 'visible', 'kind', 'source', 'properties') - self.db.insert_row(table='indicators', columns=columns, values=values) + self.data.remove_row( + filter_vals=('user_name', user_name), + table='users' + ) - def remove_indicator(self): - return - - def is_logged_in(self, user_name) -> bool: + def get_user_data(self, user_name: str) -> pd.DataFrame | None: """ - Returns True if the session var contains a currently logged-in user. - :param user_name: The name of the user. - :return: True|False + Retrieves user data from the cache or database. If not found in the cache, + it loads from the database and caches it. + + :param user_name: The name of the user. + :return: A DataFrame containing the user's data. + :raises ValueError: If the user is not found in both the cache and the database. 
""" + # Attempt to fetch the user data from the cache or database via DataCache + user = self.data.fetch_cached_rows( + table='users', + filter_vals=('user_name', user_name) + ) - def poll_login_flag(name: str): - # Get the user from the cache. - user = self.get_user_from_cache(name) + if user is None or user.empty: + raise ValueError(f"User '{user_name}' not found in database or cache!") - if user is not None: - # If the user is in the cache. - if user.at[0, 'status'] == 'logged_in': - # Return true if the user is logged in. - return True - else: - # Remove the user from the cached if the user's status flag is not set. - self.remove_user_from_cache(name) - # Return False. - return False + return user - # If the user was not in the cache, check the db incase the application has been restarted. - user = self.db.get_rows_where(table='users', filter_vals=('user_name', name)) - if user is None: - # The user is not in the database. - return False + def modify_user_data(self, username: str, field_name: str, new_data: Any) -> None: + """ + Updates user data in both the cache and the database. Ensures consistency between cache and database. + + :param username: The name of the user whose data is being modified. + :param field_name: The field to be updated. + :param new_data: The new data to be set. + """ + # Use DataCache to modify the user's data + self.data.modify_cached_row( + table='users', + filter_vals=('user_name', username), + field_name=field_name, + new_data=new_data + ) + + +class UserAccountManagement(BaseUser): + """ + Manages user login, logout, validation, and account creation processes. + """ + + def __init__(self, data_cache, max_guests: int = 100): + """ + Initialize the UserAccountManagement with caching, database interaction, + and guest management capabilities. + + :param data_cache: Object responsible for managing cached data and database interaction. + :param max_guests: Maximum number of guests allowed in the system at one time. 
+ """ + super().__init__(data_cache) + + self.max_guests = max_guests # Maximum number of guests + + # Initialize data for guest suffixes and cached users + self.data.set_cache(data=[], key='guest_suffixes', do_not_overwrite=True) + self.data.set_cache(data={}, key='cached_users', do_not_overwrite=True) + + def is_logged_in(self, user_name: str) -> bool: + """ + Checks if the user is logged in by verifying their status in the cache or database. + + :param user_name: The name of the user. + :return: True if the user is logged in, False otherwise. + """ + if user_name is None: + return False + + def is_guest(_user) -> bool: + split_name = _user.at[0, 'user_name'].split("_") + return 'guest' in split_name and len(split_name) == 2 + + # Retrieve the user's data from either the cache or the database. + user = self.get_user_data(user_name) + + if user is not None and not user.empty: + # If the user exists, check their login status. if user.at[0, 'status'] == 'logged_in': - # Add the user back into the cache and return true. - self.add_user_to_cache(user) - # Check if the user was a guest. - split_name = user.at[0, 'user_name'].split("_") - if ('guest' in split_name) and (len(split_name) == 2): - # if it was, add the suffix to the cache. - guest_suffixes = self.cache.get_cache('guest_suffixes') - guest_suffixes.append(name[1]) - self.cache.set_cache(data=guest_suffixes, key='guest_suffixes') + # If the user is logged in, check if they are a guest. + if is_guest(user): + # Update the guest suffix cache if the user is a guest. + guest_suffixes = self.data.get_cache('guest_suffixes') + guest_suffixes.append(user_name[1]) + self.data.set_cache(data=guest_suffixes, key='guest_suffixes') return True else: - # The user is not logged in. + # If the user is not logged in, remove their data from the cache. + self._remove_user_from_memory(user_name) return False - if user_name is None: - # Not signed in if the session var is not set. 
- return False - if poll_login_flag(user_name): - # The user is logged in. - return True - else: - # The user is not logged in. + # If the user data is not found or the status is not 'logged_in', return False. + return False + + def validate_password(self, username: str, password: str) -> bool: + """ + Validates the user's password against the stored hash. + + :param username: The name of the user. + :param password: The plain text password to validate. + :return: True if the password is correct, False otherwise. + """ + # Retrieve the hashed password using DataCache + user_data = self.data.fetch_cached_rows(table='users', filter_vals=('user_name', username)) + + if user_data is None or user_data.empty: return False - def create_unique_guest_name(self): - # Retrieve the used suffixes from the cache. - guest_suffixes = self.cache.get_cache('guest_suffixes') - # Limit the number of guest allowed in the system at one time. - if len(guest_suffixes) > self.max_guests: - return None - # Pick a random suffix. - suffix = random.choice(range(0, (self.max_guests * 9))) - # If the suffix is already used try again until an unused one is found. - while suffix in guest_suffixes: - suffix = random.choice(range(0, (self.max_guests * 9))) - # return a unique user_name by appending the suffix. - return f'guest_{suffix}' + hashed_pass = user_data.iloc[0].get('password') - def create_guest(self) -> str | None: - """ - Insert a generic user into the database. - :return: str|None - Returns the user_name of the guest. | None if the - self.max_guest limit is exceeded or any other login error. - """ - - # if None is returned the guest limit is reached. - if (username := self.create_unique_guest_name()) is None: - return None - - attrs = ({'user_name': username},) - self.create_new_user_in_db(attrs=attrs) - - # Log the user in. 
- login_success = self.log_in_user(username=username, password='password') - if login_success: - return username - - def user_attr_is_taken(self, attr, val): - # Submit a request from the db for the satus of any user where attr == val. - status = self.db.get_from_static_table(item='status', table='users', indexes=HDict({attr: val})) - if status is None: - return False - else: - # If a result was returned the attribute is already taken. - return True - - @staticmethod - def scramble_text(some_text): - # More rounds increases complexity but makes it slower. - hasher = bcrypt.using(rounds=13) - return hasher.hash(some_text) - - def create_new_user_in_db(self, attrs: tuple): - """ - Retrieves a default user, modifies it, then appends the user table. - - :param attrs: tuple - a tuple of key, value dicts to apply to the new user. ({attr:value},...) - :return: None. - """ - # Get the default user template. - default_user = self.db.get_rows_where(table='users', filter_vals=('user_name', 'guest')) - - for each in attrs: - # Modify the every attribute received in the template. - key = str(*each) - default_user.loc[0, key] = each.get(key) - - # Drop the database index. - default_user = default_user.drop(columns='id') - - # Insert the modified user as a new record in the table. 
- self.db.insert_dataframe(df=default_user, table="users") - - def create_new_user(self, username: str, email: str, password: str): - - if self.user_attr_is_taken('user_name', username): + if not hashed_pass: return False - if self.user_attr_is_taken('email', email): + # Verify the password using bcrypt + try: + return bcrypt.verify(password, hashed_pass) + except ValueError: + # Handle any issues with the verification process return False - encrypted_password = self.scramble_text(password) - - attrs = ({'user_name': username}, {'email': email}, {'password': encrypted_password}) - self.create_new_user_in_db(attrs=attrs) - - return True - - def load_or_create_user(self, username: str) -> str: + def log_in_user(self, username: str, password: str) -> bool: """ - Loads user data either for existing user or a newly created user. + Validates user credentials and signs in the user by updating their status. - :param username: str|None - The user_name if exists in session. - :return: str - Either the same as passed in or a new user_name. - """ - - # Check if the user is logged in. - if not self.is_logged_in(user_name=username): - # Create a guest session if the user is not logged in. - username = self.create_guest() - if not username: - # There was a problem creating the user and system should display an error page. - raise ValueError('GuestLimitExceeded!') - # Load user data if not loaded. - self.load_user_data(user_name=username) - # Return the user_name. - return username - - def log_in_user(self, username: str, password: str): - """ - Validate User credentials and sign in the user. - - :param username: The name index for the user. - :param password: The unencrypted password. - :return: bool - True on success | False on fail. + :param username: The name of the user. + :param password: The unencrypted password. + :return: True on successful login, False otherwise. 
""" if success := self.validate_password(username=username, password=password): - # Set the user's status flag as logged in. self.modify_user_data(username=username, field_name="status", new_data="logged_in") - # Record the login time. self.modify_user_data(username=username, field_name="signin_time", new_data=dt.datetime.utcnow().timestamp()) return success - def log_out_user(self, username: str): + def log_out_user(self, username: str) -> bool: """ - Log out the user. + Logs out the user by updating their status and removing them from the cache. - :param username: The name of the user. - :return: bool + :param username: The name of the user. + :return: True on successful logout. """ - # Set the user's status flag as logged out. + # Update the user's status and sign-in time in both cache and database self.modify_user_data(username=username, field_name='status', new_data='logged_out') - # Record the logout time. self.modify_user_data(username=username, field_name='signin_time', new_data=dt.datetime.utcnow().timestamp()) - # Remove the user from the cache. - self.remove_user_from_cache(user_name=username) + + # Remove the user's data from the cache + self._remove_user_from_memory(user_name=username) + return True - def log_out_all_users(self): + def log_out_all_users(self, enforcement: str = 'hard') -> None: """ - Log out all users. - :return: None + Logs out all users by updating their status in the database and clearing the data. + + :param enforcement: 'soft' or 'hard' - Determines how strictly users are logged out. """ - for key in self.cache.cached_data: - if (self.cache.cached_data[key] is dict) and ('user_name' in self.cache.cached_data[key]): - # remove the user from the cache. 
- del self.cache.cached_data[key] + if enforcement == 'soft': + self._soft_log_out_all_users() + elif enforcement == 'hard': + # Clear all user-related entries from the cache + for index, row in self.data.cache.iterrows(): + if 'user_name' in row: + self._remove_user_from_memory(row['user_name']) - df = self.db.get_rows_where(table='users', filter_vals=('status', 'logged_in')) - if df is not None: - df = df[df.user_name != 'guest'] + df = self.data.fetch_cached_rows(table='users', filter_vals=('status', 'logged_in')) + if df is not None: + df = df[df.user_name != 'guest'] - # Set all logged-in user's status flag to logged_out. - ids = df.user_name.values - if ids is not None: - for value in ids: - sql = f"UPDATE users SET status = 'logged_out' WHERE user_name = '{value}';" - self.db.execute_sql(sql) - return + # Update the status of all logged-in users to 'logged_out' + for user_name in df.user_name.values: + self.modify_user_data(username=user_name, field_name='status', new_data='logged_out') + else: + raise ValueError("Invalid enforcement type. Use 'soft' or 'hard'.") - def load_user_data(self, user_name: str) -> None: + def _soft_log_out_all_users(self) -> None: """ - Check if user data is in the cache and load if it is not. - - :param user_name: The name of the user. - :return: None + Placeholder for soft logout logic. Soft logout might involve suspending logins, halting trades, + or allowing current operations to finish. (TODO: Implement soft logout logic) """ - # if self.cache.is_loaded(user_name): - # # Return if already loaded. - # return - - print(f'Retrieving this user: {user_name}') - # Request a dataframe containing the users properties. - user = self.get_user_from_db(user_name=user_name) - if user.empty: - raise ValueError('Attempted to load user not in db!') - - # Load chart view stored in the database and convert back to a dictionary. 
- chart_view = json.loads(user.at[0, 'chart_views']) - # todo : save chart view somewhere you can do it - # load indicators - pass - # load strategies + # TODO: Implement soft logout logic here pass - def modify_user_data(self, username: str, field_name: str, new_data: Any) -> None: + def user_attr_is_taken(self, attr: str, val: str) -> bool: """ - Update a user data field in both the cache and the database. + Checks if a specific user attribute (e.g., username, email) is already taken. - :param username: str - The name of the user. - :param field_name: str - The field to update. - :param new_data: Any - The data to be set. - :return: None + :param attr: The attribute to check (e.g., 'user_name', 'email'). + :param val: The value of the attribute to check. + :return: True if the attribute is already taken, False otherwise. """ + # Use DataCache to check if the attribute is taken + return self.data.is_attr_taken(table='users', attr=attr, val=val) - # Get the user data from the db. - user = self.db.get_rows_where(table='users', filter_vals=('user_name', username)) - - # Drop the db id column. - modified_user = user.drop(columns='id') - - # Modify the targeted field. - if type(new_data) is str: - modified_user.loc[0, field_name] = new_data - else: - modified_user.loc[0, field_name] = json.dumps(new_data) - - # Replace the user records in the cache with the updated data record. - self.cache.update_cached_dict(cache_key='cached_users', dict_key=username, data=modified_user) - - # Delete the old record from the db - sql = f"DELETE FROM users WHERE user_name = '{user.loc[0, 'user_name']}';" - self.db.execute_sql(sql) - - # Replace the records in the database. - self.db.insert_dataframe(modified_user, 'users') - - return - - def validate_password(self, username: str, password: str) -> bool: + def create_unique_guest_name(self) -> str | None: """ - Validate a user password with the hashed password stored in the db. 
+ Creates a unique guest username by appending a random suffix. - :param username: str - The name of the user. - :param password: str - The plain text password. - :return: bool - True on success. + :return: A unique guest username or None if the guest limit is reached. """ - # Validate input - if (password is None) or (username is None): + guest_suffixes = self.data.get_cache('guest_suffixes') + if len(guest_suffixes) > self.max_guests: + return None + suffix = random.choice(range(0, (self.max_guests * 9))) + while suffix in guest_suffixes: + suffix = random.choice(range(0, (self.max_guests * 9))) + return f'guest_{suffix}' + + def create_guest(self) -> str | None: + """ + Creates a guest user in the database and logs them in. + + :return: The guest username or None if the guest limit is reached. + """ + if (username := self.create_unique_guest_name()) is None: + return None + attrs = ({'user_name': username},) + self.create_new_user_in_db(attrs=attrs) + + login_success = self.log_in_user(username=username, password='password') + if login_success: + return username + + def create_new_user_in_db(self, attrs: tuple) -> None: + """ + Creates a new user in the database by modifying a default template. + + :param attrs: A tuple of key-value pairs representing user attributes. + :raises ValueError: If attrs are not well-formed. 
+ """ + if not attrs or not all(isinstance(attr, dict) and len(attr) == 1 for attr in attrs): + raise ValueError("Attributes must be a tuple of single key-value pair dictionaries.") + + # Retrieve the default user template from the database using DataCache + default_user = self.data.fetch_cached_rows(table='users', filter_vals=('user_name', 'guest')) + + if default_user is None or default_user.empty: + raise ValueError("Default user template not found in the database.") + + # Modify the default user template with the provided attributes + for attr in attrs: + key, value = next(iter(attr.items())) + default_user.loc[0, key] = value + + # Remove the 'id' column before inserting into the database + default_user = default_user.drop(columns='id') + + # Insert the modified user data into the database, skipping cache insertion + self.data.insert_data(df=default_user, table="users", skip_cache=True) + + def create_new_user(self, username: str, email: str, password: str) -> bool: + """ + Creates a new user in the system if the username and email are not already taken. + + :param username: The desired username. + :param email: The user's email address. + :param password: The user's password. + :return: True if the user was successfully created, False otherwise. + """ + if self.user_attr_is_taken('user_name', username) or self.user_attr_is_taken('email', email): return False + encrypted_password = self.scramble_text(password) + attrs = ({'user_name': username}, {'email': email}, {'password': encrypted_password}) + self.create_new_user_in_db(attrs=attrs) + return True - # Get the password from the database. 
- hashed_pass = self.db.get_from_static_table(item='password', table='users', - indexes=HDict({'user_name': username})) - if hashed_pass is None: - return False - hasher = bcrypt.using(rounds=13) # Make it slower - return hasher.verify(password, hashed_pass) - - def get_chart_view(self, user_name: str, prop: str | None = None): + def load_or_create_user(self, username: str) -> str: """ - Fetches the chart view or one specific_property of it for a specific user. + Loads existing user data or creates a guest user if the user is not logged in. - :param user_name: str - The name of the user the query is for. - :param prop: str|None - Optional the specific specific_property of the chart view. - :return: dict|str - The value of the specific_property specified in specific_property or a dict of 3 values. + :param username: The username to load or create. + :return: The username of the loaded or newly created user. + :raises ValueError: If a guest user cannot be created. """ + if not self.is_logged_in(user_name=username): + username = self.create_guest() + if not username: + raise ValueError('GuestLimitExceeded!') + self.get_user_data(user_name=username) + return username - # Get a pd.DataFrame that contains the user properties. - print(f'Retrieving this user: {user_name}') - user = self.get_user_from_db(user_name=user_name) - if user.empty: - return - # The chart view is dict object stored as a string in the database to simplify the data storage. - # This converts it back to a dictionary. - # chart_view = literal_eval(user.at[0, 'chart_views']) - chart_view = json.loads(user.at[0, 'chart_views']) - - if prop is None: - # If no specific_property is specified return the entire dict. - return chart_view - if prop == 'exchange_name': - # Return the exchange_name name as a string. - return chart_view.get('exchange') - if prop == 'timeframe': - # Return the timeframe as a string. - return chart_view.get('timeframe') - if prop == 'market': - # Return the market as a string. 
- return chart_view.get('market') - - def set_chart_view(self, user_name: str, values: dict | str, - specific_property: str | None = None, default_market: str | None = None): + @staticmethod + def scramble_text(some_text: str) -> str: """ + Encrypts text using bcrypt with 13 rounds of hashing. - :param user_name: str - The name of the user whose data is being modified. - :param values: str|list - The value of the property or values of a dict to be set. - :param specific_property: str|None - Optional the specific property of the chart view to modify. - :param default_market: str|None - a default market must be provided to display when - changing the exchange_name view. - :return: + :param some_text: The text to encrypt. + :return: The hashed text. """ + return bcrypt.hash(some_text, rounds=13) - # Get a pd.DataFrame that contains the user properties. - user = self.get_user_from_db(user_name=user_name) - if specific_property is None: - # Validate values - assert (type(values) is dict) - assert (len(values) == 3) - # user['chart_view'] = str(values) - # TODO!!! setting directly then useing modifiy user data elsewhere - user['chart_view'] = json.dumps(values) - return - - # Validate specific_property - assert (type(specific_property) is str) - - # chart_view = literal_eval(user.at[0, 'chart_views']) - chart_view = json.loads(user.at[0, 'chart_views']) - if specific_property == 'exchange_name': - # Ensure a market view is always passed in while changing the exchange_name. - assert (default_market is not None) - # Set the market. - chart_view['market'] = default_market - # Set the exchange_name. - chart_view['exchange_name'] = values - elif specific_property == 'timeframe': - # Set the timeframe - chart_view['timeframe'] = values - elif specific_property == 'market': - # Set the market. 
- chart_view['market'] = values - else: - raise ValueError(f'{specific_property} is not a specific_property of chart_views') - - self.modify_user_data(username=user_name, field_name='chart_views', new_data=chart_view) - return - - def get_user_from_db(self, user_name: str) -> pd.DataFrame | None: - user = self.db.get_rows_where(table='users', filter_vals=('user_name', user_name)) - if not user.empty: - return user - - def get_user_from_cache(self, user_name: str) -> pd.DataFrame | None: - """ - Returns the DataFrame object that contains the data of only . - or None if the user is not logged in. - """ - # Get a dictionary of all cached users data indexed by user_name. - if all_users := self.cache.get_cache('cached_users'): - # Return the specific data. - return all_users.get(user_name) - - def remove_user_from_cache(self, user_name: str) -> None: - """ - Removes a user object from the cache. - """ - # Get a dictionary of the cached user data indexed by user_name. - users = self.cache.get_cache('cached_users') - # Remove the specific user from the cached data. - del users[user_name] - # Reset the cached data - self.cache.set_cache(data=users, key='cached_users') - return - - def add_user_to_cache(self, user: pd.DataFrame) -> None: - """ - Adds a user object to the cache. - """ - self.cache.update_cached_dict(cache_key='cached_users', dict_key=user.at[0, 'user_name'], data=user) - return +class UserExchangeManagement(UserAccountManagement): + """ + Manages user exchange-related data and operations. + """ def get_api_keys(self, user_name: str, exchange: str) -> dict: - # Get the user records from the database. - user = self.get_user_from_db(user_name) + """ + Retrieves the API keys for a specific exchange associated with a user. + + :param user_name: The name of the user. + :param exchange: The name of the exchange. + :return: A dictionary containing the API keys for the exchange. 
+ """ + user = self.get_user_data(user_name) + if user is None or user.empty or 'api_keys' not in user.columns: + return {} - # Get the api keys dictionary. user_keys = user.loc[0, 'api_keys'] user_keys = json.loads(user_keys) if user_keys else {} return user_keys.get(exchange) - def update_api_keys(self, api_keys, exchange, user_name: str): - # Get the user records from the database. - user = self.get_user_from_db(user_name) + def update_api_keys(self, api_keys, exchange, user_name: str) -> None: + """ + Updates the API keys for a specific exchange and user, and activates the exchange. - # Get the old api keys dictionary. + :param api_keys: The new API keys to store. + :param exchange: The exchange for which the keys are being updated. + :param user_name: The name of the user. + """ + user = self.get_user_data(user_name) user_keys = user.loc[0, 'api_keys'] user_keys = json.loads(user_keys) if user_keys else {} - # Add the new keys to the dictionary. user_keys.update({exchange: api_keys}) - # Modify the cache and db. self.modify_user_data(username=user_name, field_name='api_keys', new_data=json.dumps(user_keys)) - # Get the old list of configured exchanges. configured_exchanges = json.loads(user.loc[0, 'configured_exchanges']) - - # Check if the exchange is already in the list if exchange not in configured_exchanges: - # Append the list of configured exchanges. configured_exchanges.append(exchange) - # Modify the cache and db. self.modify_user_data(username=user_name, field_name='configured_exchanges', new_data=json.dumps(configured_exchanges)) - # Activate the newly configured exchange. self.active_exchange(exchange=exchange, user_name=user_name, cmd='set') def get_exchanges(self, user_name: str, category: str) -> list | None: @@ -535,17 +403,13 @@ class Users: :param user_name: The name of the user. :param category: The category to retrieve ('active_exchanges' or 'configured_exchanges'). - :return: List of exchanges or None if user or category is not found. 
+ :return: A list of exchanges or None if user or category is not found. """ try: - # Get the user records from the database. - user = self.get_user_from_db(user_name) - # Get the exchanges list based on the field. + user = self.get_user_data(user_name) exchanges = user.loc[0, category] - # Return the list if it exists, otherwise return an empty list. return json.loads(exchanges) if exchanges else [] except (KeyError, IndexError, json.JSONDecodeError) as e: - # Log the error to the console print(f"Error retrieving exchanges for user '{user_name}' and field '{category}': {str(e)}") return None @@ -553,31 +417,157 @@ class Users: """ Inserts, removes, or toggles an exchange name in a list of active exchanges. - :param cmd: str - 'toggle' | 'set' | 'remove' - :param exchange: str - The name of the exchange. - :param user_name: str - The name of the user executing this command. - :return: None + :param cmd: 'toggle' | 'set' | 'remove' - The action to perform on the exchange. + :param exchange: The name of the exchange. + :param user_name: The name of the user executing this command. """ - # Get the user records from the database. - user = self.get_user_from_db(user_name) - # Get the old active_exchanges list, or initialize as an empty list if it is None. + user = self.get_user_data(user_name) active_exchanges = user.loc[0, 'active_exchanges'] if active_exchanges is None: active_exchanges = [] else: active_exchanges = json.loads(active_exchanges) - # Define the actions for each command actions = { 'toggle': lambda x: active_exchanges.remove(x) if x in active_exchanges else active_exchanges.append(x), 'set': lambda x: active_exchanges.append(x) if x not in active_exchanges else active_exchanges, 'remove': lambda x: active_exchanges.remove(x) if x in active_exchanges else active_exchanges } - # Perform the action based on the command action = actions.get(cmd) if action: action(exchange) - # Modify the cache and db. 
self.modify_user_data(username=user_name, field_name='active_exchanges', new_data=json.dumps(active_exchanges)) + +class UserIndicatorManagement(UserExchangeManagement): + """ + Manages user indicators and related operations. + """ + + def get_indicators(self, user_name: str) -> pd.DataFrame | None: + """ + Retrieves all indicators for a specific user. + + :param user_name: The name of the user. + :return: A DataFrame containing the user's indicators or None if not found. + """ + # Retrieve the user's ID + user_id = self.get_id(user_name) + + # Fetch the indicators from the database using DataCache + df = self.data.fetch_cached_rows(table='indicators', filter_vals=('creator', user_id)) + + # If indicators are found, process the JSON fields + if df is not None and not df.empty: + df['source'] = df['source'].apply(json.loads) + df['properties'] = df['properties'].apply(json.loads) + + return df + + def save_indicators(self, indicators: pd.DataFrame) -> None: + """ + Stores one or many indicators in the database. + + :param indicators: A DataFrame containing indicator attributes and properties. + """ + for _, indicator in indicators.iterrows(): + # Convert necessary fields to JSON strings + src_string = json.dumps(indicator['source']) + prop_string = json.dumps(indicator['properties']) + + # Prepare the values and columns for insertion + values = (indicator['creator'], indicator['name'], indicator['visible'], + indicator['kind'], src_string, prop_string) + columns = ('creator', 'name', 'visible', 'kind', 'source', 'properties') + + # Insert the row into the database using DataCache + self.data.insert_row(table='indicators', columns=columns, values=values) + + def remove_indicator(self, indicator_name: str, user_name: str) -> None: + """ + Removes a specific indicator from the database and cache. + + :param indicator_name: The name of the indicator to remove. + :param user_name: The name of the user who created the indicator. 
+ """ + user_id = self.get_id(user_name) + self.data.remove_row( + filter_vals=('name', indicator_name), + additional_filter=('creator', user_id), + table='indicators' + ) + + def get_chart_view(self, user_name: str, prop: str | None = None): + """ + Fetches the chart view or one specific property of it for a specific user. + + :param user_name: The name of the user. + :param prop: Optional specific property of the chart view to retrieve. + :return: A dictionary of chart views or a specific property if specified. + """ + user = self.get_user_data(user_name) + if user.empty: + return + chart_view = json.loads(user.at[0, 'chart_views']) + + if prop is None: + return chart_view + if prop == 'exchange_name': + return chart_view.get('exchange') + if prop == 'timeframe': + return chart_view.get('timeframe') + if prop == 'market': + return chart_view.get('market') + + def set_chart_view(self, user_name: str, values: dict | str, + specific_property: str | None = None, default_market: str | None = None) -> None: + """ + Sets or updates the chart view for a specific user. + + :param user_name: The name of the user. + :param values: The values to set for the chart view. + :param specific_property: Optional specific property of the chart view to modify. + :param default_market: Default market to display when changing the exchange view. + :raises ValueError: If inputs are not valid. 
+ """ + user = self.get_user_data(user_name) + + if specific_property is None: + if not isinstance(values, dict) or len(values) != 3: + raise ValueError("Values must be a dictionary with exactly 3 key-value pairs.") + user['chart_view'] = json.dumps(values) + return + + if not isinstance(specific_property, str): + raise ValueError("Specific property must be a string.") + + chart_view = json.loads(user.at[0, 'chart_views']) + if specific_property == 'exchange_name': + if default_market is None: + raise ValueError("Default market must be provided when setting the exchange name.") + chart_view['market'] = default_market + chart_view['exchange_name'] = values + elif specific_property == 'timeframe': + chart_view['timeframe'] = values + elif specific_property == 'market': + chart_view['market'] = values + else: + raise ValueError(f'{specific_property} is not a specific property of chart_views') + + self.modify_user_data(username=user_name, field_name='chart_views', new_data=chart_view) + + +class Users(UserIndicatorManagement): + """ + Comprehensive user management class that inherits all functionalities from UserIndicatorManagement. + """ + + def __init__(self, data_cache, max_guests: int = 100): + """ + Initialize the Users class with caching, database interaction, and user management capabilities. + + :param data_cache: Object responsible for managing cached data and database interaction. + :param max_guests: Maximum number of guests allowed in the system at one time. + """ + super().__init__(data_cache, max_guests) diff --git a/src/app.py b/src/app.py index 66ab56e..388accc 100644 --- a/src/app.py +++ b/src/app.py @@ -7,7 +7,6 @@ from flask_sock import Sock from email_validator import validate_email, EmailNotValidError # Handles all updates and requests for locally stored data. -import config from BrighterTrades import BrighterTrades # Set up logging @@ -54,7 +53,7 @@ def index(): try: # Log the user in. 
-        user_name = brighter_trades.config.users.load_or_create_user(username=session.get('user'))
+        user_name = brighter_trades.users.load_or_create_user(username=session.get('user'))
     except ValueError as e:
         if str(e) != 'GuestLimitExceeded!':
             raise
diff --git a/src/archived_code/DataCache.py b/src/archived_code/DataCache.py
index c2388ca..2853bff 100644
--- a/src/archived_code/DataCache.py
+++ b/src/archived_code/DataCache.py
@@ -37,16 +37,16 @@ class DataCache:
 
     def cache_exists(self, key: str) -> bool:
         """
-        Checks if a cache exists for the given key.
+        Checks if a cache exists for the given key.
 
         :param key: The access key.
-        :return: True if cache exists, False otherwise.
+        :return: True if cache exists, False otherwise.
         """
         return key in self.cached_data
 
     def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
         """
-        Adds records to existing cache.
+        Adds records to existing cache.
 
         :param more_records: The new records to be added.
         :param key: The access key.
@@ -59,9 +59,9 @@ class DataCache:
 
     def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
         """
-        Creates a new cache key and inserts data.
+        Creates a new cache key and inserts data.
 
-        :param data: The records to insert into cache.
+        :param data: The records to insert into cache.
         :param key: The index key for the data.
         :param do_not_overwrite: Flag to prevent overwriting existing data.
         :return: None
@@ -72,10 +72,10 @@ class DataCache:
 
     def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None:
         """
-        Updates a dictionary stored in cache.
+        Updates a dictionary stored in cache.
 
-        :param data: The data to insert into cache.
-        :param cache_key: The cache index key for the dictionary.
+        :param data: The data to insert into cache.
+        :param cache_key: The cache index key for the dictionary.
         :param dict_key: The dictionary key for the data.
         :return: None
         """
@@ -89,7 +89,7 @@
 
         :return: Any|None - The requested data or None on key error.
"""
         if key not in self.cached_data:
-            logger.warning(f"The requested cache key({key}) doesn't exist!")
+            logger.warning(f"The requested cache key({key}) doesn't exist!")
             return None
         return self.cached_data[key]
 
@@ -98,14 +98,14 @@
         """
         Fetches records since the specified start datetime.
 
-        :param key: The cache key.
+        :param key: The cache key.
         :param start_datetime: The start datetime to fetch records from.
         :param record_length: The required number of records.
         :param ex_details: Exchange details.
         :return: DataFrame containing the records.
         """
         try:
-            target = 'cache'
+            target = 'cache'
             args = {
                 'key': key,
                 'start_datetime': start_datetime,
@@ -167,7 +167,7 @@
                 'record_length': record_length,
             }
 
-        if target == 'cache':
+        if target == 'cache':
             result = get_from_cache()
             if data_complete(result, **request_criteria):
                 return result
@@ -193,7 +193,7 @@
         """
         Fetches records since the specified start datetime.
 
-        :param key: The cache key.
+        :param key: The cache key.
         :param start_datetime: The start datetime to fetch records from.
         :param record_length: The required number of records.
         :param ex_details: Exchange details.
@@ -203,11 +203,11 @@
         end_datetime = dt.datetime.utcnow()
 
         if self.cache_exists(key=key):
-            logger.debug('Getting records from cache.')
+            logger.debug('Getting records from cache.')
             records = self.get_cache(key)
         else:
             logger.debug(
-                f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
+                f'Records not in cache. 
Requesting from DB: starting at: {start_datetime} to: {end_datetime}') records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_datetime, rl=record_length, ex_details=ex_details) logger.debug(f'Got {len(records.index)} records from DB.') diff --git a/src/candles.py b/src/candles.py index 2c00ccc..3ac8ec9 100644 --- a/src/candles.py +++ b/src/candles.py @@ -9,23 +9,23 @@ from shared_utilities import timeframe_to_minutes, ts_of_n_minutes_ago # log.basicConfig(level=log.ERROR) class Candles: - def __init__(self, exchanges, config_obj, data_source): + def __init__(self, exchanges, users, data_source, config): # A reference to the app configuration - self.config = config_obj + self.users = users # The maximum amount of candles to load at one time. - self.max_records = self.config.app_data.get('max_data_loaded') + self.max_records = config.get_setting('max_data_loaded') # This object maintains all the cached data. self.data = data_source - # print('Setting the candle cache.') - # # Populate the cache: - # self.set_cache(symbol=self.config.users.get_chart_view(user_name='guest', specific_property='market'), - # interval=self.config.users.get_chart_view(user_name='guest', specific_property='timeframe'), - # exchange_name=self.config.users.get_chart_view(user_name='guest', specific_property='exchange_name')) - # print('DONE Setting cache') + # print('Setting the candle data.') + # # Populate the data: + # self.set_cache(symbol=self.users.get_chart_view(user_name='guest', specific_property='market'), + # interval=self.users.get_chart_view(user_name='guest', specific_property='timeframe'), + # exchange_name=self.users.get_chart_view(user_name='guest', specific_property='exchange_name')) + # print('DONE Setting data') def get_last_n_candles(self, num_candles: int, asset: str, timeframe: str, exchange: str, user_name: str): """ @@ -83,7 +83,7 @@ class Candles: # def set_cache(self, symbol=None, interval=None, exchange_name=None, user_name=None): """ - 
This method requests a chart from memory to ensure the cache is initialized.
+        This method requests a chart from memory to ensure the cache is initialized.
 
         :param user_name:
         :param symbol: str - The symbol of the market.
@@ -91,18 +91,18 @@ class Candles:
         :param exchange_name: str - The name of the exchange_name to fetch from.
         :return: None
         """
-        # By default, initialise cache with the last viewed chart.
+        # By default, initialise cache with the last viewed chart.
         if not symbol:
             assert user_name is not None
-            symbol = self.config.users.get_chart_view(user_name=user_name, prop='market')
+            symbol = self.users.get_chart_view(user_name=user_name, prop='market')
             log.info(f'set_candle_history(): No symbol provided. Using{symbol}')
         if not interval:
             assert user_name is not None
-            interval = self.config.users.get_chart_view(user_name=user_name, prop='timeframe')
+            interval = self.users.get_chart_view(user_name=user_name, prop='timeframe')
             log.info(f'set_candle_history(): No timeframe provided. Using{interval}')
         if not exchange_name:
             assert user_name is not None
-            exchange_name = self.config.users.get_chart_view(user_name=user_name, prop='exchange_name')
+            exchange_name = self.users.get_chart_view(user_name=user_name, prop='exchange_name')
 
         # Log the completion to the console.
         log.info('set_candle_history(): Loading candle data...')
@@ -199,20 +199,20 @@ class Candles:
         :param user_name: str - The name of the user who owns the exchange.
         :return: list - Candle records in the lightweight charts format.
         """
-        # By default, initialise cache with the last viewed chart.
+        # By default, initialise cache with the last viewed chart.
         if not symbol:
             assert user_name is not None
-            symbol = self.config.users.get_chart_view(user_name=user_name, prop='market''market')
+            symbol = self.users.get_chart_view(user_name=user_name, prop='market')
             log.info(f'set_candle_history(): No symbol provided. 
Using{symbol}')
         if not interval:
             assert user_name is not None
-            interval = self.config.users.get_chart_view(user_name=user_name, prop='market''timeframe')
+            interval = self.users.get_chart_view(user_name=user_name, prop='timeframe')
             log.info(f'set_candle_history(): No timeframe provided. Using{interval}')
         if not exchange_name:
             assert user_name is not None
-            exchange_name = self.config.users.get_chart_view(user_name=user_name, prop='market''exchange_name')
+            exchange_name = self.users.get_chart_view(user_name=user_name, prop='exchange_name')
             log.info(f'get_candle_history(): No exchange name provided. Using {exchange_name}')
 
         candlesticks = self.get_last_n_candles(num_candles=num_records, asset=symbol, timeframe=interval,
diff --git a/src/indicators.py b/src/indicators.py
index 47bf112..0cd77be 100644
--- a/src/indicators.py
+++ b/src/indicators.py
@@ -308,12 +308,12 @@ indicator_types.append('MACD')
 
 
 class Indicators:
-    def __init__(self, candles, config):
+    def __init__(self, candles, users):
 
         # Object manages and serves price and candle data.
         self.candles = candles
 
-        # A connection to an object that handles user configuration and persistent data.
-        self.config = config
+        # A connection to an object that handles user data.
+        self.users = users
 
         # Collection of instantiated indicators objects
         self.indicators = pd.DataFrame(columns=['creator', 'name', 'visible',
@@ -331,7 +331,7 @@ class Indicators:
         Get the users watch-list from the database and load the indicators into a dataframe.
         :return: None
         """
-        active_indicators: pd.DataFrame = self.config.users.get_indicators(user_name)
+        active_indicators: pd.DataFrame = self.users.get_indicators(user_name)
 
         if active_indicators is not None:
             # Create an instance for each indicator.
@@ -347,7 +347,7 @@ class Indicators:
         Saves the indicators in the database indexed by the user id. 
:return: None """ - self.config.users.save_indicators(indicator) + self.users.save_indicators(indicator) # @staticmethod # def get_indicator_defaults(): @@ -389,7 +389,7 @@ class Indicators: :param only_enabled: bool - If True, return only indicators marked as visible. :return: dict - A dictionary of indicator names as keys and their attributes as values. """ - user_id = self.config.users.get_id(username) + user_id = self.users.get_id(username) if not user_id: raise ValueError(f"Invalid user_name: {username}") @@ -498,7 +498,7 @@ class Indicators: :param num_results: The number of results being requested. :return: The results of the indicator analysis as a DataFrame. """ - username = self.config.users.get_username(indicator.creator) + username = self.users.get_username(indicator.creator) src = indicator.source symbol, timeframe, exchange_name = src['symbol'], src['timeframe'], src['exchange_name'] @@ -532,7 +532,7 @@ class Indicators: if start_ts: print("Warning: start_ts has not implemented in get_indicator_data()!") - user_id = self.config.users.get_id(user_name=user_name) + user_id = self.users.get_id(user_name=user_name) # Construct the query based on user_id and visibility. query = f"creator == {user_id}" @@ -582,7 +582,7 @@ class Indicators: if not indicator_name: raise ValueError("No indicator name provided.") self.indicators = self.indicators.query("name != @indicator_name").reset_index(drop=True) - self.config.users.save_indicators() + self.users.save_indicators() def create_indicator(self, creator: str, name: str, kind: str, source: dict, properties: dict, visible: bool = True): @@ -618,7 +618,7 @@ class Indicators: indicator = indicator_class(name, kind, properties) # Add the new indicator to a pandas dataframe. 
- creator_id = self.config.users.get_id(creator) + creator_id = self.users.get_id(creator) row_data = { 'creator': creator_id, 'name': name, diff --git a/src/maintenence/debuging_testing.py b/src/maintenence/debuging_testing.py index 5d3e367..0f13369 100644 --- a/src/maintenence/debuging_testing.py +++ b/src/maintenence/debuging_testing.py @@ -1,38 +1,33 @@ -import ccxt import pandas as pd -import datetime +import matplotlib.pyplot as plt +from matplotlib.table import Table + +# Simulating the cache as a DataFrame +data = { + 'key': ['BTC/USD_2h_binance', 'ETH/USD_1h_coinbase'], + 'data': ['{"open": 50000, "close": 50500}', '{"open": 1800, "close": 1825}'] +} +cache_df = pd.DataFrame(data) -def fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5): - # Initialize the exchange - exchange_class = getattr(ccxt, exchange_name) - exchange = exchange_class() +# Visualization function +def visualize_cache(df): + fig, ax = plt.subplots(figsize=(6, 3)) + ax.set_axis_off() + tb = Table(ax, bbox=[0, 0, 1, 1]) - # Fetch historical candlestick data with a limit - ohlcv = exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit) + # Adding column headers + for i, column in enumerate(df.columns): + tb.add_cell(0, i, width=0.4, height=0.3, text=column, loc='center', facecolor='lightgrey') - # Convert to DataFrame for better readability - df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']) + # Adding rows and cells + for i in range(len(df)): + for j, value in enumerate(df.iloc[i]): + tb.add_cell(i + 1, j, width=0.4, height=0.3, text=value, loc='center', facecolor='white') - # Print the first few rows of the DataFrame - print("First few rows of the fetched OHLCV data:") - print(df.head()) - - # Print the timestamps in human-readable format - df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms') - print("\nFirst few timestamps in human-readable format:") - print(df[['timestamp', 'datetime']].head()) - - # Confirm 
the format of the timestamps - print("\nTimestamp format confirmation:") - for ts in df['timestamp']: - print(f"{ts} (milliseconds since Unix epoch)") + ax.add_table(tb) + plt.title("Visualizing Cache Data") + plt.show() -# Example usage -exchange_name = 'binance' # Change this to your exchange -symbol = 'BTC/USDT' -timeframe = '5m' -since = int((datetime.datetime(2024, 8, 1) - datetime.datetime(1970, 1, 1)).total_seconds() * 1000) - -fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5) +visualize_cache(cache_df) diff --git a/src/trade.py b/src/trade.py index 263b678..75b3976 100644 --- a/src/trade.py +++ b/src/trade.py @@ -267,10 +267,10 @@ class Trade: class Trades: - def __init__(self, loaded_trades=None): + def __init__(self, users): """ This class receives, executes, tracks and stores all active_trades. - :param loaded_trades: A bunch of active_trades to create and store. + :param users: A class that maintains users each user may have trades. """ # Object that maintains exchange_interface and account data self.exchange_interface = None @@ -290,7 +290,8 @@ class Trades: self.settled_trades = [] self.stats = {'num_trades': 0, 'total_position': 0, 'total_position_value': 0} - # Load any trades that were passed into the constructor. + # Load all trades. + loaded_trades = users.get_all_active_user_trades() if loaded_trades is not None: # Create the active_trades loaded from file. 
self.load_trades(loaded_trades) diff --git a/tests/test_DataCache_v2.py b/tests/test_DataCache_v2.py index 6c729ab..77edae4 100644 --- a/tests/test_DataCache_v2.py +++ b/tests/test_DataCache_v2.py @@ -1,5 +1,5 @@ import pytz -from DataCache_v2 import DataCache +from DataCache_v2 import DataCache, timeframe_to_timedelta, estimate_record_count from ExchangeInterface import ExchangeInterface import unittest import pandas as pd @@ -224,10 +224,23 @@ class TestDataCacheV2(unittest.TestCase): exchange_id INTEGER, FOREIGN KEY (exchange_id) REFERENCES exchange(id) )""" + sql_create_table_4 = f""" + CREATE TABLE IF NOT EXISTS test_table_2 ( + key TEXT PRIMARY KEY, + data TEXT NOT NULL + )""" + sql_create_table_5 = f""" + CREATE TABLE IF NOT EXISTS users ( + users_data TEXT PRIMARY KEY, + data TEXT NOT NULL + )""" + with SQLite(db_file=self.db_file) as con: con.execute(sql_create_table_1) con.execute(sql_create_table_2) con.execute(sql_create_table_3) + con.execute(sql_create_table_4) + con.execute(sql_create_table_5) self.data = DataCache(self.exchanges) self.data.db = self.database @@ -241,30 +254,41 @@ class TestDataCacheV2(unittest.TestCase): def test_set_cache(self): print('\nTesting set_cache() method without no-overwrite flag:') self.data.set_cache(data='data', key=self.key) - attr = self.data.__getattribute__('cached_data') - self.assertEqual(attr[self.key], 'data') - print(' - Set cache without no-overwrite flag passed.') + + # Access the cache data using the DataFrame structure + cached_value = self.data.get_cache(key=self.key) + self.assertEqual(cached_value, 'data') + print(' - Set data without no-overwrite flag passed.') print('Testing set_cache() once again with new data without no-overwrite flag:') self.data.set_cache(data='more_data', key=self.key) - attr = self.data.__getattribute__('cached_data') - self.assertEqual(attr[self.key], 'more_data') - print(' - Set cache with new data without no-overwrite flag passed.') + + # Access the updated cache data + 
cached_value = self.data.get_cache(key=self.key) + self.assertEqual(cached_value, 'more_data') + print(' - Set data with new data without no-overwrite flag passed.') print('Testing set_cache() method once again with more data with no-overwrite flag set:') self.data.set_cache(data='even_more_data', key=self.key, do_not_overwrite=True) - attr = self.data.__getattribute__('cached_data') - self.assertEqual(attr[self.key], 'more_data') - print(' - Set cache with no-overwrite flag passed.') + + # Since do_not_overwrite is True, the cached data should not change + cached_value = self.data.get_cache(key=self.key) + self.assertEqual(cached_value, 'more_data') + print(' - Set data with no-overwrite flag passed.') def test_cache_exists(self): print('Testing cache_exists() method:') - self.assertFalse(self.data.cache_exists(key=self.key)) - print(' - Check for non-existent cache passed.') + # Check that the cache does not contain the key before setting it + self.assertFalse(self.data.cache_exists(key=self.key)) + print(' - Check for non-existent data passed.') + + # Set the cache with a DataFrame containing the key-value pair self.data.set_cache(data='data', key=self.key) + + # Check that the cache now contains the key self.assertTrue(self.data.cache_exists(key=self.key)) - print(' - Check for existent cache passed.') + print(' - Check for existent data passed.') def test_update_candle_cache(self): print('Testing update_candle_cache() method:') @@ -273,21 +297,21 @@ class TestDataCacheV2(unittest.TestCase): data_gen = DataGenerator('5m') # Create initial DataFrame and insert into cache - df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0)) + df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc)) print(f'Inserting this table into cache:\n{df_initial}\n') self.data.set_cache(data=df_initial, key=self.key) # Create new DataFrame to be added to cache - df_new = data_gen.create_table(num_rec=3, 
start=dt.datetime(2024, 8, 9, 0, 15, 0)) + df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0, tzinfo=dt.timezone.utc)) print(f'Updating cache with this table:\n{df_new}\n') - self.data.update_candle_cache(more_records=df_new, key=self.key) + self.data._update_candle_cache(more_records=df_new, key=self.key) # Retrieve the resulting DataFrame from cache result = self.data.get_cache(key=self.key) print(f'The resulting table in cache is:\n{result}\n') # Create the expected DataFrame - expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0)) + expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc)) print(f'The expected open_time values are:\n{expected["open_time"].tolist()}\n') # Assert that the open_time values in the result match those in the expected DataFrame, in order @@ -295,37 +319,50 @@ class TestDataCacheV2(unittest.TestCase): f"open_time values in result are {result['open_time'].tolist()}" \ f" but expected {expected['open_time'].tolist()}" - print(f'The results open_time values match:\n{result["open_time"].tolist()}\n') + print(f'The result open_time values match:\n{result["open_time"].tolist()}\n') print(' - Update cache with new records passed.') def test_update_cached_dict(self): print('Testing update_cached_dict() method:') + + # Set an empty dictionary in the cache for the specified key self.data.set_cache(data={}, key=self.key) + + # Update the cached dictionary with a new key-value pair self.data.update_cached_dict(cache_key=self.key, dict_key='sub_key', data='value') + # Retrieve the updated cache cache = self.data.get_cache(key=self.key) + + # Verify that the 'sub_key' in the cached dictionary has the correct value self.assertEqual(cache['sub_key'], 'value') print(' - Update dictionary in cache passed.') def test_get_cache(self): print('Testing get_cache() method:') + + # Set some data into the cache self.data.set_cache(data='data', 
key=self.key) + + # Retrieve the cached data using the get_cache method result = self.data.get_cache(key=self.key) + + # Verify that the result matches the data we set self.assertEqual(result, 'data') - print(' - Retrieve cache passed.') + print(' - Retrieve data passed.') def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None, num_rec=None, ex_details=None, simulate_scenarios=None): """ Test the get_records_since() method by generating a table of simulated data, - inserting it into cache and/or database, and then querying the records. + inserting it into data and/or database, and then querying the records. Parameters: set_cache (bool): If True, the generated table is inserted into the cache. set_db (bool): If True, the generated table is inserted into the database. query_offset (int, optional): The offset in the timeframe units for the query. num_rec (int, optional): The number of records to generate in the simulated table. - ex_details (list, optional): Exchange details to generate the cache key. + ex_details (list, optional): Exchange details to generate the data key. simulate_scenarios (str, optional): The type of scenario to simulate. Options are: - 'not_enough_data': The table data doesn't go far enough back. - 'incomplete_data': The table doesn't have enough records to satisfy the query. @@ -336,7 +373,7 @@ class TestDataCacheV2(unittest.TestCase): # Use provided ex_details or fallback to the class attribute. ex_details = ex_details or self.ex_details - # Generate a cache/database key using exchange details. + # Generate a data/database key using exchange details. key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}' # Set default number of records if not provided. @@ -374,13 +411,13 @@ class TestDataCacheV2(unittest.TestCase): print(f'Table Created:\n{temp_df}') if set_cache: - # Insert the generated table into cache. - print('Inserting table into cache.') + # Insert the generated table into the cache. 
+ print('Inserting table into the cache.') self.data.set_cache(data=df_initial, key=key) if set_db: # Insert the generated table into the database. - print('Inserting table into database.') + print('Inserting table into the database.') with SQLite(self.db_file) as con: df_initial.to_sql(key, con, if_exists='replace', index=False) @@ -414,7 +451,7 @@ class TestDataCacheV2(unittest.TestCase): # Check that the result has more rows than the expected incomplete data. assert result.shape[0] > expected.shape[ 0], "Result has fewer or equal rows compared to the incomplete data." - print("\nThe returned DataFrames has filled in the missing data!") + print("\nThe returned DataFrame has filled in the missing data!") else: # Ensure the result and expected dataframes match in shape and content. assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}" @@ -444,10 +481,10 @@ class TestDataCacheV2(unittest.TestCase): print(' - Fetch records within the specified time range passed.') def test_get_records_since(self): - print('\nTest get_records_since with records set in cache') + print('\nTest get_records_since with records set in data') self._test_get_records_since() - print('\nTest get_records_since with records not in cache') + print('\nTest get_records_since with records not in data') self._test_get_records_since(set_cache=False) print('\nTest get_records_since with records not in database') @@ -467,21 +504,21 @@ class TestDataCacheV2(unittest.TestCase): self._test_get_records_since(simulate_scenarios='missing_section') def test_other_timeframes(self): - # print('\nTest get_records_since with a different timeframe') - # ex_details = ['BTC/USD', '15m', 'binance', 'test_guy'] - # start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2) - # # Query the records since the calculated start time. 
- # result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) - # last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') - # assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=15.1) - # - # print('\nTest get_records_since with a different timeframe') - # ex_details = ['BTC/USD', '5m', 'binance', 'test_guy'] - # start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=1) - # # Query the records since the calculated start time. - # result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) - # last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') - # assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=5.1) + print('\nTest get_records_since with a different timeframe') + ex_details = ['BTC/USD', '15m', 'binance', 'test_guy'] + start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2) + # Query the records since the calculated start time. + result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) + last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') + assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=15.1) + + print('\nTest get_records_since with a different timeframe') + ex_details = ['BTC/USD', '5m', 'binance', 'test_guy'] + start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=1) + # Query the records since the calculated start time. 
+ result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) + last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') + assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=5.1) print('\nTest get_records_since with a different timeframe') ex_details = ['BTC/USD', '4h', 'binance', 'test_guy'] @@ -506,16 +543,217 @@ class TestDataCacheV2(unittest.TestCase): def test_fetch_candles_from_exchange(self): print('Testing _fetch_candles_from_exchange() method:') - start_time = dt.datetime.utcnow() - dt.timedelta(days=1) - end_time = dt.datetime.utcnow() - result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance', - user_name='test_guy', start_datetime=start_time, - end_datetime=end_time) + # Define start and end times for the data fetch + start_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc) - dt.timedelta(days=1) + end_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc) + + # Fetch the candles from the exchange using the method + result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', + exchange_name='binance', user_name='test_guy', + start_datetime=start_time, end_datetime=end_time) + + # Validate that the result is a DataFrame self.assertIsInstance(result, pd.DataFrame) - self.assertFalse(result.empty) + + # Validate that the DataFrame is not empty + self.assertFalse(result.empty, "The DataFrame returned from the exchange is empty.") + + # Ensure that the 'open_time' column exists in the DataFrame + self.assertIn('open_time', result.columns, "'open_time' column is missing in the result DataFrame.") + + # Check if the DataFrame contains valid timestamps within the specified range + min_time = pd.to_datetime(result['open_time'].min(), unit='ms').tz_localize('UTC') + max_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') + self.assertTrue(start_time <= min_time 
<= end_time, f"Data starts outside the expected range: {min_time}") + self.assertTrue(start_time <= max_time <= end_time, f"Data ends outside the expected range: {max_time}") + print(' - Fetch candle data from exchange passed.') + def test_remove_row(self): + print('Testing remove_row() method:') + + # Insert data into the cache with the expected columns + df = pd.DataFrame({ + 'key': [self.key], + 'data': ['test_data'] + }) + self.data.set_cache(data='test_data', key=self.key) + + # Ensure the data is in the cache + self.assertTrue(self.data.cache_exists(self.key), "Data was not correctly inserted into the cache.") + + # Remove the row from the cache only (soft delete) + self.data.remove_row(table='test_table_2', filter_vals=('key', self.key), remove_from_db=False) + + # Verify the row has been removed from the cache + self.assertFalse(self.data.cache_exists(self.key), "Row was not correctly removed from the cache.") + + # Reinsert the data for hard delete test + self.data.set_cache(data='test_data', key=self.key) + + # Mock database delete by adding the row to the database + self.data.db.insert_row(table='test_table_2', columns=('key', 'data'), values=(self.key, 'test_data')) + + # Remove the row from both cache and database (hard delete) + self.data.remove_row(table='test_table_2', filter_vals=('key', self.key), remove_from_db=True) + + # Verify the row has been removed from the cache + self.assertFalse(self.data.cache_exists(self.key), "Row was not correctly removed from the cache.") + + # Verify the row has been removed from the database + with SQLite(self.db_file) as con: + result = pd.read_sql(f'SELECT * FROM test_table_2 WHERE key="{self.key}"', con) + self.assertTrue(result.empty, "Row was not correctly removed from the database.") + + print(' - Remove row from cache and database passed.') + + def test_timeframe_to_timedelta(self): + print('Testing timeframe_to_timedelta() function:') + + result = timeframe_to_timedelta('2h') + expected = 
pd.Timedelta(hours=2) + self.assertEqual(result, expected, "Failed to convert '2h' to Timedelta") + + result = timeframe_to_timedelta('5m') + expected = pd.Timedelta(minutes=5) + self.assertEqual(result, expected, "Failed to convert '5m' to Timedelta") + + result = timeframe_to_timedelta('1d') + expected = pd.Timedelta(days=1) + self.assertEqual(result, expected, "Failed to convert '1d' to Timedelta") + + result = timeframe_to_timedelta('3M') + expected = pd.DateOffset(months=3) + self.assertEqual(result, expected, "Failed to convert '3M' to DateOffset") + + result = timeframe_to_timedelta('1Y') + expected = pd.DateOffset(years=1) + self.assertEqual(result, expected, "Failed to convert '1Y' to DateOffset") + + with self.assertRaises(ValueError): + timeframe_to_timedelta('5x') + + print(' - All timeframe_to_timedelta() tests passed.') + + def test_estimate_record_count(self): + print('Testing estimate_record_count() function:') + + start_time = dt.datetime(2023, 8, 1, 0, 0, 0, tzinfo=dt.timezone.utc) + end_time = dt.datetime(2023, 8, 2, 0, 0, 0, tzinfo=dt.timezone.utc) + + result = estimate_record_count(start_time, end_time, '1h') + expected = 24 + self.assertEqual(result, expected, "Failed to estimate record count for 1h timeframe") + + result = estimate_record_count(start_time, end_time, '1d') + expected = 1 + self.assertEqual(result, expected, "Failed to estimate record count for 1d timeframe") + + start_time = int(start_time.timestamp() * 1000) # Convert to milliseconds + end_time = int(end_time.timestamp() * 1000) # Convert to milliseconds + + result = estimate_record_count(start_time, end_time, '1h') + expected = 24 + self.assertEqual(result, expected, "Failed to estimate record count for 1h timeframe with milliseconds") + + with self.assertRaises(ValueError): + estimate_record_count("invalid_start", end_time, '1h') + + print(' - All estimate_record_count() tests passed.') + + def test_fetch_cached_rows(self): + print('Testing fetch_cached_rows() method:') + + 
# Set up mock data in the cache + df = pd.DataFrame({ + 'table': ['test_table_2'], + 'key': ['test_key'], + 'data': ['test_data'] + }) + self.data.cache = pd.concat([self.data.cache, df]) + + # Test fetching from cache + result = self.data.fetch_cached_rows('test_table_2', ('key', 'test_key')) + self.assertIsInstance(result, pd.DataFrame, "Failed to fetch DataFrame from cache") + self.assertFalse(result.empty, "The fetched DataFrame is empty") + self.assertEqual(result.iloc[0]['data'], 'test_data', "Incorrect data fetched from cache") + + # Test fetching from database (assuming the method calls it) + # Here we would typically mock the database call + # But since we're not doing I/O, we will skip that part + print(' - Fetch from cache and database simulated.') + + def test_is_attr_taken(self): + print('Testing is_attr_taken() method:') + + # Set up mock data in the cache + df = pd.DataFrame({ + 'table': ['users'], + 'user_name': ['test_user'], + 'data': ['test_data'] + }) + self.data.cache = pd.concat([self.data.cache, df]) + + # Test for existing attribute + result = self.data.is_attr_taken('users', 'user_name', 'test_user') + self.assertTrue(result, "Failed to detect existing attribute") + + # Test for non-existing attribute + result = self.data.is_attr_taken('users', 'user_name', 'non_existing_user') + self.assertFalse(result, "Incorrectly detected non-existing attribute") + + print(' - All is_attr_taken() tests passed.') + + def test_insert_data(self): + print('Testing insert_data() method:') + + # Create a DataFrame to insert + df = pd.DataFrame({ + 'key': ['new_key'], + 'data': ['new_data'] + }) + + # Insert data into the database and cache + self.data.insert_data(df=df, table='test_table_2') + + # Verify that the data was added to the cache + cached_value = self.data.get_cache('new_key') + self.assertEqual(cached_value, 'new_data', "Failed to insert data into cache") + + # Normally, we would also verify that the data was inserted into the database + # This 
would typically be done with a mock database or by checking the database state directly + print(' - Data insertion into cache and database simulated.') + + def test_insert_row(self): + print('Testing insert_row() method:') + + self.data.insert_row(table='test_table_2', columns=('key', 'data'), values=('test_key', 'test_data')) + + # Verify the row was inserted + with SQLite(self.db_file) as con: + result = pd.read_sql('SELECT * FROM test_table_2 WHERE key="test_key"', con) + self.assertFalse(result.empty, "Row was not inserted into the database.") + + print(' - Insert row passed.') + + def test_fill_data_holes(self): + print('Testing _fill_data_holes() method:') + + # Create mock data with gaps + df = pd.DataFrame({ + 'open_time': [dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc).timestamp() * 1000, + dt.datetime(2023, 1, 1, 2, tzinfo=dt.timezone.utc).timestamp() * 1000, + dt.datetime(2023, 1, 1, 6, tzinfo=dt.timezone.utc).timestamp() * 1000, + dt.datetime(2023, 1, 1, 8, tzinfo=dt.timezone.utc).timestamp() * 1000, + dt.datetime(2023, 1, 1, 12, tzinfo=dt.timezone.utc).timestamp() * 1000] + }) + + # Call the method + result = self.data._fill_data_holes(records=df, interval='2h') + self.assertEqual(len(result), 7, "Data holes were not filled correctly.") + print(' - _fill_data_holes passed.') + if __name__ == '__main__': unittest.main() diff --git a/tests/test_Users.py b/tests/test_Users.py index 3f33a2c..0f1b118 100644 --- a/tests/test_Users.py +++ b/tests/test_Users.py @@ -115,7 +115,7 @@ def test_log_out_all_users(): def test_load_user_data(): # Todo method incomplete - result = config.users.load_user_data(user_name='RobbieD') + result = config.users.get_user_data(user_name='RobbieD') print('\n Test result:', result) assert result is None