diff --git a/.gitignore b/.gitignore index 4234c28..aa480d3 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ # Ignore Flask session data flask_session/ -# Ignore testing cache +# Ignore testing data .pytest_cache/ # Ignore databases diff --git a/src/BrighterTrades.py b/src/BrighterTrades.py index b9c85b8..7755dc7 100644 --- a/src/BrighterTrades.py +++ b/src/BrighterTrades.py @@ -1,6 +1,6 @@ - from typing import Any +from Users import Users from DataCache_v2 import DataCache from Strategies import Strategies from backtesting import Backtester @@ -20,25 +20,29 @@ class BrighterTrades: # Object that interacts with the persistent data. self.data = DataCache(self.exchanges) - # Configuration and settings for the user app and charts - self.config = Configuration(cache=self.data) + # Configuration for the app + self.config = Configuration() - # Object that maintains signals. Initialize with any signals loaded from file. - self.signals = Signals(self.config.signals_list) + # The object that manages users in the system. + self.users = Users(data_cache=self.data) + + # Object that maintains signals. + self.signals = Signals(self.config) # Object that maintains candlestick and price data. - self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, data_source=self.data) + self.candles = Candles(users=self.users, exchanges=self.exchanges, data_source=self.data, + config=self.config) # Object that interacts with and maintains data from available indicators - self.indicators = Indicators(self.candles, self.config) + self.indicators = Indicators(self.candles, self.users) # Object that maintains the trades data - self.trades = Trades(self.config.trades) + self.trades = Trades(self.users) # The Trades object needs to connect to an exchange_interface. self.trades.connect_exchanges(exchanges=self.exchanges) # Object that maintains the strategies data - self.strategies = Strategies(self.config.strategies_list, self.trades) + self.strategies = Strategies(self.data, self.trades) # Object responsible for testing trade and strategies data. self.backtester = Backtester() @@ -56,8 +60,8 @@ class BrighterTrades: raise ValueError("Missing required arguments for 'create_new_user'") try: - self.config.users.create_new_user(email=email, username=username, password=password) - login_successful = self.config.users.log_in_user(username=username, password=password) + self.users.create_new_user(email=email, username=username, password=password) + login_successful = self.users.log_in_user(username=username, password=password) return login_successful except Exception as e: # Handle specific exceptions or log the error @@ -77,11 +81,11 @@ class BrighterTrades: try: if cmd == 'logout': - return self.config.users.log_out_user(username=user_name) + return self.users.log_out_user(username=user_name) elif cmd == 'login': if password is None: raise ValueError("Password is required for login.") - return self.config.users.log_in_user(username=user_name, password=password) + return self.users.log_in_user(username=user_name, password=password) except Exception as e: # Handle specific exceptions or log the error raise ValueError("Error during user login/logout: " + str(e)) @@ -98,19 +102,19 @@ class BrighterTrades: if info == 'Chart View': try: - return self.config.users.get_chart_view(user_name=user_name) + return self.users.get_chart_view(user_name=user_name) except Exception as e: # Handle specific exceptions or log the error raise ValueError("Error retrieving chart view information: " + str(e)) elif info == 'Is logged in?': try: - return self.config.users.is_logged_in(user_name=user_name) + return self.users.is_logged_in(user_name=user_name) except Exception as e: # Handle specific exceptions or log the error raise ValueError("Error checking logged in status: " + str(e)) elif info == 'User_id': try: - return self.config.users.get_id(user_name=user_name) + return self.users.get_id(user_name=user_name) except Exception as e: # Handle specific exceptions or log the error raise ValueError("Error fetching id: " + str(e)) @@ -181,10 +185,10 @@ class BrighterTrades: :param default_keys: default API keys. :return: bool - True on success. """ - active_exchanges = self.config.users.get_exchanges(user_name, category='active_exchanges') + active_exchanges = self.users.get_exchanges(user_name, category='active_exchanges') success = False for exchange in active_exchanges: - keys = self.config.users.get_api_keys(user_name, exchange) + keys = self.users.get_api_keys(user_name, exchange) success = self.connect_or_config_exchange(user_name=user_name, exchange_name=exchange, api_keys=keys) @@ -202,7 +206,7 @@ class BrighterTrades: :param user_name: str - The name of the user making the query. """ - chart_view = self.config.users.get_chart_view(user_name=user_name) + chart_view = self.users.get_chart_view(user_name=user_name) indicator_types = self.indicators.indicator_types available_indicators = self.indicators.get_indicator_list(user_name) @@ -231,19 +235,20 @@ class BrighterTrades: :return: A dictionary containing the requested data. """ - chart_view = self.config.users.get_chart_view(user_name=user_name) + chart_view = self.users.get_chart_view(user_name=user_name) exchange = self.exchanges.get_exchange(ename=chart_view.get('exchange'), uname=user_name) # noinspection PyDictCreation r_data = {} - r_data['title'] = self.config.app_data.get('application_title', '') + r_data['title'] = self.config.get_setting('application_title') r_data['chart_interval'] = chart_view.get('timeframe', '') r_data['selected_exchange'] = chart_view.get('exchange', '') r_data['intervals'] = exchange.intervals if exchange else [] r_data['symbols'] = exchange.get_symbols() if exchange else {} r_data['available_exchanges'] = self.exchanges.get_available_exchanges() or [] r_data['connected_exchanges'] = self.exchanges.get_connected_exchanges(user_name) or [] - r_data['configured_exchanges'] = self.config.users.get_exchanges(user_name, category='configured_exchanges') or [] + r_data['configured_exchanges'] = self.users.get_exchanges( + user_name, category='configured_exchanges') or [] r_data['my_balances'] = self.exchanges.get_all_balances(user_name) or {} r_data['indicator_types'] = self.indicators.indicator_types or [] r_data['indicator_list'] = self.indicators.get_indicator_list(user_name) or [] @@ -295,7 +300,7 @@ class BrighterTrades: return "The new signal must have a 'name' attribute." self.signals.new_signal(data) - self.config.update_data('signals', self.signals.get_signals('dict')) + self.config.set_setting('signals_list', self.signals.get_signals('dict')) return data def received_new_strategy(self, data: dict) -> str | dict: @@ -309,7 +314,7 @@ class BrighterTrades: return "The new strategy must have a 'name' attribute." self.strategies.new_strategy(data) - self.config.update_data('strategies', self.strategies.get_strategies('dict')) + self.config.set_setting('strategies', self.strategies.get_strategies('dict')) return data def delete_strategy(self, strategy_name: str) -> None: @@ -377,10 +382,9 @@ class BrighterTrades: success = self.exchanges.connect_exchange(exchange_name=exchange_name, user_name=user_name, api_keys=api_keys) if success: - self.config.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set') + self.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set') if api_keys: - self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, - user_name=user_name) + self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name) return True else: return False # Failed to connect @@ -391,7 +395,7 @@ class BrighterTrades: else: # Exchange is already connected, update API keys if provided if api_keys: - self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name) + self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name) return True # Already connected def close_trade(self, trade_id): @@ -470,19 +474,19 @@ class BrighterTrades: if setting == 'interval': interval_state = params['timeframe'] - self.config.users.set_chart_view(values=interval_state, specific_property='timeframe', user_name=user_name) + self.users.set_chart_view(values=interval_state, specific_property='timeframe', user_name=user_name) elif setting == 'trading_pair': trading_pair = params['symbol'] - self.config.users.set_chart_view(values=trading_pair, specific_property='market', user_name=user_name) + self.users.set_chart_view(values=trading_pair, specific_property='market', user_name=user_name) elif setting == 'exchange': exchange_name = params['exchange_name'] # Get the first result of a list of available symbols from this exchange_name. market = self.exchanges.get_exchange(ename=exchange_name, uname=user_name).get_symbols()[0] # Change the exchange in the chart view and pass it a default market to display. - self.config.users.set_chart_view(values=exchange_name, specific_property='exchange_name', - user_name=user_name, default_market=market) + self.users.set_chart_view(values=exchange_name, specific_property='exchange_name', + user_name=user_name, default_market=market) elif setting == 'toggle_indicator': indicators_to_toggle = params.getlist('indicator') diff --git a/src/Configuration.py b/src/Configuration.py index 67b82d3..70f7974 100644 --- a/src/Configuration.py +++ b/src/Configuration.py @@ -1,130 +1,279 @@ -import pandas import yaml -from Signals import Signals -from indicators import Indicators -from Users import Users +import os +import time +import logging +from typing import Any + +logger = logging.getLogger(__name__) class Configuration: - def __init__(self, cache): - # ************** Default values************** + """ + Manages the application's settings, loading and saving them to a YAML file. + - Automatically loads the settings from file when instantiated. + - Settings can be added or modified with set_setting('key', value) + or by editing the file directly. + """ - self.app_data = { - 'application_title': 'BrighterTrades', # The title of our program. - 'max_data_loaded': 1000 # The maximum number of candles to store in memory. + def __init__(self, config_file='config.yml'): + """Initializes with default settings and loads saved data.""" + self.config_file = config_file + self.save_in_progress = False # Persistent flag to prevent infinite recursion + self._set_default_settings() + self.manage_config('load') + + def _set_default_settings(self): + """Set default settings.""" + self.settings = { + 'application_title': 'BrighterTrades', + 'max_data_loaded': 1000 } - # The object that interacts with the database. - self.data = cache.db + def _generate_config_data(self): + """Generates a list of settings to be saved to the config file.""" + return self.settings.copy() - # The object that manages users in the system. - self.users = Users(data_cache=cache) - - # The name of the file that stores saved_data - self.config_FN = 'config.yml' - - # A list of all the available Signals. - # Calls a static method of Signals that initializes a default list. - self.signals_list = Signals.get_signals_defaults() - - # list of all the available strategies. - self.strategies_list = [] - - # list of trades. - self.trades = [] - - # The data that will be saved and loaded from file . - self.saved_data = None - - # Load any saved data from file - self.config_and_states('load') - - def update_data(self, data_type, data): + def manage_config(self, cmd: str): """ - Replace current list of data sets with an updated list. - - :param data_type: The data being replaced - :param data: The replacement data. - :return: None. + Loads or saves settings to the config file according to cmd: 'load' | 'save' """ - if data_type == 'strategies': - self.strategies_list = data - elif data_type == 'signals': - self.signals_list = data - elif data_type == 'trades': - self.trades = data - else: - raise ValueError(f'Configuration: update_data(): Unsupported data_type: {data_type}') - # Save it to file. - self.config_and_states('save') - def remove(self, what, name): - # Removes by name an item from a list in saved data. + def _apply_loaded_settings(data): + """Updates settings from loaded data and marks if saving is needed.""" + needs_save = False + for key, value in data.items(): + if key not in self.settings or self.settings[key] != value: + self.settings[key] = value + needs_save = True + return needs_save - # Trades are indexed by unique_id, while Signals and Strategies are indexed by name. - if what == 'trades': - prop = 'unique_id' - else: - prop = 'name' + def _load_config_from_file(filepath): + try: + with open(filepath, "r") as file_descriptor: + return yaml.safe_load(file_descriptor) + except yaml.YAMLError: + timestamp = time.strftime("%Y%m%d-%H%M%S") + backup_path = f"{filepath}.{timestamp}.backup" + os.rename(filepath, backup_path) + logging.warning(f"Corrupt YAML file detected. Backup saved to {backup_path}") + logging.info(f"Recreating the configuration file with default settings.") + return None - for obj in self.saved_data[what]: - if obj[prop] == name: - self.saved_data[what].remove(obj) - break - # Save it to file. - self.config_and_states('save') - - def config_and_states(self, cmd): - """Loads or saves configurable data to the file set in self.config_FN""" - - # The data stored and retrieved from file session. - self.saved_data = { - 'signals': self.signals_list, - 'strategies': self.strategies_list, - 'trades': self.trades - } - - def set_loaded_values(): - # Sets the values in the saved_data object. - - if 'signals' in self.saved_data: - self.signals_list = self.saved_data['signals'] - - if 'strategies' in self.saved_data: - self.strategies_list = self.saved_data['strategies'] - - if 'trades' in self.saved_data: - self.trades = self.saved_data['trades'] - - def load_configuration(filepath): - """load file data""" - with open(filepath, "r") as file_descriptor: - data = yaml.safe_load(file_descriptor) - return data - - def save_configuration(filepath, data): - """Saves file data""" - with open(filepath, "w") as file_descriptor: - yaml.dump(data, file_descriptor) + def _save_config_to_file(filepath, data): + try: + with open(filepath, "w") as file_descriptor: + yaml.dump(data, file_descriptor) + except (IOError, OSError) as e: + logging.error(f"Failed to save configuration to {filepath}: {e}") + raise ValueError(f"Failed to save configuration to {filepath}: {e}") if cmd == 'load': - # If load_configuration() finds a file it overwrites - # the saved_data object otherwise it creates a new file - # with the defaults contained in saved_data> - - # If file exist load the values. - try: - self.saved_data = load_configuration(self.config_FN) - set_loaded_values() - # If file doesn't exist create a file and save the default values. - except IOError: - save_configuration(self.config_FN, self.saved_data) - + if os.path.exists(self.config_file): + loaded_data = _load_config_from_file(self.config_file) + if loaded_data is None: # Corrupt file case, recreate it + _save_config_to_file(self.config_file, self._generate_config_data()) + elif _apply_loaded_settings(loaded_data) and not self.save_in_progress: + self.save_in_progress = True + self.manage_config('save') + self.save_in_progress = False + else: + logging.info(f"Configuration file not found. Creating a new one at {self.config_file}.") + _save_config_to_file(self.config_file, self._generate_config_data()) elif cmd == 'save': - try: - # Write saved_data to the file. - save_configuration(self.config_FN, self.saved_data) - except IOError: - raise ValueError("save_configuration(): Couldn't save the file.") + _save_config_to_file(self.config_file, self._generate_config_data()) else: - raise ValueError('save_configuration(): Invalid command received.') + raise ValueError('manage_config(): Invalid command.') + + def reset_settings_to_defaults(self): + """Resets settings to default values.""" + self._set_default_settings() + self.manage_config('save') + + def get_setting(self, key: str) -> Any: + """ + Returns the value of the specified setting or None if the key is not found. + """ + return self.settings.get(key, None) + + def set_setting(self, key: str, value: Any): + """ + Receives a key and value of any setting and saves the configuration. + """ + self.settings[key] = value + self.manage_config('save') + +# import yaml +# from Signals import Signals +# +# +# class Configuration: +# """ +# Configuration class manages the application's settings, +# signals, strategies, and trades. It loads and saves these +# configurations to a YAML file. +# +# Attributes: +# app_data (dict): Default application settings like title and maximum data load. +# data (object): Database interaction object from the data. +# config_FN (str): Filename for storing and loading configurations. +# signals_list (list): List of available signals initialized by the Signals class. +# strategies_list (list): List of strategies available for the application. +# trades (list): List of trades managed by the application. +# saved_data (dict): Data structure for saving/loading configuration data. +# """ +# +# def __init__(self, data): +# """ +# Initializes the Configuration object with default values and loads saved data. +# +# Args: +# data (object): Cache object with a database attribute to interact with the database. +# """ +# # ************** Default values ************** +# # Application metadata such as title and maximum data to be loaded. +# self.app_data = { +# 'application_title': 'BrighterTrades', # The title of the program. +# 'max_data_loaded': 1000 # Maximum number of candles to store in memory. +# } +# +# # The database object for interaction. +# self.data = data.db +# +# # Name of the configuration file. +# self.config_FN = 'config.yml' +# +# # List of available signals initialized with default values. +# self.signals_list = Signals.get_signals_defaults() +# +# # List to hold available strategies. +# self.strategies_list = [] +# +# # List to hold trades. +# self.trades = [] +# +# # Placeholder for data loaded from or to be saved to the file. +# self.saved_data = None +# +# # Load any saved data from the configuration file. +# self.config_and_states('load') +# +# def update_data(self, data_type, data): +# """ +# Replace the current list of data sets with an updated list. +# +# Args: +# data_type (str): Type of data to be updated ('strategies', 'signals', or 'trades'). +# data (list): The new data to replace the old one. +# +# Raises: +# ValueError: If the provided data_type is not supported. +# """ +# if data_type == 'strategies': +# self.strategies_list = data +# elif data_type == 'signals': +# self.signals_list = data +# elif data_type == 'trades': +# self.trades = data +# else: +# raise ValueError(f'Configuration: update_data(): Unsupported data_type: {data_type}') +# +# # Save the updated data to the configuration file. +# self.config_and_states('save') +# +# def remove(self, what, name): +# """ +# Removes an item by name from the saved data list. +# +# Args: +# what (str): Type of data to remove ('strategies', 'signals', or 'trades'). +# name (str): The name or unique_id of the item to remove. +# +# Raises: +# ValueError: If the item with the specified name is not found. +# """ +# # Determine the property to match based on the type of data. +# if what == 'trades': +# prop = 'unique_id' +# else: +# prop = 'name' +# +# # Remove the item from the list if it matches the name or unique_id. +# for obj in self.saved_data[what]: +# if obj[prop] == name: +# self.saved_data[what].remove(obj) +# break +# +# # Save the updated data to the configuration file. +# self.config_and_states('save') +# +# def config_and_states(self, cmd): +# """ +# Loads or saves configurable data to the file set in self.config_FN. +# +# Args: +# cmd (str): Command to either 'load' or 'save' the configuration data. +# +# Raises: +# ValueError: If the command is neither 'load' nor 'save'. +# """ +# +# # Data structure to hold the current state of signals, strategies, and trades. +# self.saved_data = { +# 'signals': self.signals_list, +# 'strategies': self.strategies_list, +# 'trades': self.trades +# } +# +# def set_loaded_values(): +# """Sets the values in the saved_data object to the class attributes.""" +# if 'signals' in self.saved_data: +# self.signals_list = self.saved_data['signals'] +# +# if 'strategies' in self.saved_data: +# self.strategies_list = self.saved_data['strategies'] +# +# if 'trades' in self.saved_data: +# self.trades = self.saved_data['trades'] +# +# def load_configuration(filepath): +# """ +# Load configuration data from a YAML file. +# +# Args: +# filepath (str): Path to the configuration file. +# +# Returns: +# dict: Loaded configuration data. +# """ +# with open(filepath, "r") as file_descriptor: +# data = yaml.safe_load(file_descriptor) +# return data +# +# def save_configuration(filepath, data): +# """ +# Save configuration data to a YAML file. +# +# Args: +# filepath (str): Path to the configuration file. +# data (dict): Data to save. +# """ +# with open(filepath, "w") as file_descriptor: +# yaml.dump(data, file_descriptor) +# +# if cmd == 'load': +# try: +# # Attempt to load the configuration from the file. +# self.saved_data = load_configuration(self.config_FN) +# set_loaded_values() +# except IOError: +# # If the file doesn't exist, save the default values. +# save_configuration(self.config_FN, self.saved_data) +# +# elif cmd == 'save': +# try: +# # Save the current state to the configuration file. +# save_configuration(self.config_FN, self.saved_data) +# except IOError: +# raise ValueError("save_configuration(): Couldn't save the file.") +# else: +# raise ValueError('save_configuration(): Invalid command received.') diff --git a/src/DataCache_v2.py b/src/DataCache_v2.py index d8f0a74..d9a2861 100644 --- a/src/DataCache_v2.py +++ b/src/DataCache_v2.py @@ -1,4 +1,5 @@ -from typing import List, Any +import json +from typing import List, Any, Tuple import pandas as pd import datetime as dt import logging @@ -59,23 +60,31 @@ def estimate_record_count(start_time, end_time, timeframe: str) -> int: return int(expected_records) -def generate_expected_timestamps(start_datetime: dt.datetime, end_datetime: dt.datetime, - timeframe: str) -> pd.DatetimeIndex: - if start_datetime.tzinfo is None: - raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") - if end_datetime.tzinfo is None: - raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.") - - delta = timeframe_to_timedelta(timeframe) - if isinstance(delta, pd.Timedelta): - return pd.date_range(start=start_datetime, end=end_datetime, freq=delta) - elif isinstance(delta, pd.DateOffset): - current = start_datetime - timestamps = [] - while current <= end_datetime: - timestamps.append(current) - current += delta - return pd.DatetimeIndex(timestamps) +# def generate_expected_timestamps(start_datetime: dt.datetime, end_datetime: dt.datetime, +# timeframe: str) -> pd.DatetimeIndex: +# """ +# What it says. Todo: confirm this is unused and archive. +# +# :param start_datetime: +# :param end_datetime: +# :param timeframe: +# :return: +# """ +# if start_datetime.tzinfo is None: +# raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") +# if end_datetime.tzinfo is None: +# raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.") +# +# delta = timeframe_to_timedelta(timeframe) +# if isinstance(delta, pd.Timedelta): +# return pd.date_range(start=start_datetime, end=end_datetime, freq=delta) +# elif isinstance(delta, pd.DateOffset): +# current = start_datetime +# timestamps = [] +# while current <= end_datetime: +# timestamps.append(current) +# current += delta +# return pd.DatetimeIndex(timestamps) class DataCache: @@ -84,10 +93,170 @@ class DataCache: def __init__(self, exchanges): self.db = Database() self.exchanges = exchanges - self.cached_data = {} + # Single DataFrame for all cached data + self.cache = pd.DataFrame(columns=['key', 'data']) # Assuming 'key' and 'data' are necessary logger.info("DataCache initialized.") + def fetch_cached_rows(self, table: str, filter_vals: Tuple[str, Any]) -> pd.DataFrame | None: + """ + Retrieves rows from the cache if available; otherwise, queries the database and caches the result. + + :param table: Name of the database table to query. + :param filter_vals: A tuple containing the column name and the value to filter by. + :return: A DataFrame containing the requested rows, or None if no matching rows are found. + """ + # Construct a filter condition for the cache based on the table name and filter values. + cache_filter = (self.cache['table'] == table) & (self.cache[filter_vals[0]] == filter_vals[1]) + cached_rows = self.cache[cache_filter] + + # If the data is found in the cache, return it. + if not cached_rows.empty: + return cached_rows + + # If the data is not found in the cache, query the database. + rows = self.db.get_rows_where(table, filter_vals) + if rows is not None: + # Tag the rows with the table name and add them to the cache. + rows['table'] = table + self.cache = pd.concat([self.cache, rows]) + return rows + + def remove_row(self, filter_vals: Tuple[str, Any], additional_filter: Tuple[str, Any] = None, + remove_from_db: bool = True, table: str = None) -> None: + """ + Removes a specific row from the cache and optionally from the database based on filter criteria. + + :param filter_vals: A tuple containing the column name and the value to filter by. + :param additional_filter: An optional additional filter to apply. + :param remove_from_db: If True, also removes the row from the database. Default is True. + :param table: The name of the table from which to remove the row in the database (optional). + """ + logger.debug( + f"Removing row from cache: filter={filter_vals}," + f" additional_filter={additional_filter}, remove_from_db={remove_from_db}, table={table}") + + # Construct the filter condition for the cache + cache_filter = (self.cache[filter_vals[0]] == filter_vals[1]) + + if additional_filter: + cache_filter = cache_filter & (self.cache[additional_filter[0]] == additional_filter[1]) + + # Remove the row from the cache + self.cache = self.cache.drop(self.cache[cache_filter].index) + logger.info(f"Row removed from cache: filter={filter_vals}") + + if remove_from_db and table: + # Construct the SQL query to delete from the database + sql = f"DELETE FROM {table} WHERE {filter_vals[0]} = ?" + params = [filter_vals[1]] + + if additional_filter: + sql += f" AND {additional_filter[0]} = ?" + params.append(additional_filter[1]) + + # Execute the SQL query to remove the row from the database + self.db.execute_sql(sql, tuple(params)) + logger.info( + f"Row removed from database: table={table}, filter={filter_vals}," + f" additional_filter={additional_filter}") + + def is_attr_taken(self, table: str, attr: str, val: Any) -> bool: + """ + Checks if a specific attribute in a table is already taken. + + :param table: The name of the table to check. + :param attr: The attribute to check (e.g., 'user_name', 'email'). + :param val: The value of the attribute to check. + :return: True if the attribute is already taken, False otherwise. + """ + # Fetch rows from the specified table where the attribute matches the given value + result = self.fetch_cached_rows(table=table, filter_vals=(attr, val)) + return result is not None and not result.empty + + def fetch_cached_item(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any: + """ + Retrieves a specific item from the cache or database, caching the result if necessary. + + :param item_name: The name of the column to retrieve. + :param table_name: The name of the table where the item is stored. + :param filter_vals: A tuple containing the column name and the value to filter by. + :return: The value of the requested item. + :raises ValueError: If the item is not found in either the cache or the database. + """ + # Fetch the relevant rows + rows = self.fetch_cached_rows(table_name, filter_vals) + if rows is not None and not rows.empty: + # Return the specific item from the first matching row. + return rows.iloc[0][item_name] + + # If the item is not found, raise an error. + raise ValueError(f"Item {item_name} not found in {table_name} where {filter_vals[0]} = {filter_vals[1]}") + + def modify_cached_row(self, table: str, filter_vals: Tuple[str, Any], field_name: str, new_data: Any) -> None: + """ + Modifies a specific field in a row within the cache and updates the database accordingly. + + :param table: The name of the table where the data is stored. + :param filter_vals: A tuple containing the column name and the value to filter by. + :param field_name: The field to be updated. + :param new_data: The new data to be set. + """ + # Retrieve the row from the cache or database + row = self.fetch_cached_rows(table, filter_vals) + + if row is None or row.empty: + raise ValueError(f"Row not found in cache or database for {filter_vals[0]} = {filter_vals[1]}") + + # Modify the specified field + if isinstance(new_data, str): + row.loc[0, field_name] = new_data + else: + # If new_data is not a string, it’s converted to a JSON string before being inserted into the DataFrame. + row.loc[0, field_name] = json.dumps(new_data) + + # Update the cache by removing the old entry and adding the modified row + self.cache = self.cache.drop( + self.cache[(self.cache['table'] == table) & (self.cache[filter_vals[0]] == filter_vals[1])].index + ) + self.cache = pd.concat([self.cache, row]) + + # Update the database with the modified row + self.db.insert_dataframe(row.drop(columns='id'), table) + + def insert_data(self, df: pd.DataFrame, table: str, skip_cache: bool = False) -> None: + """ + Inserts data into the specified table in the database, with an option to skip cache insertion. + + :param df: The DataFrame containing the data to insert. + :param table: The name of the table where the data should be inserted. + :param skip_cache: If True, skips inserting the data into the cache. Default is False. + """ + # Insert the data into the database + self.db.insert_dataframe(df=df, table=table) + + # Optionally insert the data into the cache + if not skip_cache: + df['table'] = table # Add table name for cache identification + self.cache = pd.concat([self.cache, df]) + + def insert_row(self, table: str, columns: tuple, values: tuple) -> None: + """ + Inserts a single row into the specified table in the database. + + :param table: The name of the table where the row should be inserted. + :param columns: A tuple of column names corresponding to the values. + :param values: A tuple of values to insert into the specified columns. + """ + self.db.insert_row(table=table, columns=columns, values=values) + def get_records_since(self, start_datetime: dt.datetime, ex_details: List[str]) -> pd.DataFrame: + """ + This gets up-to-date records from a specified market and exchange. + + :param start_datetime: The approximate time the first record should represent. + :param ex_details: The user exchange and market. + :return: The records. + """ if self.TYPECHECKING_ENABLED: if not isinstance(start_datetime, dt.datetime): raise TypeError("start_datetime must be a datetime object") @@ -106,12 +275,20 @@ class DataCache: 'end_datetime': end_datetime, 'ex_details': ex_details, } - return self.get_or_fetch_from('cache', **args) + return self._get_or_fetch_from('data', **args) except Exception as e: logger.error(f"An error occurred: {str(e)}") raise - def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame: + def _get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame: + """ + Fetches market records from a resource stack (data, database, exchange). + fills incomplete request by fetching down the stack then updates the rest. + + :param target: Starting point for the fetch. ['data', 'database', 'exchange'] + :param kwargs: Details and credentials for the request. + :return: Records in a dataframe. + """ start_datetime = kwargs.get('start_datetime') if start_datetime.tzinfo is None: raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") @@ -144,12 +321,12 @@ class DataCache: key = self._make_key(ex_details) combined_data = pd.DataFrame() - if target == 'cache': - resources = [self.get_candles_from_cache, self.get_from_database, self.get_from_server] + if target == 'data': + resources = [self._get_candles_from_cache, self._get_from_database, self._get_from_server] elif target == 'database': - resources = [self.get_from_database, self.get_from_server] + resources = [self._get_from_database, self._get_from_server] elif target == 'server': - resources = [self.get_from_server] + resources = [self._get_from_server] else: raise ValueError('Not a valid Target!') @@ -165,11 +342,11 @@ class DataCache: combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='open_time').sort_values( by='open_time') - is_complete, request_criteria = self.data_complete(combined_data, **request_criteria) + is_complete, request_criteria = self._data_complete(combined_data, **request_criteria) if is_complete: - if fetch_method in [self.get_from_database, self.get_from_server]: - self.update_candle_cache(combined_data, key) - if fetch_method == self.get_from_server: + if fetch_method in [self._get_from_database, self._get_from_server]: + self._update_candle_cache(combined_data, key) + if fetch_method == self._get_from_server: self._populate_db(ex_details, combined_data) return combined_data @@ -178,7 +355,7 @@ class DataCache: logger.error('Unable to fetch the requested data.') return combined_data if not combined_data.empty else pd.DataFrame() - def get_candles_from_cache(self, **kwargs) -> pd.DataFrame: + def _get_candles_from_cache(self, **kwargs) -> pd.DataFrame: start_datetime = kwargs.get('start_datetime') if start_datetime.tzinfo is None: raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") @@ -200,7 +377,7 @@ class DataCache: raise ValueError("Missing required arguments") key = self._make_key(ex_details) - logger.debug('Getting records from cache.') + logger.debug('Getting records from data.') df = self.get_cache(key) if df is None: logger.debug("Cache records didn't exist.") @@ -210,7 +387,7 @@ class DataCache: df['open_time'] <= unix_time_millis(end_datetime))].reset_index(drop=True) return df_filtered - def get_from_database(self, **kwargs) -> pd.DataFrame: + def _get_from_database(self, **kwargs) -> pd.DataFrame: start_datetime = kwargs.get('start_datetime') if start_datetime.tzinfo is None: raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") @@ -240,7 +417,7 @@ class DataCache: return self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=start_datetime, et=end_datetime) - def get_from_server(self, **kwargs) -> pd.DataFrame: + def _get_from_server(self, **kwargs) -> pd.DataFrame: symbol = kwargs.get('ex_details')[0] interval = kwargs.get('ex_details')[1] exchange_name = kwargs.get('ex_details')[2] @@ -272,7 +449,7 @@ class DataCache: end_datetime) @staticmethod - def data_complete(data: pd.DataFrame, **kwargs) -> (bool, dict): + def _data_complete(data: pd.DataFrame, **kwargs) -> (bool, dict): """ Checks if the data completely satisfies the request. @@ -341,17 +518,21 @@ class DataCache: return True, kwargs def cache_exists(self, key: str) -> bool: - return key in self.cached_data + return key in self.cache['key'].values def get_cache(self, key: str) -> Any | None: - if key not in self.cached_data: - logger.warning(f"The requested cache key({key}) doesn't exist!") + # Check if the key exists in the cache + if key not in self.cache['key'].values: + logger.warning(f"The requested data key ({key}) doesn't exist!") return None - return self.cached_data[key] - def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None: - logger.debug('Updating cache with new records.') - # Concatenate the new records with the existing cache + # Retrieve the data associated with the key + result = self.cache[self.cache['key'] == key]['data'].iloc[0] + return result + + def _update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None: + logger.debug('Updating data with new records.') + # Concatenate the new records with the existing data records = pd.concat([self.get_cache(key), more_records], axis=0, ignore_index=True) # Drop duplicates based on 'open_time' and keep the first occurrence records = records.drop_duplicates(subset="open_time", keep='first') @@ -359,24 +540,49 @@ class DataCache: records = records.sort_values(by='open_time').reset_index(drop=True) # Reindex 'id' to ensure the expected order records['id'] = range(1, len(records) + 1) - # Set the updated DataFrame back to cache + # Set the updated DataFrame back to data self.set_cache(data=records, key=key) def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None: """ - Updates a dictionary stored in cache. + Updates a dictionary stored in the DataFrame cache. - :param data: The data to insert into cache. - :param cache_key: The cache index key for the dictionary. - :param dict_key: The dictionary key for the data. + :param data: The data to insert into the dictionary. + :param cache_key: The cache key for the dictionary. + :param dict_key: The key within the dictionary to update. :return: None """ - self.cached_data[cache_key].update({dict_key: data}) + # Locate the row in the DataFrame that matches the cache_key + cache_index = self.cache.index[self.cache['key'] == cache_key] + + if not cache_index.empty: + # Update the dictionary stored in the 'data' column + cache_dict = self.cache.at[cache_index[0], 'data'] + + if isinstance(cache_dict, dict): + cache_dict[dict_key] = data + + # Ensure the DataFrame is updated with the new dictionary + self.cache.at[cache_index[0], 'data'] = cache_dict + else: + raise ValueError(f"Expected a dictionary in cache, but found {type(cache_dict)}.") + else: + raise KeyError(f"Cache key '{cache_key}' not found.") def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None: - if do_not_overwrite and key in self.cached_data: + if do_not_overwrite and key in self.cache['key'].values: return - self.cached_data[key] = data + + # Corrected construction of the new row + new_row = pd.DataFrame({'key': [key], 'data': [data]}) + + # If the key already exists, drop the old entry + self.cache = self.cache[self.cache['key'] != key] + + # Append the new row to the cache + self.cache = pd.concat([self.cache, new_row], ignore_index=True) + + print(f'Current Cache: {self.cache}') logger.debug(f'Cache set for key: {key}') def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str, @@ -422,12 +628,12 @@ class DataCache: if num_rec_records < estimated_num_records: logger.info('Detected gaps in the data, attempting to fill missing records.') - candles = self.fill_data_holes(candles, interval) + candles = self._fill_data_holes(candles, interval) return candles @staticmethod - def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame: + def _fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame: time_span = timeframe_to_timedelta(interval).total_seconds() / 60 last_timestamp = None filled_records = [] diff --git a/src/Database.py b/src/Database.py index 27f1aa4..b0477c1 100644 --- a/src/Database.py +++ b/src/Database.py @@ -33,7 +33,7 @@ class SQLite: class HDict(dict): """ - Hashable dictionary to use as cache keys. + Hashable dictionary to use as data keys. Example usage: -------------- @@ -85,15 +85,16 @@ class Database: def __init__(self, db_file: str = None): self.db_file = db_file - def execute_sql(self, sql: str) -> None: + def execute_sql(self, sql: str, params: tuple = ()) -> None: """ - Executes a raw SQL statement. + Executes a raw SQL statement with optional parameters. :param sql: SQL statement to execute. + :param params: Optional tuple of parameters to pass with the SQL statement. """ with SQLite(self.db_file) as con: cur = con.cursor() - cur.execute(sql) + cur.execute(sql, params) def get_item_where(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any: """ @@ -120,12 +121,17 @@ class Database: :param table: Name of the table. :param filter_vals: Tuple of column name and value to filter by. - :return: DataFrame of the query result or None if empty. + :return: DataFrame of the query result or None if empty or column does not exist. """ - with SQLite(self.db_file) as con: - qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?" - result = pd.read_sql(qry, con, params=(filter_vals[1],)) - return result if not result.empty else None + try: + with SQLite(self.db_file) as con: + qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?" + result = pd.read_sql(qry, con, params=(filter_vals[1],)) + return result if not result.empty else None + except (sqlite3.OperationalError, pd.errors.DatabaseError) as e: + # Log the error or handle it appropriately + print(f"Error querying table '{table}' for column '{filter_vals[0]}': {e}") + return None def insert_dataframe(self, df: pd.DataFrame, table: str) -> None: """ diff --git a/src/Signals.py b/src/Signals.py index bdec03b..fe99d8e 100644 --- a/src/Signals.py +++ b/src/Signals.py @@ -52,11 +52,18 @@ class Signal: class Signals: - def __init__(self, loaded_signals=None): + def __init__(self, config): # list of Signal objects. self.signals = [] + # load a list of existing signals from file. + loaded_signals = config.get_setting('signals_list') + if loaded_signals is None: + # Populate the list and file with defaults defined in this class. + loaded_signals = self.get_signals_defaults() + config.set_setting('signals_list', loaded_signals) + # Initialize signals with loaded data. if loaded_signals is not None: self.create_signal_from_dic(loaded_signals) diff --git a/src/Strategies.py b/src/Strategies.py index 64301f9..fb0d303 100644 --- a/src/Strategies.py +++ b/src/Strategies.py @@ -1,5 +1,5 @@ import json - +from DataCache_v2 import DataCache class Strategy: def __init__(self, **args): @@ -154,14 +154,26 @@ class Strategy: class Strategies: - def __init__(self, loaded_strats, trades): + def __init__(self, data: DataCache, trades): + # Handles database connections + self.data = data + # Reference to the trades object that maintains all trading actions and data. self.trades = trades - # A list of all the Strategies created. - self.strat_list = [] - # Initialise all the stately objects with the data saved to file. - for entry in loaded_strats: - self.strat_list.append(Strategy(**entry)) + + def get_all_strategy_names(self) -> list | None: + """Return a list of all strategies in the database""" + self.data._get_from_database() + # Load existing Strategies from file. + loaded_strategies = config.get_setting('strategies') + if loaded_strategies is None: + # Populate the list and file with defaults defined in this class. + loaded_strategies = self.get_strategy_defaults() + config.set_setting('strategies', loaded_strategies) + + for entry in loaded_strategies: + # Initialise all the strategy objects with data from file. + self.strat_list.append(Strategy(**entry)) return None def new_strategy(self, data): # Create an instance of the new Strategy. @@ -238,6 +250,7 @@ class Strategies: published strategies and evaluates conditions against the data. This function returns a list of strategies and action commands. """ + def process_strategy(strategy): action, cmd = strategy.evaluate_strategy(signals) if action != 'do_nothing': @@ -262,4 +275,3 @@ class Strategies: return False else: return return_obj - diff --git a/src/Users.py b/src/Users.py index 7bcb84a..6f450ec 100644 --- a/src/Users.py +++ b/src/Users.py @@ -2,531 +2,399 @@ import datetime as dt import json import random from typing import Any - from passlib.hash import bcrypt import pandas as pd -from Database import HDict +from DataCache_v2 import DataCache -class Users: +class BaseUser: """ - Manages user data and states. + Handles basic user data retrieval and manipulation. + This is the base class for all user-related operations. """ - def __init__(self, data_cache): - - # The object that handles all cached data. - self.cache = data_cache - - # The unique identifiers for any guests in the system. - self.cache.set_cache(data=[], key='guest_suffixes', do_not_overwrite=True) - - # Maximum number of guest allowed in the system at one time. - self.max_guests = 100 - - # A place to cache data for users logged into the system. - self.cache.set_cache(data={}, key='cached_users', do_not_overwrite=True) - - # The class contains methods that interact with the database. - self.db = data_cache.db - - # Clear the status of any users that were signed in before the application last shut down. - # Configured it to recover the user, but I may put this back in. - # self.log_out_all_users() - - def get_indicators(self, user_name: str) -> pd.DataFrame | None: + def __init__(self, data_cache: DataCache): """ - Return a dataframe containing all the indicators for a specific user. + Initialize the BaseUser with caching and database interaction capabilities. - :param user_name: The name of the user. - :return: pd.Dataframe - Indicator attributes and properties or None if query fails. + :param data_cache: Object responsible for managing cached data and database interaction. """ - user_id = self.get_id(user_name) - df = self.db.get_rows_where(table='indicators', filter_vals=('creator', user_id)) - - # Convert the 'source' and 'properties' columns from strings to dictionaries - if df is not None: - df['source'] = df['source'].apply(json.loads) - df['properties'] = df['properties'].apply(json.loads) - - return df + self.data = data_cache def get_id(self, user_name: str) -> int: """ - Gets user id from the users table in the database. - :param user_name: str - The name of the user. - :return: The id of the user. + Retrieves the user ID based on the username. + + :param user_name: The name of the user. + :return: The ID of the user as an integer. """ - return self.db.get_item_where(item_name='id', table_name='users', filter_vals=('user_name', user_name)) + return self.data.fetch_cached_item( + item_name='id', + table_name='users', + filter_vals=('user_name', user_name) + ) - def get_username(self, u_id: int) -> str: + def _remove_user_from_memory(self, user_name: str) -> None: """ - Gets user_name from the users table in the database. - :param u_id: int - The id of the user. - :return: str - The name of the user. + Private method to remove a user's data from the cache (memory). + + :param user_name: The name of the user to remove from the cache. """ - return self.db.get_item_where(item_name='user_name', table_name='users', filter_vals=('id', u_id)) + # Use DataCache to remove the user from the cache only + self.data.remove_row( + table='users', + filter_vals=('user_name', user_name) + ) - def save_indicators(self, indicators: pd.DataFrame) -> None: + def delete_user(self, user_name: str) -> None: """ - Store one or many indicators in the database. + Deletes the user from both the cache and the database. - :param indicators: pd.dataframe - Indicator attributes and properties. - :return: None. + :param user_name: The name of the user to delete. """ - for _, indicator in indicators.iterrows(): - src_string = json.dumps(indicator['source']) - prop_string = json.dumps(indicator['properties']) - values = (indicator['creator'], indicator['name'], indicator['visible'], - indicator['kind'], src_string, prop_string) - columns = ('creator', 'name', 'visible', 'kind', 'source', 'properties') - self.db.insert_row(table='indicators', columns=columns, values=values) + self.data.remove_row( + filter_vals=('user_name', user_name), + table='users' + ) - def remove_indicator(self): - return - - def is_logged_in(self, user_name) -> bool: + def get_user_data(self, user_name: str) -> pd.DataFrame | None: """ - Returns True if the session var contains a currently logged-in user. - :param user_name: The name of the user. - :return: True|False + Retrieves user data from the cache or database. If not found in the cache, + it loads from the database and caches it. + + :param user_name: The name of the user. + :return: A DataFrame containing the user's data. + :raises ValueError: If the user is not found in both the cache and the database. """ + # Attempt to fetch the user data from the cache or database via DataCache + user = self.data.fetch_cached_rows( + table='users', + filter_vals=('user_name', user_name) + ) - def poll_login_flag(name: str): - # Get the user from the cache. - user = self.get_user_from_cache(name) + if user is None or user.empty: + raise ValueError(f"User '{user_name}' not found in database or cache!") - if user is not None: - # If the user is in the cache. - if user.at[0, 'status'] == 'logged_in': - # Return true if the user is logged in. - return True - else: - # Remove the user from the cached if the user's status flag is not set. - self.remove_user_from_cache(name) - # Return False. - return False + return user - # If the user was not in the cache, check the db incase the application has been restarted. - user = self.db.get_rows_where(table='users', filter_vals=('user_name', name)) - if user is None: - # The user is not in the database. - return False + def modify_user_data(self, username: str, field_name: str, new_data: Any) -> None: + """ + Updates user data in both the cache and the database. Ensures consistency between cache and database. + + :param username: The name of the user whose data is being modified. + :param field_name: The field to be updated. + :param new_data: The new data to be set. + """ + # Use DataCache to modify the user's data + self.data.modify_cached_row( + table='users', + filter_vals=('user_name', username), + field_name=field_name, + new_data=new_data + ) + + +class UserAccountManagement(BaseUser): + """ + Manages user login, logout, validation, and account creation processes. + """ + + def __init__(self, data_cache, max_guests: int = 100): + """ + Initialize the UserAccountManagement with caching, database interaction, + and guest management capabilities. + + :param data_cache: Object responsible for managing cached data and database interaction. + :param max_guests: Maximum number of guests allowed in the system at one time. + """ + super().__init__(data_cache) + + self.max_guests = max_guests # Maximum number of guests + + # Initialize data for guest suffixes and cached users + self.data.set_cache(data=[], key='guest_suffixes', do_not_overwrite=True) + self.data.set_cache(data={}, key='cached_users', do_not_overwrite=True) + + def is_logged_in(self, user_name: str) -> bool: + """ + Checks if the user is logged in by verifying their status in the cache or database. + + :param user_name: The name of the user. + :return: True if the user is logged in, False otherwise. + """ + if user_name is None: + return False + + def is_guest(_user) -> bool: + split_name = _user.at[0, 'user_name'].split("_") + return 'guest' in split_name and len(split_name) == 2 + + # Retrieve the user's data from either the cache or the database. + user = self.get_user_data(user_name) + + if user is not None and not user.empty: + # If the user exists, check their login status. if user.at[0, 'status'] == 'logged_in': - # Add the user back into the cache and return true. - self.add_user_to_cache(user) - # Check if the user was a guest. - split_name = user.at[0, 'user_name'].split("_") - if ('guest' in split_name) and (len(split_name) == 2): - # if it was, add the suffix to the cache. - guest_suffixes = self.cache.get_cache('guest_suffixes') - guest_suffixes.append(name[1]) - self.cache.set_cache(data=guest_suffixes, key='guest_suffixes') + # If the user is logged in, check if they are a guest. + if is_guest(user): + # Update the guest suffix cache if the user is a guest. + guest_suffixes = self.data.get_cache('guest_suffixes') + guest_suffixes.append(user_name[1]) + self.data.set_cache(data=guest_suffixes, key='guest_suffixes') return True else: - # The user is not logged in. + # If the user is not logged in, remove their data from the cache. + self._remove_user_from_memory(user_name) return False - if user_name is None: - # Not signed in if the session var is not set. - return False - if poll_login_flag(user_name): - # The user is logged in. - return True - else: - # The user is not logged in. + # If the user data is not found or the status is not 'logged_in', return False. + return False + + def validate_password(self, username: str, password: str) -> bool: + """ + Validates the user's password against the stored hash. + + :param username: The name of the user. + :param password: The plain text password to validate. + :return: True if the password is correct, False otherwise. + """ + # Retrieve the hashed password using DataCache + user_data = self.data.fetch_cached_rows(table='users', filter_vals=('user_name', username)) + + if user_data is None or user_data.empty: return False - def create_unique_guest_name(self): - # Retrieve the used suffixes from the cache. - guest_suffixes = self.cache.get_cache('guest_suffixes') - # Limit the number of guest allowed in the system at one time. - if len(guest_suffixes) > self.max_guests: - return None - # Pick a random suffix. - suffix = random.choice(range(0, (self.max_guests * 9))) - # If the suffix is already used try again until an unused one is found. - while suffix in guest_suffixes: - suffix = random.choice(range(0, (self.max_guests * 9))) - # return a unique user_name by appending the suffix. - return f'guest_{suffix}' + hashed_pass = user_data.iloc[0].get('password') - def create_guest(self) -> str | None: - """ - Insert a generic user into the database. - :return: str|None - Returns the user_name of the guest. | None if the - self.max_guest limit is exceeded or any other login error. - """ - - # if None is returned the guest limit is reached. - if (username := self.create_unique_guest_name()) is None: - return None - - attrs = ({'user_name': username},) - self.create_new_user_in_db(attrs=attrs) - - # Log the user in. - login_success = self.log_in_user(username=username, password='password') - if login_success: - return username - - def user_attr_is_taken(self, attr, val): - # Submit a request from the db for the satus of any user where attr == val. - status = self.db.get_from_static_table(item='status', table='users', indexes=HDict({attr: val})) - if status is None: - return False - else: - # If a result was returned the attribute is already taken. - return True - - @staticmethod - def scramble_text(some_text): - # More rounds increases complexity but makes it slower. - hasher = bcrypt.using(rounds=13) - return hasher.hash(some_text) - - def create_new_user_in_db(self, attrs: tuple): - """ - Retrieves a default user, modifies it, then appends the user table. - - :param attrs: tuple - a tuple of key, value dicts to apply to the new user. ({attr:value},...) - :return: None. - """ - # Get the default user template. - default_user = self.db.get_rows_where(table='users', filter_vals=('user_name', 'guest')) - - for each in attrs: - # Modify the every attribute received in the template. - key = str(*each) - default_user.loc[0, key] = each.get(key) - - # Drop the database index. - default_user = default_user.drop(columns='id') - - # Insert the modified user as a new record in the table. - self.db.insert_dataframe(df=default_user, table="users") - - def create_new_user(self, username: str, email: str, password: str): - - if self.user_attr_is_taken('user_name', username): + if not hashed_pass: return False - if self.user_attr_is_taken('email', email): + # Verify the password using bcrypt + try: + return bcrypt.verify(password, hashed_pass) + except ValueError: + # Handle any issues with the verification process return False - encrypted_password = self.scramble_text(password) - - attrs = ({'user_name': username}, {'email': email}, {'password': encrypted_password}) - self.create_new_user_in_db(attrs=attrs) - - return True - - def load_or_create_user(self, username: str) -> str: + def log_in_user(self, username: str, password: str) -> bool: """ - Loads user data either for existing user or a newly created user. + Validates user credentials and signs in the user by updating their status. - :param username: str|None - The user_name if exists in session. - :return: str - Either the same as passed in or a new user_name. - """ - - # Check if the user is logged in. - if not self.is_logged_in(user_name=username): - # Create a guest session if the user is not logged in. - username = self.create_guest() - if not username: - # There was a problem creating the user and system should display an error page. - raise ValueError('GuestLimitExceeded!') - # Load user data if not loaded. - self.load_user_data(user_name=username) - # Return the user_name. - return username - - def log_in_user(self, username: str, password: str): - """ - Validate User credentials and sign in the user. - - :param username: The name index for the user. - :param password: The unencrypted password. - :return: bool - True on success | False on fail. + :param username: The name of the user. + :param password: The unencrypted password. + :return: True on successful login, False otherwise. """ if success := self.validate_password(username=username, password=password): - # Set the user's status flag as logged in. self.modify_user_data(username=username, field_name="status", new_data="logged_in") - # Record the login time. self.modify_user_data(username=username, field_name="signin_time", new_data=dt.datetime.utcnow().timestamp()) return success - def log_out_user(self, username: str): + def log_out_user(self, username: str) -> bool: """ - Log out the user. + Logs out the user by updating their status and removing them from the cache. - :param username: The name of the user. - :return: bool + :param username: The name of the user. + :return: True on successful logout. """ - # Set the user's status flag as logged out. + # Update the user's status and sign-in time in both cache and database self.modify_user_data(username=username, field_name='status', new_data='logged_out') - # Record the logout time. self.modify_user_data(username=username, field_name='signin_time', new_data=dt.datetime.utcnow().timestamp()) - # Remove the user from the cache. - self.remove_user_from_cache(user_name=username) + + # Remove the user's data from the cache + self._remove_user_from_memory(user_name=username) + return True - def log_out_all_users(self): + def log_out_all_users(self, enforcement: str = 'hard') -> None: """ - Log out all users. - :return: None + Logs out all users by updating their status in the database and clearing the data. + + :param enforcement: 'soft' or 'hard' - Determines how strictly users are logged out. """ - for key in self.cache.cached_data: - if (self.cache.cached_data[key] is dict) and ('user_name' in self.cache.cached_data[key]): - # remove the user from the cache. - del self.cache.cached_data[key] + if enforcement == 'soft': + self._soft_log_out_all_users() + elif enforcement == 'hard': + # Clear all user-related entries from the cache + for index, row in self.data.cache.iterrows(): + if 'user_name' in row: + self._remove_user_from_memory(row['user_name']) - df = self.db.get_rows_where(table='users', filter_vals=('status', 'logged_in')) - if df is not None: - df = df[df.user_name != 'guest'] + df = self.data.fetch_cached_rows(table='users', filter_vals=('status', 'logged_in')) + if df is not None: + df = df[df.user_name != 'guest'] - # Set all logged-in user's status flag to logged_out. - ids = df.user_name.values - if ids is not None: - for value in ids: - sql = f"UPDATE users SET status = 'logged_out' WHERE user_name = '{value}';" - self.db.execute_sql(sql) - return + # Update the status of all logged-in users to 'logged_out' + for user_name in df.user_name.values: + self.modify_user_data(username=user_name, field_name='status', new_data='logged_out') + else: + raise ValueError("Invalid enforcement type. Use 'soft' or 'hard'.") - def load_user_data(self, user_name: str) -> None: + def _soft_log_out_all_users(self) -> None: """ - Check if user data is in the cache and load if it is not. - - :param user_name: The name of the user. - :return: None + Placeholder for soft logout logic. Soft logout might involve suspending logins, halting trades, + or allowing current operations to finish. (TODO: Implement soft logout logic) """ - # if self.cache.is_loaded(user_name): - # # Return if already loaded. - # return - - print(f'Retrieving this user: {user_name}') - # Request a dataframe containing the users properties. - user = self.get_user_from_db(user_name=user_name) - if user.empty: - raise ValueError('Attempted to load user not in db!') - - # Load chart view stored in the database and convert back to a dictionary. - chart_view = json.loads(user.at[0, 'chart_views']) - # todo : save chart view somewhere you can do it - # load indicators - pass - # load strategies + # TODO: Implement soft logout logic here pass - def modify_user_data(self, username: str, field_name: str, new_data: Any) -> None: + def user_attr_is_taken(self, attr: str, val: str) -> bool: """ - Update a user data field in both the cache and the database. + Checks if a specific user attribute (e.g., username, email) is already taken. - :param username: str - The name of the user. - :param field_name: str - The field to update. - :param new_data: Any - The data to be set. - :return: None + :param attr: The attribute to check (e.g., 'user_name', 'email'). + :param val: The value of the attribute to check. + :return: True if the attribute is already taken, False otherwise. """ + # Use DataCache to check if the attribute is taken + return self.data.is_attr_taken(table='users', attr=attr, val=val) - # Get the user data from the db. - user = self.db.get_rows_where(table='users', filter_vals=('user_name', username)) - - # Drop the db id column. - modified_user = user.drop(columns='id') - - # Modify the targeted field. - if type(new_data) is str: - modified_user.loc[0, field_name] = new_data - else: - modified_user.loc[0, field_name] = json.dumps(new_data) - - # Replace the user records in the cache with the updated data record. - self.cache.update_cached_dict(cache_key='cached_users', dict_key=username, data=modified_user) - - # Delete the old record from the db - sql = f"DELETE FROM users WHERE user_name = '{user.loc[0, 'user_name']}';" - self.db.execute_sql(sql) - - # Replace the records in the database. - self.db.insert_dataframe(modified_user, 'users') - - return - - def validate_password(self, username: str, password: str) -> bool: + def create_unique_guest_name(self) -> str | None: """ - Validate a user password with the hashed password stored in the db. + Creates a unique guest username by appending a random suffix. - :param username: str - The name of the user. - :param password: str - The plain text password. - :return: bool - True on success. + :return: A unique guest username or None if the guest limit is reached. """ - # Validate input - if (password is None) or (username is None): + guest_suffixes = self.data.get_cache('guest_suffixes') + if len(guest_suffixes) > self.max_guests: + return None + suffix = random.choice(range(0, (self.max_guests * 9))) + while suffix in guest_suffixes: + suffix = random.choice(range(0, (self.max_guests * 9))) + return f'guest_{suffix}' + + def create_guest(self) -> str | None: + """ + Creates a guest user in the database and logs them in. + + :return: The guest username or None if the guest limit is reached. + """ + if (username := self.create_unique_guest_name()) is None: + return None + attrs = ({'user_name': username},) + self.create_new_user_in_db(attrs=attrs) + + login_success = self.log_in_user(username=username, password='password') + if login_success: + return username + + def create_new_user_in_db(self, attrs: tuple) -> None: + """ + Creates a new user in the database by modifying a default template. + + :param attrs: A tuple of key-value pairs representing user attributes. + :raises ValueError: If attrs are not well-formed. + """ + if not attrs or not all(isinstance(attr, dict) and len(attr) == 1 for attr in attrs): + raise ValueError("Attributes must be a tuple of single key-value pair dictionaries.") + + # Retrieve the default user template from the database using DataCache + default_user = self.data.fetch_cached_rows(table='users', filter_vals=('user_name', 'guest')) + + if default_user is None or default_user.empty: + raise ValueError("Default user template not found in the database.") + + # Modify the default user template with the provided attributes + for attr in attrs: + key, value = next(iter(attr.items())) + default_user.loc[0, key] = value + + # Remove the 'id' column before inserting into the database + default_user = default_user.drop(columns='id') + + # Insert the modified user data into the database, skipping cache insertion + self.data.insert_data(df=default_user, table="users", skip_cache=True) + + def create_new_user(self, username: str, email: str, password: str) -> bool: + """ + Creates a new user in the system if the username and email are not already taken. + + :param username: The desired username. + :param email: The user's email address. + :param password: The user's password. + :return: True if the user was successfully created, False otherwise. + """ + if self.user_attr_is_taken('user_name', username) or self.user_attr_is_taken('email', email): return False + encrypted_password = self.scramble_text(password) + attrs = ({'user_name': username}, {'email': email}, {'password': encrypted_password}) + self.create_new_user_in_db(attrs=attrs) + return True - # Get the password from the database. - hashed_pass = self.db.get_from_static_table(item='password', table='users', - indexes=HDict({'user_name': username})) - if hashed_pass is None: - return False - hasher = bcrypt.using(rounds=13) # Make it slower - return hasher.verify(password, hashed_pass) - - def get_chart_view(self, user_name: str, prop: str | None = None): + def load_or_create_user(self, username: str) -> str: """ - Fetches the chart view or one specific_property of it for a specific user. + Loads existing user data or creates a guest user if the user is not logged in. - :param user_name: str - The name of the user the query is for. - :param prop: str|None - Optional the specific specific_property of the chart view. - :return: dict|str - The value of the specific_property specified in specific_property or a dict of 3 values. + :param username: The username to load or create. + :return: The username of the loaded or newly created user. + :raises ValueError: If a guest user cannot be created. """ + if not self.is_logged_in(user_name=username): + username = self.create_guest() + if not username: + raise ValueError('GuestLimitExceeded!') + self.get_user_data(user_name=username) + return username - # Get a pd.DataFrame that contains the user properties. - print(f'Retrieving this user: {user_name}') - user = self.get_user_from_db(user_name=user_name) - if user.empty: - return - # The chart view is dict object stored as a string in the database to simplify the data storage. - # This converts it back to a dictionary. - # chart_view = literal_eval(user.at[0, 'chart_views']) - chart_view = json.loads(user.at[0, 'chart_views']) - - if prop is None: - # If no specific_property is specified return the entire dict. - return chart_view - if prop == 'exchange_name': - # Return the exchange_name name as a string. - return chart_view.get('exchange') - if prop == 'timeframe': - # Return the timeframe as a string. - return chart_view.get('timeframe') - if prop == 'market': - # Return the market as a string. - return chart_view.get('market') - - def set_chart_view(self, user_name: str, values: dict | str, - specific_property: str | None = None, default_market: str | None = None): + @staticmethod + def scramble_text(some_text: str) -> str: """ + Encrypts text using bcrypt with 13 rounds of hashing. - :param user_name: str - The name of the user whose data is being modified. - :param values: str|list - The value of the property or values of a dict to be set. - :param specific_property: str|None - Optional the specific property of the chart view to modify. - :param default_market: str|None - a default market must be provided to display when - changing the exchange_name view. - :return: + :param some_text: The text to encrypt. + :return: The hashed text. """ + return bcrypt.hash(some_text, rounds=13) - # Get a pd.DataFrame that contains the user properties. - user = self.get_user_from_db(user_name=user_name) - if specific_property is None: - # Validate values - assert (type(values) is dict) - assert (len(values) == 3) - # user['chart_view'] = str(values) - # TODO!!! setting directly then useing modifiy user data elsewhere - user['chart_view'] = json.dumps(values) - return - - # Validate specific_property - assert (type(specific_property) is str) - - # chart_view = literal_eval(user.at[0, 'chart_views']) - chart_view = json.loads(user.at[0, 'chart_views']) - if specific_property == 'exchange_name': - # Ensure a market view is always passed in while changing the exchange_name. - assert (default_market is not None) - # Set the market. - chart_view['market'] = default_market - # Set the exchange_name. - chart_view['exchange_name'] = values - elif specific_property == 'timeframe': - # Set the timeframe - chart_view['timeframe'] = values - elif specific_property == 'market': - # Set the market. - chart_view['market'] = values - else: - raise ValueError(f'{specific_property} is not a specific_property of chart_views') - - self.modify_user_data(username=user_name, field_name='chart_views', new_data=chart_view) - return - - def get_user_from_db(self, user_name: str) -> pd.DataFrame | None: - user = self.db.get_rows_where(table='users', filter_vals=('user_name', user_name)) - if not user.empty: - return user - - def get_user_from_cache(self, user_name: str) -> pd.DataFrame | None: - """ - Returns the DataFrame object that contains the data of only . - or None if the user is not logged in. - """ - # Get a dictionary of all cached users data indexed by user_name. - if all_users := self.cache.get_cache('cached_users'): - # Return the specific data. - return all_users.get(user_name) - - def remove_user_from_cache(self, user_name: str) -> None: - """ - Removes a user object from the cache. - """ - # Get a dictionary of the cached user data indexed by user_name. - users = self.cache.get_cache('cached_users') - # Remove the specific user from the cached data. - del users[user_name] - # Reset the cached data - self.cache.set_cache(data=users, key='cached_users') - return - - def add_user_to_cache(self, user: pd.DataFrame) -> None: - """ - Adds a user object to the cache. - """ - self.cache.update_cached_dict(cache_key='cached_users', dict_key=user.at[0, 'user_name'], data=user) - return +class UserExchangeManagement(UserAccountManagement): + """ + Manages user exchange-related data and operations. + """ def get_api_keys(self, user_name: str, exchange: str) -> dict: - # Get the user records from the database. - user = self.get_user_from_db(user_name) + """ + Retrieves the API keys for a specific exchange associated with a user. + + :param user_name: The name of the user. + :param exchange: The name of the exchange. + :return: A dictionary containing the API keys for the exchange. + """ + user = self.get_user_data(user_name) + if user is None or user.empty or 'api_keys' not in user.columns: + return {} - # Get the api keys dictionary. user_keys = user.loc[0, 'api_keys'] user_keys = json.loads(user_keys) if user_keys else {} return user_keys.get(exchange) - def update_api_keys(self, api_keys, exchange, user_name: str): - # Get the user records from the database. - user = self.get_user_from_db(user_name) + def update_api_keys(self, api_keys, exchange, user_name: str) -> None: + """ + Updates the API keys for a specific exchange and user, and activates the exchange. - # Get the old api keys dictionary. + :param api_keys: The new API keys to store. + :param exchange: The exchange for which the keys are being updated. + :param user_name: The name of the user. + """ + user = self.get_user_data(user_name) user_keys = user.loc[0, 'api_keys'] user_keys = json.loads(user_keys) if user_keys else {} - # Add the new keys to the dictionary. user_keys.update({exchange: api_keys}) - # Modify the cache and db. self.modify_user_data(username=user_name, field_name='api_keys', new_data=json.dumps(user_keys)) - # Get the old list of configured exchanges. configured_exchanges = json.loads(user.loc[0, 'configured_exchanges']) - - # Check if the exchange is already in the list if exchange not in configured_exchanges: - # Append the list of configured exchanges. configured_exchanges.append(exchange) - # Modify the cache and db. self.modify_user_data(username=user_name, field_name='configured_exchanges', new_data=json.dumps(configured_exchanges)) - # Activate the newly configured exchange. self.active_exchange(exchange=exchange, user_name=user_name, cmd='set') def get_exchanges(self, user_name: str, category: str) -> list | None: @@ -535,17 +403,13 @@ class Users: :param user_name: The name of the user. :param category: The category to retrieve ('active_exchanges' or 'configured_exchanges'). - :return: List of exchanges or None if user or category is not found. + :return: A list of exchanges or None if user or category is not found. """ try: - # Get the user records from the database. - user = self.get_user_from_db(user_name) - # Get the exchanges list based on the field. + user = self.get_user_data(user_name) exchanges = user.loc[0, category] - # Return the list if it exists, otherwise return an empty list. return json.loads(exchanges) if exchanges else [] except (KeyError, IndexError, json.JSONDecodeError) as e: - # Log the error to the console print(f"Error retrieving exchanges for user '{user_name}' and field '{category}': {str(e)}") return None @@ -553,31 +417,157 @@ class Users: """ Inserts, removes, or toggles an exchange name in a list of active exchanges. - :param cmd: str - 'toggle' | 'set' | 'remove' - :param exchange: str - The name of the exchange. - :param user_name: str - The name of the user executing this command. - :return: None + :param cmd: 'toggle' | 'set' | 'remove' - The action to perform on the exchange. + :param exchange: The name of the exchange. + :param user_name: The name of the user executing this command. """ - # Get the user records from the database. - user = self.get_user_from_db(user_name) - # Get the old active_exchanges list, or initialize as an empty list if it is None. + user = self.get_user_data(user_name) active_exchanges = user.loc[0, 'active_exchanges'] if active_exchanges is None: active_exchanges = [] else: active_exchanges = json.loads(active_exchanges) - # Define the actions for each command actions = { 'toggle': lambda x: active_exchanges.remove(x) if x in active_exchanges else active_exchanges.append(x), 'set': lambda x: active_exchanges.append(x) if x not in active_exchanges else active_exchanges, 'remove': lambda x: active_exchanges.remove(x) if x in active_exchanges else active_exchanges } - # Perform the action based on the command action = actions.get(cmd) if action: action(exchange) - # Modify the cache and db. self.modify_user_data(username=user_name, field_name='active_exchanges', new_data=json.dumps(active_exchanges)) + +class UserIndicatorManagement(UserExchangeManagement): + """ + Manages user indicators and related operations. + """ + + def get_indicators(self, user_name: str) -> pd.DataFrame | None: + """ + Retrieves all indicators for a specific user. + + :param user_name: The name of the user. + :return: A DataFrame containing the user's indicators or None if not found. + """ + # Retrieve the user's ID + user_id = self.get_id(user_name) + + # Fetch the indicators from the database using DataCache + df = self.data.fetch_cached_rows(table='indicators', filter_vals=('creator', user_id)) + + # If indicators are found, process the JSON fields + if df is not None and not df.empty: + df['source'] = df['source'].apply(json.loads) + df['properties'] = df['properties'].apply(json.loads) + + return df + + def save_indicators(self, indicators: pd.DataFrame) -> None: + """ + Stores one or many indicators in the database. + + :param indicators: A DataFrame containing indicator attributes and properties. + """ + for _, indicator in indicators.iterrows(): + # Convert necessary fields to JSON strings + src_string = json.dumps(indicator['source']) + prop_string = json.dumps(indicator['properties']) + + # Prepare the values and columns for insertion + values = (indicator['creator'], indicator['name'], indicator['visible'], + indicator['kind'], src_string, prop_string) + columns = ('creator', 'name', 'visible', 'kind', 'source', 'properties') + + # Insert the row into the database using DataCache + self.data.insert_row(table='indicators', columns=columns, values=values) + + def remove_indicator(self, indicator_name: str, user_name: str) -> None: + """ + Removes a specific indicator from the database and cache. + + :param indicator_name: The name of the indicator to remove. + :param user_name: The name of the user who created the indicator. + """ + user_id = self.get_id(user_name) + self.data.remove_row( + filter_vals=('name', indicator_name), + additional_filter=('creator', user_id), + table='indicators' + ) + + def get_chart_view(self, user_name: str, prop: str | None = None): + """ + Fetches the chart view or one specific property of it for a specific user. + + :param user_name: The name of the user. + :param prop: Optional specific property of the chart view to retrieve. + :return: A dictionary of chart views or a specific property if specified. + """ + user = self.get_user_data(user_name) + if user.empty: + return + chart_view = json.loads(user.at[0, 'chart_views']) + + if prop is None: + return chart_view + if prop == 'exchange_name': + return chart_view.get('exchange') + if prop == 'timeframe': + return chart_view.get('timeframe') + if prop == 'market': + return chart_view.get('market') + + def set_chart_view(self, user_name: str, values: dict | str, + specific_property: str | None = None, default_market: str | None = None) -> None: + """ + Sets or updates the chart view for a specific user. + + :param user_name: The name of the user. + :param values: The values to set for the chart view. + :param specific_property: Optional specific property of the chart view to modify. + :param default_market: Default market to display when changing the exchange view. + :raises ValueError: If inputs are not valid. + """ + user = self.get_user_data(user_name) + + if specific_property is None: + if not isinstance(values, dict) or len(values) != 3: + raise ValueError("Values must be a dictionary with exactly 3 key-value pairs.") + user['chart_view'] = json.dumps(values) + return + + if not isinstance(specific_property, str): + raise ValueError("Specific property must be a string.") + + chart_view = json.loads(user.at[0, 'chart_views']) + if specific_property == 'exchange_name': + if default_market is None: + raise ValueError("Default market must be provided when setting the exchange name.") + chart_view['market'] = default_market + chart_view['exchange_name'] = values + elif specific_property == 'timeframe': + chart_view['timeframe'] = values + elif specific_property == 'market': + chart_view['market'] = values + else: + raise ValueError(f'{specific_property} is not a specific property of chart_views') + + self.modify_user_data(username=user_name, field_name='chart_views', new_data=chart_view) + + +class Users(UserIndicatorManagement): + """ + Comprehensive user management class that inherits all functionalities from UserIndicatorManagement. + """ + + def __init__(self, data_cache, max_guests: int = 100): + """ + Initialize the Users class with caching, database interaction, and user management capabilities. + + :param data_cache: Object responsible for managing cached data and database interaction. + :param max_guests: Maximum number of guests allowed in the system at one time. + """ + super().__init__(data_cache, max_guests) diff --git a/src/app.py b/src/app.py index 66ab56e..388accc 100644 --- a/src/app.py +++ b/src/app.py @@ -7,7 +7,6 @@ from flask_sock import Sock from email_validator import validate_email, EmailNotValidError # Handles all updates and requests for locally stored data. -import config from BrighterTrades import BrighterTrades # Set up logging @@ -54,7 +53,7 @@ def index(): try: # Log the user in. - user_name = brighter_trades.config.users.load_or_create_user(username=session.get('user')) + user_name = brighter_trades.users.load_or_create_user(username=session.get('user')) except ValueError as e: if str(e) != 'GuestLimitExceeded!': raise diff --git a/src/archived_code/DataCache.py b/src/archived_code/DataCache.py index c2388ca..2853bff 100644 --- a/src/archived_code/DataCache.py +++ b/src/archived_code/DataCache.py @@ -37,16 +37,16 @@ class DataCache: def cache_exists(self, key: str) -> bool: """ - Checks if a cache exists for the given key. + Checks if a data exists for the given key. :param key: The access key. - :return: True if cache exists, False otherwise. + :return: True if data exists, False otherwise. """ return key in self.cached_data def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None: """ - Adds records to existing cache. + Adds records to existing data. :param more_records: The new records to be added. :param key: The access key. @@ -59,9 +59,9 @@ class DataCache: def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None: """ - Creates a new cache key and inserts data. + Creates a new data key and inserts data. - :param data: The records to insert into cache. + :param data: The records to insert into data. :param key: The index key for the data. :param do_not_overwrite: Flag to prevent overwriting existing data. :return: None @@ -72,10 +72,10 @@ class DataCache: def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None: """ - Updates a dictionary stored in cache. + Updates a dictionary stored in data. - :param data: The data to insert into cache. - :param cache_key: The cache index key for the dictionary. + :param data: The data to insert into data. + :param cache_key: The data index key for the dictionary. :param dict_key: The dictionary key for the data. :return: None """ @@ -89,7 +89,7 @@ class DataCache: :return: Any|None - The requested data or None on key error. """ if key not in self.cached_data: - logger.warning(f"The requested cache key({key}) doesn't exist!") + logger.warning(f"The requested data key({key}) doesn't exist!") return None return self.cached_data[key] @@ -98,14 +98,14 @@ class DataCache: """ Fetches records since the specified start datetime. - :param key: The cache key. + :param key: The data key. :param start_datetime: The start datetime to fetch records from. :param record_length: The required number of records. :param ex_details: Exchange details. :return: DataFrame containing the records. """ try: - target = 'cache' + target = 'data' args = { 'key': key, 'start_datetime': start_datetime, @@ -167,7 +167,7 @@ class DataCache: 'record_length': record_length, } - if target == 'cache': + if target == 'data': result = get_from_cache() if data_complete(result, **request_criteria): return result @@ -193,7 +193,7 @@ class DataCache: """ Fetches records since the specified start datetime. - :param key: The cache key. + :param key: The data key. :param start_datetime: The start datetime to fetch records from. :param record_length: The required number of records. :param ex_details: Exchange details. @@ -203,11 +203,11 @@ class DataCache: end_datetime = dt.datetime.utcnow() if self.cache_exists(key=key): - logger.debug('Getting records from cache.') + logger.debug('Getting records from data.') records = self.get_cache(key) else: logger.debug( - f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}') + f'Records not in data. Requesting from DB: starting at: {start_datetime} to: {end_datetime}') records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_datetime, rl=record_length, ex_details=ex_details) logger.debug(f'Got {len(records.index)} records from DB.') diff --git a/src/candles.py b/src/candles.py index 2c00ccc..3ac8ec9 100644 --- a/src/candles.py +++ b/src/candles.py @@ -9,23 +9,23 @@ from shared_utilities import timeframe_to_minutes, ts_of_n_minutes_ago # log.basicConfig(level=log.ERROR) class Candles: - def __init__(self, exchanges, config_obj, data_source): + def __init__(self, exchanges, users, data_source, config): # A reference to the app configuration - self.config = config_obj + self.users = users # The maximum amount of candles to load at one time. - self.max_records = self.config.app_data.get('max_data_loaded') + self.max_records = config.get_setting('max_data_loaded') # This object maintains all the cached data. self.data = data_source - # print('Setting the candle cache.') - # # Populate the cache: - # self.set_cache(symbol=self.config.users.get_chart_view(user_name='guest', specific_property='market'), - # interval=self.config.users.get_chart_view(user_name='guest', specific_property='timeframe'), - # exchange_name=self.config.users.get_chart_view(user_name='guest', specific_property='exchange_name')) - # print('DONE Setting cache') + # print('Setting the candle data.') + # # Populate the data: + # self.set_cache(symbol=self.users.get_chart_view(user_name='guest', specific_property='market'), + # interval=self.users.get_chart_view(user_name='guest', specific_property='timeframe'), + # exchange_name=self.users.get_chart_view(user_name='guest', specific_property='exchange_name')) + # print('DONE Setting data') def get_last_n_candles(self, num_candles: int, asset: str, timeframe: str, exchange: str, user_name: str): """ @@ -83,7 +83,7 @@ class Candles: # def set_cache(self, symbol=None, interval=None, exchange_name=None, user_name=None): """ - This method requests a chart from memory to ensure the cache is initialized. + This method requests a chart from memory to ensure the data is initialized. :param user_name: :param symbol: str - The symbol of the market. @@ -91,18 +91,18 @@ class Candles: :param exchange_name: str - The name of the exchange_name to fetch from. :return: None """ - # By default, initialise cache with the last viewed chart. + # By default, initialise data with the last viewed chart. if not symbol: assert user_name is not None - symbol = self.config.users.get_chart_view(user_name=user_name, prop='market') + symbol = self.users.get_chart_view(user_name=user_name, prop='market') log.info(f'set_candle_history(): No symbol provided. Using{symbol}') if not interval: assert user_name is not None - interval = self.config.users.get_chart_view(user_name=user_name, prop='timeframe') + interval = self.users.get_chart_view(user_name=user_name, prop='timeframe') log.info(f'set_candle_history(): No timeframe provided. Using{interval}') if not exchange_name: assert user_name is not None - exchange_name = self.config.users.get_chart_view(user_name=user_name, prop='exchange_name') + exchange_name = self.users.get_chart_view(user_name=user_name, prop='exchange_name') # Log the completion to the console. log.info('set_candle_history(): Loading candle data...') @@ -199,20 +199,20 @@ class Candles: :param user_name: str - The name of the user who owns the exchange. :return: list - Candle records in the lightweight charts format. """ - # By default, initialise cache with the last viewed chart. + # By default, initialise data with the last viewed chart. if not symbol: assert user_name is not None - symbol = self.config.users.get_chart_view(user_name=user_name, prop='market''market') + symbol = self.users.get_chart_view(user_name=user_name, prop='market''market') log.info(f'set_candle_history(): No symbol provided. Using{symbol}') if not interval: assert user_name is not None - interval = self.config.users.get_chart_view(user_name=user_name, prop='market''timeframe') + interval = self.users.get_chart_view(user_name=user_name, prop='market''timeframe') log.info(f'set_candle_history(): No timeframe provided. Using{interval}') if not exchange_name: assert user_name is not None - exchange_name = self.config.users.get_chart_view(user_name=user_name, prop='market''exchange_name') + exchange_name = self.users.get_chart_view(user_name=user_name, prop='market''exchange_name') log.info(f'get_candle_history(): No exchange name provided. Using {exchange_name}') candlesticks = self.get_last_n_candles(num_candles=num_records, asset=symbol, timeframe=interval, diff --git a/src/indicators.py b/src/indicators.py index 47bf112..0cd77be 100644 --- a/src/indicators.py +++ b/src/indicators.py @@ -308,12 +308,12 @@ indicator_types.append('MACD') class Indicators: - def __init__(self, candles, config): + def __init__(self, candles, users): # Object manages and serves price and candle data. self.candles = candles - # A connection to an object that handles user configuration and persistent data. - self.config = config + # A connection to an object that handles user data. + self.users = users # Collection of instantiated indicators objects self.indicators = pd.DataFrame(columns=['creator', 'name', 'visible', @@ -331,7 +331,7 @@ class Indicators: Get the users watch-list from the database and load the indicators into a dataframe. :return: None """ - active_indicators: pd.DataFrame = self.config.users.get_indicators(user_name) + active_indicators: pd.DataFrame = self.users.get_indicators(user_name) if active_indicators is not None: # Create an instance for each indicator. @@ -347,7 +347,7 @@ class Indicators: Saves the indicators in the database indexed by the user id. :return: None """ - self.config.users.save_indicators(indicator) + self.users.save_indicators(indicator) # @staticmethod # def get_indicator_defaults(): @@ -389,7 +389,7 @@ class Indicators: :param only_enabled: bool - If True, return only indicators marked as visible. :return: dict - A dictionary of indicator names as keys and their attributes as values. """ - user_id = self.config.users.get_id(username) + user_id = self.users.get_id(username) if not user_id: raise ValueError(f"Invalid user_name: {username}") @@ -498,7 +498,7 @@ class Indicators: :param num_results: The number of results being requested. :return: The results of the indicator analysis as a DataFrame. """ - username = self.config.users.get_username(indicator.creator) + username = self.users.get_username(indicator.creator) src = indicator.source symbol, timeframe, exchange_name = src['symbol'], src['timeframe'], src['exchange_name'] @@ -532,7 +532,7 @@ class Indicators: if start_ts: print("Warning: start_ts has not implemented in get_indicator_data()!") - user_id = self.config.users.get_id(user_name=user_name) + user_id = self.users.get_id(user_name=user_name) # Construct the query based on user_id and visibility. query = f"creator == {user_id}" @@ -582,7 +582,7 @@ class Indicators: if not indicator_name: raise ValueError("No indicator name provided.") self.indicators = self.indicators.query("name != @indicator_name").reset_index(drop=True) - self.config.users.save_indicators() + self.users.save_indicators() def create_indicator(self, creator: str, name: str, kind: str, source: dict, properties: dict, visible: bool = True): @@ -618,7 +618,7 @@ class Indicators: indicator = indicator_class(name, kind, properties) # Add the new indicator to a pandas dataframe. - creator_id = self.config.users.get_id(creator) + creator_id = self.users.get_id(creator) row_data = { 'creator': creator_id, 'name': name, diff --git a/src/maintenence/debuging_testing.py b/src/maintenence/debuging_testing.py index 5d3e367..0f13369 100644 --- a/src/maintenence/debuging_testing.py +++ b/src/maintenence/debuging_testing.py @@ -1,38 +1,33 @@ -import ccxt import pandas as pd -import datetime +import matplotlib.pyplot as plt +from matplotlib.table import Table + +# Simulating the cache as a DataFrame +data = { + 'key': ['BTC/USD_2h_binance', 'ETH/USD_1h_coinbase'], + 'data': ['{"open": 50000, "close": 50500}', '{"open": 1800, "close": 1825}'] +} +cache_df = pd.DataFrame(data) -def fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5): - # Initialize the exchange - exchange_class = getattr(ccxt, exchange_name) - exchange = exchange_class() +# Visualization function +def visualize_cache(df): + fig, ax = plt.subplots(figsize=(6, 3)) + ax.set_axis_off() + tb = Table(ax, bbox=[0, 0, 1, 1]) - # Fetch historical candlestick data with a limit - ohlcv = exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit) + # Adding column headers + for i, column in enumerate(df.columns): + tb.add_cell(0, i, width=0.4, height=0.3, text=column, loc='center', facecolor='lightgrey') - # Convert to DataFrame for better readability - df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']) + # Adding rows and cells + for i in range(len(df)): + for j, value in enumerate(df.iloc[i]): + tb.add_cell(i + 1, j, width=0.4, height=0.3, text=value, loc='center', facecolor='white') - # Print the first few rows of the DataFrame - print("First few rows of the fetched OHLCV data:") - print(df.head()) - - # Print the timestamps in human-readable format - df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms') - print("\nFirst few timestamps in human-readable format:") - print(df[['timestamp', 'datetime']].head()) - - # Confirm the format of the timestamps - print("\nTimestamp format confirmation:") - for ts in df['timestamp']: - print(f"{ts} (milliseconds since Unix epoch)") + ax.add_table(tb) + plt.title("Visualizing Cache Data") + plt.show() -# Example usage -exchange_name = 'binance' # Change this to your exchange -symbol = 'BTC/USDT' -timeframe = '5m' -since = int((datetime.datetime(2024, 8, 1) - datetime.datetime(1970, 1, 1)).total_seconds() * 1000) - -fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5) +visualize_cache(cache_df) diff --git a/src/trade.py b/src/trade.py index 263b678..75b3976 100644 --- a/src/trade.py +++ b/src/trade.py @@ -267,10 +267,10 @@ class Trade: class Trades: - def __init__(self, loaded_trades=None): + def __init__(self, users): """ This class receives, executes, tracks and stores all active_trades. - :param loaded_trades: A bunch of active_trades to create and store. + :param users: A class that maintains users each user may have trades. """ # Object that maintains exchange_interface and account data self.exchange_interface = None @@ -290,7 +290,8 @@ class Trades: self.settled_trades = [] self.stats = {'num_trades': 0, 'total_position': 0, 'total_position_value': 0} - # Load any trades that were passed into the constructor. + # Load all trades. + loaded_trades = users.get_all_active_user_trades() if loaded_trades is not None: # Create the active_trades loaded from file. self.load_trades(loaded_trades) diff --git a/tests/test_DataCache_v2.py b/tests/test_DataCache_v2.py index 6c729ab..77edae4 100644 --- a/tests/test_DataCache_v2.py +++ b/tests/test_DataCache_v2.py @@ -1,5 +1,5 @@ import pytz -from DataCache_v2 import DataCache +from DataCache_v2 import DataCache, timeframe_to_timedelta, estimate_record_count from ExchangeInterface import ExchangeInterface import unittest import pandas as pd @@ -224,10 +224,23 @@ class TestDataCacheV2(unittest.TestCase): exchange_id INTEGER, FOREIGN KEY (exchange_id) REFERENCES exchange(id) )""" + sql_create_table_4 = f""" + CREATE TABLE IF NOT EXISTS test_table_2 ( + key TEXT PRIMARY KEY, + data TEXT NOT NULL + )""" + sql_create_table_5 = f""" + CREATE TABLE IF NOT EXISTS users ( + users_data TEXT PRIMARY KEY, + data TEXT NOT NULL + )""" + with SQLite(db_file=self.db_file) as con: con.execute(sql_create_table_1) con.execute(sql_create_table_2) con.execute(sql_create_table_3) + con.execute(sql_create_table_4) + con.execute(sql_create_table_5) self.data = DataCache(self.exchanges) self.data.db = self.database @@ -241,30 +254,41 @@ class TestDataCacheV2(unittest.TestCase): def test_set_cache(self): print('\nTesting set_cache() method without no-overwrite flag:') self.data.set_cache(data='data', key=self.key) - attr = self.data.__getattribute__('cached_data') - self.assertEqual(attr[self.key], 'data') - print(' - Set cache without no-overwrite flag passed.') + + # Access the cache data using the DataFrame structure + cached_value = self.data.get_cache(key=self.key) + self.assertEqual(cached_value, 'data') + print(' - Set data without no-overwrite flag passed.') print('Testing set_cache() once again with new data without no-overwrite flag:') self.data.set_cache(data='more_data', key=self.key) - attr = self.data.__getattribute__('cached_data') - self.assertEqual(attr[self.key], 'more_data') - print(' - Set cache with new data without no-overwrite flag passed.') + + # Access the updated cache data + cached_value = self.data.get_cache(key=self.key) + self.assertEqual(cached_value, 'more_data') + print(' - Set data with new data without no-overwrite flag passed.') print('Testing set_cache() method once again with more data with no-overwrite flag set:') self.data.set_cache(data='even_more_data', key=self.key, do_not_overwrite=True) - attr = self.data.__getattribute__('cached_data') - self.assertEqual(attr[self.key], 'more_data') - print(' - Set cache with no-overwrite flag passed.') + + # Since do_not_overwrite is True, the cached data should not change + cached_value = self.data.get_cache(key=self.key) + self.assertEqual(cached_value, 'more_data') + print(' - Set data with no-overwrite flag passed.') def test_cache_exists(self): print('Testing cache_exists() method:') - self.assertFalse(self.data.cache_exists(key=self.key)) - print(' - Check for non-existent cache passed.') + # Check that the cache does not contain the key before setting it + self.assertFalse(self.data.cache_exists(key=self.key)) + print(' - Check for non-existent data passed.') + + # Set the cache with a DataFrame containing the key-value pair self.data.set_cache(data='data', key=self.key) + + # Check that the cache now contains the key self.assertTrue(self.data.cache_exists(key=self.key)) - print(' - Check for existent cache passed.') + print(' - Check for existent data passed.') def test_update_candle_cache(self): print('Testing update_candle_cache() method:') @@ -273,21 +297,21 @@ class TestDataCacheV2(unittest.TestCase): data_gen = DataGenerator('5m') # Create initial DataFrame and insert into cache - df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0)) + df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc)) print(f'Inserting this table into cache:\n{df_initial}\n') self.data.set_cache(data=df_initial, key=self.key) # Create new DataFrame to be added to cache - df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0)) + df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0, tzinfo=dt.timezone.utc)) print(f'Updating cache with this table:\n{df_new}\n') - self.data.update_candle_cache(more_records=df_new, key=self.key) + self.data._update_candle_cache(more_records=df_new, key=self.key) # Retrieve the resulting DataFrame from cache result = self.data.get_cache(key=self.key) print(f'The resulting table in cache is:\n{result}\n') # Create the expected DataFrame - expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0)) + expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc)) print(f'The expected open_time values are:\n{expected["open_time"].tolist()}\n') # Assert that the open_time values in the result match those in the expected DataFrame, in order @@ -295,37 +319,50 @@ class TestDataCacheV2(unittest.TestCase): f"open_time values in result are {result['open_time'].tolist()}" \ f" but expected {expected['open_time'].tolist()}" - print(f'The results open_time values match:\n{result["open_time"].tolist()}\n') + print(f'The result open_time values match:\n{result["open_time"].tolist()}\n') print(' - Update cache with new records passed.') def test_update_cached_dict(self): print('Testing update_cached_dict() method:') + + # Set an empty dictionary in the cache for the specified key self.data.set_cache(data={}, key=self.key) + + # Update the cached dictionary with a new key-value pair self.data.update_cached_dict(cache_key=self.key, dict_key='sub_key', data='value') + # Retrieve the updated cache cache = self.data.get_cache(key=self.key) + + # Verify that the 'sub_key' in the cached dictionary has the correct value self.assertEqual(cache['sub_key'], 'value') print(' - Update dictionary in cache passed.') def test_get_cache(self): print('Testing get_cache() method:') + + # Set some data into the cache self.data.set_cache(data='data', key=self.key) + + # Retrieve the cached data using the get_cache method result = self.data.get_cache(key=self.key) + + # Verify that the result matches the data we set self.assertEqual(result, 'data') - print(' - Retrieve cache passed.') + print(' - Retrieve data passed.') def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None, num_rec=None, ex_details=None, simulate_scenarios=None): """ Test the get_records_since() method by generating a table of simulated data, - inserting it into cache and/or database, and then querying the records. + inserting it into data and/or database, and then querying the records. Parameters: set_cache (bool): If True, the generated table is inserted into the cache. set_db (bool): If True, the generated table is inserted into the database. query_offset (int, optional): The offset in the timeframe units for the query. num_rec (int, optional): The number of records to generate in the simulated table. - ex_details (list, optional): Exchange details to generate the cache key. + ex_details (list, optional): Exchange details to generate the data key. simulate_scenarios (str, optional): The type of scenario to simulate. Options are: - 'not_enough_data': The table data doesn't go far enough back. - 'incomplete_data': The table doesn't have enough records to satisfy the query. @@ -336,7 +373,7 @@ class TestDataCacheV2(unittest.TestCase): # Use provided ex_details or fallback to the class attribute. ex_details = ex_details or self.ex_details - # Generate a cache/database key using exchange details. + # Generate a data/database key using exchange details. key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}' # Set default number of records if not provided. @@ -374,13 +411,13 @@ class TestDataCacheV2(unittest.TestCase): print(f'Table Created:\n{temp_df}') if set_cache: - # Insert the generated table into cache. - print('Inserting table into cache.') + # Insert the generated table into the cache. + print('Inserting table into the cache.') self.data.set_cache(data=df_initial, key=key) if set_db: # Insert the generated table into the database. - print('Inserting table into database.') + print('Inserting table into the database.') with SQLite(self.db_file) as con: df_initial.to_sql(key, con, if_exists='replace', index=False) @@ -414,7 +451,7 @@ class TestDataCacheV2(unittest.TestCase): # Check that the result has more rows than the expected incomplete data. assert result.shape[0] > expected.shape[ 0], "Result has fewer or equal rows compared to the incomplete data." - print("\nThe returned DataFrames has filled in the missing data!") + print("\nThe returned DataFrame has filled in the missing data!") else: # Ensure the result and expected dataframes match in shape and content. assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}" @@ -444,10 +481,10 @@ class TestDataCacheV2(unittest.TestCase): print(' - Fetch records within the specified time range passed.') def test_get_records_since(self): - print('\nTest get_records_since with records set in cache') + print('\nTest get_records_since with records set in data') self._test_get_records_since() - print('\nTest get_records_since with records not in cache') + print('\nTest get_records_since with records not in data') self._test_get_records_since(set_cache=False) print('\nTest get_records_since with records not in database') @@ -467,21 +504,21 @@ class TestDataCacheV2(unittest.TestCase): self._test_get_records_since(simulate_scenarios='missing_section') def test_other_timeframes(self): - # print('\nTest get_records_since with a different timeframe') - # ex_details = ['BTC/USD', '15m', 'binance', 'test_guy'] - # start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2) - # # Query the records since the calculated start time. - # result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) - # last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') - # assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=15.1) - # - # print('\nTest get_records_since with a different timeframe') - # ex_details = ['BTC/USD', '5m', 'binance', 'test_guy'] - # start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=1) - # # Query the records since the calculated start time. - # result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) - # last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') - # assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=5.1) + print('\nTest get_records_since with a different timeframe') + ex_details = ['BTC/USD', '15m', 'binance', 'test_guy'] + start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2) + # Query the records since the calculated start time. + result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) + last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') + assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=15.1) + + print('\nTest get_records_since with a different timeframe') + ex_details = ['BTC/USD', '5m', 'binance', 'test_guy'] + start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=1) + # Query the records since the calculated start time. + result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) + last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') + assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=5.1) print('\nTest get_records_since with a different timeframe') ex_details = ['BTC/USD', '4h', 'binance', 'test_guy'] @@ -506,16 +543,217 @@ class TestDataCacheV2(unittest.TestCase): def test_fetch_candles_from_exchange(self): print('Testing _fetch_candles_from_exchange() method:') - start_time = dt.datetime.utcnow() - dt.timedelta(days=1) - end_time = dt.datetime.utcnow() - result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance', - user_name='test_guy', start_datetime=start_time, - end_datetime=end_time) + # Define start and end times for the data fetch + start_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc) - dt.timedelta(days=1) + end_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc) + + # Fetch the candles from the exchange using the method + result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', + exchange_name='binance', user_name='test_guy', + start_datetime=start_time, end_datetime=end_time) + + # Validate that the result is a DataFrame self.assertIsInstance(result, pd.DataFrame) - self.assertFalse(result.empty) + + # Validate that the DataFrame is not empty + self.assertFalse(result.empty, "The DataFrame returned from the exchange is empty.") + + # Ensure that the 'open_time' column exists in the DataFrame + self.assertIn('open_time', result.columns, "'open_time' column is missing in the result DataFrame.") + + # Check if the DataFrame contains valid timestamps within the specified range + min_time = pd.to_datetime(result['open_time'].min(), unit='ms').tz_localize('UTC') + max_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') + self.assertTrue(start_time <= min_time <= end_time, f"Data starts outside the expected range: {min_time}") + self.assertTrue(start_time <= max_time <= end_time, f"Data ends outside the expected range: {max_time}") + print(' - Fetch candle data from exchange passed.') + def test_remove_row(self): + print('Testing remove_row() method:') + + # Insert data into the cache with the expected columns + df = pd.DataFrame({ + 'key': [self.key], + 'data': ['test_data'] + }) + self.data.set_cache(data='test_data', key=self.key) + + # Ensure the data is in the cache + self.assertTrue(self.data.cache_exists(self.key), "Data was not correctly inserted into the cache.") + + # Remove the row from the cache only (soft delete) + self.data.remove_row(table='test_table_2', filter_vals=('key', self.key), remove_from_db=False) + + # Verify the row has been removed from the cache + self.assertFalse(self.data.cache_exists(self.key), "Row was not correctly removed from the cache.") + + # Reinsert the data for hard delete test + self.data.set_cache(data='test_data', key=self.key) + + # Mock database delete by adding the row to the database + self.data.db.insert_row(table='test_table_2', columns=('key', 'data'), values=(self.key, 'test_data')) + + # Remove the row from both cache and database (hard delete) + self.data.remove_row(table='test_table_2', filter_vals=('key', self.key), remove_from_db=True) + + # Verify the row has been removed from the cache + self.assertFalse(self.data.cache_exists(self.key), "Row was not correctly removed from the cache.") + + # Verify the row has been removed from the database + with SQLite(self.db_file) as con: + result = pd.read_sql(f'SELECT * FROM test_table_2 WHERE key="{self.key}"', con) + self.assertTrue(result.empty, "Row was not correctly removed from the database.") + + print(' - Remove row from cache and database passed.') + + def test_timeframe_to_timedelta(self): + print('Testing timeframe_to_timedelta() function:') + + result = timeframe_to_timedelta('2h') + expected = pd.Timedelta(hours=2) + self.assertEqual(result, expected, "Failed to convert '2h' to Timedelta") + + result = timeframe_to_timedelta('5m') + expected = pd.Timedelta(minutes=5) + self.assertEqual(result, expected, "Failed to convert '5m' to Timedelta") + + result = timeframe_to_timedelta('1d') + expected = pd.Timedelta(days=1) + self.assertEqual(result, expected, "Failed to convert '1d' to Timedelta") + + result = timeframe_to_timedelta('3M') + expected = pd.DateOffset(months=3) + self.assertEqual(result, expected, "Failed to convert '3M' to DateOffset") + + result = timeframe_to_timedelta('1Y') + expected = pd.DateOffset(years=1) + self.assertEqual(result, expected, "Failed to convert '1Y' to DateOffset") + + with self.assertRaises(ValueError): + timeframe_to_timedelta('5x') + + print(' - All timeframe_to_timedelta() tests passed.') + + def test_estimate_record_count(self): + print('Testing estimate_record_count() function:') + + start_time = dt.datetime(2023, 8, 1, 0, 0, 0, tzinfo=dt.timezone.utc) + end_time = dt.datetime(2023, 8, 2, 0, 0, 0, tzinfo=dt.timezone.utc) + + result = estimate_record_count(start_time, end_time, '1h') + expected = 24 + self.assertEqual(result, expected, "Failed to estimate record count for 1h timeframe") + + result = estimate_record_count(start_time, end_time, '1d') + expected = 1 + self.assertEqual(result, expected, "Failed to estimate record count for 1d timeframe") + + start_time = int(start_time.timestamp() * 1000) # Convert to milliseconds + end_time = int(end_time.timestamp() * 1000) # Convert to milliseconds + + result = estimate_record_count(start_time, end_time, '1h') + expected = 24 + self.assertEqual(result, expected, "Failed to estimate record count for 1h timeframe with milliseconds") + + with self.assertRaises(ValueError): + estimate_record_count("invalid_start", end_time, '1h') + + print(' - All estimate_record_count() tests passed.') + + def test_fetch_cached_rows(self): + print('Testing fetch_cached_rows() method:') + + # Set up mock data in the cache + df = pd.DataFrame({ + 'table': ['test_table_2'], + 'key': ['test_key'], + 'data': ['test_data'] + }) + self.data.cache = pd.concat([self.data.cache, df]) + + # Test fetching from cache + result = self.data.fetch_cached_rows('test_table_2', ('key', 'test_key')) + self.assertIsInstance(result, pd.DataFrame, "Failed to fetch DataFrame from cache") + self.assertFalse(result.empty, "The fetched DataFrame is empty") + self.assertEqual(result.iloc[0]['data'], 'test_data', "Incorrect data fetched from cache") + + # Test fetching from database (assuming the method calls it) + # Here we would typically mock the database call + # But since we're not doing I/O, we will skip that part + print(' - Fetch from cache and database simulated.') + + def test_is_attr_taken(self): + print('Testing is_attr_taken() method:') + + # Set up mock data in the cache + df = pd.DataFrame({ + 'table': ['users'], + 'user_name': ['test_user'], + 'data': ['test_data'] + }) + self.data.cache = pd.concat([self.data.cache, df]) + + # Test for existing attribute + result = self.data.is_attr_taken('users', 'user_name', 'test_user') + self.assertTrue(result, "Failed to detect existing attribute") + + # Test for non-existing attribute + result = self.data.is_attr_taken('users', 'user_name', 'non_existing_user') + self.assertFalse(result, "Incorrectly detected non-existing attribute") + + print(' - All is_attr_taken() tests passed.') + + def test_insert_data(self): + print('Testing insert_data() method:') + + # Create a DataFrame to insert + df = pd.DataFrame({ + 'key': ['new_key'], + 'data': ['new_data'] + }) + + # Insert data into the database and cache + self.data.insert_data(df=df, table='test_table_2') + + # Verify that the data was added to the cache + cached_value = self.data.get_cache('new_key') + self.assertEqual(cached_value, 'new_data', "Failed to insert data into cache") + + # Normally, we would also verify that the data was inserted into the database + # This would typically be done with a mock database or by checking the database state directly + print(' - Data insertion into cache and database simulated.') + + def test_insert_row(self): + print('Testing insert_row() method:') + + self.data.insert_row(table='test_table_2', columns=('key', 'data'), values=('test_key', 'test_data')) + + # Verify the row was inserted + with SQLite(self.db_file) as con: + result = pd.read_sql('SELECT * FROM test_table_2 WHERE key="test_key"', con) + self.assertFalse(result.empty, "Row was not inserted into the database.") + + print(' - Insert row passed.') + + def test_fill_data_holes(self): + print('Testing _fill_data_holes() method:') + + # Create mock data with gaps + df = pd.DataFrame({ + 'open_time': [dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc).timestamp() * 1000, + dt.datetime(2023, 1, 1, 2, tzinfo=dt.timezone.utc).timestamp() * 1000, + dt.datetime(2023, 1, 1, 6, tzinfo=dt.timezone.utc).timestamp() * 1000, + dt.datetime(2023, 1, 1, 8, tzinfo=dt.timezone.utc).timestamp() * 1000, + dt.datetime(2023, 1, 1, 12, tzinfo=dt.timezone.utc).timestamp() * 1000] + }) + + # Call the method + result = self.data._fill_data_holes(records=df, interval='2h') + self.assertEqual(len(result), 7, "Data holes were not filled correctly.") + print(' - _fill_data_holes passed.') + if __name__ == '__main__': unittest.main() diff --git a/tests/test_Users.py b/tests/test_Users.py index 3f33a2c..0f1b118 100644 --- a/tests/test_Users.py +++ b/tests/test_Users.py @@ -115,7 +115,7 @@ def test_log_out_all_users(): def test_load_user_data(): # Todo method incomplete - result = config.users.load_user_data(user_name='RobbieD') + result = config.users.get_user_data(user_name='RobbieD') print('\n Test result:', result) assert result is None