Refactored DataCache, Database and Users. All db interactions are now all inside Database. All data request from Users now go through DataCache. Expanded DataCache test to include all methods. All DataCache tests pass.

This commit is contained in:
Rob 2024-08-19 23:10:13 -03:00
parent fc407708ba
commit 8361efd965
16 changed files with 1383 additions and 776 deletions

2
.gitignore vendored
View File

@ -4,7 +4,7 @@
# Ignore Flask session data # Ignore Flask session data
flask_session/ flask_session/
# Ignore testing cache # Ignore testing data
.pytest_cache/ .pytest_cache/
# Ignore databases # Ignore databases

View File

@ -1,6 +1,6 @@
from typing import Any from typing import Any
from Users import Users
from DataCache_v2 import DataCache from DataCache_v2 import DataCache
from Strategies import Strategies from Strategies import Strategies
from backtesting import Backtester from backtesting import Backtester
@ -20,25 +20,29 @@ class BrighterTrades:
# Object that interacts with the persistent data. # Object that interacts with the persistent data.
self.data = DataCache(self.exchanges) self.data = DataCache(self.exchanges)
# Configuration and settings for the user app and charts # Configuration for the app
self.config = Configuration(cache=self.data) self.config = Configuration()
# Object that maintains signals. Initialize with any signals loaded from file. # The object that manages users in the system.
self.signals = Signals(self.config.signals_list) self.users = Users(data_cache=self.data)
# Object that maintains signals.
self.signals = Signals(self.config)
# Object that maintains candlestick and price data. # Object that maintains candlestick and price data.
self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, data_source=self.data) self.candles = Candles(users=self.users, exchanges=self.exchanges, data_source=self.data,
config=self.config)
# Object that interacts with and maintains data from available indicators # Object that interacts with and maintains data from available indicators
self.indicators = Indicators(self.candles, self.config) self.indicators = Indicators(self.candles, self.users)
# Object that maintains the trades data # Object that maintains the trades data
self.trades = Trades(self.config.trades) self.trades = Trades(self.users)
# The Trades object needs to connect to an exchange_interface. # The Trades object needs to connect to an exchange_interface.
self.trades.connect_exchanges(exchanges=self.exchanges) self.trades.connect_exchanges(exchanges=self.exchanges)
# Object that maintains the strategies data # Object that maintains the strategies data
self.strategies = Strategies(self.config.strategies_list, self.trades) self.strategies = Strategies(self.data, self.trades)
# Object responsible for testing trade and strategies data. # Object responsible for testing trade and strategies data.
self.backtester = Backtester() self.backtester = Backtester()
@ -56,8 +60,8 @@ class BrighterTrades:
raise ValueError("Missing required arguments for 'create_new_user'") raise ValueError("Missing required arguments for 'create_new_user'")
try: try:
self.config.users.create_new_user(email=email, username=username, password=password) self.users.create_new_user(email=email, username=username, password=password)
login_successful = self.config.users.log_in_user(username=username, password=password) login_successful = self.users.log_in_user(username=username, password=password)
return login_successful return login_successful
except Exception as e: except Exception as e:
# Handle specific exceptions or log the error # Handle specific exceptions or log the error
@ -77,11 +81,11 @@ class BrighterTrades:
try: try:
if cmd == 'logout': if cmd == 'logout':
return self.config.users.log_out_user(username=user_name) return self.users.log_out_user(username=user_name)
elif cmd == 'login': elif cmd == 'login':
if password is None: if password is None:
raise ValueError("Password is required for login.") raise ValueError("Password is required for login.")
return self.config.users.log_in_user(username=user_name, password=password) return self.users.log_in_user(username=user_name, password=password)
except Exception as e: except Exception as e:
# Handle specific exceptions or log the error # Handle specific exceptions or log the error
raise ValueError("Error during user login/logout: " + str(e)) raise ValueError("Error during user login/logout: " + str(e))
@ -98,19 +102,19 @@ class BrighterTrades:
if info == 'Chart View': if info == 'Chart View':
try: try:
return self.config.users.get_chart_view(user_name=user_name) return self.users.get_chart_view(user_name=user_name)
except Exception as e: except Exception as e:
# Handle specific exceptions or log the error # Handle specific exceptions or log the error
raise ValueError("Error retrieving chart view information: " + str(e)) raise ValueError("Error retrieving chart view information: " + str(e))
elif info == 'Is logged in?': elif info == 'Is logged in?':
try: try:
return self.config.users.is_logged_in(user_name=user_name) return self.users.is_logged_in(user_name=user_name)
except Exception as e: except Exception as e:
# Handle specific exceptions or log the error # Handle specific exceptions or log the error
raise ValueError("Error checking logged in status: " + str(e)) raise ValueError("Error checking logged in status: " + str(e))
elif info == 'User_id': elif info == 'User_id':
try: try:
return self.config.users.get_id(user_name=user_name) return self.users.get_id(user_name=user_name)
except Exception as e: except Exception as e:
# Handle specific exceptions or log the error # Handle specific exceptions or log the error
raise ValueError("Error fetching id: " + str(e)) raise ValueError("Error fetching id: " + str(e))
@ -181,10 +185,10 @@ class BrighterTrades:
:param default_keys: default API keys. :param default_keys: default API keys.
:return: bool - True on success. :return: bool - True on success.
""" """
active_exchanges = self.config.users.get_exchanges(user_name, category='active_exchanges') active_exchanges = self.users.get_exchanges(user_name, category='active_exchanges')
success = False success = False
for exchange in active_exchanges: for exchange in active_exchanges:
keys = self.config.users.get_api_keys(user_name, exchange) keys = self.users.get_api_keys(user_name, exchange)
success = self.connect_or_config_exchange(user_name=user_name, success = self.connect_or_config_exchange(user_name=user_name,
exchange_name=exchange, exchange_name=exchange,
api_keys=keys) api_keys=keys)
@ -202,7 +206,7 @@ class BrighterTrades:
:param user_name: str - The name of the user making the query. :param user_name: str - The name of the user making the query.
""" """
chart_view = self.config.users.get_chart_view(user_name=user_name) chart_view = self.users.get_chart_view(user_name=user_name)
indicator_types = self.indicators.indicator_types indicator_types = self.indicators.indicator_types
available_indicators = self.indicators.get_indicator_list(user_name) available_indicators = self.indicators.get_indicator_list(user_name)
@ -231,19 +235,20 @@ class BrighterTrades:
:return: A dictionary containing the requested data. :return: A dictionary containing the requested data.
""" """
chart_view = self.config.users.get_chart_view(user_name=user_name) chart_view = self.users.get_chart_view(user_name=user_name)
exchange = self.exchanges.get_exchange(ename=chart_view.get('exchange'), uname=user_name) exchange = self.exchanges.get_exchange(ename=chart_view.get('exchange'), uname=user_name)
# noinspection PyDictCreation # noinspection PyDictCreation
r_data = {} r_data = {}
r_data['title'] = self.config.app_data.get('application_title', '') r_data['title'] = self.config.get_setting('application_title')
r_data['chart_interval'] = chart_view.get('timeframe', '') r_data['chart_interval'] = chart_view.get('timeframe', '')
r_data['selected_exchange'] = chart_view.get('exchange', '') r_data['selected_exchange'] = chart_view.get('exchange', '')
r_data['intervals'] = exchange.intervals if exchange else [] r_data['intervals'] = exchange.intervals if exchange else []
r_data['symbols'] = exchange.get_symbols() if exchange else {} r_data['symbols'] = exchange.get_symbols() if exchange else {}
r_data['available_exchanges'] = self.exchanges.get_available_exchanges() or [] r_data['available_exchanges'] = self.exchanges.get_available_exchanges() or []
r_data['connected_exchanges'] = self.exchanges.get_connected_exchanges(user_name) or [] r_data['connected_exchanges'] = self.exchanges.get_connected_exchanges(user_name) or []
r_data['configured_exchanges'] = self.config.users.get_exchanges(user_name, category='configured_exchanges') or [] r_data['configured_exchanges'] = self.users.get_exchanges(
user_name, category='configured_exchanges') or []
r_data['my_balances'] = self.exchanges.get_all_balances(user_name) or {} r_data['my_balances'] = self.exchanges.get_all_balances(user_name) or {}
r_data['indicator_types'] = self.indicators.indicator_types or [] r_data['indicator_types'] = self.indicators.indicator_types or []
r_data['indicator_list'] = self.indicators.get_indicator_list(user_name) or [] r_data['indicator_list'] = self.indicators.get_indicator_list(user_name) or []
@ -295,7 +300,7 @@ class BrighterTrades:
return "The new signal must have a 'name' attribute." return "The new signal must have a 'name' attribute."
self.signals.new_signal(data) self.signals.new_signal(data)
self.config.update_data('signals', self.signals.get_signals('dict')) self.config.set_setting('signals_list', self.signals.get_signals('dict'))
return data return data
def received_new_strategy(self, data: dict) -> str | dict: def received_new_strategy(self, data: dict) -> str | dict:
@ -309,7 +314,7 @@ class BrighterTrades:
return "The new strategy must have a 'name' attribute." return "The new strategy must have a 'name' attribute."
self.strategies.new_strategy(data) self.strategies.new_strategy(data)
self.config.update_data('strategies', self.strategies.get_strategies('dict')) self.config.set_setting('strategies', self.strategies.get_strategies('dict'))
return data return data
def delete_strategy(self, strategy_name: str) -> None: def delete_strategy(self, strategy_name: str) -> None:
@ -377,10 +382,9 @@ class BrighterTrades:
success = self.exchanges.connect_exchange(exchange_name=exchange_name, user_name=user_name, success = self.exchanges.connect_exchange(exchange_name=exchange_name, user_name=user_name,
api_keys=api_keys) api_keys=api_keys)
if success: if success:
self.config.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set') self.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set')
if api_keys: if api_keys:
self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name)
user_name=user_name)
return True return True
else: else:
return False # Failed to connect return False # Failed to connect
@ -391,7 +395,7 @@ class BrighterTrades:
else: else:
# Exchange is already connected, update API keys if provided # Exchange is already connected, update API keys if provided
if api_keys: if api_keys:
self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name) self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name)
return True # Already connected return True # Already connected
def close_trade(self, trade_id): def close_trade(self, trade_id):
@ -470,19 +474,19 @@ class BrighterTrades:
if setting == 'interval': if setting == 'interval':
interval_state = params['timeframe'] interval_state = params['timeframe']
self.config.users.set_chart_view(values=interval_state, specific_property='timeframe', user_name=user_name) self.users.set_chart_view(values=interval_state, specific_property='timeframe', user_name=user_name)
elif setting == 'trading_pair': elif setting == 'trading_pair':
trading_pair = params['symbol'] trading_pair = params['symbol']
self.config.users.set_chart_view(values=trading_pair, specific_property='market', user_name=user_name) self.users.set_chart_view(values=trading_pair, specific_property='market', user_name=user_name)
elif setting == 'exchange': elif setting == 'exchange':
exchange_name = params['exchange_name'] exchange_name = params['exchange_name']
# Get the first result of a list of available symbols from this exchange_name. # Get the first result of a list of available symbols from this exchange_name.
market = self.exchanges.get_exchange(ename=exchange_name, uname=user_name).get_symbols()[0] market = self.exchanges.get_exchange(ename=exchange_name, uname=user_name).get_symbols()[0]
# Change the exchange in the chart view and pass it a default market to display. # Change the exchange in the chart view and pass it a default market to display.
self.config.users.set_chart_view(values=exchange_name, specific_property='exchange_name', self.users.set_chart_view(values=exchange_name, specific_property='exchange_name',
user_name=user_name, default_market=market) user_name=user_name, default_market=market)
elif setting == 'toggle_indicator': elif setting == 'toggle_indicator':
indicators_to_toggle = params.getlist('indicator') indicators_to_toggle = params.getlist('indicator')

View File

@ -1,130 +1,279 @@
import pandas
import yaml import yaml
from Signals import Signals import os
from indicators import Indicators import time
from Users import Users import logging
from typing import Any
logger = logging.getLogger(__name__)
class Configuration: class Configuration:
def __init__(self, cache): """
# ************** Default values************** Manages the application's settings, loading and saving them to a YAML file.
- Automatically loads the settings from file when instantiated.
- Settings can be added or modified with set_setting('key', value)
or by editing the file directly.
"""
self.app_data = { def __init__(self, config_file='config.yml'):
'application_title': 'BrighterTrades', # The title of our program. """Initializes with default settings and loads saved data."""
'max_data_loaded': 1000 # The maximum number of candles to store in memory. self.config_file = config_file
self.save_in_progress = False # Persistent flag to prevent infinite recursion
self._set_default_settings()
self.manage_config('load')
def _set_default_settings(self):
"""Set default settings."""
self.settings = {
'application_title': 'BrighterTrades',
'max_data_loaded': 1000
} }
# The object that interacts with the database. def _generate_config_data(self):
self.data = cache.db """Generates a list of settings to be saved to the config file."""
return self.settings.copy()
# The object that manages users in the system. def manage_config(self, cmd: str):
self.users = Users(data_cache=cache)
# The name of the file that stores saved_data
self.config_FN = 'config.yml'
# A list of all the available Signals.
# Calls a static method of Signals that initializes a default list.
self.signals_list = Signals.get_signals_defaults()
# list of all the available strategies.
self.strategies_list = []
# list of trades.
self.trades = []
# The data that will be saved and loaded from file .
self.saved_data = None
# Load any saved data from file
self.config_and_states('load')
def update_data(self, data_type, data):
""" """
Replace current list of data sets with an updated list. Loads or saves settings to the config file according to cmd: 'load' | 'save'
:param data_type: The data being replaced
:param data: The replacement data.
:return: None.
""" """
if data_type == 'strategies':
self.strategies_list = data
elif data_type == 'signals':
self.signals_list = data
elif data_type == 'trades':
self.trades = data
else:
raise ValueError(f'Configuration: update_data(): Unsupported data_type: {data_type}')
# Save it to file.
self.config_and_states('save')
def remove(self, what, name): def _apply_loaded_settings(data):
# Removes by name an item from a list in saved data. """Updates settings from loaded data and marks if saving is needed."""
needs_save = False
for key, value in data.items():
if key not in self.settings or self.settings[key] != value:
self.settings[key] = value
needs_save = True
return needs_save
# Trades are indexed by unique_id, while Signals and Strategies are indexed by name. def _load_config_from_file(filepath):
if what == 'trades': try:
prop = 'unique_id' with open(filepath, "r") as file_descriptor:
else: return yaml.safe_load(file_descriptor)
prop = 'name' except yaml.YAMLError:
timestamp = time.strftime("%Y%m%d-%H%M%S")
backup_path = f"{filepath}.{timestamp}.backup"
os.rename(filepath, backup_path)
logging.warning(f"Corrupt YAML file detected. Backup saved to {backup_path}")
logging.info(f"Recreating the configuration file with default settings.")
return None
for obj in self.saved_data[what]: def _save_config_to_file(filepath, data):
if obj[prop] == name: try:
self.saved_data[what].remove(obj) with open(filepath, "w") as file_descriptor:
break yaml.dump(data, file_descriptor)
# Save it to file. except (IOError, OSError) as e:
self.config_and_states('save') logging.error(f"Failed to save configuration to {filepath}: {e}")
raise ValueError(f"Failed to save configuration to {filepath}: {e}")
def config_and_states(self, cmd):
"""Loads or saves configurable data to the file set in self.config_FN"""
# The data stored and retrieved from file session.
self.saved_data = {
'signals': self.signals_list,
'strategies': self.strategies_list,
'trades': self.trades
}
def set_loaded_values():
# Sets the values in the saved_data object.
if 'signals' in self.saved_data:
self.signals_list = self.saved_data['signals']
if 'strategies' in self.saved_data:
self.strategies_list = self.saved_data['strategies']
if 'trades' in self.saved_data:
self.trades = self.saved_data['trades']
def load_configuration(filepath):
"""load file data"""
with open(filepath, "r") as file_descriptor:
data = yaml.safe_load(file_descriptor)
return data
def save_configuration(filepath, data):
"""Saves file data"""
with open(filepath, "w") as file_descriptor:
yaml.dump(data, file_descriptor)
if cmd == 'load': if cmd == 'load':
# If load_configuration() finds a file it overwrites if os.path.exists(self.config_file):
# the saved_data object otherwise it creates a new file loaded_data = _load_config_from_file(self.config_file)
# with the defaults contained in saved_data> if loaded_data is None: # Corrupt file case, recreate it
_save_config_to_file(self.config_file, self._generate_config_data())
# If file exist load the values. elif _apply_loaded_settings(loaded_data) and not self.save_in_progress:
try: self.save_in_progress = True
self.saved_data = load_configuration(self.config_FN) self.manage_config('save')
set_loaded_values() self.save_in_progress = False
# If file doesn't exist create a file and save the default values. else:
except IOError: logging.info(f"Configuration file not found. Creating a new one at {self.config_file}.")
save_configuration(self.config_FN, self.saved_data) _save_config_to_file(self.config_file, self._generate_config_data())
elif cmd == 'save': elif cmd == 'save':
try: _save_config_to_file(self.config_file, self._generate_config_data())
# Write saved_data to the file.
save_configuration(self.config_FN, self.saved_data)
except IOError:
raise ValueError("save_configuration(): Couldn't save the file.")
else: else:
raise ValueError('save_configuration(): Invalid command received.') raise ValueError('manage_config(): Invalid command.')
def reset_settings_to_defaults(self):
"""Resets settings to default values."""
self._set_default_settings()
self.manage_config('save')
def get_setting(self, key: str) -> Any:
"""
Returns the value of the specified setting or None if the key is not found.
"""
return self.settings.get(key, None)
def set_setting(self, key: str, value: Any):
"""
Receives a key and value of any setting and saves the configuration.
"""
self.settings[key] = value
self.manage_config('save')
# import yaml
# from Signals import Signals
#
#
# class Configuration:
# """
# Configuration class manages the application's settings,
# signals, strategies, and trades. It loads and saves these
# configurations to a YAML file.
#
# Attributes:
# app_data (dict): Default application settings like title and maximum data load.
# data (object): Database interaction object from the data.
# config_FN (str): Filename for storing and loading configurations.
# signals_list (list): List of available signals initialized by the Signals class.
# strategies_list (list): List of strategies available for the application.
# trades (list): List of trades managed by the application.
# saved_data (dict): Data structure for saving/loading configuration data.
# """
#
# def __init__(self, data):
# """
# Initializes the Configuration object with default values and loads saved data.
#
# Args:
# data (object): Cache object with a database attribute to interact with the database.
# """
# # ************** Default values **************
# # Application metadata such as title and maximum data to be loaded.
# self.app_data = {
# 'application_title': 'BrighterTrades', # The title of the program.
# 'max_data_loaded': 1000 # Maximum number of candles to store in memory.
# }
#
# # The database object for interaction.
# self.data = data.db
#
# # Name of the configuration file.
# self.config_FN = 'config.yml'
#
# # List of available signals initialized with default values.
# self.signals_list = Signals.get_signals_defaults()
#
# # List to hold available strategies.
# self.strategies_list = []
#
# # List to hold trades.
# self.trades = []
#
# # Placeholder for data loaded from or to be saved to the file.
# self.saved_data = None
#
# # Load any saved data from the configuration file.
# self.config_and_states('load')
#
# def update_data(self, data_type, data):
# """
# Replace the current list of data sets with an updated list.
#
# Args:
# data_type (str): Type of data to be updated ('strategies', 'signals', or 'trades').
# data (list): The new data to replace the old one.
#
# Raises:
# ValueError: If the provided data_type is not supported.
# """
# if data_type == 'strategies':
# self.strategies_list = data
# elif data_type == 'signals':
# self.signals_list = data
# elif data_type == 'trades':
# self.trades = data
# else:
# raise ValueError(f'Configuration: update_data(): Unsupported data_type: {data_type}')
#
# # Save the updated data to the configuration file.
# self.config_and_states('save')
#
# def remove(self, what, name):
# """
# Removes an item by name from the saved data list.
#
# Args:
# what (str): Type of data to remove ('strategies', 'signals', or 'trades').
# name (str): The name or unique_id of the item to remove.
#
# Raises:
# ValueError: If the item with the specified name is not found.
# """
# # Determine the property to match based on the type of data.
# if what == 'trades':
# prop = 'unique_id'
# else:
# prop = 'name'
#
# # Remove the item from the list if it matches the name or unique_id.
# for obj in self.saved_data[what]:
# if obj[prop] == name:
# self.saved_data[what].remove(obj)
# break
#
# # Save the updated data to the configuration file.
# self.config_and_states('save')
#
# def config_and_states(self, cmd):
# """
# Loads or saves configurable data to the file set in self.config_FN.
#
# Args:
# cmd (str): Command to either 'load' or 'save' the configuration data.
#
# Raises:
# ValueError: If the command is neither 'load' nor 'save'.
# """
#
# # Data structure to hold the current state of signals, strategies, and trades.
# self.saved_data = {
# 'signals': self.signals_list,
# 'strategies': self.strategies_list,
# 'trades': self.trades
# }
#
# def set_loaded_values():
# """Sets the values in the saved_data object to the class attributes."""
# if 'signals' in self.saved_data:
# self.signals_list = self.saved_data['signals']
#
# if 'strategies' in self.saved_data:
# self.strategies_list = self.saved_data['strategies']
#
# if 'trades' in self.saved_data:
# self.trades = self.saved_data['trades']
#
# def load_configuration(filepath):
# """
# Load configuration data from a YAML file.
#
# Args:
# filepath (str): Path to the configuration file.
#
# Returns:
# dict: Loaded configuration data.
# """
# with open(filepath, "r") as file_descriptor:
# data = yaml.safe_load(file_descriptor)
# return data
#
# def save_configuration(filepath, data):
# """
# Save configuration data to a YAML file.
#
# Args:
# filepath (str): Path to the configuration file.
# data (dict): Data to save.
# """
# with open(filepath, "w") as file_descriptor:
# yaml.dump(data, file_descriptor)
#
# if cmd == 'load':
# try:
# # Attempt to load the configuration from the file.
# self.saved_data = load_configuration(self.config_FN)
# set_loaded_values()
# except IOError:
# # If the file doesn't exist, save the default values.
# save_configuration(self.config_FN, self.saved_data)
#
# elif cmd == 'save':
# try:
# # Save the current state to the configuration file.
# save_configuration(self.config_FN, self.saved_data)
# except IOError:
# raise ValueError("save_configuration(): Couldn't save the file.")
# else:
# raise ValueError('save_configuration(): Invalid command received.')

View File

@ -1,4 +1,5 @@
from typing import List, Any import json
from typing import List, Any, Tuple
import pandas as pd import pandas as pd
import datetime as dt import datetime as dt
import logging import logging
@ -59,23 +60,31 @@ def estimate_record_count(start_time, end_time, timeframe: str) -> int:
return int(expected_records) return int(expected_records)
def generate_expected_timestamps(start_datetime: dt.datetime, end_datetime: dt.datetime, # def generate_expected_timestamps(start_datetime: dt.datetime, end_datetime: dt.datetime,
timeframe: str) -> pd.DatetimeIndex: # timeframe: str) -> pd.DatetimeIndex:
if start_datetime.tzinfo is None: # """
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") # What it says. Todo: confirm this is unused and archive.
if end_datetime.tzinfo is None: #
raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.") # :param start_datetime:
# :param end_datetime:
delta = timeframe_to_timedelta(timeframe) # :param timeframe:
if isinstance(delta, pd.Timedelta): # :return:
return pd.date_range(start=start_datetime, end=end_datetime, freq=delta) # """
elif isinstance(delta, pd.DateOffset): # if start_datetime.tzinfo is None:
current = start_datetime # raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
timestamps = [] # if end_datetime.tzinfo is None:
while current <= end_datetime: # raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
timestamps.append(current) #
current += delta # delta = timeframe_to_timedelta(timeframe)
return pd.DatetimeIndex(timestamps) # if isinstance(delta, pd.Timedelta):
# return pd.date_range(start=start_datetime, end=end_datetime, freq=delta)
# elif isinstance(delta, pd.DateOffset):
# current = start_datetime
# timestamps = []
# while current <= end_datetime:
# timestamps.append(current)
# current += delta
# return pd.DatetimeIndex(timestamps)
class DataCache: class DataCache:
@ -84,10 +93,170 @@ class DataCache:
def __init__(self, exchanges): def __init__(self, exchanges):
self.db = Database() self.db = Database()
self.exchanges = exchanges self.exchanges = exchanges
self.cached_data = {} # Single DataFrame for all cached data
self.cache = pd.DataFrame(columns=['key', 'data']) # Assuming 'key' and 'data' are necessary
logger.info("DataCache initialized.") logger.info("DataCache initialized.")
def fetch_cached_rows(self, table: str, filter_vals: Tuple[str, Any]) -> pd.DataFrame | None:
"""
Retrieves rows from the cache if available; otherwise, queries the database and caches the result.
:param table: Name of the database table to query.
:param filter_vals: A tuple containing the column name and the value to filter by.
:return: A DataFrame containing the requested rows, or None if no matching rows are found.
"""
# Construct a filter condition for the cache based on the table name and filter values.
cache_filter = (self.cache['table'] == table) & (self.cache[filter_vals[0]] == filter_vals[1])
cached_rows = self.cache[cache_filter]
# If the data is found in the cache, return it.
if not cached_rows.empty:
return cached_rows
# If the data is not found in the cache, query the database.
rows = self.db.get_rows_where(table, filter_vals)
if rows is not None:
# Tag the rows with the table name and add them to the cache.
rows['table'] = table
self.cache = pd.concat([self.cache, rows])
return rows
def remove_row(self, filter_vals: Tuple[str, Any], additional_filter: Tuple[str, Any] = None,
remove_from_db: bool = True, table: str = None) -> None:
"""
Removes a specific row from the cache and optionally from the database based on filter criteria.
:param filter_vals: A tuple containing the column name and the value to filter by.
:param additional_filter: An optional additional filter to apply.
:param remove_from_db: If True, also removes the row from the database. Default is True.
:param table: The name of the table from which to remove the row in the database (optional).
"""
logger.debug(
f"Removing row from cache: filter={filter_vals},"
f" additional_filter={additional_filter}, remove_from_db={remove_from_db}, table={table}")
# Construct the filter condition for the cache
cache_filter = (self.cache[filter_vals[0]] == filter_vals[1])
if additional_filter:
cache_filter = cache_filter & (self.cache[additional_filter[0]] == additional_filter[1])
# Remove the row from the cache
self.cache = self.cache.drop(self.cache[cache_filter].index)
logger.info(f"Row removed from cache: filter={filter_vals}")
if remove_from_db and table:
# Construct the SQL query to delete from the database
sql = f"DELETE FROM {table} WHERE {filter_vals[0]} = ?"
params = [filter_vals[1]]
if additional_filter:
sql += f" AND {additional_filter[0]} = ?"
params.append(additional_filter[1])
# Execute the SQL query to remove the row from the database
self.db.execute_sql(sql, tuple(params))
logger.info(
f"Row removed from database: table={table}, filter={filter_vals},"
f" additional_filter={additional_filter}")
def is_attr_taken(self, table: str, attr: str, val: Any) -> bool:
"""
Checks if a specific attribute in a table is already taken.
:param table: The name of the table to check.
:param attr: The attribute to check (e.g., 'user_name', 'email').
:param val: The value of the attribute to check.
:return: True if the attribute is already taken, False otherwise.
"""
# Fetch rows from the specified table where the attribute matches the given value
result = self.fetch_cached_rows(table=table, filter_vals=(attr, val))
return result is not None and not result.empty
def fetch_cached_item(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any:
"""
Retrieves a specific item from the cache or database, caching the result if necessary.
:param item_name: The name of the column to retrieve.
:param table_name: The name of the table where the item is stored.
:param filter_vals: A tuple containing the column name and the value to filter by.
:return: The value of the requested item.
:raises ValueError: If the item is not found in either the cache or the database.
"""
# Fetch the relevant rows
rows = self.fetch_cached_rows(table_name, filter_vals)
if rows is not None and not rows.empty:
# Return the specific item from the first matching row.
return rows.iloc[0][item_name]
# If the item is not found, raise an error.
raise ValueError(f"Item {item_name} not found in {table_name} where {filter_vals[0]} = {filter_vals[1]}")
def modify_cached_row(self, table: str, filter_vals: Tuple[str, Any], field_name: str, new_data: Any) -> None:
"""
Modifies a specific field in a row within the cache and updates the database accordingly.
:param table: The name of the table where the data is stored.
:param filter_vals: A tuple containing the column name and the value to filter by.
:param field_name: The field to be updated.
:param new_data: The new data to be set.
"""
# Retrieve the row from the cache or database
row = self.fetch_cached_rows(table, filter_vals)
if row is None or row.empty:
raise ValueError(f"Row not found in cache or database for {filter_vals[0]} = {filter_vals[1]}")
# Modify the specified field
if isinstance(new_data, str):
row.loc[0, field_name] = new_data
else:
# If new_data is not a string, its converted to a JSON string before being inserted into the DataFrame.
row.loc[0, field_name] = json.dumps(new_data)
# Update the cache by removing the old entry and adding the modified row
self.cache = self.cache.drop(
self.cache[(self.cache['table'] == table) & (self.cache[filter_vals[0]] == filter_vals[1])].index
)
self.cache = pd.concat([self.cache, row])
# Update the database with the modified row
self.db.insert_dataframe(row.drop(columns='id'), table)
def insert_data(self, df: pd.DataFrame, table: str, skip_cache: bool = False) -> None:
"""
Inserts data into the specified table in the database, with an option to skip cache insertion.
:param df: The DataFrame containing the data to insert.
:param table: The name of the table where the data should be inserted.
:param skip_cache: If True, skips inserting the data into the cache. Default is False.
"""
# Insert the data into the database
self.db.insert_dataframe(df=df, table=table)
# Optionally insert the data into the cache
if not skip_cache:
df['table'] = table # Add table name for cache identification
self.cache = pd.concat([self.cache, df])
def insert_row(self, table: str, columns: tuple, values: tuple) -> None:
"""
Inserts a single row into the specified table in the database.
:param table: The name of the table where the row should be inserted.
:param columns: A tuple of column names corresponding to the values.
:param values: A tuple of values to insert into the specified columns.
"""
self.db.insert_row(table=table, columns=columns, values=values)
def get_records_since(self, start_datetime: dt.datetime, ex_details: List[str]) -> pd.DataFrame: def get_records_since(self, start_datetime: dt.datetime, ex_details: List[str]) -> pd.DataFrame:
"""
This gets up-to-date records from a specified market and exchange.
:param start_datetime: The approximate time the first record should represent.
:param ex_details: The user exchange and market.
:return: The records.
"""
if self.TYPECHECKING_ENABLED: if self.TYPECHECKING_ENABLED:
if not isinstance(start_datetime, dt.datetime): if not isinstance(start_datetime, dt.datetime):
raise TypeError("start_datetime must be a datetime object") raise TypeError("start_datetime must be a datetime object")
@ -106,12 +275,20 @@ class DataCache:
'end_datetime': end_datetime, 'end_datetime': end_datetime,
'ex_details': ex_details, 'ex_details': ex_details,
} }
return self.get_or_fetch_from('cache', **args) return self._get_or_fetch_from('data', **args)
except Exception as e: except Exception as e:
logger.error(f"An error occurred: {str(e)}") logger.error(f"An error occurred: {str(e)}")
raise raise
def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame: def _get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame:
"""
Fetches market records from a resource stack (data, database, exchange).
fills incomplete request by fetching down the stack then updates the rest.
:param target: Starting point for the fetch. ['data', 'database', 'exchange']
:param kwargs: Details and credentials for the request.
:return: Records in a dataframe.
"""
start_datetime = kwargs.get('start_datetime') start_datetime = kwargs.get('start_datetime')
if start_datetime.tzinfo is None: if start_datetime.tzinfo is None:
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
@ -144,12 +321,12 @@ class DataCache:
key = self._make_key(ex_details) key = self._make_key(ex_details)
combined_data = pd.DataFrame() combined_data = pd.DataFrame()
if target == 'cache': if target == 'data':
resources = [self.get_candles_from_cache, self.get_from_database, self.get_from_server] resources = [self._get_candles_from_cache, self._get_from_database, self._get_from_server]
elif target == 'database': elif target == 'database':
resources = [self.get_from_database, self.get_from_server] resources = [self._get_from_database, self._get_from_server]
elif target == 'server': elif target == 'server':
resources = [self.get_from_server] resources = [self._get_from_server]
else: else:
raise ValueError('Not a valid Target!') raise ValueError('Not a valid Target!')
@ -165,11 +342,11 @@ class DataCache:
combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='open_time').sort_values( combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='open_time').sort_values(
by='open_time') by='open_time')
is_complete, request_criteria = self.data_complete(combined_data, **request_criteria) is_complete, request_criteria = self._data_complete(combined_data, **request_criteria)
if is_complete: if is_complete:
if fetch_method in [self.get_from_database, self.get_from_server]: if fetch_method in [self._get_from_database, self._get_from_server]:
self.update_candle_cache(combined_data, key) self._update_candle_cache(combined_data, key)
if fetch_method == self.get_from_server: if fetch_method == self._get_from_server:
self._populate_db(ex_details, combined_data) self._populate_db(ex_details, combined_data)
return combined_data return combined_data
@ -178,7 +355,7 @@ class DataCache:
logger.error('Unable to fetch the requested data.') logger.error('Unable to fetch the requested data.')
return combined_data if not combined_data.empty else pd.DataFrame() return combined_data if not combined_data.empty else pd.DataFrame()
def get_candles_from_cache(self, **kwargs) -> pd.DataFrame: def _get_candles_from_cache(self, **kwargs) -> pd.DataFrame:
start_datetime = kwargs.get('start_datetime') start_datetime = kwargs.get('start_datetime')
if start_datetime.tzinfo is None: if start_datetime.tzinfo is None:
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
@ -200,7 +377,7 @@ class DataCache:
raise ValueError("Missing required arguments") raise ValueError("Missing required arguments")
key = self._make_key(ex_details) key = self._make_key(ex_details)
logger.debug('Getting records from cache.') logger.debug('Getting records from data.')
df = self.get_cache(key) df = self.get_cache(key)
if df is None: if df is None:
logger.debug("Cache records didn't exist.") logger.debug("Cache records didn't exist.")
@ -210,7 +387,7 @@ class DataCache:
df['open_time'] <= unix_time_millis(end_datetime))].reset_index(drop=True) df['open_time'] <= unix_time_millis(end_datetime))].reset_index(drop=True)
return df_filtered return df_filtered
def get_from_database(self, **kwargs) -> pd.DataFrame: def _get_from_database(self, **kwargs) -> pd.DataFrame:
start_datetime = kwargs.get('start_datetime') start_datetime = kwargs.get('start_datetime')
if start_datetime.tzinfo is None: if start_datetime.tzinfo is None:
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.") raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
@ -240,7 +417,7 @@ class DataCache:
return self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=start_datetime, return self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=start_datetime,
et=end_datetime) et=end_datetime)
def get_from_server(self, **kwargs) -> pd.DataFrame: def _get_from_server(self, **kwargs) -> pd.DataFrame:
symbol = kwargs.get('ex_details')[0] symbol = kwargs.get('ex_details')[0]
interval = kwargs.get('ex_details')[1] interval = kwargs.get('ex_details')[1]
exchange_name = kwargs.get('ex_details')[2] exchange_name = kwargs.get('ex_details')[2]
@ -272,7 +449,7 @@ class DataCache:
end_datetime) end_datetime)
@staticmethod @staticmethod
def data_complete(data: pd.DataFrame, **kwargs) -> (bool, dict): def _data_complete(data: pd.DataFrame, **kwargs) -> (bool, dict):
""" """
Checks if the data completely satisfies the request. Checks if the data completely satisfies the request.
@ -341,17 +518,21 @@ class DataCache:
return True, kwargs return True, kwargs
def cache_exists(self, key: str) -> bool: def cache_exists(self, key: str) -> bool:
return key in self.cached_data return key in self.cache['key'].values
def get_cache(self, key: str) -> Any | None: def get_cache(self, key: str) -> Any | None:
if key not in self.cached_data: # Check if the key exists in the cache
logger.warning(f"The requested cache key({key}) doesn't exist!") if key not in self.cache['key'].values:
logger.warning(f"The requested data key ({key}) doesn't exist!")
return None return None
return self.cached_data[key]
def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None: # Retrieve the data associated with the key
logger.debug('Updating cache with new records.') result = self.cache[self.cache['key'] == key]['data'].iloc[0]
# Concatenate the new records with the existing cache return result
def _update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
logger.debug('Updating data with new records.')
# Concatenate the new records with the existing data
records = pd.concat([self.get_cache(key), more_records], axis=0, ignore_index=True) records = pd.concat([self.get_cache(key), more_records], axis=0, ignore_index=True)
# Drop duplicates based on 'open_time' and keep the first occurrence # Drop duplicates based on 'open_time' and keep the first occurrence
records = records.drop_duplicates(subset="open_time", keep='first') records = records.drop_duplicates(subset="open_time", keep='first')
@ -359,24 +540,49 @@ class DataCache:
records = records.sort_values(by='open_time').reset_index(drop=True) records = records.sort_values(by='open_time').reset_index(drop=True)
# Reindex 'id' to ensure the expected order # Reindex 'id' to ensure the expected order
records['id'] = range(1, len(records) + 1) records['id'] = range(1, len(records) + 1)
# Set the updated DataFrame back to cache # Set the updated DataFrame back to data
self.set_cache(data=records, key=key) self.set_cache(data=records, key=key)
def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None: def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None:
""" """
Updates a dictionary stored in cache. Updates a dictionary stored in the DataFrame cache.
:param data: The data to insert into cache. :param data: The data to insert into the dictionary.
:param cache_key: The cache index key for the dictionary. :param cache_key: The cache key for the dictionary.
:param dict_key: The dictionary key for the data. :param dict_key: The key within the dictionary to update.
:return: None :return: None
""" """
self.cached_data[cache_key].update({dict_key: data}) # Locate the row in the DataFrame that matches the cache_key
cache_index = self.cache.index[self.cache['key'] == cache_key]
if not cache_index.empty:
# Update the dictionary stored in the 'data' column
cache_dict = self.cache.at[cache_index[0], 'data']
if isinstance(cache_dict, dict):
cache_dict[dict_key] = data
# Ensure the DataFrame is updated with the new dictionary
self.cache.at[cache_index[0], 'data'] = cache_dict
else:
raise ValueError(f"Expected a dictionary in cache, but found {type(cache_dict)}.")
else:
raise KeyError(f"Cache key '{cache_key}' not found.")
def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None: def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
if do_not_overwrite and key in self.cached_data: if do_not_overwrite and key in self.cache['key'].values:
return return
self.cached_data[key] = data
# Corrected construction of the new row
new_row = pd.DataFrame({'key': [key], 'data': [data]})
# If the key already exists, drop the old entry
self.cache = self.cache[self.cache['key'] != key]
# Append the new row to the cache
self.cache = pd.concat([self.cache, new_row], ignore_index=True)
print(f'Current Cache: {self.cache}')
logger.debug(f'Cache set for key: {key}') logger.debug(f'Cache set for key: {key}')
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str, def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
@ -422,12 +628,12 @@ class DataCache:
if num_rec_records < estimated_num_records: if num_rec_records < estimated_num_records:
logger.info('Detected gaps in the data, attempting to fill missing records.') logger.info('Detected gaps in the data, attempting to fill missing records.')
candles = self.fill_data_holes(candles, interval) candles = self._fill_data_holes(candles, interval)
return candles return candles
@staticmethod @staticmethod
def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame: def _fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
time_span = timeframe_to_timedelta(interval).total_seconds() / 60 time_span = timeframe_to_timedelta(interval).total_seconds() / 60
last_timestamp = None last_timestamp = None
filled_records = [] filled_records = []

View File

@ -33,7 +33,7 @@ class SQLite:
class HDict(dict): class HDict(dict):
""" """
Hashable dictionary to use as cache keys. Hashable dictionary to use as data keys.
Example usage: Example usage:
-------------- --------------
@ -85,15 +85,16 @@ class Database:
def __init__(self, db_file: str = None): def __init__(self, db_file: str = None):
self.db_file = db_file self.db_file = db_file
def execute_sql(self, sql: str) -> None: def execute_sql(self, sql: str, params: tuple = ()) -> None:
""" """
Executes a raw SQL statement. Executes a raw SQL statement with optional parameters.
:param sql: SQL statement to execute. :param sql: SQL statement to execute.
:param params: Optional tuple of parameters to pass with the SQL statement.
""" """
with SQLite(self.db_file) as con: with SQLite(self.db_file) as con:
cur = con.cursor() cur = con.cursor()
cur.execute(sql) cur.execute(sql, params)
def get_item_where(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any: def get_item_where(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any:
""" """
@ -120,12 +121,17 @@ class Database:
:param table: Name of the table. :param table: Name of the table.
:param filter_vals: Tuple of column name and value to filter by. :param filter_vals: Tuple of column name and value to filter by.
:return: DataFrame of the query result or None if empty. :return: DataFrame of the query result or None if empty or column does not exist.
""" """
with SQLite(self.db_file) as con: try:
qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?" with SQLite(self.db_file) as con:
result = pd.read_sql(qry, con, params=(filter_vals[1],)) qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?"
return result if not result.empty else None result = pd.read_sql(qry, con, params=(filter_vals[1],))
return result if not result.empty else None
except (sqlite3.OperationalError, pd.errors.DatabaseError) as e:
# Log the error or handle it appropriately
print(f"Error querying table '{table}' for column '{filter_vals[0]}': {e}")
return None
def insert_dataframe(self, df: pd.DataFrame, table: str) -> None: def insert_dataframe(self, df: pd.DataFrame, table: str) -> None:
""" """

View File

@ -52,11 +52,18 @@ class Signal:
class Signals: class Signals:
def __init__(self, loaded_signals=None): def __init__(self, config):
# list of Signal objects. # list of Signal objects.
self.signals = [] self.signals = []
# load a list of existing signals from file.
loaded_signals = config.get_setting('signals_list')
if loaded_signals is None:
# Populate the list and file with defaults defined in this class.
loaded_signals = self.get_signals_defaults()
config.set_setting('signals_list', loaded_signals)
# Initialize signals with loaded data. # Initialize signals with loaded data.
if loaded_signals is not None: if loaded_signals is not None:
self.create_signal_from_dic(loaded_signals) self.create_signal_from_dic(loaded_signals)

View File

@ -1,5 +1,5 @@
import json import json
from DataCache_v2 import DataCache
class Strategy: class Strategy:
def __init__(self, **args): def __init__(self, **args):
@ -154,14 +154,26 @@ class Strategy:
class Strategies: class Strategies:
def __init__(self, loaded_strats, trades): def __init__(self, data: DataCache, trades):
# Handles database connections
self.data = data
# Reference to the trades object that maintains all trading actions and data. # Reference to the trades object that maintains all trading actions and data.
self.trades = trades self.trades = trades
# A list of all the Strategies created.
self.strat_list = [] def get_all_strategy_names(self) -> list | None:
# Initialise all the stately objects with the data saved to file. """Return a list of all strategies in the database"""
for entry in loaded_strats: self.data._get_from_database()
self.strat_list.append(Strategy(**entry)) # Load existing Strategies from file.
loaded_strategies = config.get_setting('strategies')
if loaded_strategies is None:
# Populate the list and file with defaults defined in this class.
loaded_strategies = self.get_strategy_defaults()
config.set_setting('strategies', loaded_strategies)
for entry in loaded_strategies:
# Initialise all the strategy objects with data from file.
self.strat_list.append(Strategy(**entry)) return None
def new_strategy(self, data): def new_strategy(self, data):
# Create an instance of the new Strategy. # Create an instance of the new Strategy.
@ -238,6 +250,7 @@ class Strategies:
published strategies and evaluates conditions against the data. published strategies and evaluates conditions against the data.
This function returns a list of strategies and action commands. This function returns a list of strategies and action commands.
""" """
def process_strategy(strategy): def process_strategy(strategy):
action, cmd = strategy.evaluate_strategy(signals) action, cmd = strategy.evaluate_strategy(signals)
if action != 'do_nothing': if action != 'do_nothing':
@ -262,4 +275,3 @@ class Strategies:
return False return False
else: else:
return return_obj return return_obj

File diff suppressed because it is too large Load Diff

View File

@ -7,7 +7,6 @@ from flask_sock import Sock
from email_validator import validate_email, EmailNotValidError from email_validator import validate_email, EmailNotValidError
# Handles all updates and requests for locally stored data. # Handles all updates and requests for locally stored data.
import config
from BrighterTrades import BrighterTrades from BrighterTrades import BrighterTrades
# Set up logging # Set up logging
@ -54,7 +53,7 @@ def index():
try: try:
# Log the user in. # Log the user in.
user_name = brighter_trades.config.users.load_or_create_user(username=session.get('user')) user_name = brighter_trades.users.load_or_create_user(username=session.get('user'))
except ValueError as e: except ValueError as e:
if str(e) != 'GuestLimitExceeded!': if str(e) != 'GuestLimitExceeded!':
raise raise

View File

@ -37,16 +37,16 @@ class DataCache:
def cache_exists(self, key: str) -> bool: def cache_exists(self, key: str) -> bool:
""" """
Checks if a cache exists for the given key. Checks if a data exists for the given key.
:param key: The access key. :param key: The access key.
:return: True if cache exists, False otherwise. :return: True if data exists, False otherwise.
""" """
return key in self.cached_data return key in self.cached_data
def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None: def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
""" """
Adds records to existing cache. Adds records to existing data.
:param more_records: The new records to be added. :param more_records: The new records to be added.
:param key: The access key. :param key: The access key.
@ -59,9 +59,9 @@ class DataCache:
def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None: def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
""" """
Creates a new cache key and inserts data. Creates a new data key and inserts data.
:param data: The records to insert into cache. :param data: The records to insert into data.
:param key: The index key for the data. :param key: The index key for the data.
:param do_not_overwrite: Flag to prevent overwriting existing data. :param do_not_overwrite: Flag to prevent overwriting existing data.
:return: None :return: None
@ -72,10 +72,10 @@ class DataCache:
def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None: def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None:
""" """
Updates a dictionary stored in cache. Updates a dictionary stored in data.
:param data: The data to insert into cache. :param data: The data to insert into data.
:param cache_key: The cache index key for the dictionary. :param cache_key: The data index key for the dictionary.
:param dict_key: The dictionary key for the data. :param dict_key: The dictionary key for the data.
:return: None :return: None
""" """
@ -89,7 +89,7 @@ class DataCache:
:return: Any|None - The requested data or None on key error. :return: Any|None - The requested data or None on key error.
""" """
if key not in self.cached_data: if key not in self.cached_data:
logger.warning(f"The requested cache key({key}) doesn't exist!") logger.warning(f"The requested data key({key}) doesn't exist!")
return None return None
return self.cached_data[key] return self.cached_data[key]
@ -98,14 +98,14 @@ class DataCache:
""" """
Fetches records since the specified start datetime. Fetches records since the specified start datetime.
:param key: The cache key. :param key: The data key.
:param start_datetime: The start datetime to fetch records from. :param start_datetime: The start datetime to fetch records from.
:param record_length: The required number of records. :param record_length: The required number of records.
:param ex_details: Exchange details. :param ex_details: Exchange details.
:return: DataFrame containing the records. :return: DataFrame containing the records.
""" """
try: try:
target = 'cache' target = 'data'
args = { args = {
'key': key, 'key': key,
'start_datetime': start_datetime, 'start_datetime': start_datetime,
@ -167,7 +167,7 @@ class DataCache:
'record_length': record_length, 'record_length': record_length,
} }
if target == 'cache': if target == 'data':
result = get_from_cache() result = get_from_cache()
if data_complete(result, **request_criteria): if data_complete(result, **request_criteria):
return result return result
@ -193,7 +193,7 @@ class DataCache:
""" """
Fetches records since the specified start datetime. Fetches records since the specified start datetime.
:param key: The cache key. :param key: The data key.
:param start_datetime: The start datetime to fetch records from. :param start_datetime: The start datetime to fetch records from.
:param record_length: The required number of records. :param record_length: The required number of records.
:param ex_details: Exchange details. :param ex_details: Exchange details.
@ -203,11 +203,11 @@ class DataCache:
end_datetime = dt.datetime.utcnow() end_datetime = dt.datetime.utcnow()
if self.cache_exists(key=key): if self.cache_exists(key=key):
logger.debug('Getting records from cache.') logger.debug('Getting records from data.')
records = self.get_cache(key) records = self.get_cache(key)
else: else:
logger.debug( logger.debug(
f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}') f'Records not in data. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_datetime, records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_datetime,
rl=record_length, ex_details=ex_details) rl=record_length, ex_details=ex_details)
logger.debug(f'Got {len(records.index)} records from DB.') logger.debug(f'Got {len(records.index)} records from DB.')

View File

@ -9,23 +9,23 @@ from shared_utilities import timeframe_to_minutes, ts_of_n_minutes_ago
# log.basicConfig(level=log.ERROR) # log.basicConfig(level=log.ERROR)
class Candles: class Candles:
def __init__(self, exchanges, config_obj, data_source): def __init__(self, exchanges, users, data_source, config):
# A reference to the app configuration # A reference to the app configuration
self.config = config_obj self.users = users
# The maximum amount of candles to load at one time. # The maximum amount of candles to load at one time.
self.max_records = self.config.app_data.get('max_data_loaded') self.max_records = config.get_setting('max_data_loaded')
# This object maintains all the cached data. # This object maintains all the cached data.
self.data = data_source self.data = data_source
# print('Setting the candle cache.') # print('Setting the candle data.')
# # Populate the cache: # # Populate the data:
# self.set_cache(symbol=self.config.users.get_chart_view(user_name='guest', specific_property='market'), # self.set_cache(symbol=self.users.get_chart_view(user_name='guest', specific_property='market'),
# interval=self.config.users.get_chart_view(user_name='guest', specific_property='timeframe'), # interval=self.users.get_chart_view(user_name='guest', specific_property='timeframe'),
# exchange_name=self.config.users.get_chart_view(user_name='guest', specific_property='exchange_name')) # exchange_name=self.users.get_chart_view(user_name='guest', specific_property='exchange_name'))
# print('DONE Setting cache') # print('DONE Setting data')
def get_last_n_candles(self, num_candles: int, asset: str, timeframe: str, exchange: str, user_name: str): def get_last_n_candles(self, num_candles: int, asset: str, timeframe: str, exchange: str, user_name: str):
""" """
@ -83,7 +83,7 @@ class Candles:
# #
def set_cache(self, symbol=None, interval=None, exchange_name=None, user_name=None): def set_cache(self, symbol=None, interval=None, exchange_name=None, user_name=None):
""" """
This method requests a chart from memory to ensure the cache is initialized. This method requests a chart from memory to ensure the data is initialized.
:param user_name: :param user_name:
:param symbol: str - The symbol of the market. :param symbol: str - The symbol of the market.
@ -91,18 +91,18 @@ class Candles:
:param exchange_name: str - The name of the exchange_name to fetch from. :param exchange_name: str - The name of the exchange_name to fetch from.
:return: None :return: None
""" """
# By default, initialise cache with the last viewed chart. # By default, initialise data with the last viewed chart.
if not symbol: if not symbol:
assert user_name is not None assert user_name is not None
symbol = self.config.users.get_chart_view(user_name=user_name, prop='market') symbol = self.users.get_chart_view(user_name=user_name, prop='market')
log.info(f'set_candle_history(): No symbol provided. Using{symbol}') log.info(f'set_candle_history(): No symbol provided. Using{symbol}')
if not interval: if not interval:
assert user_name is not None assert user_name is not None
interval = self.config.users.get_chart_view(user_name=user_name, prop='timeframe') interval = self.users.get_chart_view(user_name=user_name, prop='timeframe')
log.info(f'set_candle_history(): No timeframe provided. Using{interval}') log.info(f'set_candle_history(): No timeframe provided. Using{interval}')
if not exchange_name: if not exchange_name:
assert user_name is not None assert user_name is not None
exchange_name = self.config.users.get_chart_view(user_name=user_name, prop='exchange_name') exchange_name = self.users.get_chart_view(user_name=user_name, prop='exchange_name')
# Log the completion to the console. # Log the completion to the console.
log.info('set_candle_history(): Loading candle data...') log.info('set_candle_history(): Loading candle data...')
@ -199,20 +199,20 @@ class Candles:
:param user_name: str - The name of the user who owns the exchange. :param user_name: str - The name of the user who owns the exchange.
:return: list - Candle records in the lightweight charts format. :return: list - Candle records in the lightweight charts format.
""" """
# By default, initialise cache with the last viewed chart. # By default, initialise data with the last viewed chart.
if not symbol: if not symbol:
assert user_name is not None assert user_name is not None
symbol = self.config.users.get_chart_view(user_name=user_name, prop='market''market') symbol = self.users.get_chart_view(user_name=user_name, prop='market''market')
log.info(f'set_candle_history(): No symbol provided. Using{symbol}') log.info(f'set_candle_history(): No symbol provided. Using{symbol}')
if not interval: if not interval:
assert user_name is not None assert user_name is not None
interval = self.config.users.get_chart_view(user_name=user_name, prop='market''timeframe') interval = self.users.get_chart_view(user_name=user_name, prop='market''timeframe')
log.info(f'set_candle_history(): No timeframe provided. Using{interval}') log.info(f'set_candle_history(): No timeframe provided. Using{interval}')
if not exchange_name: if not exchange_name:
assert user_name is not None assert user_name is not None
exchange_name = self.config.users.get_chart_view(user_name=user_name, prop='market''exchange_name') exchange_name = self.users.get_chart_view(user_name=user_name, prop='market''exchange_name')
log.info(f'get_candle_history(): No exchange name provided. Using {exchange_name}') log.info(f'get_candle_history(): No exchange name provided. Using {exchange_name}')
candlesticks = self.get_last_n_candles(num_candles=num_records, asset=symbol, timeframe=interval, candlesticks = self.get_last_n_candles(num_candles=num_records, asset=symbol, timeframe=interval,

View File

@ -308,12 +308,12 @@ indicator_types.append('MACD')
class Indicators: class Indicators:
def __init__(self, candles, config): def __init__(self, candles, users):
# Object manages and serves price and candle data. # Object manages and serves price and candle data.
self.candles = candles self.candles = candles
# A connection to an object that handles user configuration and persistent data. # A connection to an object that handles user data.
self.config = config self.users = users
# Collection of instantiated indicators objects # Collection of instantiated indicators objects
self.indicators = pd.DataFrame(columns=['creator', 'name', 'visible', self.indicators = pd.DataFrame(columns=['creator', 'name', 'visible',
@ -331,7 +331,7 @@ class Indicators:
Get the users watch-list from the database and load the indicators into a dataframe. Get the users watch-list from the database and load the indicators into a dataframe.
:return: None :return: None
""" """
active_indicators: pd.DataFrame = self.config.users.get_indicators(user_name) active_indicators: pd.DataFrame = self.users.get_indicators(user_name)
if active_indicators is not None: if active_indicators is not None:
# Create an instance for each indicator. # Create an instance for each indicator.
@ -347,7 +347,7 @@ class Indicators:
Saves the indicators in the database indexed by the user id. Saves the indicators in the database indexed by the user id.
:return: None :return: None
""" """
self.config.users.save_indicators(indicator) self.users.save_indicators(indicator)
# @staticmethod # @staticmethod
# def get_indicator_defaults(): # def get_indicator_defaults():
@ -389,7 +389,7 @@ class Indicators:
:param only_enabled: bool - If True, return only indicators marked as visible. :param only_enabled: bool - If True, return only indicators marked as visible.
:return: dict - A dictionary of indicator names as keys and their attributes as values. :return: dict - A dictionary of indicator names as keys and their attributes as values.
""" """
user_id = self.config.users.get_id(username) user_id = self.users.get_id(username)
if not user_id: if not user_id:
raise ValueError(f"Invalid user_name: {username}") raise ValueError(f"Invalid user_name: {username}")
@ -498,7 +498,7 @@ class Indicators:
:param num_results: The number of results being requested. :param num_results: The number of results being requested.
:return: The results of the indicator analysis as a DataFrame. :return: The results of the indicator analysis as a DataFrame.
""" """
username = self.config.users.get_username(indicator.creator) username = self.users.get_username(indicator.creator)
src = indicator.source src = indicator.source
symbol, timeframe, exchange_name = src['symbol'], src['timeframe'], src['exchange_name'] symbol, timeframe, exchange_name = src['symbol'], src['timeframe'], src['exchange_name']
@ -532,7 +532,7 @@ class Indicators:
if start_ts: if start_ts:
print("Warning: start_ts has not implemented in get_indicator_data()!") print("Warning: start_ts has not implemented in get_indicator_data()!")
user_id = self.config.users.get_id(user_name=user_name) user_id = self.users.get_id(user_name=user_name)
# Construct the query based on user_id and visibility. # Construct the query based on user_id and visibility.
query = f"creator == {user_id}" query = f"creator == {user_id}"
@ -582,7 +582,7 @@ class Indicators:
if not indicator_name: if not indicator_name:
raise ValueError("No indicator name provided.") raise ValueError("No indicator name provided.")
self.indicators = self.indicators.query("name != @indicator_name").reset_index(drop=True) self.indicators = self.indicators.query("name != @indicator_name").reset_index(drop=True)
self.config.users.save_indicators() self.users.save_indicators()
def create_indicator(self, creator: str, name: str, kind: str, def create_indicator(self, creator: str, name: str, kind: str,
source: dict, properties: dict, visible: bool = True): source: dict, properties: dict, visible: bool = True):
@ -618,7 +618,7 @@ class Indicators:
indicator = indicator_class(name, kind, properties) indicator = indicator_class(name, kind, properties)
# Add the new indicator to a pandas dataframe. # Add the new indicator to a pandas dataframe.
creator_id = self.config.users.get_id(creator) creator_id = self.users.get_id(creator)
row_data = { row_data = {
'creator': creator_id, 'creator': creator_id,
'name': name, 'name': name,

View File

@ -1,38 +1,33 @@
import ccxt
import pandas as pd import pandas as pd
import datetime import matplotlib.pyplot as plt
from matplotlib.table import Table
# Simulating the cache as a DataFrame
data = {
'key': ['BTC/USD_2h_binance', 'ETH/USD_1h_coinbase'],
'data': ['{"open": 50000, "close": 50500}', '{"open": 1800, "close": 1825}']
}
cache_df = pd.DataFrame(data)
def fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5): # Visualization function
# Initialize the exchange def visualize_cache(df):
exchange_class = getattr(ccxt, exchange_name) fig, ax = plt.subplots(figsize=(6, 3))
exchange = exchange_class() ax.set_axis_off()
tb = Table(ax, bbox=[0, 0, 1, 1])
# Fetch historical candlestick data with a limit # Adding column headers
ohlcv = exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit) for i, column in enumerate(df.columns):
tb.add_cell(0, i, width=0.4, height=0.3, text=column, loc='center', facecolor='lightgrey')
# Convert to DataFrame for better readability # Adding rows and cells
df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']) for i in range(len(df)):
for j, value in enumerate(df.iloc[i]):
tb.add_cell(i + 1, j, width=0.4, height=0.3, text=value, loc='center', facecolor='white')
# Print the first few rows of the DataFrame ax.add_table(tb)
print("First few rows of the fetched OHLCV data:") plt.title("Visualizing Cache Data")
print(df.head()) plt.show()
# Print the timestamps in human-readable format
df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
print("\nFirst few timestamps in human-readable format:")
print(df[['timestamp', 'datetime']].head())
# Confirm the format of the timestamps
print("\nTimestamp format confirmation:")
for ts in df['timestamp']:
print(f"{ts} (milliseconds since Unix epoch)")
# Example usage visualize_cache(cache_df)
exchange_name = 'binance' # Change this to your exchange
symbol = 'BTC/USDT'
timeframe = '5m'
since = int((datetime.datetime(2024, 8, 1) - datetime.datetime(1970, 1, 1)).total_seconds() * 1000)
fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5)

View File

@ -267,10 +267,10 @@ class Trade:
class Trades: class Trades:
def __init__(self, loaded_trades=None): def __init__(self, users):
""" """
This class receives, executes, tracks and stores all active_trades. This class receives, executes, tracks and stores all active_trades.
:param loaded_trades: <list?dict?> A bunch of active_trades to create and store. :param users: <Users> A class that maintains users each user may have trades.
""" """
# Object that maintains exchange_interface and account data # Object that maintains exchange_interface and account data
self.exchange_interface = None self.exchange_interface = None
@ -290,7 +290,8 @@ class Trades:
self.settled_trades = [] self.settled_trades = []
self.stats = {'num_trades': 0, 'total_position': 0, 'total_position_value': 0} self.stats = {'num_trades': 0, 'total_position': 0, 'total_position_value': 0}
# Load any trades that were passed into the constructor. # Load all trades.
loaded_trades = users.get_all_active_user_trades()
if loaded_trades is not None: if loaded_trades is not None:
# Create the active_trades loaded from file. # Create the active_trades loaded from file.
self.load_trades(loaded_trades) self.load_trades(loaded_trades)

View File

@ -1,5 +1,5 @@
import pytz import pytz
from DataCache_v2 import DataCache from DataCache_v2 import DataCache, timeframe_to_timedelta, estimate_record_count
from ExchangeInterface import ExchangeInterface from ExchangeInterface import ExchangeInterface
import unittest import unittest
import pandas as pd import pandas as pd
@ -224,10 +224,23 @@ class TestDataCacheV2(unittest.TestCase):
exchange_id INTEGER, exchange_id INTEGER,
FOREIGN KEY (exchange_id) REFERENCES exchange(id) FOREIGN KEY (exchange_id) REFERENCES exchange(id)
)""" )"""
sql_create_table_4 = f"""
CREATE TABLE IF NOT EXISTS test_table_2 (
key TEXT PRIMARY KEY,
data TEXT NOT NULL
)"""
sql_create_table_5 = f"""
CREATE TABLE IF NOT EXISTS users (
users_data TEXT PRIMARY KEY,
data TEXT NOT NULL
)"""
with SQLite(db_file=self.db_file) as con: with SQLite(db_file=self.db_file) as con:
con.execute(sql_create_table_1) con.execute(sql_create_table_1)
con.execute(sql_create_table_2) con.execute(sql_create_table_2)
con.execute(sql_create_table_3) con.execute(sql_create_table_3)
con.execute(sql_create_table_4)
con.execute(sql_create_table_5)
self.data = DataCache(self.exchanges) self.data = DataCache(self.exchanges)
self.data.db = self.database self.data.db = self.database
@ -241,30 +254,41 @@ class TestDataCacheV2(unittest.TestCase):
def test_set_cache(self): def test_set_cache(self):
print('\nTesting set_cache() method without no-overwrite flag:') print('\nTesting set_cache() method without no-overwrite flag:')
self.data.set_cache(data='data', key=self.key) self.data.set_cache(data='data', key=self.key)
attr = self.data.__getattribute__('cached_data')
self.assertEqual(attr[self.key], 'data') # Access the cache data using the DataFrame structure
print(' - Set cache without no-overwrite flag passed.') cached_value = self.data.get_cache(key=self.key)
self.assertEqual(cached_value, 'data')
print(' - Set data without no-overwrite flag passed.')
print('Testing set_cache() once again with new data without no-overwrite flag:') print('Testing set_cache() once again with new data without no-overwrite flag:')
self.data.set_cache(data='more_data', key=self.key) self.data.set_cache(data='more_data', key=self.key)
attr = self.data.__getattribute__('cached_data')
self.assertEqual(attr[self.key], 'more_data') # Access the updated cache data
print(' - Set cache with new data without no-overwrite flag passed.') cached_value = self.data.get_cache(key=self.key)
self.assertEqual(cached_value, 'more_data')
print(' - Set data with new data without no-overwrite flag passed.')
print('Testing set_cache() method once again with more data with no-overwrite flag set:') print('Testing set_cache() method once again with more data with no-overwrite flag set:')
self.data.set_cache(data='even_more_data', key=self.key, do_not_overwrite=True) self.data.set_cache(data='even_more_data', key=self.key, do_not_overwrite=True)
attr = self.data.__getattribute__('cached_data')
self.assertEqual(attr[self.key], 'more_data') # Since do_not_overwrite is True, the cached data should not change
print(' - Set cache with no-overwrite flag passed.') cached_value = self.data.get_cache(key=self.key)
self.assertEqual(cached_value, 'more_data')
print(' - Set data with no-overwrite flag passed.')
def test_cache_exists(self): def test_cache_exists(self):
print('Testing cache_exists() method:') print('Testing cache_exists() method:')
self.assertFalse(self.data.cache_exists(key=self.key))
print(' - Check for non-existent cache passed.')
# Check that the cache does not contain the key before setting it
self.assertFalse(self.data.cache_exists(key=self.key))
print(' - Check for non-existent data passed.')
# Set the cache with a DataFrame containing the key-value pair
self.data.set_cache(data='data', key=self.key) self.data.set_cache(data='data', key=self.key)
# Check that the cache now contains the key
self.assertTrue(self.data.cache_exists(key=self.key)) self.assertTrue(self.data.cache_exists(key=self.key))
print(' - Check for existent cache passed.') print(' - Check for existent data passed.')
def test_update_candle_cache(self): def test_update_candle_cache(self):
print('Testing update_candle_cache() method:') print('Testing update_candle_cache() method:')
@ -273,21 +297,21 @@ class TestDataCacheV2(unittest.TestCase):
data_gen = DataGenerator('5m') data_gen = DataGenerator('5m')
# Create initial DataFrame and insert into cache # Create initial DataFrame and insert into cache
df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0)) df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc))
print(f'Inserting this table into cache:\n{df_initial}\n') print(f'Inserting this table into cache:\n{df_initial}\n')
self.data.set_cache(data=df_initial, key=self.key) self.data.set_cache(data=df_initial, key=self.key)
# Create new DataFrame to be added to cache # Create new DataFrame to be added to cache
df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0)) df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0, tzinfo=dt.timezone.utc))
print(f'Updating cache with this table:\n{df_new}\n') print(f'Updating cache with this table:\n{df_new}\n')
self.data.update_candle_cache(more_records=df_new, key=self.key) self.data._update_candle_cache(more_records=df_new, key=self.key)
# Retrieve the resulting DataFrame from cache # Retrieve the resulting DataFrame from cache
result = self.data.get_cache(key=self.key) result = self.data.get_cache(key=self.key)
print(f'The resulting table in cache is:\n{result}\n') print(f'The resulting table in cache is:\n{result}\n')
# Create the expected DataFrame # Create the expected DataFrame
expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0)) expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc))
print(f'The expected open_time values are:\n{expected["open_time"].tolist()}\n') print(f'The expected open_time values are:\n{expected["open_time"].tolist()}\n')
# Assert that the open_time values in the result match those in the expected DataFrame, in order # Assert that the open_time values in the result match those in the expected DataFrame, in order
@ -295,37 +319,50 @@ class TestDataCacheV2(unittest.TestCase):
f"open_time values in result are {result['open_time'].tolist()}" \ f"open_time values in result are {result['open_time'].tolist()}" \
f" but expected {expected['open_time'].tolist()}" f" but expected {expected['open_time'].tolist()}"
print(f'The results open_time values match:\n{result["open_time"].tolist()}\n') print(f'The result open_time values match:\n{result["open_time"].tolist()}\n')
print(' - Update cache with new records passed.') print(' - Update cache with new records passed.')
def test_update_cached_dict(self): def test_update_cached_dict(self):
print('Testing update_cached_dict() method:') print('Testing update_cached_dict() method:')
# Set an empty dictionary in the cache for the specified key
self.data.set_cache(data={}, key=self.key) self.data.set_cache(data={}, key=self.key)
# Update the cached dictionary with a new key-value pair
self.data.update_cached_dict(cache_key=self.key, dict_key='sub_key', data='value') self.data.update_cached_dict(cache_key=self.key, dict_key='sub_key', data='value')
# Retrieve the updated cache
cache = self.data.get_cache(key=self.key) cache = self.data.get_cache(key=self.key)
# Verify that the 'sub_key' in the cached dictionary has the correct value
self.assertEqual(cache['sub_key'], 'value') self.assertEqual(cache['sub_key'], 'value')
print(' - Update dictionary in cache passed.') print(' - Update dictionary in cache passed.')
def test_get_cache(self): def test_get_cache(self):
print('Testing get_cache() method:') print('Testing get_cache() method:')
# Set some data into the cache
self.data.set_cache(data='data', key=self.key) self.data.set_cache(data='data', key=self.key)
# Retrieve the cached data using the get_cache method
result = self.data.get_cache(key=self.key) result = self.data.get_cache(key=self.key)
# Verify that the result matches the data we set
self.assertEqual(result, 'data') self.assertEqual(result, 'data')
print(' - Retrieve cache passed.') print(' - Retrieve data passed.')
def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None, num_rec=None, ex_details=None, def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None, num_rec=None, ex_details=None,
simulate_scenarios=None): simulate_scenarios=None):
""" """
Test the get_records_since() method by generating a table of simulated data, Test the get_records_since() method by generating a table of simulated data,
inserting it into cache and/or database, and then querying the records. inserting it into data and/or database, and then querying the records.
Parameters: Parameters:
set_cache (bool): If True, the generated table is inserted into the cache. set_cache (bool): If True, the generated table is inserted into the cache.
set_db (bool): If True, the generated table is inserted into the database. set_db (bool): If True, the generated table is inserted into the database.
query_offset (int, optional): The offset in the timeframe units for the query. query_offset (int, optional): The offset in the timeframe units for the query.
num_rec (int, optional): The number of records to generate in the simulated table. num_rec (int, optional): The number of records to generate in the simulated table.
ex_details (list, optional): Exchange details to generate the cache key. ex_details (list, optional): Exchange details to generate the data key.
simulate_scenarios (str, optional): The type of scenario to simulate. Options are: simulate_scenarios (str, optional): The type of scenario to simulate. Options are:
- 'not_enough_data': The table data doesn't go far enough back. - 'not_enough_data': The table data doesn't go far enough back.
- 'incomplete_data': The table doesn't have enough records to satisfy the query. - 'incomplete_data': The table doesn't have enough records to satisfy the query.
@ -336,7 +373,7 @@ class TestDataCacheV2(unittest.TestCase):
# Use provided ex_details or fallback to the class attribute. # Use provided ex_details or fallback to the class attribute.
ex_details = ex_details or self.ex_details ex_details = ex_details or self.ex_details
# Generate a cache/database key using exchange details. # Generate a data/database key using exchange details.
key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}' key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}'
# Set default number of records if not provided. # Set default number of records if not provided.
@ -374,13 +411,13 @@ class TestDataCacheV2(unittest.TestCase):
print(f'Table Created:\n{temp_df}') print(f'Table Created:\n{temp_df}')
if set_cache: if set_cache:
# Insert the generated table into cache. # Insert the generated table into the cache.
print('Inserting table into cache.') print('Inserting table into the cache.')
self.data.set_cache(data=df_initial, key=key) self.data.set_cache(data=df_initial, key=key)
if set_db: if set_db:
# Insert the generated table into the database. # Insert the generated table into the database.
print('Inserting table into database.') print('Inserting table into the database.')
with SQLite(self.db_file) as con: with SQLite(self.db_file) as con:
df_initial.to_sql(key, con, if_exists='replace', index=False) df_initial.to_sql(key, con, if_exists='replace', index=False)
@ -414,7 +451,7 @@ class TestDataCacheV2(unittest.TestCase):
# Check that the result has more rows than the expected incomplete data. # Check that the result has more rows than the expected incomplete data.
assert result.shape[0] > expected.shape[ assert result.shape[0] > expected.shape[
0], "Result has fewer or equal rows compared to the incomplete data." 0], "Result has fewer or equal rows compared to the incomplete data."
print("\nThe returned DataFrames has filled in the missing data!") print("\nThe returned DataFrame has filled in the missing data!")
else: else:
# Ensure the result and expected dataframes match in shape and content. # Ensure the result and expected dataframes match in shape and content.
assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}" assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}"
@ -444,10 +481,10 @@ class TestDataCacheV2(unittest.TestCase):
print(' - Fetch records within the specified time range passed.') print(' - Fetch records within the specified time range passed.')
def test_get_records_since(self): def test_get_records_since(self):
print('\nTest get_records_since with records set in cache') print('\nTest get_records_since with records set in data')
self._test_get_records_since() self._test_get_records_since()
print('\nTest get_records_since with records not in cache') print('\nTest get_records_since with records not in data')
self._test_get_records_since(set_cache=False) self._test_get_records_since(set_cache=False)
print('\nTest get_records_since with records not in database') print('\nTest get_records_since with records not in database')
@ -467,21 +504,21 @@ class TestDataCacheV2(unittest.TestCase):
self._test_get_records_since(simulate_scenarios='missing_section') self._test_get_records_since(simulate_scenarios='missing_section')
def test_other_timeframes(self): def test_other_timeframes(self):
# print('\nTest get_records_since with a different timeframe') print('\nTest get_records_since with a different timeframe')
# ex_details = ['BTC/USD', '15m', 'binance', 'test_guy'] ex_details = ['BTC/USD', '15m', 'binance', 'test_guy']
# start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2) start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2)
# # Query the records since the calculated start time. # Query the records since the calculated start time.
# result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
# last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC')
# assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=15.1) assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=15.1)
#
# print('\nTest get_records_since with a different timeframe') print('\nTest get_records_since with a different timeframe')
# ex_details = ['BTC/USD', '5m', 'binance', 'test_guy'] ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']
# start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=1) start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=1)
# # Query the records since the calculated start time. # Query the records since the calculated start time.
# result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
# last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC')
# assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=5.1) assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=5.1)
print('\nTest get_records_since with a different timeframe') print('\nTest get_records_since with a different timeframe')
ex_details = ['BTC/USD', '4h', 'binance', 'test_guy'] ex_details = ['BTC/USD', '4h', 'binance', 'test_guy']
@ -506,16 +543,217 @@ class TestDataCacheV2(unittest.TestCase):
def test_fetch_candles_from_exchange(self): def test_fetch_candles_from_exchange(self):
print('Testing _fetch_candles_from_exchange() method:') print('Testing _fetch_candles_from_exchange() method:')
start_time = dt.datetime.utcnow() - dt.timedelta(days=1)
end_time = dt.datetime.utcnow()
result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h', exchange_name='binance', # Define start and end times for the data fetch
user_name='test_guy', start_datetime=start_time, start_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc) - dt.timedelta(days=1)
end_datetime=end_time) end_time = dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
# Fetch the candles from the exchange using the method
result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h',
exchange_name='binance', user_name='test_guy',
start_datetime=start_time, end_datetime=end_time)
# Validate that the result is a DataFrame
self.assertIsInstance(result, pd.DataFrame) self.assertIsInstance(result, pd.DataFrame)
self.assertFalse(result.empty)
# Validate that the DataFrame is not empty
self.assertFalse(result.empty, "The DataFrame returned from the exchange is empty.")
# Ensure that the 'open_time' column exists in the DataFrame
self.assertIn('open_time', result.columns, "'open_time' column is missing in the result DataFrame.")
# Check if the DataFrame contains valid timestamps within the specified range
min_time = pd.to_datetime(result['open_time'].min(), unit='ms').tz_localize('UTC')
max_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC')
self.assertTrue(start_time <= min_time <= end_time, f"Data starts outside the expected range: {min_time}")
self.assertTrue(start_time <= max_time <= end_time, f"Data ends outside the expected range: {max_time}")
print(' - Fetch candle data from exchange passed.') print(' - Fetch candle data from exchange passed.')
def test_remove_row(self):
print('Testing remove_row() method:')
# Insert data into the cache with the expected columns
df = pd.DataFrame({
'key': [self.key],
'data': ['test_data']
})
self.data.set_cache(data='test_data', key=self.key)
# Ensure the data is in the cache
self.assertTrue(self.data.cache_exists(self.key), "Data was not correctly inserted into the cache.")
# Remove the row from the cache only (soft delete)
self.data.remove_row(table='test_table_2', filter_vals=('key', self.key), remove_from_db=False)
# Verify the row has been removed from the cache
self.assertFalse(self.data.cache_exists(self.key), "Row was not correctly removed from the cache.")
# Reinsert the data for hard delete test
self.data.set_cache(data='test_data', key=self.key)
# Mock database delete by adding the row to the database
self.data.db.insert_row(table='test_table_2', columns=('key', 'data'), values=(self.key, 'test_data'))
# Remove the row from both cache and database (hard delete)
self.data.remove_row(table='test_table_2', filter_vals=('key', self.key), remove_from_db=True)
# Verify the row has been removed from the cache
self.assertFalse(self.data.cache_exists(self.key), "Row was not correctly removed from the cache.")
# Verify the row has been removed from the database
with SQLite(self.db_file) as con:
result = pd.read_sql(f'SELECT * FROM test_table_2 WHERE key="{self.key}"', con)
self.assertTrue(result.empty, "Row was not correctly removed from the database.")
print(' - Remove row from cache and database passed.')
def test_timeframe_to_timedelta(self):
print('Testing timeframe_to_timedelta() function:')
result = timeframe_to_timedelta('2h')
expected = pd.Timedelta(hours=2)
self.assertEqual(result, expected, "Failed to convert '2h' to Timedelta")
result = timeframe_to_timedelta('5m')
expected = pd.Timedelta(minutes=5)
self.assertEqual(result, expected, "Failed to convert '5m' to Timedelta")
result = timeframe_to_timedelta('1d')
expected = pd.Timedelta(days=1)
self.assertEqual(result, expected, "Failed to convert '1d' to Timedelta")
result = timeframe_to_timedelta('3M')
expected = pd.DateOffset(months=3)
self.assertEqual(result, expected, "Failed to convert '3M' to DateOffset")
result = timeframe_to_timedelta('1Y')
expected = pd.DateOffset(years=1)
self.assertEqual(result, expected, "Failed to convert '1Y' to DateOffset")
with self.assertRaises(ValueError):
timeframe_to_timedelta('5x')
print(' - All timeframe_to_timedelta() tests passed.')
def test_estimate_record_count(self):
print('Testing estimate_record_count() function:')
start_time = dt.datetime(2023, 8, 1, 0, 0, 0, tzinfo=dt.timezone.utc)
end_time = dt.datetime(2023, 8, 2, 0, 0, 0, tzinfo=dt.timezone.utc)
result = estimate_record_count(start_time, end_time, '1h')
expected = 24
self.assertEqual(result, expected, "Failed to estimate record count for 1h timeframe")
result = estimate_record_count(start_time, end_time, '1d')
expected = 1
self.assertEqual(result, expected, "Failed to estimate record count for 1d timeframe")
start_time = int(start_time.timestamp() * 1000) # Convert to milliseconds
end_time = int(end_time.timestamp() * 1000) # Convert to milliseconds
result = estimate_record_count(start_time, end_time, '1h')
expected = 24
self.assertEqual(result, expected, "Failed to estimate record count for 1h timeframe with milliseconds")
with self.assertRaises(ValueError):
estimate_record_count("invalid_start", end_time, '1h')
print(' - All estimate_record_count() tests passed.')
def test_fetch_cached_rows(self):
print('Testing fetch_cached_rows() method:')
# Set up mock data in the cache
df = pd.DataFrame({
'table': ['test_table_2'],
'key': ['test_key'],
'data': ['test_data']
})
self.data.cache = pd.concat([self.data.cache, df])
# Test fetching from cache
result = self.data.fetch_cached_rows('test_table_2', ('key', 'test_key'))
self.assertIsInstance(result, pd.DataFrame, "Failed to fetch DataFrame from cache")
self.assertFalse(result.empty, "The fetched DataFrame is empty")
self.assertEqual(result.iloc[0]['data'], 'test_data', "Incorrect data fetched from cache")
# Test fetching from database (assuming the method calls it)
# Here we would typically mock the database call
# But since we're not doing I/O, we will skip that part
print(' - Fetch from cache and database simulated.')
def test_is_attr_taken(self):
print('Testing is_attr_taken() method:')
# Set up mock data in the cache
df = pd.DataFrame({
'table': ['users'],
'user_name': ['test_user'],
'data': ['test_data']
})
self.data.cache = pd.concat([self.data.cache, df])
# Test for existing attribute
result = self.data.is_attr_taken('users', 'user_name', 'test_user')
self.assertTrue(result, "Failed to detect existing attribute")
# Test for non-existing attribute
result = self.data.is_attr_taken('users', 'user_name', 'non_existing_user')
self.assertFalse(result, "Incorrectly detected non-existing attribute")
print(' - All is_attr_taken() tests passed.')
def test_insert_data(self):
print('Testing insert_data() method:')
# Create a DataFrame to insert
df = pd.DataFrame({
'key': ['new_key'],
'data': ['new_data']
})
# Insert data into the database and cache
self.data.insert_data(df=df, table='test_table_2')
# Verify that the data was added to the cache
cached_value = self.data.get_cache('new_key')
self.assertEqual(cached_value, 'new_data', "Failed to insert data into cache")
# Normally, we would also verify that the data was inserted into the database
# This would typically be done with a mock database or by checking the database state directly
print(' - Data insertion into cache and database simulated.')
def test_insert_row(self):
print('Testing insert_row() method:')
self.data.insert_row(table='test_table_2', columns=('key', 'data'), values=('test_key', 'test_data'))
# Verify the row was inserted
with SQLite(self.db_file) as con:
result = pd.read_sql('SELECT * FROM test_table_2 WHERE key="test_key"', con)
self.assertFalse(result.empty, "Row was not inserted into the database.")
print(' - Insert row passed.')
def test_fill_data_holes(self):
print('Testing _fill_data_holes() method:')
# Create mock data with gaps
df = pd.DataFrame({
'open_time': [dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc).timestamp() * 1000,
dt.datetime(2023, 1, 1, 2, tzinfo=dt.timezone.utc).timestamp() * 1000,
dt.datetime(2023, 1, 1, 6, tzinfo=dt.timezone.utc).timestamp() * 1000,
dt.datetime(2023, 1, 1, 8, tzinfo=dt.timezone.utc).timestamp() * 1000,
dt.datetime(2023, 1, 1, 12, tzinfo=dt.timezone.utc).timestamp() * 1000]
})
# Call the method
result = self.data._fill_data_holes(records=df, interval='2h')
self.assertEqual(len(result), 7, "Data holes were not filled correctly.")
print(' - _fill_data_holes passed.')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -115,7 +115,7 @@ def test_log_out_all_users():
def test_load_user_data(): def test_load_user_data():
# Todo method incomplete # Todo method incomplete
result = config.users.load_user_data(user_name='RobbieD') result = config.users.get_user_data(user_name='RobbieD')
print('\n Test result:', result) print('\n Test result:', result)
assert result is None assert result is None