Refactored DataCache, Database and Users. All DB interactions are now inside Database. All data requests from Users now go through DataCache. Expanded the DataCache tests to cover all methods. All DataCache tests pass.

This commit is contained in:
Rob 2024-08-19 23:10:13 -03:00
parent fc407708ba
commit 8361efd965
16 changed files with 1383 additions and 776 deletions

2
.gitignore vendored
View File

@ -4,7 +4,7 @@
# Ignore Flask session data
flask_session/
# Ignore testing cache
# Ignore testing data
.pytest_cache/
# Ignore databases

View File

@ -1,6 +1,6 @@
from typing import Any
from Users import Users
from DataCache_v2 import DataCache
from Strategies import Strategies
from backtesting import Backtester
@ -20,25 +20,29 @@ class BrighterTrades:
# Object that interacts with the persistent data.
self.data = DataCache(self.exchanges)
# Configuration and settings for the user app and charts
self.config = Configuration(cache=self.data)
# Configuration for the app
self.config = Configuration()
# Object that maintains signals. Initialize with any signals loaded from file.
self.signals = Signals(self.config.signals_list)
# The object that manages users in the system.
self.users = Users(data_cache=self.data)
# Object that maintains signals.
self.signals = Signals(self.config)
# Object that maintains candlestick and price data.
self.candles = Candles(config_obj=self.config, exchanges=self.exchanges, data_source=self.data)
self.candles = Candles(users=self.users, exchanges=self.exchanges, data_source=self.data,
config=self.config)
# Object that interacts with and maintains data from available indicators
self.indicators = Indicators(self.candles, self.config)
self.indicators = Indicators(self.candles, self.users)
# Object that maintains the trades data
self.trades = Trades(self.config.trades)
self.trades = Trades(self.users)
# The Trades object needs to connect to an exchange_interface.
self.trades.connect_exchanges(exchanges=self.exchanges)
# Object that maintains the strategies data
self.strategies = Strategies(self.config.strategies_list, self.trades)
self.strategies = Strategies(self.data, self.trades)
# Object responsible for testing trade and strategies data.
self.backtester = Backtester()
@ -56,8 +60,8 @@ class BrighterTrades:
raise ValueError("Missing required arguments for 'create_new_user'")
try:
self.config.users.create_new_user(email=email, username=username, password=password)
login_successful = self.config.users.log_in_user(username=username, password=password)
self.users.create_new_user(email=email, username=username, password=password)
login_successful = self.users.log_in_user(username=username, password=password)
return login_successful
except Exception as e:
# Handle specific exceptions or log the error
@ -77,11 +81,11 @@ class BrighterTrades:
try:
if cmd == 'logout':
return self.config.users.log_out_user(username=user_name)
return self.users.log_out_user(username=user_name)
elif cmd == 'login':
if password is None:
raise ValueError("Password is required for login.")
return self.config.users.log_in_user(username=user_name, password=password)
return self.users.log_in_user(username=user_name, password=password)
except Exception as e:
# Handle specific exceptions or log the error
raise ValueError("Error during user login/logout: " + str(e))
@ -98,19 +102,19 @@ class BrighterTrades:
if info == 'Chart View':
try:
return self.config.users.get_chart_view(user_name=user_name)
return self.users.get_chart_view(user_name=user_name)
except Exception as e:
# Handle specific exceptions or log the error
raise ValueError("Error retrieving chart view information: " + str(e))
elif info == 'Is logged in?':
try:
return self.config.users.is_logged_in(user_name=user_name)
return self.users.is_logged_in(user_name=user_name)
except Exception as e:
# Handle specific exceptions or log the error
raise ValueError("Error checking logged in status: " + str(e))
elif info == 'User_id':
try:
return self.config.users.get_id(user_name=user_name)
return self.users.get_id(user_name=user_name)
except Exception as e:
# Handle specific exceptions or log the error
raise ValueError("Error fetching id: " + str(e))
@ -181,10 +185,10 @@ class BrighterTrades:
:param default_keys: default API keys.
:return: bool - True on success.
"""
active_exchanges = self.config.users.get_exchanges(user_name, category='active_exchanges')
active_exchanges = self.users.get_exchanges(user_name, category='active_exchanges')
success = False
for exchange in active_exchanges:
keys = self.config.users.get_api_keys(user_name, exchange)
keys = self.users.get_api_keys(user_name, exchange)
success = self.connect_or_config_exchange(user_name=user_name,
exchange_name=exchange,
api_keys=keys)
@ -202,7 +206,7 @@ class BrighterTrades:
:param user_name: str - The name of the user making the query.
"""
chart_view = self.config.users.get_chart_view(user_name=user_name)
chart_view = self.users.get_chart_view(user_name=user_name)
indicator_types = self.indicators.indicator_types
available_indicators = self.indicators.get_indicator_list(user_name)
@ -231,19 +235,20 @@ class BrighterTrades:
:return: A dictionary containing the requested data.
"""
chart_view = self.config.users.get_chart_view(user_name=user_name)
chart_view = self.users.get_chart_view(user_name=user_name)
exchange = self.exchanges.get_exchange(ename=chart_view.get('exchange'), uname=user_name)
# noinspection PyDictCreation
r_data = {}
r_data['title'] = self.config.app_data.get('application_title', '')
r_data['title'] = self.config.get_setting('application_title')
r_data['chart_interval'] = chart_view.get('timeframe', '')
r_data['selected_exchange'] = chart_view.get('exchange', '')
r_data['intervals'] = exchange.intervals if exchange else []
r_data['symbols'] = exchange.get_symbols() if exchange else {}
r_data['available_exchanges'] = self.exchanges.get_available_exchanges() or []
r_data['connected_exchanges'] = self.exchanges.get_connected_exchanges(user_name) or []
r_data['configured_exchanges'] = self.config.users.get_exchanges(user_name, category='configured_exchanges') or []
r_data['configured_exchanges'] = self.users.get_exchanges(
user_name, category='configured_exchanges') or []
r_data['my_balances'] = self.exchanges.get_all_balances(user_name) or {}
r_data['indicator_types'] = self.indicators.indicator_types or []
r_data['indicator_list'] = self.indicators.get_indicator_list(user_name) or []
@ -295,7 +300,7 @@ class BrighterTrades:
return "The new signal must have a 'name' attribute."
self.signals.new_signal(data)
self.config.update_data('signals', self.signals.get_signals('dict'))
self.config.set_setting('signals_list', self.signals.get_signals('dict'))
return data
def received_new_strategy(self, data: dict) -> str | dict:
@ -309,7 +314,7 @@ class BrighterTrades:
return "The new strategy must have a 'name' attribute."
self.strategies.new_strategy(data)
self.config.update_data('strategies', self.strategies.get_strategies('dict'))
self.config.set_setting('strategies', self.strategies.get_strategies('dict'))
return data
def delete_strategy(self, strategy_name: str) -> None:
@ -377,10 +382,9 @@ class BrighterTrades:
success = self.exchanges.connect_exchange(exchange_name=exchange_name, user_name=user_name,
api_keys=api_keys)
if success:
self.config.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set')
self.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set')
if api_keys:
self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name,
user_name=user_name)
self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name)
return True
else:
return False # Failed to connect
@ -391,7 +395,7 @@ class BrighterTrades:
else:
# Exchange is already connected, update API keys if provided
if api_keys:
self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name)
self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name)
return True # Already connected
def close_trade(self, trade_id):
@ -470,19 +474,19 @@ class BrighterTrades:
if setting == 'interval':
interval_state = params['timeframe']
self.config.users.set_chart_view(values=interval_state, specific_property='timeframe', user_name=user_name)
self.users.set_chart_view(values=interval_state, specific_property='timeframe', user_name=user_name)
elif setting == 'trading_pair':
trading_pair = params['symbol']
self.config.users.set_chart_view(values=trading_pair, specific_property='market', user_name=user_name)
self.users.set_chart_view(values=trading_pair, specific_property='market', user_name=user_name)
elif setting == 'exchange':
exchange_name = params['exchange_name']
# Get the first result of a list of available symbols from this exchange_name.
market = self.exchanges.get_exchange(ename=exchange_name, uname=user_name).get_symbols()[0]
# Change the exchange in the chart view and pass it a default market to display.
self.config.users.set_chart_view(values=exchange_name, specific_property='exchange_name',
user_name=user_name, default_market=market)
self.users.set_chart_view(values=exchange_name, specific_property='exchange_name',
user_name=user_name, default_market=market)
elif setting == 'toggle_indicator':
indicators_to_toggle = params.getlist('indicator')

View File

@ -1,130 +1,279 @@
import pandas
import yaml
from Signals import Signals
from indicators import Indicators
from Users import Users
import os
import time
import logging
from typing import Any
logger = logging.getLogger(__name__)
class Configuration:
def __init__(self, cache):
# ************** Default values**************
"""
Manages the application's settings, loading and saving them to a YAML file.
- Automatically loads the settings from file when instantiated.
- Settings can be added or modified with set_setting('key', value)
or by editing the file directly.
"""
self.app_data = {
'application_title': 'BrighterTrades', # The title of our program.
'max_data_loaded': 1000 # The maximum number of candles to store in memory.
def __init__(self, config_file='config.yml'):
"""Initializes with default settings and loads saved data."""
self.config_file = config_file
self.save_in_progress = False # Persistent flag to prevent infinite recursion
self._set_default_settings()
self.manage_config('load')
def _set_default_settings(self):
"""Set default settings."""
self.settings = {
'application_title': 'BrighterTrades',
'max_data_loaded': 1000
}
# The object that interacts with the database.
self.data = cache.db
def _generate_config_data(self):
"""Generates a list of settings to be saved to the config file."""
return self.settings.copy()
# The object that manages users in the system.
self.users = Users(data_cache=cache)
# The name of the file that stores saved_data
self.config_FN = 'config.yml'
# A list of all the available Signals.
# Calls a static method of Signals that initializes a default list.
self.signals_list = Signals.get_signals_defaults()
# list of all the available strategies.
self.strategies_list = []
# list of trades.
self.trades = []
# The data that will be saved and loaded from file .
self.saved_data = None
# Load any saved data from file
self.config_and_states('load')
def update_data(self, data_type, data):
def manage_config(self, cmd: str):
"""
Replace current list of data sets with an updated list.
:param data_type: The data being replaced
:param data: The replacement data.
:return: None.
Loads or saves settings to the config file according to cmd: 'load' | 'save'
"""
if data_type == 'strategies':
self.strategies_list = data
elif data_type == 'signals':
self.signals_list = data
elif data_type == 'trades':
self.trades = data
else:
raise ValueError(f'Configuration: update_data(): Unsupported data_type: {data_type}')
# Save it to file.
self.config_and_states('save')
def remove(self, what, name):
# Removes by name an item from a list in saved data.
def _apply_loaded_settings(data):
"""Updates settings from loaded data and marks if saving is needed."""
needs_save = False
for key, value in data.items():
if key not in self.settings or self.settings[key] != value:
self.settings[key] = value
needs_save = True
return needs_save
# Trades are indexed by unique_id, while Signals and Strategies are indexed by name.
if what == 'trades':
prop = 'unique_id'
else:
prop = 'name'
def _load_config_from_file(filepath):
try:
with open(filepath, "r") as file_descriptor:
return yaml.safe_load(file_descriptor)
except yaml.YAMLError:
timestamp = time.strftime("%Y%m%d-%H%M%S")
backup_path = f"{filepath}.{timestamp}.backup"
os.rename(filepath, backup_path)
logging.warning(f"Corrupt YAML file detected. Backup saved to {backup_path}")
logging.info(f"Recreating the configuration file with default settings.")
return None
for obj in self.saved_data[what]:
if obj[prop] == name:
self.saved_data[what].remove(obj)
break
# Save it to file.
self.config_and_states('save')
def config_and_states(self, cmd):
"""Loads or saves configurable data to the file set in self.config_FN"""
# The data stored and retrieved from file session.
self.saved_data = {
'signals': self.signals_list,
'strategies': self.strategies_list,
'trades': self.trades
}
def set_loaded_values():
# Sets the values in the saved_data object.
if 'signals' in self.saved_data:
self.signals_list = self.saved_data['signals']
if 'strategies' in self.saved_data:
self.strategies_list = self.saved_data['strategies']
if 'trades' in self.saved_data:
self.trades = self.saved_data['trades']
def load_configuration(filepath):
"""load file data"""
with open(filepath, "r") as file_descriptor:
data = yaml.safe_load(file_descriptor)
return data
def save_configuration(filepath, data):
"""Saves file data"""
with open(filepath, "w") as file_descriptor:
yaml.dump(data, file_descriptor)
def _save_config_to_file(filepath, data):
try:
with open(filepath, "w") as file_descriptor:
yaml.dump(data, file_descriptor)
except (IOError, OSError) as e:
logging.error(f"Failed to save configuration to {filepath}: {e}")
raise ValueError(f"Failed to save configuration to {filepath}: {e}")
if cmd == 'load':
# If load_configuration() finds a file it overwrites
# the saved_data object otherwise it creates a new file
# with the defaults contained in saved_data>
# If file exist load the values.
try:
self.saved_data = load_configuration(self.config_FN)
set_loaded_values()
# If file doesn't exist create a file and save the default values.
except IOError:
save_configuration(self.config_FN, self.saved_data)
if os.path.exists(self.config_file):
loaded_data = _load_config_from_file(self.config_file)
if loaded_data is None: # Corrupt file case, recreate it
_save_config_to_file(self.config_file, self._generate_config_data())
elif _apply_loaded_settings(loaded_data) and not self.save_in_progress:
self.save_in_progress = True
self.manage_config('save')
self.save_in_progress = False
else:
logging.info(f"Configuration file not found. Creating a new one at {self.config_file}.")
_save_config_to_file(self.config_file, self._generate_config_data())
elif cmd == 'save':
try:
# Write saved_data to the file.
save_configuration(self.config_FN, self.saved_data)
except IOError:
raise ValueError("save_configuration(): Couldn't save the file.")
_save_config_to_file(self.config_file, self._generate_config_data())
else:
raise ValueError('save_configuration(): Invalid command received.')
raise ValueError('manage_config(): Invalid command.')
def reset_settings_to_defaults(self):
"""Resets settings to default values."""
self._set_default_settings()
self.manage_config('save')
def get_setting(self, key: str) -> Any:
"""
Returns the value of the specified setting or None if the key is not found.
"""
return self.settings.get(key, None)
def set_setting(self, key: str, value: Any):
"""
Receives a key and value of any setting and saves the configuration.
"""
self.settings[key] = value
self.manage_config('save')
# import yaml
# from Signals import Signals
#
#
# class Configuration:
# """
# Configuration class manages the application's settings,
# signals, strategies, and trades. It loads and saves these
# configurations to a YAML file.
#
# Attributes:
# app_data (dict): Default application settings like title and maximum data load.
# data (object): Database interaction object from the data.
# config_FN (str): Filename for storing and loading configurations.
# signals_list (list): List of available signals initialized by the Signals class.
# strategies_list (list): List of strategies available for the application.
# trades (list): List of trades managed by the application.
# saved_data (dict): Data structure for saving/loading configuration data.
# """
#
# def __init__(self, data):
# """
# Initializes the Configuration object with default values and loads saved data.
#
# Args:
# data (object): Cache object with a database attribute to interact with the database.
# """
# # ************** Default values **************
# # Application metadata such as title and maximum data to be loaded.
# self.app_data = {
# 'application_title': 'BrighterTrades', # The title of the program.
# 'max_data_loaded': 1000 # Maximum number of candles to store in memory.
# }
#
# # The database object for interaction.
# self.data = data.db
#
# # Name of the configuration file.
# self.config_FN = 'config.yml'
#
# # List of available signals initialized with default values.
# self.signals_list = Signals.get_signals_defaults()
#
# # List to hold available strategies.
# self.strategies_list = []
#
# # List to hold trades.
# self.trades = []
#
# # Placeholder for data loaded from or to be saved to the file.
# self.saved_data = None
#
# # Load any saved data from the configuration file.
# self.config_and_states('load')
#
# def update_data(self, data_type, data):
# """
# Replace the current list of data sets with an updated list.
#
# Args:
# data_type (str): Type of data to be updated ('strategies', 'signals', or 'trades').
# data (list): The new data to replace the old one.
#
# Raises:
# ValueError: If the provided data_type is not supported.
# """
# if data_type == 'strategies':
# self.strategies_list = data
# elif data_type == 'signals':
# self.signals_list = data
# elif data_type == 'trades':
# self.trades = data
# else:
# raise ValueError(f'Configuration: update_data(): Unsupported data_type: {data_type}')
#
# # Save the updated data to the configuration file.
# self.config_and_states('save')
#
# def remove(self, what, name):
# """
# Removes an item by name from the saved data list.
#
# Args:
# what (str): Type of data to remove ('strategies', 'signals', or 'trades').
# name (str): The name or unique_id of the item to remove.
#
# Raises:
# ValueError: If the item with the specified name is not found.
# """
# # Determine the property to match based on the type of data.
# if what == 'trades':
# prop = 'unique_id'
# else:
# prop = 'name'
#
# # Remove the item from the list if it matches the name or unique_id.
# for obj in self.saved_data[what]:
# if obj[prop] == name:
# self.saved_data[what].remove(obj)
# break
#
# # Save the updated data to the configuration file.
# self.config_and_states('save')
#
# def config_and_states(self, cmd):
# """
# Loads or saves configurable data to the file set in self.config_FN.
#
# Args:
# cmd (str): Command to either 'load' or 'save' the configuration data.
#
# Raises:
# ValueError: If the command is neither 'load' nor 'save'.
# """
#
# # Data structure to hold the current state of signals, strategies, and trades.
# self.saved_data = {
# 'signals': self.signals_list,
# 'strategies': self.strategies_list,
# 'trades': self.trades
# }
#
# def set_loaded_values():
# """Sets the values in the saved_data object to the class attributes."""
# if 'signals' in self.saved_data:
# self.signals_list = self.saved_data['signals']
#
# if 'strategies' in self.saved_data:
# self.strategies_list = self.saved_data['strategies']
#
# if 'trades' in self.saved_data:
# self.trades = self.saved_data['trades']
#
# def load_configuration(filepath):
# """
# Load configuration data from a YAML file.
#
# Args:
# filepath (str): Path to the configuration file.
#
# Returns:
# dict: Loaded configuration data.
# """
# with open(filepath, "r") as file_descriptor:
# data = yaml.safe_load(file_descriptor)
# return data
#
# def save_configuration(filepath, data):
# """
# Save configuration data to a YAML file.
#
# Args:
# filepath (str): Path to the configuration file.
# data (dict): Data to save.
# """
# with open(filepath, "w") as file_descriptor:
# yaml.dump(data, file_descriptor)
#
# if cmd == 'load':
# try:
# # Attempt to load the configuration from the file.
# self.saved_data = load_configuration(self.config_FN)
# set_loaded_values()
# except IOError:
# # If the file doesn't exist, save the default values.
# save_configuration(self.config_FN, self.saved_data)
#
# elif cmd == 'save':
# try:
# # Save the current state to the configuration file.
# save_configuration(self.config_FN, self.saved_data)
# except IOError:
# raise ValueError("save_configuration(): Couldn't save the file.")
# else:
# raise ValueError('save_configuration(): Invalid command received.')

View File

@ -1,4 +1,5 @@
from typing import List, Any
import json
from typing import List, Any, Tuple
import pandas as pd
import datetime as dt
import logging
@ -59,23 +60,31 @@ def estimate_record_count(start_time, end_time, timeframe: str) -> int:
return int(expected_records)
def generate_expected_timestamps(start_datetime: dt.datetime, end_datetime: dt.datetime,
timeframe: str) -> pd.DatetimeIndex:
if start_datetime.tzinfo is None:
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
if end_datetime.tzinfo is None:
raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
delta = timeframe_to_timedelta(timeframe)
if isinstance(delta, pd.Timedelta):
return pd.date_range(start=start_datetime, end=end_datetime, freq=delta)
elif isinstance(delta, pd.DateOffset):
current = start_datetime
timestamps = []
while current <= end_datetime:
timestamps.append(current)
current += delta
return pd.DatetimeIndex(timestamps)
# def generate_expected_timestamps(start_datetime: dt.datetime, end_datetime: dt.datetime,
# timeframe: str) -> pd.DatetimeIndex:
# """
# What it says. Todo: confirm this is unused and archive.
#
# :param start_datetime:
# :param end_datetime:
# :param timeframe:
# :return:
# """
# if start_datetime.tzinfo is None:
# raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
# if end_datetime.tzinfo is None:
# raise ValueError("end_datetime is timezone naive. Please provide a timezone-aware datetime.")
#
# delta = timeframe_to_timedelta(timeframe)
# if isinstance(delta, pd.Timedelta):
# return pd.date_range(start=start_datetime, end=end_datetime, freq=delta)
# elif isinstance(delta, pd.DateOffset):
# current = start_datetime
# timestamps = []
# while current <= end_datetime:
# timestamps.append(current)
# current += delta
# return pd.DatetimeIndex(timestamps)
class DataCache:
@ -84,10 +93,170 @@ class DataCache:
def __init__(self, exchanges):
self.db = Database()
self.exchanges = exchanges
self.cached_data = {}
# Single DataFrame for all cached data
self.cache = pd.DataFrame(columns=['key', 'data']) # Assuming 'key' and 'data' are necessary
logger.info("DataCache initialized.")
def fetch_cached_rows(self, table: str, filter_vals: Tuple[str, Any]) -> pd.DataFrame | None:
"""
Retrieves rows from the cache if available; otherwise, queries the database and caches the result.
:param table: Name of the database table to query.
:param filter_vals: A tuple containing the column name and the value to filter by.
:return: A DataFrame containing the requested rows, or None if no matching rows are found.
"""
# Construct a filter condition for the cache based on the table name and filter values.
cache_filter = (self.cache['table'] == table) & (self.cache[filter_vals[0]] == filter_vals[1])
cached_rows = self.cache[cache_filter]
# If the data is found in the cache, return it.
if not cached_rows.empty:
return cached_rows
# If the data is not found in the cache, query the database.
rows = self.db.get_rows_where(table, filter_vals)
if rows is not None:
# Tag the rows with the table name and add them to the cache.
rows['table'] = table
self.cache = pd.concat([self.cache, rows])
return rows
def remove_row(self, filter_vals: Tuple[str, Any], additional_filter: Tuple[str, Any] = None,
remove_from_db: bool = True, table: str = None) -> None:
"""
Removes a specific row from the cache and optionally from the database based on filter criteria.
:param filter_vals: A tuple containing the column name and the value to filter by.
:param additional_filter: An optional additional filter to apply.
:param remove_from_db: If True, also removes the row from the database. Default is True.
:param table: The name of the table from which to remove the row in the database (optional).
"""
logger.debug(
f"Removing row from cache: filter={filter_vals},"
f" additional_filter={additional_filter}, remove_from_db={remove_from_db}, table={table}")
# Construct the filter condition for the cache
cache_filter = (self.cache[filter_vals[0]] == filter_vals[1])
if additional_filter:
cache_filter = cache_filter & (self.cache[additional_filter[0]] == additional_filter[1])
# Remove the row from the cache
self.cache = self.cache.drop(self.cache[cache_filter].index)
logger.info(f"Row removed from cache: filter={filter_vals}")
if remove_from_db and table:
# Construct the SQL query to delete from the database
sql = f"DELETE FROM {table} WHERE {filter_vals[0]} = ?"
params = [filter_vals[1]]
if additional_filter:
sql += f" AND {additional_filter[0]} = ?"
params.append(additional_filter[1])
# Execute the SQL query to remove the row from the database
self.db.execute_sql(sql, tuple(params))
logger.info(
f"Row removed from database: table={table}, filter={filter_vals},"
f" additional_filter={additional_filter}")
def is_attr_taken(self, table: str, attr: str, val: Any) -> bool:
"""
Checks if a specific attribute in a table is already taken.
:param table: The name of the table to check.
:param attr: The attribute to check (e.g., 'user_name', 'email').
:param val: The value of the attribute to check.
:return: True if the attribute is already taken, False otherwise.
"""
# Fetch rows from the specified table where the attribute matches the given value
result = self.fetch_cached_rows(table=table, filter_vals=(attr, val))
return result is not None and not result.empty
def fetch_cached_item(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any:
"""
Retrieves a specific item from the cache or database, caching the result if necessary.
:param item_name: The name of the column to retrieve.
:param table_name: The name of the table where the item is stored.
:param filter_vals: A tuple containing the column name and the value to filter by.
:return: The value of the requested item.
:raises ValueError: If the item is not found in either the cache or the database.
"""
# Fetch the relevant rows
rows = self.fetch_cached_rows(table_name, filter_vals)
if rows is not None and not rows.empty:
# Return the specific item from the first matching row.
return rows.iloc[0][item_name]
# If the item is not found, raise an error.
raise ValueError(f"Item {item_name} not found in {table_name} where {filter_vals[0]} = {filter_vals[1]}")
def modify_cached_row(self, table: str, filter_vals: Tuple[str, Any], field_name: str, new_data: Any) -> None:
"""
Modifies a specific field in a row within the cache and updates the database accordingly.
:param table: The name of the table where the data is stored.
:param filter_vals: A tuple containing the column name and the value to filter by.
:param field_name: The field to be updated.
:param new_data: The new data to be set.
"""
# Retrieve the row from the cache or database
row = self.fetch_cached_rows(table, filter_vals)
if row is None or row.empty:
raise ValueError(f"Row not found in cache or database for {filter_vals[0]} = {filter_vals[1]}")
# Modify the specified field
if isinstance(new_data, str):
row.loc[0, field_name] = new_data
else:
# If new_data is not a string, its converted to a JSON string before being inserted into the DataFrame.
row.loc[0, field_name] = json.dumps(new_data)
# Update the cache by removing the old entry and adding the modified row
self.cache = self.cache.drop(
self.cache[(self.cache['table'] == table) & (self.cache[filter_vals[0]] == filter_vals[1])].index
)
self.cache = pd.concat([self.cache, row])
# Update the database with the modified row
self.db.insert_dataframe(row.drop(columns='id'), table)
def insert_data(self, df: pd.DataFrame, table: str, skip_cache: bool = False) -> None:
"""
Inserts data into the specified table in the database, with an option to skip cache insertion.
:param df: The DataFrame containing the data to insert.
:param table: The name of the table where the data should be inserted.
:param skip_cache: If True, skips inserting the data into the cache. Default is False.
"""
# Insert the data into the database
self.db.insert_dataframe(df=df, table=table)
# Optionally insert the data into the cache
if not skip_cache:
df['table'] = table # Add table name for cache identification
self.cache = pd.concat([self.cache, df])
def insert_row(self, table: str, columns: tuple, values: tuple) -> None:
"""
Inserts a single row into the specified table in the database.
:param table: The name of the table where the row should be inserted.
:param columns: A tuple of column names corresponding to the values.
:param values: A tuple of values to insert into the specified columns.
"""
self.db.insert_row(table=table, columns=columns, values=values)
def get_records_since(self, start_datetime: dt.datetime, ex_details: List[str]) -> pd.DataFrame:
"""
This gets up-to-date records from a specified market and exchange.
:param start_datetime: The approximate time the first record should represent.
:param ex_details: The user exchange and market.
:return: The records.
"""
if self.TYPECHECKING_ENABLED:
if not isinstance(start_datetime, dt.datetime):
raise TypeError("start_datetime must be a datetime object")
@ -106,12 +275,20 @@ class DataCache:
'end_datetime': end_datetime,
'ex_details': ex_details,
}
return self.get_or_fetch_from('cache', **args)
return self._get_or_fetch_from('data', **args)
except Exception as e:
logger.error(f"An error occurred: {str(e)}")
raise
def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame:
def _get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame:
"""
Fetches market records from a resource stack (data, database, exchange).
fills incomplete request by fetching down the stack then updates the rest.
:param target: Starting point for the fetch. ['data', 'database', 'exchange']
:param kwargs: Details and credentials for the request.
:return: Records in a dataframe.
"""
start_datetime = kwargs.get('start_datetime')
if start_datetime.tzinfo is None:
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
@ -144,12 +321,12 @@ class DataCache:
key = self._make_key(ex_details)
combined_data = pd.DataFrame()
if target == 'cache':
resources = [self.get_candles_from_cache, self.get_from_database, self.get_from_server]
if target == 'data':
resources = [self._get_candles_from_cache, self._get_from_database, self._get_from_server]
elif target == 'database':
resources = [self.get_from_database, self.get_from_server]
resources = [self._get_from_database, self._get_from_server]
elif target == 'server':
resources = [self.get_from_server]
resources = [self._get_from_server]
else:
raise ValueError('Not a valid Target!')
@ -165,11 +342,11 @@ class DataCache:
combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='open_time').sort_values(
by='open_time')
is_complete, request_criteria = self.data_complete(combined_data, **request_criteria)
is_complete, request_criteria = self._data_complete(combined_data, **request_criteria)
if is_complete:
if fetch_method in [self.get_from_database, self.get_from_server]:
self.update_candle_cache(combined_data, key)
if fetch_method == self.get_from_server:
if fetch_method in [self._get_from_database, self._get_from_server]:
self._update_candle_cache(combined_data, key)
if fetch_method == self._get_from_server:
self._populate_db(ex_details, combined_data)
return combined_data
@ -178,7 +355,7 @@ class DataCache:
logger.error('Unable to fetch the requested data.')
return combined_data if not combined_data.empty else pd.DataFrame()
def get_candles_from_cache(self, **kwargs) -> pd.DataFrame:
def _get_candles_from_cache(self, **kwargs) -> pd.DataFrame:
start_datetime = kwargs.get('start_datetime')
if start_datetime.tzinfo is None:
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
@ -200,7 +377,7 @@ class DataCache:
raise ValueError("Missing required arguments")
key = self._make_key(ex_details)
logger.debug('Getting records from cache.')
logger.debug('Getting records from data.')
df = self.get_cache(key)
if df is None:
logger.debug("Cache records didn't exist.")
@ -210,7 +387,7 @@ class DataCache:
df['open_time'] <= unix_time_millis(end_datetime))].reset_index(drop=True)
return df_filtered
def get_from_database(self, **kwargs) -> pd.DataFrame:
def _get_from_database(self, **kwargs) -> pd.DataFrame:
start_datetime = kwargs.get('start_datetime')
if start_datetime.tzinfo is None:
raise ValueError("start_datetime is timezone naive. Please provide a timezone-aware datetime.")
@ -240,7 +417,7 @@ class DataCache:
return self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=start_datetime,
et=end_datetime)
def get_from_server(self, **kwargs) -> pd.DataFrame:
def _get_from_server(self, **kwargs) -> pd.DataFrame:
symbol = kwargs.get('ex_details')[0]
interval = kwargs.get('ex_details')[1]
exchange_name = kwargs.get('ex_details')[2]
@ -272,7 +449,7 @@ class DataCache:
end_datetime)
@staticmethod
def data_complete(data: pd.DataFrame, **kwargs) -> (bool, dict):
def _data_complete(data: pd.DataFrame, **kwargs) -> (bool, dict):
"""
Checks if the data completely satisfies the request.
@ -341,17 +518,21 @@ class DataCache:
return True, kwargs
def cache_exists(self, key: str) -> bool:
return key in self.cached_data
return key in self.cache['key'].values
def get_cache(self, key: str) -> Any | None:
if key not in self.cached_data:
logger.warning(f"The requested cache key({key}) doesn't exist!")
# Check if the key exists in the cache
if key not in self.cache['key'].values:
logger.warning(f"The requested data key ({key}) doesn't exist!")
return None
return self.cached_data[key]
def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
logger.debug('Updating cache with new records.')
# Concatenate the new records with the existing cache
# Retrieve the data associated with the key
result = self.cache[self.cache['key'] == key]['data'].iloc[0]
return result
def _update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
logger.debug('Updating data with new records.')
# Concatenate the new records with the existing data
records = pd.concat([self.get_cache(key), more_records], axis=0, ignore_index=True)
# Drop duplicates based on 'open_time' and keep the first occurrence
records = records.drop_duplicates(subset="open_time", keep='first')
@ -359,24 +540,49 @@ class DataCache:
records = records.sort_values(by='open_time').reset_index(drop=True)
# Reindex 'id' to ensure the expected order
records['id'] = range(1, len(records) + 1)
# Set the updated DataFrame back to cache
# Set the updated DataFrame back to the cache
self.set_cache(data=records, key=key)
def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None:
"""
Updates a dictionary stored in cache.
Updates a dictionary stored in the DataFrame cache.
:param data: The data to insert into cache.
:param cache_key: The cache index key for the dictionary.
:param dict_key: The dictionary key for the data.
:param data: The data to insert into the dictionary.
:param cache_key: The cache key for the dictionary.
:param dict_key: The key within the dictionary to update.
:return: None
"""
self.cached_data[cache_key].update({dict_key: data})
# Locate the row in the DataFrame that matches the cache_key
cache_index = self.cache.index[self.cache['key'] == cache_key]
if not cache_index.empty:
# Update the dictionary stored in the 'data' column
cache_dict = self.cache.at[cache_index[0], 'data']
if isinstance(cache_dict, dict):
cache_dict[dict_key] = data
# Ensure the DataFrame is updated with the new dictionary
self.cache.at[cache_index[0], 'data'] = cache_dict
else:
raise ValueError(f"Expected a dictionary in cache, but found {type(cache_dict)}.")
else:
raise KeyError(f"Cache key '{cache_key}' not found.")
def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
if do_not_overwrite and key in self.cached_data:
if do_not_overwrite and key in self.cache['key'].values:
return
self.cached_data[key] = data
# Corrected construction of the new row
new_row = pd.DataFrame({'key': [key], 'data': [data]})
# If the key already exists, drop the old entry
self.cache = self.cache[self.cache['key'] != key]
# Append the new row to the cache
self.cache = pd.concat([self.cache, new_row], ignore_index=True)
print(f'Current Cache: {self.cache}')
logger.debug(f'Cache set for key: {key}')
def _fetch_candles_from_exchange(self, symbol: str, interval: str, exchange_name: str, user_name: str,
@ -422,12 +628,12 @@ class DataCache:
if num_rec_records < estimated_num_records:
logger.info('Detected gaps in the data, attempting to fill missing records.')
candles = self.fill_data_holes(candles, interval)
candles = self._fill_data_holes(candles, interval)
return candles
@staticmethod
def fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
def _fill_data_holes(records: pd.DataFrame, interval: str) -> pd.DataFrame:
time_span = timeframe_to_timedelta(interval).total_seconds() / 60
last_timestamp = None
filled_records = []

View File

@ -33,7 +33,7 @@ class SQLite:
class HDict(dict):
"""
Hashable dictionary to use as cache keys.
Hashable dictionary to use as cache keys.
Example usage:
--------------
@ -85,15 +85,16 @@ class Database:
def __init__(self, db_file: str = None):
self.db_file = db_file
def execute_sql(self, sql: str) -> None:
def execute_sql(self, sql: str, params: tuple = ()) -> None:
"""
Executes a raw SQL statement.
Executes a raw SQL statement with optional parameters.
:param sql: SQL statement to execute.
:param params: Optional tuple of parameters to pass with the SQL statement.
"""
with SQLite(self.db_file) as con:
cur = con.cursor()
cur.execute(sql)
cur.execute(sql, params)
def get_item_where(self, item_name: str, table_name: str, filter_vals: Tuple[str, Any]) -> Any:
"""
@ -120,12 +121,17 @@ class Database:
:param table: Name of the table.
:param filter_vals: Tuple of column name and value to filter by.
:return: DataFrame of the query result or None if empty.
:return: DataFrame of the query result or None if empty or column does not exist.
"""
with SQLite(self.db_file) as con:
qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?"
result = pd.read_sql(qry, con, params=(filter_vals[1],))
return result if not result.empty else None
try:
with SQLite(self.db_file) as con:
qry = f"SELECT * FROM {table} WHERE {filter_vals[0]} = ?"
result = pd.read_sql(qry, con, params=(filter_vals[1],))
return result if not result.empty else None
except (sqlite3.OperationalError, pd.errors.DatabaseError) as e:
# Log the error or handle it appropriately
print(f"Error querying table '{table}' for column '{filter_vals[0]}': {e}")
return None
def insert_dataframe(self, df: pd.DataFrame, table: str) -> None:
"""

View File

@ -52,11 +52,18 @@ class Signal:
class Signals:
def __init__(self, loaded_signals=None):
def __init__(self, config):
# list of Signal objects.
self.signals = []
# load a list of existing signals from file.
loaded_signals = config.get_setting('signals_list')
if loaded_signals is None:
# Populate the list and file with defaults defined in this class.
loaded_signals = self.get_signals_defaults()
config.set_setting('signals_list', loaded_signals)
# Initialize signals with loaded data.
if loaded_signals is not None:
self.create_signal_from_dic(loaded_signals)

View File

@ -1,5 +1,5 @@
import json
from DataCache_v2 import DataCache
class Strategy:
def __init__(self, **args):
@ -154,14 +154,26 @@ class Strategy:
class Strategies:
def __init__(self, loaded_strats, trades):
def __init__(self, data: DataCache, trades):
# Handles database connections
self.data = data
# Reference to the trades object that maintains all trading actions and data.
self.trades = trades
# A list of all the Strategies created.
self.strat_list = []
# Initialise all the stately objects with the data saved to file.
for entry in loaded_strats:
self.strat_list.append(Strategy(**entry))
def get_all_strategy_names(self) -> list | None:
"""Return a list of all strategies in the database"""
self.data._get_from_database()
# Load existing Strategies from file.
loaded_strategies = config.get_setting('strategies')
if loaded_strategies is None:
# Populate the list and file with defaults defined in this class.
loaded_strategies = self.get_strategy_defaults()
config.set_setting('strategies', loaded_strategies)
for entry in loaded_strategies:
# Initialise all the strategy objects with data from file.
self.strat_list.append(Strategy(**entry)) return None
def new_strategy(self, data):
# Create an instance of the new Strategy.
@ -238,6 +250,7 @@ class Strategies:
published strategies and evaluates conditions against the data.
This function returns a list of strategies and action commands.
"""
def process_strategy(strategy):
action, cmd = strategy.evaluate_strategy(signals)
if action != 'do_nothing':
@ -262,4 +275,3 @@ class Strategies:
return False
else:
return return_obj

File diff suppressed because it is too large Load Diff

View File

@ -7,7 +7,6 @@ from flask_sock import Sock
from email_validator import validate_email, EmailNotValidError
# Handles all updates and requests for locally stored data.
import config
from BrighterTrades import BrighterTrades
# Set up logging
@ -54,7 +53,7 @@ def index():
try:
# Log the user in.
user_name = brighter_trades.config.users.load_or_create_user(username=session.get('user'))
user_name = brighter_trades.users.load_or_create_user(username=session.get('user'))
except ValueError as e:
if str(e) != 'GuestLimitExceeded!':
raise

View File

@ -37,16 +37,16 @@ class DataCache:
def cache_exists(self, key: str) -> bool:
"""
Checks if a cache exists for the given key.
Checks if a cached entry exists for the given key.
:param key: The access key.
:return: True if cache exists, False otherwise.
:return: True if a cached entry exists, False otherwise.
"""
return key in self.cached_data
def update_candle_cache(self, more_records: pd.DataFrame, key: str) -> None:
"""
Adds records to existing cache.
Adds records to the existing cached data.
:param more_records: The new records to be added.
:param key: The access key.
@ -59,9 +59,9 @@ class DataCache:
def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None:
"""
Creates a new cache key and inserts data.
Creates a new cache entry and inserts data.
:param data: The records to insert into cache.
:param data: The records to insert into the cache.
:param key: The index key for the data.
:param do_not_overwrite: Flag to prevent overwriting existing data.
:return: None
@ -72,10 +72,10 @@ class DataCache:
def update_cached_dict(self, cache_key: str, dict_key: str, data: Any) -> None:
"""
Updates a dictionary stored in cache.
Updates a dictionary stored in the cache.
:param data: The data to insert into cache.
:param cache_key: The cache index key for the dictionary.
:param data: The data to insert into the cached dictionary.
:param cache_key: The cache index key for the dictionary.
:param dict_key: The dictionary key for the data.
:return: None
"""
@ -89,7 +89,7 @@ class DataCache:
:return: Any|None - The requested data or None on key error.
"""
if key not in self.cached_data:
logger.warning(f"The requested cache key({key}) doesn't exist!")
logger.warning(f"The requested data key({key}) doesn't exist!")
return None
return self.cached_data[key]
@ -98,14 +98,14 @@ class DataCache:
"""
Fetches records since the specified start datetime.
:param key: The cache key.
:param key: The cache key.
:param start_datetime: The start datetime to fetch records from.
:param record_length: The required number of records.
:param ex_details: Exchange details.
:return: DataFrame containing the records.
"""
try:
target = 'cache'
target = 'data'
args = {
'key': key,
'start_datetime': start_datetime,
@ -167,7 +167,7 @@ class DataCache:
'record_length': record_length,
}
if target == 'cache':
if target == 'data':
result = get_from_cache()
if data_complete(result, **request_criteria):
return result
@ -193,7 +193,7 @@ class DataCache:
"""
Fetches records since the specified start datetime.
:param key: The cache key.
:param key: The cache key.
:param start_datetime: The start datetime to fetch records from.
:param record_length: The required number of records.
:param ex_details: Exchange details.
@ -203,11 +203,11 @@ class DataCache:
end_datetime = dt.datetime.utcnow()
if self.cache_exists(key=key):
logger.debug('Getting records from cache.')
logger.debug('Getting records from data.')
records = self.get_cache(key)
else:
logger.debug(
f'Records not in cache. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
f'Records not in data. Requesting from DB: starting at: {start_datetime} to: {end_datetime}')
records = self.get_records_since_from_db(table_name=key, st=start_datetime, et=end_datetime,
rl=record_length, ex_details=ex_details)
logger.debug(f'Got {len(records.index)} records from DB.')

View File

@ -9,23 +9,23 @@ from shared_utilities import timeframe_to_minutes, ts_of_n_minutes_ago
# log.basicConfig(level=log.ERROR)
class Candles:
def __init__(self, exchanges, config_obj, data_source):
def __init__(self, exchanges, users, data_source, config):
# A reference to the app configuration
self.config = config_obj
self.users = users
# The maximum amount of candles to load at one time.
self.max_records = self.config.app_data.get('max_data_loaded')
self.max_records = config.get_setting('max_data_loaded')
# This object maintains all the cached data.
self.data = data_source
# print('Setting the candle cache.')
# # Populate the cache:
# self.set_cache(symbol=self.config.users.get_chart_view(user_name='guest', specific_property='market'),
# interval=self.config.users.get_chart_view(user_name='guest', specific_property='timeframe'),
# exchange_name=self.config.users.get_chart_view(user_name='guest', specific_property='exchange_name'))
# print('DONE Setting cache')
# print('Setting the candle data.')
# # Populate the data:
# self.set_cache(symbol=self.users.get_chart_view(user_name='guest', specific_property='market'),
# interval=self.users.get_chart_view(user_name='guest', specific_property='timeframe'),
# exchange_name=self.users.get_chart_view(user_name='guest', specific_property='exchange_name'))
# print('DONE Setting data')
def get_last_n_candles(self, num_candles: int, asset: str, timeframe: str, exchange: str, user_name: str):
"""
@ -83,7 +83,7 @@ class Candles:
#
def set_cache(self, symbol=None, interval=None, exchange_name=None, user_name=None):
"""
This method requests a chart from memory to ensure the cache is initialized.
This method requests a chart from memory to ensure the cache is initialized.
:param user_name:
:param symbol: str - The symbol of the market.
@ -91,18 +91,18 @@ class Candles:
:param exchange_name: str - The name of the exchange_name to fetch from.
:return: None
"""
# By default, initialise cache with the last viewed chart.
# By default, initialise the cache with the last viewed chart.
if not symbol:
assert user_name is not None
symbol = self.config.users.get_chart_view(user_name=user_name, prop='market')
symbol = self.users.get_chart_view(user_name=user_name, prop='market')
log.info(f'set_candle_history(): No symbol provided. Using{symbol}')
if not interval:
assert user_name is not None
interval = self.config.users.get_chart_view(user_name=user_name, prop='timeframe')
interval = self.users.get_chart_view(user_name=user_name, prop='timeframe')
log.info(f'set_candle_history(): No timeframe provided. Using{interval}')
if not exchange_name:
assert user_name is not None
exchange_name = self.config.users.get_chart_view(user_name=user_name, prop='exchange_name')
exchange_name = self.users.get_chart_view(user_name=user_name, prop='exchange_name')
# Log the completion to the console.
log.info('set_candle_history(): Loading candle data...')
@ -199,20 +199,20 @@ class Candles:
:param user_name: str - The name of the user who owns the exchange.
:return: list - Candle records in the lightweight charts format.
"""
# By default, initialise cache with the last viewed chart.
# By default, initialise the cache with the last viewed chart.
if not symbol:
assert user_name is not None
symbol = self.config.users.get_chart_view(user_name=user_name, prop='market''market')
symbol = self.users.get_chart_view(user_name=user_name, prop='market''market')
log.info(f'set_candle_history(): No symbol provided. Using{symbol}')
if not interval:
assert user_name is not None
interval = self.config.users.get_chart_view(user_name=user_name, prop='market''timeframe')
interval = self.users.get_chart_view(user_name=user_name, prop='market''timeframe')
log.info(f'set_candle_history(): No timeframe provided. Using{interval}')
if not exchange_name:
assert user_name is not None
exchange_name = self.config.users.get_chart_view(user_name=user_name, prop='market''exchange_name')
exchange_name = self.users.get_chart_view(user_name=user_name, prop='market''exchange_name')
log.info(f'get_candle_history(): No exchange name provided. Using {exchange_name}')
candlesticks = self.get_last_n_candles(num_candles=num_records, asset=symbol, timeframe=interval,

View File

@ -308,12 +308,12 @@ indicator_types.append('MACD')
class Indicators:
def __init__(self, candles, config):
def __init__(self, candles, users):
# Object manages and serves price and candle data.
self.candles = candles
# A connection to an object that handles user configuration and persistent data.
self.config = config
# A connection to an object that handles user data.
self.users = users
# Collection of instantiated indicators objects
self.indicators = pd.DataFrame(columns=['creator', 'name', 'visible',
@ -331,7 +331,7 @@ class Indicators:
Get the users watch-list from the database and load the indicators into a dataframe.
:return: None
"""
active_indicators: pd.DataFrame = self.config.users.get_indicators(user_name)
active_indicators: pd.DataFrame = self.users.get_indicators(user_name)
if active_indicators is not None:
# Create an instance for each indicator.
@ -347,7 +347,7 @@ class Indicators:
Saves the indicators in the database indexed by the user id.
:return: None
"""
self.config.users.save_indicators(indicator)
self.users.save_indicators(indicator)
# @staticmethod
# def get_indicator_defaults():
@ -389,7 +389,7 @@ class Indicators:
:param only_enabled: bool - If True, return only indicators marked as visible.
:return: dict - A dictionary of indicator names as keys and their attributes as values.
"""
user_id = self.config.users.get_id(username)
user_id = self.users.get_id(username)
if not user_id:
raise ValueError(f"Invalid user_name: {username}")
@ -498,7 +498,7 @@ class Indicators:
:param num_results: The number of results being requested.
:return: The results of the indicator analysis as a DataFrame.
"""
username = self.config.users.get_username(indicator.creator)
username = self.users.get_username(indicator.creator)
src = indicator.source
symbol, timeframe, exchange_name = src['symbol'], src['timeframe'], src['exchange_name']
@ -532,7 +532,7 @@ class Indicators:
if start_ts:
print("Warning: start_ts has not implemented in get_indicator_data()!")
user_id = self.config.users.get_id(user_name=user_name)
user_id = self.users.get_id(user_name=user_name)
# Construct the query based on user_id and visibility.
query = f"creator == {user_id}"
@ -582,7 +582,7 @@ class Indicators:
if not indicator_name:
raise ValueError("No indicator name provided.")
self.indicators = self.indicators.query("name != @indicator_name").reset_index(drop=True)
self.config.users.save_indicators()
self.users.save_indicators()
def create_indicator(self, creator: str, name: str, kind: str,
source: dict, properties: dict, visible: bool = True):
@ -618,7 +618,7 @@ class Indicators:
indicator = indicator_class(name, kind, properties)
# Add the new indicator to a pandas dataframe.
creator_id = self.config.users.get_id(creator)
creator_id = self.users.get_id(creator)
row_data = {
'creator': creator_id,
'name': name,

View File

@ -1,38 +1,33 @@
import ccxt
import pandas as pd
import datetime
import matplotlib.pyplot as plt
from matplotlib.table import Table
# Simulating the cache as a DataFrame
data = {
'key': ['BTC/USD_2h_binance', 'ETH/USD_1h_coinbase'],
'data': ['{"open": 50000, "close": 50500}', '{"open": 1800, "close": 1825}']
}
cache_df = pd.DataFrame(data)
def fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5):
# Initialize the exchange
exchange_class = getattr(ccxt, exchange_name)
exchange = exchange_class()
# Visualization function
def visualize_cache(df):
fig, ax = plt.subplots(figsize=(6, 3))
ax.set_axis_off()
tb = Table(ax, bbox=[0, 0, 1, 1])
# Fetch historical candlestick data with a limit
ohlcv = exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit)
# Adding column headers
for i, column in enumerate(df.columns):
tb.add_cell(0, i, width=0.4, height=0.3, text=column, loc='center', facecolor='lightgrey')
# Convert to DataFrame for better readability
df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
# Adding rows and cells
for i in range(len(df)):
for j, value in enumerate(df.iloc[i]):
tb.add_cell(i + 1, j, width=0.4, height=0.3, text=value, loc='center', facecolor='white')
# Print the first few rows of the DataFrame
print("First few rows of the fetched OHLCV data:")
print(df.head())
# Print the timestamps in human-readable format
df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
print("\nFirst few timestamps in human-readable format:")
print(df[['timestamp', 'datetime']].head())
# Confirm the format of the timestamps
print("\nTimestamp format confirmation:")
for ts in df['timestamp']:
print(f"{ts} (milliseconds since Unix epoch)")
ax.add_table(tb)
plt.title("Visualizing Cache Data")
plt.show()
# Example usage
exchange_name = 'binance' # Change this to your exchange
symbol = 'BTC/USDT'
timeframe = '5m'
since = int((datetime.datetime(2024, 8, 1) - datetime.datetime(1970, 1, 1)).total_seconds() * 1000)
fetch_and_print_ohlcv(exchange_name, symbol, timeframe, since, limit=5)
visualize_cache(cache_df)

View File

@ -267,10 +267,10 @@ class Trade:
class Trades:
def __init__(self, loaded_trades=None):
def __init__(self, users):
"""
This class receives, executes, tracks and stores all active_trades.
:param loaded_trades: <list?dict?> A bunch of active_trades to create and store.
:param users: <Users> A class that maintains users each user may have trades.
"""
# Object that maintains exchange_interface and account data
self.exchange_interface = None
@ -290,7 +290,8 @@ class Trades:
self.settled_trades = []
self.stats = {'num_trades': 0, 'total_position': 0, 'total_position_value': 0}
# Load any trades that were passed into the constructor.
# Load all trades.
loaded_trades = users.get_all_active_user_trades()
if loaded_trades is not None:
# Create the active_trades loaded from file.
self.load_trades(loaded_trades)

View File

@ -1,5 +1,5 @@
import pytz
from DataCache_v2 import DataCache
from DataCache_v2 import DataCache, timeframe_to_timedelta, estimate_record_count
from ExchangeInterface import ExchangeInterface
import unittest
import pandas as pd
@ -224,10 +224,23 @@ class TestDataCacheV2(unittest.TestCase):
exchange_id INTEGER,
FOREIGN KEY (exchange_id) REFERENCES exchange(id)
)"""
sql_create_table_4 = f"""
CREATE TABLE IF NOT EXISTS test_table_2 (
key TEXT PRIMARY KEY,
data TEXT NOT NULL
)"""
sql_create_table_5 = f"""
CREATE TABLE IF NOT EXISTS users (
users_data TEXT PRIMARY KEY,
data TEXT NOT NULL
)"""
with SQLite(db_file=self.db_file) as con:
con.execute(sql_create_table_1)
con.execute(sql_create_table_2)
con.execute(sql_create_table_3)
con.execute(sql_create_table_4)
con.execute(sql_create_table_5)
self.data = DataCache(self.exchanges)
self.data.db = self.database
@ -241,30 +254,41 @@ class TestDataCacheV2(unittest.TestCase):
def test_set_cache(self):
print('\nTesting set_cache() method without no-overwrite flag:')
self.data.set_cache(data='data', key=self.key)
attr = self.data.__getattribute__('cached_data')
self.assertEqual(attr[self.key], 'data')
print(' - Set cache without no-overwrite flag passed.')
# Access the cache data using the DataFrame structure
cached_value = self.data.get_cache(key=self.key)
self.assertEqual(cached_value, 'data')
print(' - Set data without no-overwrite flag passed.')
print('Testing set_cache() once again with new data without no-overwrite flag:')
self.data.set_cache(data='more_data', key=self.key)
attr = self.data.__getattribute__('cached_data')
self.assertEqual(attr[self.key], 'more_data')
print(' - Set cache with new data without no-overwrite flag passed.')
# Access the updated cache data
cached_value = self.data.get_cache(key=self.key)
self.assertEqual(cached_value, 'more_data')
print(' - Set data with new data without no-overwrite flag passed.')
print('Testing set_cache() method once again with more data with no-overwrite flag set:')
self.data.set_cache(data='even_more_data', key=self.key, do_not_overwrite=True)
attr = self.data.__getattribute__('cached_data')
self.assertEqual(attr[self.key], 'more_data')
print(' - Set cache with no-overwrite flag passed.')
# Since do_not_overwrite is True, the cached data should not change
cached_value = self.data.get_cache(key=self.key)
self.assertEqual(cached_value, 'more_data')
print(' - Set data with no-overwrite flag passed.')
def test_cache_exists(self):
print('Testing cache_exists() method:')
self.assertFalse(self.data.cache_exists(key=self.key))
print(' - Check for non-existent cache passed.')
# Check that the cache does not contain the key before setting it
self.assertFalse(self.data.cache_exists(key=self.key))
print(' - Check for non-existent data passed.')
# Set the cache with a DataFrame containing the key-value pair
self.data.set_cache(data='data', key=self.key)
# Check that the cache now contains the key
self.assertTrue(self.data.cache_exists(key=self.key))
print(' - Check for existent cache passed.')
print(' - Check for existent data passed.')
def test_update_candle_cache(self):
print('Testing update_candle_cache() method:')
@ -273,21 +297,21 @@ class TestDataCacheV2(unittest.TestCase):
data_gen = DataGenerator('5m')
# Create initial DataFrame and insert into cache
df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0))
df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc))
print(f'Inserting this table into cache:\n{df_initial}\n')
self.data.set_cache(data=df_initial, key=self.key)
# Create new DataFrame to be added to cache
df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0))
df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0, tzinfo=dt.timezone.utc))
print(f'Updating cache with this table:\n{df_new}\n')
self.data.update_candle_cache(more_records=df_new, key=self.key)
self.data._update_candle_cache(more_records=df_new, key=self.key)
# Retrieve the resulting DataFrame from cache
result = self.data.get_cache(key=self.key)
print(f'The resulting table in cache is:\n{result}\n')
# Create the expected DataFrame
expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0))
expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc))
print(f'The expected open_time values are:\n{expected["open_time"].tolist()}\n')
# Assert that the open_time values in the result match those in the expected DataFrame, in order
@ -295,37 +319,50 @@ class TestDataCacheV2(unittest.TestCase):
f"open_time values in result are {result['open_time'].tolist()}" \
f" but expected {expected['open_time'].tolist()}"
print(f'The results open_time values match:\n{result["open_time"].tolist()}\n')
print(f'The result open_time values match:\n{result["open_time"].tolist()}\n')
print(' - Update cache with new records passed.')
def test_update_cached_dict(self):
print('Testing update_cached_dict() method:')
# Set an empty dictionary in the cache for the specified key
self.data.set_cache(data={}, key=self.key)
# Update the cached dictionary with a new key-value pair
self.data.update_cached_dict(cache_key=self.key, dict_key='sub_key', data='value')
# Retrieve the updated cache
cache = self.data.get_cache(key=self.key)
# Verify that the 'sub_key' in the cached dictionary has the correct value
self.assertEqual(cache['sub_key'], 'value')
print(' - Update dictionary in cache passed.')
def test_get_cache(self):
print('Testing get_cache() method:')
# Set some data into the cache
self.data.set_cache(data='data', key=self.key)
# Retrieve the cached data using the get_cache method
result = self.data.get_cache(key=self.key)
# Verify that the result matches the data we set
self.assertEqual(result, 'data')
print(' - Retrieve cache passed.')
print(' - Retrieve data passed.')
def _test_get_records_since(self, set_cache=True, set_db=True, query_offset=None, num_rec=None, ex_details=None,
simulate_scenarios=None):
"""
Test the get_records_since() method by generating a table of simulated data,
inserting it into cache and/or database, and then querying the records.
inserting it into the cache and/or database, and then querying the records.
Parameters:
set_cache (bool): If True, the generated table is inserted into the cache.
set_db (bool): If True, the generated table is inserted into the database.
query_offset (int, optional): The offset in the timeframe units for the query.
num_rec (int, optional): The number of records to generate in the simulated table.
ex_details (list, optional): Exchange details to generate the cache key.
ex_details (list, optional): Exchange details to generate the cache key.
simulate_scenarios (str, optional): The type of scenario to simulate. Options are:
- 'not_enough_data': The table data doesn't go far enough back.
- 'incomplete_data': The table doesn't have enough records to satisfy the query.
@ -336,7 +373,7 @@ class TestDataCacheV2(unittest.TestCase):
# Use provided ex_details or fallback to the class attribute.
ex_details = ex_details or self.ex_details
# Generate a cache/database key using exchange details.
# Generate a cache/database key using exchange details.
key = f'{ex_details[0]}_{ex_details[1]}_{ex_details[2]}'
# Set default number of records if not provided.
@ -374,13 +411,13 @@ class TestDataCacheV2(unittest.TestCase):
print(f'Table Created:\n{temp_df}')
if set_cache:
# Insert the generated table into cache.
print('Inserting table into cache.')
# Insert the generated table into the cache.
print('Inserting table into the cache.')
self.data.set_cache(data=df_initial, key=key)
if set_db:
# Insert the generated table into the database.
print('Inserting table into database.')
print('Inserting table into the database.')
with SQLite(self.db_file) as con:
df_initial.to_sql(key, con, if_exists='replace', index=False)
@ -414,7 +451,7 @@ class TestDataCacheV2(unittest.TestCase):
# Check that the result has more rows than the expected incomplete data.
assert result.shape[0] > expected.shape[
0], "Result has fewer or equal rows compared to the incomplete data."
print("\nThe returned DataFrames has filled in the missing data!")
print("\nThe returned DataFrame has filled in the missing data!")
else:
# Ensure the result and expected dataframes match in shape and content.
assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}"
@ -444,10 +481,10 @@ class TestDataCacheV2(unittest.TestCase):
print(' - Fetch records within the specified time range passed.')
def test_get_records_since(self):
print('\nTest get_records_since with records set in cache')
print('\nTest get_records_since with records set in data')
self._test_get_records_since()
print('\nTest get_records_since with records not in cache')
print('\nTest get_records_since with records not in data')
self._test_get_records_since(set_cache=False)
print('\nTest get_records_since with records not in database')
@ -467,21 +504,21 @@ class TestDataCacheV2(unittest.TestCase):
self._test_get_records_since(simulate_scenarios='missing_section')
def test_other_timeframes(self):
# print('\nTest get_records_since with a different timeframe')
# ex_details = ['BTC/USD', '15m', 'binance', 'test_guy']
# start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2)
# # Query the records since the calculated start time.
# result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
# last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC')
# assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=15.1)
#
# print('\nTest get_records_since with a different timeframe')
# ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']
# start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=1)
# # Query the records since the calculated start time.
# result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
# last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC')
# assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=5.1)
print('\nTest get_records_since with a different timeframe')
ex_details = ['BTC/USD', '15m', 'binance', 'test_guy']
start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2)
# Query the records since the calculated start time.
result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC')
assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=15.1)
print('\nTest get_records_since with a different timeframe')
ex_details = ['BTC/USD', '5m', 'binance', 'test_guy']
start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=1)
# Query the records since the calculated start time.
result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details)
last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC')
assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=5.1)
print('\nTest get_records_since with a different timeframe')
ex_details = ['BTC/USD', '4h', 'binance', 'test_guy']
@ -506,16 +543,217 @@ class TestDataCacheV2(unittest.TestCase):
def test_fetch_candles_from_exchange(self):
    """Verify _fetch_candles_from_exchange() returns candle data inside the requested window.

    Fetches 2h BTC/USD candles for the last 24 hours from binance and checks
    that a non-empty DataFrame with an 'open_time' column comes back, and that
    every timestamp falls within the requested [start, end] range.
    """
    print('Testing _fetch_candles_from_exchange() method:')
    # Timezone-aware bounds for the fetch window: the last 24 hours up to now (UTC).
    # datetime.now(timezone.utc) replaces the deprecated utcnow().replace(...) pattern.
    start_time = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=1)
    end_time = dt.datetime.now(dt.timezone.utc)
    # Fetch the candles from the exchange using the method under test.
    result = self.data._fetch_candles_from_exchange(symbol='BTC/USD', interval='2h',
                                                    exchange_name='binance', user_name='test_guy',
                                                    start_datetime=start_time, end_datetime=end_time)
    # Validate that the result is a DataFrame.
    self.assertIsInstance(result, pd.DataFrame)
    # Validate that the DataFrame is not empty.
    self.assertFalse(result.empty, "The DataFrame returned from the exchange is empty.")
    # Ensure that the 'open_time' column exists in the DataFrame.
    self.assertIn('open_time', result.columns, "'open_time' column is missing in the result DataFrame.")
    # Check if the DataFrame contains valid timestamps within the specified range.
    # open_time is stored as epoch milliseconds — localize to UTC for comparison.
    min_time = pd.to_datetime(result['open_time'].min(), unit='ms').tz_localize('UTC')
    max_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC')
    self.assertTrue(start_time <= min_time <= end_time, f"Data starts outside the expected range: {min_time}")
    self.assertTrue(start_time <= max_time <= end_time, f"Data ends outside the expected range: {max_time}")
    print(' - Fetch candle data from exchange passed.')
def test_remove_row(self):
    """Verify remove_row() deletes a row from the cache and, optionally, the database.

    Exercises both the soft delete (cache only, remove_from_db=False) and the
    hard delete (cache + database, remove_from_db=True) paths.
    """
    print('Testing remove_row() method:')
    # Seed the cache with a value under the test key.
    # (The unused DataFrame literal previously built here was dead code and is removed.)
    self.data.set_cache(data='test_data', key=self.key)
    # Ensure the data is in the cache before attempting removal.
    self.assertTrue(self.data.cache_exists(self.key), "Data was not correctly inserted into the cache.")
    # Soft delete: remove the row from the cache only.
    self.data.remove_row(table='test_table_2', filter_vals=('key', self.key), remove_from_db=False)
    # Verify the row has been removed from the cache.
    self.assertFalse(self.data.cache_exists(self.key), "Row was not correctly removed from the cache.")
    # Reinsert the data for the hard-delete test.
    self.data.set_cache(data='test_data', key=self.key)
    # Mirror the row into the database so there is something to hard-delete.
    self.data.db.insert_row(table='test_table_2', columns=('key', 'data'), values=(self.key, 'test_data'))
    # Hard delete: remove the row from both cache and database.
    self.data.remove_row(table='test_table_2', filter_vals=('key', self.key), remove_from_db=True)
    # Verify the row has been removed from the cache.
    self.assertFalse(self.data.cache_exists(self.key), "Row was not correctly removed from the cache.")
    # Verify the row has been removed from the database.
    with SQLite(self.db_file) as con:
        result = pd.read_sql(f'SELECT * FROM test_table_2 WHERE key="{self.key}"', con)
        self.assertTrue(result.empty, "Row was not correctly removed from the database.")
    print(' - Remove row from cache and database passed.')
def test_timeframe_to_timedelta(self):
    """Verify timeframe_to_timedelta() maps timeframe strings to pandas offsets."""
    print('Testing timeframe_to_timedelta() function:')
    # Fixed-length timeframes resolve to pd.Timedelta values.
    for tf, expected in (('2h', pd.Timedelta(hours=2)),
                         ('5m', pd.Timedelta(minutes=5)),
                         ('1d', pd.Timedelta(days=1))):
        self.assertEqual(timeframe_to_timedelta(tf), expected,
                         f"Failed to convert '{tf}' to Timedelta")
    # Calendar-length timeframes resolve to pd.DateOffset values.
    for tf, expected in (('3M', pd.DateOffset(months=3)),
                         ('1Y', pd.DateOffset(years=1))):
        self.assertEqual(timeframe_to_timedelta(tf), expected,
                         f"Failed to convert '{tf}' to DateOffset")
    # An unrecognised unit suffix must raise ValueError.
    with self.assertRaises(ValueError):
        timeframe_to_timedelta('5x')
    print(' - All timeframe_to_timedelta() tests passed.')
def test_estimate_record_count(self):
    """Verify estimate_record_count() for datetime and epoch-millisecond bounds."""
    print('Testing estimate_record_count() function:')
    range_start = dt.datetime(2023, 8, 1, 0, 0, 0, tzinfo=dt.timezone.utc)
    range_end = dt.datetime(2023, 8, 2, 0, 0, 0, tzinfo=dt.timezone.utc)
    # One full day at hourly resolution yields 24 records.
    self.assertEqual(estimate_record_count(range_start, range_end, '1h'), 24,
                     "Failed to estimate record count for 1h timeframe")
    # One full day at daily resolution yields a single record.
    self.assertEqual(estimate_record_count(range_start, range_end, '1d'), 1,
                     "Failed to estimate record count for 1d timeframe")
    # The same bounds expressed as epoch milliseconds must behave identically.
    start_ms = int(range_start.timestamp() * 1000)  # Convert to milliseconds
    end_ms = int(range_end.timestamp() * 1000)  # Convert to milliseconds
    self.assertEqual(estimate_record_count(start_ms, end_ms, '1h'), 24,
                     "Failed to estimate record count for 1h timeframe with milliseconds")
    # Non-temporal start values must raise ValueError.
    with self.assertRaises(ValueError):
        estimate_record_count("invalid_start", end_ms, '1h')
    print(' - All estimate_record_count() tests passed.')
def test_fetch_cached_rows(self):
    """Verify fetch_cached_rows() returns matching rows from the in-memory cache."""
    print('Testing fetch_cached_rows() method:')
    # Seed the cache with a single known row.
    seeded = pd.DataFrame({'table': ['test_table_2'],
                           'key': ['test_key'],
                           'data': ['test_data']})
    self.data.cache = pd.concat([self.data.cache, seeded])
    # A lookup on the seeded key should return that row as a DataFrame.
    fetched = self.data.fetch_cached_rows('test_table_2', ('key', 'test_key'))
    self.assertIsInstance(fetched, pd.DataFrame, "Failed to fetch DataFrame from cache")
    self.assertFalse(fetched.empty, "The fetched DataFrame is empty")
    self.assertEqual(fetched.iloc[0]['data'], 'test_data', "Incorrect data fetched from cache")
    # The database-fallback path is not exercised here — it would need a mocked
    # database call, and this test intentionally avoids I/O.
    print(' - Fetch from cache and database simulated.')
def test_is_attr_taken(self):
    """Verify is_attr_taken() reports whether an attribute value already exists."""
    print('Testing is_attr_taken() method:')
    # Seed the cache with one user row.
    seeded = pd.DataFrame({'table': ['users'],
                           'user_name': ['test_user'],
                           'data': ['test_data']})
    self.data.cache = pd.concat([self.data.cache, seeded])
    # The seeded user name must be reported as taken.
    self.assertTrue(self.data.is_attr_taken('users', 'user_name', 'test_user'),
                    "Failed to detect existing attribute")
    # An unknown user name must not be reported as taken.
    self.assertFalse(self.data.is_attr_taken('users', 'user_name', 'non_existing_user'),
                     "Incorrectly detected non-existing attribute")
    print(' - All is_attr_taken() tests passed.')
def test_insert_data(self):
    """Verify insert_data() makes newly inserted rows retrievable from the cache."""
    print('Testing insert_data() method:')
    # Build the single-row DataFrame to insert.
    new_rows = pd.DataFrame({'key': ['new_key'],
                             'data': ['new_data']})
    # Insert into the database and cache via the method under test.
    self.data.insert_data(df=new_rows, table='test_table_2')
    # The inserted value must now be retrievable from the cache.
    self.assertEqual(self.data.get_cache('new_key'), 'new_data',
                     "Failed to insert data into cache")
    # Database-side verification would require a mock database or a direct
    # inspection of the database state, which this test deliberately skips.
    print(' - Data insertion into cache and database simulated.')
def test_insert_row(self):
    """Verify insert_row() persists a single row to the backing database."""
    print('Testing insert_row() method:')
    self.data.insert_row(table='test_table_2', columns=('key', 'data'),
                         values=('test_key', 'test_data'))
    # Read the row back straight from SQLite to confirm it was persisted.
    with SQLite(self.db_file) as con:
        rows = pd.read_sql('SELECT * FROM test_table_2 WHERE key="test_key"', con)
    self.assertFalse(rows.empty, "Row was not inserted into the database.")
    print(' - Insert row passed.')
def test_fill_data_holes(self):
    """Verify _fill_data_holes() inserts records for gaps in a 2h candle series."""
    print('Testing _fill_data_holes() method:')
    # Five timestamps spanning 00:00-12:00 UTC with gaps at the 2h resolution
    # (04:00 and 10:00 are missing). Stored as epoch milliseconds.
    hours_present = (0, 2, 6, 8, 12)
    gappy = pd.DataFrame({
        'open_time': [dt.datetime(2023, 1, 1, h, tzinfo=dt.timezone.utc).timestamp() * 1000
                      for h in hours_present]
    })
    # Fill the holes at a 2h interval.
    filled = self.data._fill_data_holes(records=gappy, interval='2h')
    # 00:00..12:00 inclusive at 2h steps -> 7 records once the holes are filled.
    self.assertEqual(len(filled), 7, "Data holes were not filled correctly.")
    print(' - _fill_data_holes passed.')
# Run the full DataCache test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()

View File

@ -115,7 +115,7 @@ def test_log_out_all_users():
def test_load_user_data():
    """Verify get_user_data() returns None for a user with no stored data.

    Note: diff residue left both the old load_user_data() call and the new
    get_user_data() call in place; only the renamed method is kept, since the
    refactor replaced load_user_data() with get_user_data().
    """
    # Todo method incomplete
    result = config.users.get_user_data(user_name='RobbieD')
    print('\n Test result:', result)
    assert result is None