From b666ec22afd0982f590b2b92dcc9898511f5eea9 Mon Sep 17 00:00:00 2001 From: Rob Date: Sat, 7 Sep 2024 19:51:01 -0300 Subject: [PATCH] Extended DataCache with indicator caching functionality. All DataCache tests pass. --- markdown/candles.md | 6 +- src/BrighterTrades.py | 99 +++- src/DataCache_v2.py | 38 +- src/DataCache_v3.py | 448 ++++++++++++-- src/Database.py | 2 +- src/Exchange.py | 10 +- src/ExchangeInterface.py | 15 + src/Users.py | 140 ++--- src/app.py | 6 +- src/archived_code/DataCache.py | 16 +- src/archived_code/test_DataCache.py | 18 +- src/candles.py | 24 +- src/indicators.py | 325 ++++------ src/maintenence/debuging_testing.py | 37 +- src/shared_utilities.py | 4 +- src/static/communication.js | 2 + src/static/exchanges.js | 61 +- src/static/indicators.js | 349 +++++++---- src/templates/exchange_config_popup.html | 2 +- src/templates/exchange_info_hud.html | 14 +- src/templates/index.html | 1 + src/templates/indicators_hud.html | 61 +- src/templates/new_indicator_popup.html | 66 +++ src/templates/price_chart.html | 2 +- .../working_public_exchanges.txt | 0 tests/test_DataCache.py | 560 +++++++++++++++--- tests/test_Exchange.py | 2 +- tests/test_candles.py | 6 +- tests/test_database.py | 10 +- tests/test_shared_utilities.py | 14 +- 30 files changed, 1579 insertions(+), 759 deletions(-) create mode 100644 src/templates/new_indicator_popup.html rename src/{maintenence => }/working_public_exchanges.txt (100%) diff --git a/markdown/candles.md b/markdown/candles.md index 2b0bd78..10f460b 100644 --- a/markdown/candles.md +++ b/markdown/candles.md @@ -65,7 +65,7 @@ start @startuml start :Reset index of candles; - :Extract the open_time and volume columns; + :Extract the `time and volume columns; :Add the color field calling get_color() to supply the values; :Rename volume column to value; :Return volumes; @@ -83,9 +83,9 @@ start :Get color-coded volume; else (False) :Extract the specific values, - indexed by open_time; + indexed by time; endif - :Rename open_time column to time; + :Rename time column to time; :Convert time to seconds from milliseconds; :Return values; stop diff --git a/src/BrighterTrades.py b/src/BrighterTrades.py index 33bd4d0..6bc5ebd 100644 --- a/src/BrighterTrades.py +++ b/src/BrighterTrades.py @@ -187,16 +187,23 @@ class BrighterTrades: """ active_exchanges = self.users.get_exchanges(user_name, category='active_exchanges') success = False + for exchange in active_exchanges: keys = self.users.get_api_keys(user_name, exchange) - success = self.connect_or_config_exchange(user_name=user_name, - exchange_name=exchange, - api_keys=keys) + result = self.connect_or_config_exchange(user_name=user_name, + exchange_name=exchange, + api_keys=keys) + if (result['status'] == 'success') or (result['status'] == 'already_connected'): + success = True + if not success: # If no active exchange was successfully connected, connect to the default exchange - success = self.connect_or_config_exchange(user_name=user_name, - exchange_name=default_exchange, - api_keys=default_keys) + result = self.connect_or_config_exchange(user_name=user_name, + exchange_name=default_exchange, + api_keys=default_keys) + if result['status'] == 'success': + success = True + return success def get_js_init_data(self, user_name: str) -> dict: @@ -207,7 +214,7 @@ class BrighterTrades: :param user_name: str - The name of the user making the query. """ chart_view = self.users.get_chart_view(user_name=user_name) - indicator_types = self.indicators.indicator_types + indicator_types = self.indicators.get_available_indicator_types() available_indicators = self.indicators.get_indicator_list(user_name) if not chart_view: @@ -223,7 +230,9 @@ class BrighterTrades: 'timeframe': chart_view.get('timeframe'), 'exchange_name': chart_view.get('exchange_name'), 'trading_pair': chart_view.get('market'), - 'user_name': user_name + 'user_name': user_name, + 'public_exchanges': self.exchanges.get_public_exchanges() + } return js_data @@ -250,7 +259,7 @@ class BrighterTrades: r_data['configured_exchanges'] = self.users.get_exchanges( user_name, category='configured_exchanges') or [] r_data['my_balances'] = self.exchanges.get_all_balances(user_name) or {} - r_data['indicator_types'] = self.indicators.indicator_types or [] + r_data['indicator_types'] = self.indicators.get_available_indicator_types() or [] r_data['indicator_list'] = self.indicators.get_indicator_list(user_name) or [] r_data['enabled_indicators'] = self.indicators.get_indicator_list(user_name, only_enabled=True) or [] r_data['ma_vals'] = self.indicators.MV_AVERAGE_ENUM @@ -365,38 +374,47 @@ class BrighterTrades: """ return self.strategies.get_strategies('json') - def connect_or_config_exchange(self, user_name: str, exchange_name: str, api_keys: dict = None) -> bool: + def connect_or_config_exchange(self, user_name: str, exchange_name: str, api_keys: dict = None) -> dict: """ Connects to an exchange if not already connected, or configures the exchange connection for a single user. :param user_name: str - The name of the user. :param exchange_name: str - The name of the exchange. :param api_keys: dict - The API keys for the exchange. - :return: bool - True if the exchange was connected or configured successfully, False otherwise. + :return: dict - A dictionary containing the result of the operation. """ - # Check if the exchange is already connected - if self.exchanges.exchange_data.query("user == @user_name and name == @exchange_name").empty: - # Exchange is not connected, try to connect - try: + result = { + 'exchange': exchange_name, + 'status': '', + 'message': '' + } + + try: + if self.exchanges.exchange_data.query("user == @user_name and name == @exchange_name").empty: + # Exchange is not connected, try to connect success = self.exchanges.connect_exchange(exchange_name=exchange_name, user_name=user_name, api_keys=api_keys) if success: self.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set') if api_keys: self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name) - return True + result['status'] = 'success' + result['message'] = f'Successfully connected to {exchange_name}.' else: - return False # Failed to connect - except Exception as e: - # Handle specific exceptions or log connection errors - print(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}") - return False # Failed to connect - else: - # Exchange is already connected, update API keys if provided - if api_keys: - self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name) - return True # Already connected + result['status'] = 'failure' + result['message'] = f'Failed to connect to {exchange_name}.' + else: + # Exchange is already connected, update API keys if provided + if api_keys: + self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name) + result['status'] = 'already_connected' + result['message'] = f'{exchange_name}: API keys updated.' + except Exception as e: + result['status'] = 'error' + result['message'] = f"Failed to connect to {exchange_name} for user '{user_name}': {str(e)}" + + return result def close_trade(self, trade_id): """ @@ -482,9 +500,28 @@ class BrighterTrades: elif setting == 'exchange': exchange_name = params['exchange_name'] - # Get the first result of a list of available symbols from this exchange_name. - market = self.exchanges.get_exchange(ename=exchange_name, uname=user_name).get_symbols()[0] - # Change the exchange in the chart view and pass it a default market to display. + + # Get the list of available symbols (markets) for the specified exchange and user. + markets = self.exchanges.get_exchange(ename=exchange_name, uname=user_name).get_symbols() + + # Check if the markets list is empty + if not markets: + # If no markets are available, exit without changing the chart view. + print(f"No available markets found for exchange '{exchange_name}'. Chart view remains unchanged.") + return + + # Get the currently viewed market for the user. + current_symbol = self.users.get_chart_view(user_name=user_name, prop='market') + + # Determine the market to display based on availability. + if current_symbol not in markets: + # If the current market is not available, default to the first available market. + market = markets[0] + else: + # Otherwise, continue displaying the current market. + market = current_symbol + + # Update the user's chart view to reflect the new exchange and default market. self.users.set_chart_view(values=exchange_name, specific_property='exchange_name', user_name=user_name, default_market=market) @@ -564,8 +601,8 @@ class BrighterTrades: if msg_type == 'config_exchange': user, exchange, keys = msg_data['user'], msg_data['exch'], msg_data['keys'] - if r_data := self.connect_or_config_exchange(user_name=user, exchange_name=exchange, api_keys=keys): - return standard_reply("Exchange_connection_result", r_data) + r_data = self.connect_or_config_exchange(user_name=user, exchange_name=exchange, api_keys=keys) + return standard_reply("Exchange_connection_result", r_data) if msg_type == 'reply': # If the message is a reply log the response to the terminal. diff --git a/src/DataCache_v2.py b/src/DataCache_v2.py index 9bae91e..ed996bb 100644 --- a/src/DataCache_v2.py +++ b/src/DataCache_v2.py @@ -363,9 +363,9 @@ class DataCache: if 'id' in result.columns: result = result.drop(columns=['id']) - # Concatenate, drop duplicates based on 'open_time', and sort - combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='open_time').sort_values( - by='open_time') + # Concatenate, drop duplicates based on 'time', and sort + combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='time').sort_values( + by='time') is_complete, request_criteria = self._data_complete(combined_data, **request_criteria) if is_complete: @@ -408,8 +408,8 @@ class DataCache: logger.debug("Cache records didn't exist.") return pd.DataFrame() logger.debug('Filtering records.') - df_filtered = df[(df['open_time'] >= unix_time_millis(start_datetime)) & ( - df['open_time'] <= unix_time_millis(end_datetime))].reset_index(drop=True) + df_filtered = df[(df['time'] >= unix_time_millis(start_datetime)) & ( + df['time'] <= unix_time_millis(end_datetime))].reset_index(drop=True) return df_filtered def _get_from_database(self, **kwargs) -> pd.DataFrame: @@ -439,7 +439,7 @@ class DataCache: return pd.DataFrame() logger.debug('Getting records from database.') - return self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=start_datetime, + return self.db.get_timestamped_records(table_name=table_name, timestamp_field='time', st=start_datetime, et=end_datetime) def _get_from_server(self, **kwargs) -> pd.DataFrame: @@ -499,12 +499,12 @@ class DataCache: temp_data = data.copy() - # Convert 'open_time' to datetime with unit='ms' and localize to UTC - temp_data['open_time_dt'] = pd.to_datetime(temp_data['open_time'], + # Convert 'time' to datetime with unit='ms' and localize to UTC + temp_data['time_dt'] = pd.to_datetime(temp_data['time'], unit='ms', errors='coerce').dt.tz_localize('UTC') - min_timestamp = temp_data['open_time_dt'].min() - max_timestamp = temp_data['open_time_dt'].max() + min_timestamp = temp_data['time_dt'].min() + max_timestamp = temp_data['time_dt'].max() logger.debug(f"Data time range: {min_timestamp} to {max_timestamp}") logger.debug(f"Expected time range: {start_datetime} to {end_datetime}") @@ -526,7 +526,7 @@ class DataCache: return False, updated_request_criteria # Filter data between start_datetime and end_datetime - mask = (temp_data['open_time_dt'] >= start_datetime) & (temp_data['open_time_dt'] <= end_datetime) + mask = (temp_data['time_dt'] >= start_datetime) & (temp_data['time_dt'] <= end_datetime) data_in_range = temp_data.loc[mask] expected_count = estimate_record_count(start_datetime, end_datetime, timeframe) @@ -559,10 +559,10 @@ class DataCache: logger.debug('Updating data with new records.') # Concatenate the new records with the existing data records = pd.concat([self.get_cache(key), more_records], axis=0, ignore_index=True) - # Drop duplicates based on 'open_time' and keep the first occurrence - records = records.drop_duplicates(subset="open_time", keep='first') - # Sort the records by 'open_time' - records = records.sort_values(by='open_time').reset_index(drop=True) + # Drop duplicates based on 'time' and keep the first occurrence + records = records.drop_duplicates(subset="time", keep='first') + # Sort the records by 'time' + records = records.sort_values(by='time').reset_index(drop=True) # Reindex 'id' to ensure the expected order records['id'] = range(1, len(records) + 1) # Set the updated DataFrame back to data @@ -621,11 +621,11 @@ class DataCache: num_rec_records = len(candles.index) if num_rec_records == 0: logger.warning(f"No OHLCV data returned for {symbol}.") - return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume']) + return pd.DataFrame(columns=['time', 'open', 'high', 'low', 'close', 'volume']) logger.info(f'{num_rec_records} candles retrieved from the exchange.') - open_times = candles.open_time + open_times = candles.time min_open_time = open_times.min() max_open_time = open_times.max() @@ -650,7 +650,7 @@ class DataCache: logger.info(f"Starting to fill data holes for interval: {interval}") for index, row in records.iterrows(): - time_stamp = row['open_time'] + time_stamp = row['time'] if last_timestamp is None: last_timestamp = time_stamp @@ -670,7 +670,7 @@ class DataCache: for ts in range(int(last_timestamp) + step, int(time_stamp), step): new_row = row.copy() - new_row['open_time'] = ts + new_row['time'] = ts filled_records.append(new_row) logger.debug(f"Filled timestamp: {ts}") diff --git a/src/DataCache_v3.py b/src/DataCache_v3.py index fc593d8..fabf652 100644 --- a/src/DataCache_v3.py +++ b/src/DataCache_v3.py @@ -1,10 +1,15 @@ +import io +import pickle from abc import ABC, abstractmethod import logging import datetime as dt +from typing import Any, Tuple + import pandas as pd import numpy as np import json +from indicators import Indicator, indicators_registry from shared_utilities import unix_time_millis from Database import Database @@ -17,7 +22,9 @@ def timeframe_to_timedelta(timeframe: str) -> pd.Timedelta | pd.DateOffset: digits = int("".join([i if i.isdigit() else "" for i in timeframe])) unit = "".join([i if i.isalpha() else "" for i in timeframe]) - if unit == 'm': + if unit == 's': # Add support for seconds + return pd.Timedelta(seconds=digits) + elif unit == 'm': return pd.Timedelta(minutes=digits) elif unit == 'h': return pd.Timedelta(hours=digits) @@ -38,28 +45,43 @@ def estimate_record_count(start_time, end_time, timeframe: str) -> int: Estimate the number of records expected between start_time and end_time based on the given timeframe. Accepts either datetime objects or Unix timestamps in milliseconds. """ - # Check if the input is in milliseconds (timestamp) - if isinstance(start_time, (int, float, np.integer)) and isinstance(end_time, (int, float, np.integer)): - # Convert timestamps from milliseconds to seconds for calculation - start_time = int(start_time) / 1000 - end_time = int(end_time) / 1000 - start_datetime = dt.datetime.utcfromtimestamp(start_time).replace(tzinfo=dt.timezone.utc) - end_datetime = dt.datetime.utcfromtimestamp(end_time).replace(tzinfo=dt.timezone.utc) - elif isinstance(start_time, dt.datetime) and isinstance(end_time, dt.datetime): - if start_time.tzinfo is None: - raise ValueError("start_time is timezone naive. Please provide a timezone-aware datetime.") - if end_time.tzinfo is None: - raise ValueError("end_time is timezone naive. Please provide a timezone-aware datetime.") - start_datetime = start_time - end_datetime = end_time - else: - raise ValueError("start_time and end_time must be either both " - "datetime objects or both Unix timestamps in milliseconds.") + # Check for invalid inputs + if not isinstance(start_time, (int, float, np.integer, dt.datetime)): + raise ValueError(f"Invalid type for start_time: {type(start_time)}. Expected datetime or timestamp.") + if not isinstance(end_time, (int, float, np.integer, dt.datetime)): + raise ValueError(f"Invalid type for end_time: {type(end_time)}. Expected datetime or timestamp.") + # Convert milliseconds to datetime if needed + if isinstance(start_time, (int, float, np.integer)) and isinstance(end_time, (int, float, np.integer)): + start_time = dt.datetime.utcfromtimestamp(start_time / 1000).replace(tzinfo=dt.timezone.utc) + end_time = dt.datetime.utcfromtimestamp(end_time / 1000).replace(tzinfo=dt.timezone.utc) + elif isinstance(start_time, dt.datetime) and isinstance(end_time, dt.datetime): + if start_time.tzinfo is None or end_time.tzinfo is None: + raise ValueError("start_time and end_time must be timezone-aware.") + # Normalize both to UTC to avoid time zone differences affecting results + start_time = start_time.astimezone(dt.timezone.utc) + end_time = end_time.astimezone(dt.timezone.utc) + + # If the interval is negative, return 0 + if end_time <= start_time: + return 0 + + # Use the existing timeframe_to_timedelta function delta = timeframe_to_timedelta(timeframe) - total_seconds = (end_datetime - start_datetime).total_seconds() - expected_records = total_seconds // delta.total_seconds() - return int(expected_records) + + # Fixed-length intervals optimization + if isinstance(delta, dt.timedelta): + total_seconds = (end_time - start_time).total_seconds() + return int(total_seconds // delta.total_seconds()) + + # Handle 'M' (months) and 'Y' (years) using DateOffset + elif isinstance(delta, pd.DateOffset): + if 'months' in delta.kwds: + return (end_time.year - start_time.year) * 12 + (end_time.month - start_time.month) + elif 'years' in delta.kwds: + return end_time.year - start_time.year + + raise ValueError(f"Invalid timeframe: {timeframe}") # Cache Interface @@ -223,7 +245,7 @@ class InMemoryCache(Cache): # Mask for items that have not yet expired not_expired_mask = ( - self.cache['creation_time'] + self.cache['expire_delta'].fillna(pd.Timedelta(0)) > current_time) + self.cache['creation_time'] + self.cache['expire_delta'].fillna(pd.Timedelta(0)) > current_time) # Combine the masks mask_to_keep = non_expiring_mask | not_expired_mask @@ -324,6 +346,7 @@ class DataCacheBase: :param limit: The maximum number of items allowed in the cache (only used if creating a new cache). :param eviction_policy: The policy used when the cache reaches its limit (only used if creating a new cache). """ + # Automatically create the cache if it doesn't exist if cache_name not in self.caches: print(f"Creating Cache '{cache_name}' because it does not exist.") @@ -373,6 +396,7 @@ class DataCacheBase: :param key: The key associated with the cache item. :return Any: The cached data, or None if the key does not exist or the item is expired. """ + cache = self._get_cache(cache_name) if cache: return cache.get_item(key) @@ -568,25 +592,45 @@ class DatabaseInteractions(SnapshotDataCache): :return: A DataFrame containing the requested rows, or None if no matching rows are found. :raises ValueError: If the cache is not a DataFrame or does not contain DataFrames in the 'data' column. """ - if cache_name in self.caches: + # Attempt to retrieve cached data + cache_df = self._get_valid_cache(cache_name) - # Retrieve all items in the specified cache - cache_df = self.get_all_cache_items(cache_name=cache_name) - - if not isinstance(cache_df, pd.DataFrame): - raise ValueError(f"Cache '{cache_name}' is not a DataFrame and cannot be used with get_or_fetch_rows.") - - # Combine all the DataFrames in the 'data' column into a single DataFrame + # If the cache exists and contains data + if cache_df is not None: combined_data = pd.concat(cache_df['data'].values.tolist(), ignore_index=True) - - # Filter the combined DataFrame query_str = f"{filter_vals[0]} == @filter_vals[1]" matching_rows = combined_data.query(query_str) if not matching_rows.empty: return matching_rows - # If no data is found in the cache, fetch from the database + # Fallback to database if cache is invalid or no matching rows were found + return self._fetch_from_database(cache_name, filter_vals) + + def _get_valid_cache(self, cache_name: str) -> pd.DataFrame | None: + """ + Retrieves and validates the cache, ensuring it is a non-empty DataFrame containing 'data' column. + + :param cache_name: The key used to identify the cache. + :return: A valid DataFrame if cache is valid and contains data, otherwise None. + """ + if cache_name in self.caches: + cache_df = self.get_all_cache_items(cache_name=cache_name) + + # Return valid DataFrame if it exists and contains the 'data' column + if isinstance(cache_df, pd.DataFrame) and not cache_df.empty and 'data' in cache_df.columns: + return cache_df + + return None + + def _fetch_from_database(self, cache_name: str, filter_vals: tuple[str, any]) -> pd.DataFrame | None: + """ + Helper method to fetch rows from the database and cache the result. + + :param cache_name: The name of the table or key used to store/retrieve data. + :param filter_vals: A tuple with the filter column and value. + :return: A DataFrame with the fetched rows, or None if no data is found. + """ rows = self.db.get_rows_where(cache_name, filter_vals) if rows is not None and not rows.empty: # Store the fetched rows in the cache for future use @@ -760,17 +804,19 @@ class DatabaseInteractions(SnapshotDataCache): # Modify the specified field if isinstance(new_data, str): - row.loc[0, field_name] = new_data + updated_value = new_data else: - # If new_data is not a string, convert it to a JSON string before inserting into the DataFrame. - row.loc[0, field_name] = json.dumps(new_data) + updated_value = json.dumps(new_data) # Convert non-string data to JSON string - # Update the cache by removing the old entry and adding the modified row - self.remove_row(cache_name=cache_name, filter_vals=filter_vals) + # Update the DataFrame + row[field_name] = updated_value + + # Set the updated row in the cache (this will replace the old entry) self.set_cache_item(cache_name=cache_name, key=filter_vals[1], data=row) - # Update the database with the modified row (excluding the 'id' column if necessary) - self.db.insert_dataframe(row.drop(columns='id', errors='ignore'), table=cache_name) + # Ensure the value is a scalar before passing it to the SQL query + update_query = f"UPDATE {cache_name} SET {field_name} = ? WHERE {filter_vals[0]} = ?" + self.db.execute_sql(update_query, (updated_value, filter_vals[1])) def update_cached_dict(self, cache_name: str, cache_key: str, dict_key: str, data: any) -> None: """ @@ -827,8 +873,8 @@ class ServerInteractions(DatabaseInteractions): existing_records = pd.DataFrame() records = pd.concat([existing_records, more_records], axis=0, ignore_index=True) - records = records.drop_duplicates(subset="open_time", keep='first') - records = records.sort_values(by='open_time').reset_index(drop=True) + records = records.drop_duplicates(subset="time", keep='first') + records = records.sort_values(by='time').reset_index(drop=True) records['id'] = range(1, len(records) + 1) self.set_cache_item(cache_name='candles', key=key, data=records) @@ -859,12 +905,12 @@ class ServerInteractions(DatabaseInteractions): 'end_datetime': end_datetime, 'ex_details': ex_details, } - return self._get_or_fetch_from(target='data', **args) + return self.get_or_fetch_from(target='data', **args) except Exception as e: logger.error(f"An error occurred: {str(e)}") raise - def _get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame: + def get_or_fetch_from(self, target: str, **kwargs) -> pd.DataFrame: """ Fetches market records from a resource stack (data, database, exchange). Fills incomplete request by fetching down the stack then updates the rest. @@ -922,9 +968,9 @@ class ServerInteractions(DatabaseInteractions): if 'id' in result.columns: result = result.drop(columns=['id']) - # Concatenate, drop duplicates based on 'open_time', and sort - combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='open_time').sort_values( - by='open_time') + # Concatenate, drop duplicates based on 'time', and sort + combined_data = pd.concat([combined_data, result]).drop_duplicates(subset='time').sort_values( + by='time') is_complete, request_criteria = self._data_complete(data=combined_data, **request_criteria) if is_complete: @@ -955,8 +1001,8 @@ class ServerInteractions(DatabaseInteractions): logger.debug("No cached records found.") return pd.DataFrame() - df_filtered = df[(df['open_time'] >= unix_time_millis(start_datetime)) & - (df['open_time'] <= unix_time_millis(end_datetime))].reset_index(drop=True) + df_filtered = df[(df['time'] >= unix_time_millis(start_datetime)) & + (df['time'] <= unix_time_millis(end_datetime))].reset_index(drop=True) return df_filtered @@ -987,7 +1033,7 @@ class ServerInteractions(DatabaseInteractions): return pd.DataFrame() logger.debug('Getting records from database.') - return self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=start_datetime, + return self.db.get_timestamped_records(table_name=table_name, timestamp_field='time', st=start_datetime, et=end_datetime) def _get_from_server(self, **kwargs) -> pd.DataFrame: @@ -1048,12 +1094,12 @@ class ServerInteractions(DatabaseInteractions): temp_data = data.copy() - # Convert 'open_time' to datetime with unit='ms' and localize to UTC - temp_data['open_time_dt'] = pd.to_datetime(temp_data['open_time'], - unit='ms', errors='coerce').dt.tz_localize('UTC') + # Convert 'time' to datetime with unit='ms' and localize to UTC + temp_data['time_dt'] = pd.to_datetime(temp_data['time'], + unit='ms', errors='coerce').dt.tz_localize('UTC') - min_timestamp = temp_data['open_time_dt'].min() - max_timestamp = temp_data['open_time_dt'].max() + min_timestamp = temp_data['time_dt'].min() + max_timestamp = temp_data['time_dt'].max() logger.debug(f"Data time range: {min_timestamp} to {max_timestamp}") logger.debug(f"Expected time range: {start_datetime} to {end_datetime}") @@ -1075,7 +1121,7 @@ class ServerInteractions(DatabaseInteractions): return False, updated_request_criteria # Filter data between start_datetime and end_datetime - mask = (temp_data['open_time_dt'] >= start_datetime) & (temp_data['open_time_dt'] <= end_datetime) + mask = (temp_data['time_dt'] >= start_datetime) & (temp_data['time_dt'] <= end_datetime) data_in_range = temp_data.loc[mask] expected_count = estimate_record_count(start_datetime, end_datetime, timeframe) @@ -1118,11 +1164,11 @@ class ServerInteractions(DatabaseInteractions): num_rec_records = len(candles.index) if num_rec_records == 0: logger.warning(f"No OHLCV data returned for {symbol}.") - return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume']) + return pd.DataFrame(columns=['time', 'open', 'high', 'low', 'close', 'volume']) logger.info(f'{num_rec_records} candles retrieved from the exchange.') - open_times = candles.open_time + open_times = candles.time min_open_time = open_times.min() max_open_time = open_times.max() @@ -1147,7 +1193,7 @@ class ServerInteractions(DatabaseInteractions): logger.info(f"Starting to fill data holes for interval: {interval}") for index, row in records.iterrows(): - time_stamp = row['open_time'] + time_stamp = row['time'] if last_timestamp is None: last_timestamp = time_stamp @@ -1167,7 +1213,7 @@ class ServerInteractions(DatabaseInteractions): for ts in range(int(last_timestamp) + step, int(time_stamp), step): new_row = row.copy() - new_row['open_time'] = ts + new_row['time'] = ts filled_records.append(new_row) logger.debug(f"Filled timestamp: {ts}") @@ -1189,7 +1235,287 @@ class ServerInteractions(DatabaseInteractions): logger.info(f'Data inserted into table {table_name}') -class DataCache(ServerInteractions): +class IndicatorCache(ServerInteractions): + """ + IndicatorCache extends ServerInteractions and manages caching of both instantiated indicator objects + and their calculated data. This class optimizes the indicator calculation process by caching and + reusing data where possible, minimizing redundant computations. + + Attributes: + indicator_registry (dict): A dictionary mapping indicator types (e.g., 'SMA', 'EMA') to their classes. + """ + + def __init__(self, exchanges): + """ + Initialize the IndicatorCache with caches for indicators and their calculated data. + + :param exchanges: The exchange interfaces used for retrieving market data. + """ + super().__init__(exchanges) + # Cache for storing instantiated indicator objects + self.create_cache('indicators', cache_type=InMemoryCache, limit=100, eviction_policy='evict') + # Cache for storing calculated indicator data + self.create_cache('indicator_data', cache_type=InMemoryCache, limit=500, eviction_policy='evict') + # Registry of available indicators + self.indicator_registry = indicators_registry + + @staticmethod + def _make_indicator_key(symbol: str, timeframe: str, exchange_name: str, indicator_type: str, period: int) -> str: + """ + Generates a unique cache key for caching the indicator data based on the input properties. + """ + return f"{symbol}_{timeframe}_{exchange_name}_{indicator_type}_{period}" + + def get_indicator_instance(self, indicator_type: str, properties: dict) -> Indicator: + """ + Retrieves or creates an instance of an indicator based on its type and properties. + """ + if indicator_type not in self.indicator_registry: + raise ValueError(f"Unsupported indicator type: {indicator_type}") + indicator_class = self.indicator_registry[indicator_type] + return indicator_class(name=indicator_type, indicator_type=indicator_type, properties=properties) + + def set_cache_item(self, key: str, data: Any, expire_delta: dt.timedelta = None, + do_not_overwrite: bool = False, cache_name: str = 'default_cache', + limit: int = None, eviction_policy: str = 'evict'): + """ + Stores an item in the cache, with custom serialization for Indicator instances. + Maintains the signature consistent with the base class. + """ + # Serialize Indicator instances using pickle + if isinstance(data, Indicator): + data = pickle.dumps(data) + + # Use the base class method for actual caching + super().set_cache_item(key, data, expire_delta=expire_delta, do_not_overwrite=do_not_overwrite, + cache_name=cache_name, limit=limit, eviction_policy=eviction_policy) + + def get_cache_item(self, key: str, cache_name: str = 'default_cache') -> Any: + """ + Retrieves an item from the specified cache. + + :param cache_name: The name of the cache. + :param key: The key associated with the cache item. + :return Any: The cached data, or None if the key does not exist or the item is expired. + """ + data = super().get_cache_item(key, cache_name) + + # If no data is found, return None + if data is None: + logging.info(f"No data found in cache for key: {key}") + return None + + # Handle Indicator case (deserialize using pickle) + if cache_name == 'indicators': + logging.info(f"Indicator data retrieved from cache for key: {key}") + try: + deserialized_data = pickle.loads(data) + if isinstance(deserialized_data, Indicator): + return deserialized_data + else: + logging.warning(f"Expected Indicator instance, got {type(deserialized_data)}") + return deserialized_data # Fallback: Return deserialized data even if it's not an Indicator + except (pickle.PickleError, TypeError) as e: + logging.error(f"Deserialization failed for key {key}: {e}") + return None + + # Handle list case + if isinstance(data, list): + logging.info(f"List data retrieved from cache for key: {key}") + return data + + # Handle DataFrame case + if isinstance(data, pd.DataFrame) and not data.empty: + logging.info(f"DataFrame retrieved from cache for key: {key}") + return data + + # Return the data as-is for any other type + logging.info(f"Data retrieved from cache for key: {key}") + return data + + def set_user_indicator_properties(self, user_id: str, indicator_type: str, symbol: str, timeframe: str, + exchange_name: str, display_properties: dict): + """ + Stores or updates user-specific display properties for an indicator. + """ + if not isinstance(display_properties, dict): + raise ValueError("display_properties must be a dictionary") + + user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}" + self.set_cache_item(user_cache_key, display_properties, cache_name='user_display_properties') + + def get_user_indicator_properties(self, user_id: str, indicator_type: str, symbol: str, timeframe: str, + exchange_name: str) -> dict: + """ + Retrieves user-specific display properties for an indicator. + """ + # Ensure the arguments are valid + if not isinstance(user_id, str) or not isinstance(indicator_type, str) or \ + not isinstance(symbol, str) or not isinstance(timeframe, str) or \ + not isinstance(exchange_name, str): + raise TypeError("All arguments must be of type str") + user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}" + return self.get_cache_item(user_cache_key, cache_name='user_display_properties') + + def find_gaps_in_intervals(self, cached_data, start_idx, end_idx, timeframe, min_gap_size=None): + """ + Recursively searches for missing intervals in cached data based on time gaps. + + :param cached_data: DataFrame containing cached time-series data. + :param start_idx: Start index in the DataFrame. + :param end_idx: End index in the DataFrame. + :param timeframe: Expected interval between consecutive time points (e.g., '5m', '1h'). + :param min_gap_size: The minimum allowable gap (in terms of the number of missing records) to consider. + :return: A list of tuples representing the missing intervals (start, end). + """ + if cached_data.empty or end_idx - start_idx <= 1: + return [] + + start_time = cached_data['time'].iloc[start_idx] + end_time = cached_data['time'].iloc[end_idx] + + expected_records = estimate_record_count(start_time, end_time, timeframe) + actual_records = end_idx - start_idx + 1 + + if expected_records == actual_records: + return [] + + if min_gap_size and (expected_records - actual_records) < min_gap_size: + return [] + + mid_idx = (start_idx + end_idx) // 2 + left_missing = self.find_gaps_in_intervals(cached_data, start_idx, mid_idx, timeframe, min_gap_size) + right_missing = self.find_gaps_in_intervals(cached_data, mid_idx, end_idx, timeframe, min_gap_size) + + return left_missing + right_missing + + def calculate_indicator(self, user_name: str, symbol: str, timeframe: str, exchange_name: str, indicator_type: str, + start_datetime: dt.datetime, end_datetime: dt.datetime, properties: dict) -> dict: + """ + Fetches or calculates the indicator data for the specified range, caches the shared results, + and combines them with user-specific display properties or indicator defaults. + """ + # Ensure start_datetime and end_datetime are timezone-aware + if start_datetime.tzinfo is None or end_datetime.tzinfo is None: + raise ValueError("Datetime objects must be timezone-aware") + + # Step 1: Create cache key and indicator instance + cache_key = self._make_indicator_key(symbol, timeframe, exchange_name, indicator_type, properties['period']) + indicator_instance = self.get_indicator_instance(indicator_type, properties) + + # Step 2: Check cache for existing data + calculated_data = self._fetch_cached_data(cache_key, start_datetime, end_datetime) + missing_intervals = self._determine_missing_intervals(calculated_data, start_datetime, end_datetime, timeframe) + + # Step 3: Fetch and calculate data for missing intervals + if missing_intervals: + calculated_data = self._fetch_and_calculate_missing_data( + missing_intervals, calculated_data, indicator_instance, user_name, symbol, timeframe, exchange_name + ) + + # Step 4: Cache the newly calculated data + self.set_cache_item(cache_key, calculated_data, cache_name='indicator_data') + + # Step 5: Retrieve and merge user-specific display properties with defaults + merged_properties = self._get_merged_properties(user_name, indicator_type, symbol, timeframe, exchange_name, + indicator_instance) + + return { + 'calculation_data': calculated_data.to_dict('records'), + 'display_properties': merged_properties + } + + def _fetch_cached_data(self, cache_key: str, start_datetime: dt.datetime, + end_datetime: dt.datetime) -> pd.DataFrame: + """ + Fetches cached data for the given time range. + """ + # Retrieve cached data (expected to be a DataFrame with 'time' in Unix ms) + cached_df = self.get_cache_item(cache_key, cache_name='indicator_data') + + # If no cached data, return an empty DataFrame + if cached_df is None or cached_df.empty: + return pd.DataFrame() + + # Convert start and end datetime to Unix timestamps (milliseconds) + start_timestamp = int(start_datetime.timestamp() * 1000) + end_timestamp = int(end_datetime.timestamp() * 1000) + + # Convert cached 'time' column to Unix timestamps for comparison + cached_df['time'] = cached_df['time'].astype('int64') // 10 ** 6 + + # Return only the data within the specified time range + return cached_df[(cached_df['time'] >= start_timestamp) & (cached_df['time'] <= end_timestamp)] + + def _determine_missing_intervals(self, cached_data: pd.DataFrame, start_datetime: dt.datetime, + end_datetime: dt.datetime, timeframe: str): + """ + Determines missing intervals in the cached data by comparing the requested time range with cached data. + + :param cached_data: Cached DataFrame with time-series data in Unix timestamp (ms). + :param start_datetime: Start datetime of the requested data range. + :param end_datetime: End datetime of the requested data range. + :param timeframe: Expected interval between consecutive time points (e.g., '5m', '1h'). + :return: A list of tuples representing the missing intervals (start, end). + """ + # Convert start and end datetime to Unix timestamps (milliseconds) + start_timestamp = int(start_datetime.timestamp() * 1000) + end_timestamp = int(end_datetime.timestamp() * 1000) + + if cached_data is not None and not cached_data.empty: + # Find the last (most recent) time in the cached data (which is in Unix timestamp ms) + cached_end = cached_data['time'].max() + + # If the cached data ends before the requested end time, add that range to missing intervals + if cached_end < end_timestamp: + # Convert cached_end back to datetime for consistency in returned intervals + cached_end_dt = pd.to_datetime(cached_end, unit='ms').to_pydatetime() + return [(cached_end_dt, end_datetime)] + + # Otherwise, check for gaps within the cached data itself + return self.find_gaps_in_intervals(cached_data, 0, len(cached_data) - 1, timeframe) + + # If no cached data exists, the entire range is missing + return [(start_datetime, end_datetime)] + + def _fetch_and_calculate_missing_data(self, missing_intervals, calculated_data, indicator_instance, user_name, + symbol, timeframe, exchange_name): + """ + Fetches and calculates missing data for the specified intervals. + """ + for interval_start, interval_end in missing_intervals: + # Ensure interval_start and interval_end are timezone-aware + if interval_start.tzinfo is None: + interval_start = interval_start.replace(tzinfo=dt.timezone.utc) + if interval_end.tzinfo is None: + interval_end = interval_end.replace(tzinfo=dt.timezone.utc) + + ohlc_data = self.get_or_fetch_from('data', start_datetime=interval_start, end_datetime=interval_end, + ex_details=[symbol, timeframe, exchange_name, user_name]) + if ohlc_data.empty or 'close' not in ohlc_data.columns: + continue + + new_data = indicator_instance.calculate(candles=ohlc_data, user_name=user_name) + + calculated_data = pd.concat([calculated_data, new_data], ignore_index=True) + + return calculated_data + + def _get_merged_properties(self, user_name, indicator_type, symbol, timeframe, exchange_name, indicator_instance): + """ + Retrieves and merges user-specific properties with default indicator properties. + """ + user_properties = self.get_user_indicator_properties(user_name, indicator_type, + symbol, timeframe, exchange_name) + if not user_properties: + user_properties = { + key: value for key, + value in indicator_instance.properties.items() if key.startswith('color') or key.startswith('thickness') + } + return {**indicator_instance.properties, **user_properties} + + +class DataCache(IndicatorCache): """ Extends ServerInteractions, DatabaseInteractions, SnapshotDataCache, and DataCacheBase to create a fully functional cache management system that can store and manage custom caching objects of type InMemoryCache. The following diff --git a/src/Database.py b/src/Database.py index 198deec..c8d176f 100644 --- a/src/Database.py +++ b/src/Database.py @@ -257,7 +257,7 @@ class Database: CREATE TABLE IF NOT EXISTS '{table_name}' ( id INTEGER PRIMARY KEY, market_id INTEGER, - open_time INTEGER UNIQUE ON CONFLICT IGNORE, + time INTEGER UNIQUE ON CONFLICT IGNORE, open REAL NOT NULL, high REAL NOT NULL, low REAL NOT NULL, diff --git a/src/Exchange.py b/src/Exchange.py index 592da90..27ad093 100644 --- a/src/Exchange.py +++ b/src/Exchange.py @@ -70,7 +70,7 @@ class Exchange: # Perform an authenticated request to check if the API keys are valid self.client.fetch_balance() self.configured = True - logger.info("Authentication successful. Trading bot configured.") + logger.info("Authentication successful.") except ccxt.AuthenticationError: logger.error("Authentication failed. Please check your API keys.") except Exception as e: @@ -135,13 +135,13 @@ class Exchange: logger.warning(f"No OHLCV data returned for {symbol} from {current_start} to {current_end}.") break - df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume'] + df_columns = ['time', 'open', 'high', 'low', 'close', 'volume'] candles_df = pd.DataFrame(candles, columns=df_columns) data_frames.append(candles_df) # Update current_start to the time of the last candle retrieved - last_candle_time = candles_df['open_time'].iloc[-1] / 1000 # Convert from milliseconds to seconds + last_candle_time = candles_df['time'].iloc[-1] / 1000 # Convert from milliseconds to seconds current_start = datetime.utcfromtimestamp(last_candle_time).replace(tzinfo=timezone.utc) + timedelta( milliseconds=1) @@ -153,12 +153,12 @@ class Exchange: if data_frames: # Combine all chunks and drop duplicates in one step - result_df = pd.concat(data_frames).drop_duplicates(subset=['open_time']).reset_index(drop=True) + result_df = pd.concat(data_frames).drop_duplicates(subset=['time']).reset_index(drop=True) logger.info(f"Successfully fetched OHLCV data for {symbol}.") return result_df else: logger.warning(f"No OHLCV data fetched for {symbol}.") - return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume']) + return pd.DataFrame(columns=['time', 'open', 'high', 'low', 'close', 'volume']) def _fetch_price(self, symbol: str) -> float: """ diff --git a/src/ExchangeInterface.py b/src/ExchangeInterface.py index 21296d9..1603769 100644 --- a/src/ExchangeInterface.py +++ b/src/ExchangeInterface.py @@ -30,6 +30,21 @@ class ExchangeInterface: """Retrieve the list of available exchanges from CCXT.""" return ccxt.exchanges + def get_public_exchanges(self) -> List[str]: + """Return a list of public exchanges available from CCXT.""" + public_list = [] + file_path = 'src\working_public_exchanges.txt' + + try: + with open(file_path, 'r') as file: + public_list = [line.strip() for line in file.readlines()] + except FileNotFoundError: + print(f"Error: The file {file_path} was not found.") + except Exception as e: + print(f"An error occurred: {e}") + + return public_list + def connect_exchange(self, exchange_name: str, user_name: str, api_keys: Dict[str, str] = None) -> bool: """ Initialize and store a reference to the specified exchange. diff --git a/src/Users.py b/src/Users.py index 113735b..19da2ef 100644 --- a/src/Users.py +++ b/src/Users.py @@ -1,3 +1,4 @@ +import copy import datetime as dt import json import random @@ -34,6 +35,19 @@ class BaseUser: filter_vals=('user_name', user_name) ) + def get_username(self, id: int) -> str: + """ + Retrieves the user username based on the ID. + + :param id: The id of the user. + :return: The name of the user as a str. + """ + return self.data.fetch_item( + item_name='user_name', + cache_name='users', + filter_vals=('id', id) + ) + def _remove_user_from_memory(self, user_name: str) -> None: """ Private method to remove a user's data from the cache (memory). @@ -109,11 +123,7 @@ class UserAccountManagement(BaseUser): """ super().__init__(data_cache) - self.max_guests = max_guests # Maximum number of guests - - # Initialize data for guest suffixes and cached users - self.data.set_cache_item(data=[], key='guest_suffixes', do_not_overwrite=True) - self.data.set_cache_item(data={}, key='cached_users', do_not_overwrite=True) + self.max_guests_factor = max_guests # Maximum number of guests is determined using this number. def is_logged_in(self, user_name: str) -> bool: """ @@ -125,29 +135,14 @@ class UserAccountManagement(BaseUser): if user_name is None: return False - def is_guest(_user) -> bool: - split_name = _user.at[0, 'user_name'].split("_") - return 'guest' in split_name and len(split_name) == 2 - - # Retrieve the user's data from either the cache or the database. user = self.get_user_data(user_name) + if user is None or user.empty: + return False - if user is not None and not user.empty: - # If the user exists, check their login status. - if user.at[0, 'status'] == 'logged_in': - # If the user is logged in, check if they are a guest. - if is_guest(user): - # Update the guest suffix cache if the user is a guest. - guest_suffixes = self.data.get_cache_item(key='guest_suffixes') or [] - guest_suffixes.append(user_name.split('_')[1]) - self.data.set_cache_item(data=guest_suffixes, key='guest_suffixes') - return True - else: - # If the user is not logged in, remove their data from the cache. - self._remove_user_from_memory(user_name) - return False + if user.iloc[0]['status'] == 'logged_in': + return True - # If the user data is not found or the status is not 'logged_in', return False. + self._remove_user_from_memory(user_name) return False def validate_password(self, username: str, password: str) -> bool: @@ -184,14 +179,12 @@ class UserAccountManagement(BaseUser): :param password: The unencrypted password. :return: True on successful login, False otherwise. """ - if success := self.validate_password(username=username, password=password): - self.modify_user_data(username=username, - field_name="status", - new_data="logged_in") - self.modify_user_data(username=username, - field_name="signin_time", + if self.validate_password(username=username, password=password): + self.modify_user_data(username=username, field_name="status", new_data="logged_in") + self.modify_user_data(username=username, field_name="signin_time", new_data=dt.datetime.utcnow().timestamp()) - return success + return True + return False def log_out_user(self, username: str) -> bool: """ @@ -259,33 +252,41 @@ class UserAccountManagement(BaseUser): :return: A unique guest username or None if the guest limit is reached. """ - guest_suffixes = self.data.get_cache_item(key='guest_suffixes') or [] - if len(guest_suffixes) >= self.max_guests: - return None + initial_multiplier = 9 # Start with the initial multiplier + max_attempts = 10 # Set the maximum number of attempts before increasing the multiplier + multiplier = initial_multiplier + attempts = 0 - suffix = random.choice(range(0, self.max_guests * 9)) - while suffix in guest_suffixes: - suffix = random.choice(range(0, self.max_guests * 9)) + while True: + suffix = random.choice(range(0, self.max_guests_factor * multiplier)) + username = f'guest_{suffix}' - guest_suffixes.append(suffix) - self.data.set_cache_item(key='guest_suffixes', data=guest_suffixes) - return f'guest_{suffix}' + # Check if the username already exists in the database + if not self.data.get_or_fetch_rows(cache_name='users', filter_vals=('user_name', username)): + return username + + attempts += 1 + + # If too many attempts have been made, increase the multiplier + if attempts >= max_attempts: + multiplier += 1 # Increment the multiplier by 1 + attempts = 0 # Reset the attempt counter + + # Safety net to avoid runaway multipliers + if multiplier > self.max_guests_factor: + return None def create_guest(self) -> str | None: - """ - Creates a guest user in the database and logs them in. - - :return: The guest username or None if the guest limit is reached. - """ - if (username := self.create_unique_guest_name()) is None: + username = self.create_unique_guest_name() + if username is None: return None - attrs = ({'user_name': username},) - self.create_new_user_in_db(attrs=attrs) - login_success = self.log_in_user(username=username, password='password') - if login_success: + self.create_new_user_in_db(attrs=({'user_name': username},)) + if self.log_in_user(username=username, password='password'): return username + return None # Failure to log in. + def create_new_user_in_db(self, attrs: tuple) -> None: """ Creates a new user in the database by modifying a default template. @@ -302,16 +303,19 @@ class UserAccountManagement(BaseUser): if default_user is None or default_user.empty: raise ValueError("Default user template not found in the database.") - # Modify the default user template with the provided attributes + # Make a deep copy of the default user to preserve the original template + new_user = copy.deepcopy(default_user) + + # Modify the deep copied user template with the provided attributes for attr in attrs: key, value = next(iter(attr.items())) - default_user.loc[0, key] = value + new_user.loc[0, key] = value # Remove the 'id' column before inserting into the database - default_user = default_user.drop(columns='id') + new_user = new_user.drop(columns='id') # Insert the modified user data into the database, skipping cache insertion - self.data.insert_df(df=default_user, cache_name="users", skip_cache=True) + self.data.insert_df(df=new_user, cache_name="users", skip_cache=True) def create_new_user(self, username: str, email: str, password: str) -> bool: """ @@ -341,7 +345,6 @@ class UserAccountManagement(BaseUser): username = self.create_guest() if not username: raise ValueError('GuestLimitExceeded!') - self.get_user_data(user_name=username) return username @staticmethod @@ -402,21 +405,21 @@ class UserExchangeManagement(UserAccountManagement): self.active_exchange(exchange=exchange, user_name=user_name, cmd='set') - def get_exchanges(self, user_name: str, category: str) -> list | None: + def get_exchanges(self, user_name: str, category: str) -> list: """ Retrieves the list of active or configured exchanges for a given user. :param user_name: The name of the user. :param category: The category to retrieve ('active_exchanges' or 'configured_exchanges'). - :return: A list of exchanges or None if user or category is not found. + :return: A list of exchanges or empty ist if user or category is not found. """ try: user = self.get_user_data(user_name) - exchanges = user.loc[0, category] + exchanges = user.iloc[0][category] return json.loads(exchanges) if exchanges else [] except (KeyError, IndexError, json.JSONDecodeError) as e: print(f"Error retrieving exchanges for user '{user_name}' and field '{category}': {str(e)}") - return None + return list() def active_exchange(self, exchange: str, user_name: str, cmd: str) -> None: """ @@ -427,7 +430,7 @@ class UserExchangeManagement(UserAccountManagement): :param user_name: The name of the user executing this command. """ user = self.get_user_data(user_name) - active_exchanges = user.loc[0, 'active_exchanges'] + active_exchanges = user.iloc[0]['active_exchanges'] if active_exchanges is None: active_exchanges = [] else: @@ -458,15 +461,16 @@ class UserIndicatorManagement(UserExchangeManagement): :return: A DataFrame containing the user's indicators or None if not found. """ # Retrieve the user's ID - user_id = self.get_id(user_name) + user_id = int(self.get_id(user_name)) # Fetch the indicators from the database using DataCache df = self.data.get_or_fetch_rows(cache_name='indicators', filter_vals=('creator', user_id)) # If indicators are found, process the JSON fields if df is not None and not df.empty: - df['source'] = df['source'].apply(json.loads) - df['properties'] = df['properties'].apply(json.loads) + # Ensure that we only apply json.loads if the value is not already a dict + df['source'] = df['source'].apply(lambda x: x if isinstance(x, dict) else json.loads(x)) + df['properties'] = df['properties'].apply(lambda x: x if isinstance(x, dict) else json.loads(x)) return df @@ -496,7 +500,7 @@ class UserIndicatorManagement(UserExchangeManagement): :param indicator_name: The name of the indicator to remove. :param user_name: The name of the user who created the indicator. """ - user_id = self.get_id(user_name) + user_id = int(self.get_id(user_name)) self.data.remove_row( filter_vals=('name', indicator_name), additional_filter=('creator', user_id), @@ -514,7 +518,7 @@ class UserIndicatorManagement(UserExchangeManagement): user = self.get_user_data(user_name) if user.empty: return - chart_view = json.loads(user.at[0, 'chart_views']) + chart_view = json.loads(user.iloc[0]['chart_views']) if prop is None: return chart_view @@ -547,12 +551,12 @@ class UserIndicatorManagement(UserExchangeManagement): if not isinstance(specific_property, str): raise ValueError("Specific property must be a string.") - chart_view = json.loads(user.at[0, 'chart_views']) + chart_view = json.loads(user.iloc[0]['chart_views']) if specific_property == 'exchange_name': if default_market is None: raise ValueError("Default market must be provided when setting the exchange name.") chart_view['market'] = default_market - chart_view['exchange_name'] = values + chart_view['exchange'] = values elif specific_property == 'timeframe': chart_view['timeframe'] = values elif specific_property == 'market': diff --git a/src/app.py b/src/app.py index 388accc..74d6982 100644 --- a/src/app.py +++ b/src/app.py @@ -50,7 +50,7 @@ def index(): Fetches data from brighter_trades and inject it into an HTML template. Renders the html template and serves the web application. """ - + # session.clear() # only for debugging!!!! try: # Log the user in. user_name = brighter_trades.users.load_or_create_user(username=session.get('user')) @@ -131,8 +131,8 @@ def ws(socket_conn): while True: msg = socket_conn.receive() if msg: - # If in json format the message gets converted into a dictionary - # otherwise it is handled as a status signal from the client + # If msg is in json, convert the message into a dictionary then feed the message handler. + # Otherwise, log the output. try: json_msg = json.loads(msg) json_msg_received(json_msg) diff --git a/src/archived_code/DataCache.py b/src/archived_code/DataCache.py index 2853bff..cebe81b 100644 --- a/src/archived_code/DataCache.py +++ b/src/archived_code/DataCache.py @@ -53,8 +53,8 @@ class DataCache: :return: None. """ records = pd.concat([more_records, self.get_cache(key)], axis=0, ignore_index=True) - records = records.drop_duplicates(subset="open_time", keep='first') - records = records.sort_values(by='open_time').reset_index(drop=True) + records = records.drop_duplicates(subset="time", keep='first') + records = records.sort_values(by='time').reset_index(drop=True) self.set_cache(data=records, key=key) def set_cache(self, data: Any, key: str, do_not_overwrite: bool = False) -> None: @@ -241,7 +241,7 @@ class DataCache: _timestamp = unix_time_millis(start_datetime) logger.debug(f"Start timestamp in human-readable form: {dt.datetime.utcfromtimestamp(_timestamp / 1000.0)}") - result = self.get_cache(key).query('open_time >= @_timestamp') + result = self.get_cache(key).query('time >= @_timestamp') return result except Exception as e: @@ -266,13 +266,13 @@ class DataCache: logger.debug(f'Got {len(new_records.index)} records from exchange_name') if not new_records.empty: data = pd.concat([data, new_records], axis=0, ignore_index=True) - data = data.drop_duplicates(subset="open_time", keep='first') + data = data.drop_duplicates(subset="time", keep='first') return data if self.db.table_exists(table_name=table_name): logger.debug('Table existed retrieving records from DB') logger.debug(f'Requesting from {st} to {et}') - records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='open_time', st=st, et=et) + records = self.db.get_timestamped_records(table_name=table_name, timestamp_field='time', st=st, et=et) logger.debug(f'Got {len(records.index)} records from db') else: logger.debug(f"Table didn't exist fetching from {ex_details[2]}") @@ -354,7 +354,7 @@ class DataCache: logger.info(f"Starting to fill data holes for interval: {interval}") for index, row in records.iterrows(): - time_stamp = row['open_time'] + time_stamp = row['time'] if last_timestamp is None: last_timestamp = time_stamp @@ -374,7 +374,7 @@ class DataCache: for ts in range(int(last_timestamp) + step, int(time_stamp), step): new_row = row.copy() - new_row['open_time'] = ts + new_row['time'] = ts filled_records.append(new_row) logger.debug(f"Filled timestamp: {ts}") @@ -409,7 +409,7 @@ class DataCache: logger.info(f'{num_rec_records} candles retrieved from the exchange.') - open_times = candles.open_time + open_times = candles.time min_open_time = open_times.min() max_open_time = open_times.max() diff --git a/src/archived_code/test_DataCache.py b/src/archived_code/test_DataCache.py index 901c5aa..88a4a65 100644 --- a/src/archived_code/test_DataCache.py +++ b/src/archived_code/test_DataCache.py @@ -38,7 +38,7 @@ class TestDataCache(unittest.TestCase): CREATE TABLE IF NOT EXISTS test_table ( id INTEGER PRIMARY KEY, market_id INTEGER, - open_time INTEGER UNIQUE, + time INTEGER UNIQUE, open REAL NOT NULL, high REAL NOT NULL, low REAL NOT NULL, @@ -85,7 +85,7 @@ class TestDataCache(unittest.TestCase): def test_update_candle_cache(self): print('Testing update_candle_cache() method:') df_initial = pd.DataFrame({ - 'open_time': [1, 2, 3], + 'time': [1, 2, 3], 'open': [100, 101, 102], 'high': [110, 111, 112], 'low': [90, 91, 92], @@ -94,7 +94,7 @@ class TestDataCache(unittest.TestCase): }) df_new = pd.DataFrame({ - 'open_time': [3, 4, 5], + 'time': [3, 4, 5], 'open': [102, 103, 104], 'high': [112, 113, 114], 'low': [92, 93, 94], @@ -107,7 +107,7 @@ class TestDataCache(unittest.TestCase): result = self.data.get_cache(key=self.key1) expected = pd.DataFrame({ - 'open_time': [1, 2, 3, 4, 5], + 'time': [1, 2, 3, 4, 5], 'open': [100, 101, 102, 103, 104], 'high': [110, 111, 112, 113, 114], 'low': [90, 91, 92, 93, 94], @@ -134,7 +134,7 @@ class TestDataCache(unittest.TestCase): def test_get_records_since(self): print('Testing get_records_since() method:') df_initial = pd.DataFrame({ - 'open_time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=i)) for i in range(3)], + 'time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=i)) for i in range(3)], 'open': [100, 101, 102], 'high': [110, 111, 112], 'low': [90, 91, 92], @@ -148,7 +148,7 @@ class TestDataCache(unittest.TestCase): ex_details=['BTC/USD', '2h', 'binance']) expected = pd.DataFrame({ - 'open_time': df_initial['open_time'][:2].values, + 'time': df_initial['time'][:2].values, 'open': [100, 101], 'high': [110, 111], 'low': [90, 91], @@ -162,7 +162,7 @@ class TestDataCache(unittest.TestCase): print('Testing get_records_since_from_db() method:') df_initial = pd.DataFrame({ 'market_id': [None], - 'open_time': [unix_time_millis(dt.datetime.utcnow())], + 'time': [unix_time_millis(dt.datetime.utcnow())], 'open': [1.0], 'high': [1.0], 'low': [1.0], @@ -177,7 +177,7 @@ class TestDataCache(unittest.TestCase): end_datetime = dt.datetime.utcnow() result = self.data.get_records_since_from_db(table_name='test_table', st=start_datetime, et=end_datetime, rl=1, ex_details=['BTC/USD', '2h', 'binance']).sort_values( - by='open_time').reset_index(drop=True) + by='time').reset_index(drop=True) print("Columns in the result DataFrame:", result.columns) print("Result DataFrame:\n", result) @@ -188,7 +188,7 @@ class TestDataCache(unittest.TestCase): expected = pd.DataFrame({ 'market_id': [None], - 'open_time': [unix_time_millis(dt.datetime.utcnow())], + 'time': [unix_time_millis(dt.datetime.utcnow())], 'open': [1.0], 'high': [1.0], 'low': [1.0], diff --git a/src/candles.py b/src/candles.py index 3ac8ec9..85c2dd0 100644 --- a/src/candles.py +++ b/src/candles.py @@ -60,7 +60,7 @@ class Candles: At {minutes_per_candle} minutes_per_candle since {start_datetime}. There should be {minutes_since / minutes_per_candle} candles.""") # Return the results. The start_datetime was approximate, so we may have retrieved an extra result. - return candles[-num_candles:] + return self.convert_candles(candles[-num_candles:]) def set_new_candle(self, cdata): """ @@ -107,9 +107,10 @@ class Candles: # Log the completion to the console. log.info('set_candle_history(): Loading candle data...') + # Todo this doesn't seem necessary. # Load candles from database - _cdata = self.get_last_n_candles(num_candles=self.max_records, - asset=symbol, timeframe=interval, exchange=exchange_name, user_name=user_name) + # _cdata = self.get_last_n_candles(num_candles=self.max_records, + # asset=symbol, timeframe=interval, exchange=exchange_name, user_name=user_name) # Log the completion to the console. log.info('set_candle_history(): Candle data Loaded.') @@ -135,8 +136,8 @@ class Candles: # Make sure the index is looking nice candles.reset_index(inplace=True) - # Extract the open_time and volume columns. - volumes = candles.loc[:, ['open_time', 'volume']] + # Extract the time and volume columns. + volumes = candles.loc[:, ['time', 'volume']] # Add the color field calling get_color() to supply the values. volumes["color"] = volumes.apply(get_color, axis=1) # Rename volume column to value @@ -162,11 +163,8 @@ class Candles: if value_name == 'volume': values = self.get_colour_coded_volume(candles) else: - values = candles[['open_time', value_name]] + values = candles[['time', value_name]] - values = values.rename({'open_time': 'time'}, axis=1) - # The timestamps are in milliseconds but lightweight charts needs it divided by 1000. - values['time'] = values['time'].div(1000) return values @staticmethod @@ -175,12 +173,12 @@ class Candles: Converts a dataframe of candlesticks into the format lightweight charts expects. :param candles: dt.dataframe - :return: List - [{'open_time': value, 'open': value,...},...] + :return: List - [{'time': value, 'open': value,...},...] """ - new_candles = candles.loc[:, ['open_time', 'open', 'high', 'low', 'close']] + new_candles = candles.loc[:, ['time', 'open', 'high', 'low', 'close', 'volume']] - new_candles.rename(columns={'open_time': 'time'}, inplace=True) + new_candles.rename(columns={'time': 'time'}, inplace=True) # The timestamps are in milliseconds but lightweight charts needs it divided by 1000. new_candles.loc[:, ['time']] = new_candles.loc[:, ['time']].div(1000) @@ -219,4 +217,4 @@ class Candles: exchange=exchange_name, user_name=user_name) # Reformat relevant candlestick data into a list of python dictionary objects. - return self.convert_candles(candlesticks) + return candlesticks diff --git a/src/indicators.py b/src/indicators.py index 0cd77be..2f76e2d 100644 --- a/src/indicators.py +++ b/src/indicators.py @@ -1,14 +1,13 @@ import json import random -from typing import Any, List, Optional, Dict +from typing import Any, Optional, Dict import numpy as np import pandas as pd import talib -# A list container to hold all available indicator types. This list is -# appended everytime an indicator class is defined below. -indicator_types = [] +# A dictionary to hold both indicator types and their corresponding classes. +indicators_registry = {} class Indicator: @@ -17,26 +16,29 @@ class Indicator: self.properties = properties self.properties.setdefault('type', indicator_type) self.properties.setdefault('value', 0) + self.properties.setdefault('period', 14) def calculate(self, candles: pd.DataFrame, user_name: str, num_results: int = 1) -> dict: """ - Calculates the indicator values over a span of price data. - - :param candles: The candlestick data. - :param user_name: The user_name. - :param num_results: The number of requested results. - :return: A dictionary of indicator records. + Default calculation for indicators that follows a standard process: + - Converts the 'close' prices to NumPy array + - Uses the 'process' method to apply the indicator calculation + - Returns the result as a DataFrame """ + # Converts the close prices from the DataFrame to a NumPy array of floats closes = candles.close.to_numpy(dtype='float') + # Processing the close prices to calculate the Indicator i_values = self.process(closes, self.properties['period']) + # Stores the last calculated value. self.properties['value'] = round(float(i_values[-1]), 2) - df = pd.DataFrame({'time': candles.open_time, 'value': i_values.tolist()}) - r_data = df.iloc[self.properties['period']:] + # Create a DataFrame with 'time' and 'value' + df = pd.DataFrame({'time': candles.time, 'value': i_values.tolist()}) - return {"type": self.properties['type'], "data": r_data.to_dict('records')} + # Slice the DataFrame to skip initial rows where the indicator will be undefined + return df.iloc[self.properties['period']:] def process(self, data, period): """ @@ -48,263 +50,176 @@ class Indicator: class Volume(Indicator): def __init__(self, name: str, indicator_type: str, properties: dict): super().__init__(name, indicator_type, properties) + # Default display properties for Volume + self.properties.setdefault('color_up', 'rgba(0, 150, 136, 0.8)') # Green for increased volume + self.properties.setdefault('color_down', 'rgba(255, 82, 82, 0.8)') # Red for decreased volume - def calculate(self, candles: pd.DataFrame, user_name: str, num_results: int = 1) -> dict: + def calculate(self, candles: pd.DataFrame, user_name: str, num_results: int = 1) -> pd.DataFrame: """ - Fetches the volume data and combines it with red or green - color info representing higher or lower volume changes. - - :param candles: The price data to analyze. - :param user_name: The name of the user executing this function. - :param num_results: The number of results requested. - :return: A dictionary of volume records. + Custom calculation for Volume that doesn't need NaN handling. + Returns a DataFrame with 'time', 'value', and 'color' columns. """ def get_color(row): row_index = row.name if (row_index - 1) not in volumes.index: - return 'rgba(0, 150, 136, 0.8)' # Green + return self.properties['color_up'] # Default green if cndls.close.iloc[row_index - 1] < cndls.close.iloc[row_index]: - return 'rgba(0, 150, 136, 0.8)' # Green + return self.properties['color_up'] # Increased volume (green) else: - return 'rgba(255, 82, 82, 0.8)' # Red + return self.properties['color_down'] # Decreased volume (red) cndls = candles.copy().reset_index(drop=True) - # Extract the open_time and volume columns - volumes = cndls.loc[:, ['open_time', 'volume']] - - # Add the color field using apply() and get_color() function + # Extract the open time and volume columns + volumes = cndls.loc[:, ['time', 'volume']] volumes['color'] = volumes.apply(get_color, axis=1) # Rename the volume column to 'value' volumes = volumes.rename(columns={'volume': 'value'}) - - # Rename the open_time column to 'time' - volumes = volumes.rename(columns={'open_time': 'time'}) - - # Get the last volume value as the current volume current_volume = volumes['value'].iloc[-1] self.properties['value'] = float(current_volume) - # Prepare the result data with the required structure - r_data = volumes.to_dict('records') - - return {"type": self.properties['type'], "data": r_data} - - -indicator_types.append('Volume') + # Return the DataFrame for consistency + return volumes class SMA(Indicator): def __init__(self, name: str, indicator_type: str, properties: dict): super().__init__(name, indicator_type, properties) + # Default display properties for SMA self.properties.setdefault('color', f"#{random.randrange(0x1000000):06x}") + self.properties.setdefault('thickness', 1) # Default line thickness self.properties.setdefault('period', 20) def process(self, data: np.ndarray, period: int) -> np.ndarray: """ Calculate the Simple Moving Average (SMA) of the given data. - - :param data: A numpy array of data points. - :param period: The period over which to calculate the SMA. - :return: A numpy array containing the SMA values. """ return talib.SMA(data, period) -indicator_types.append('SMA') - - class EMA(SMA): - def __init__(self, name: str, indicator_type: str, properties: dict): - super().__init__(name, indicator_type, properties) - def process(self, data: np.ndarray, period: int) -> np.ndarray: """ Calculate the Exponential Moving Average (EMA) of the given data. - - :param data: A numpy array of data points. - :param period: The period over which to calculate the EMA. - :return: A numpy array containing the EMA values. """ return talib.EMA(data, period) -indicator_types.append('EMA') - - class RSI(SMA): def __init__(self, name: str, indicator_type: str, properties: dict): super().__init__(name, indicator_type, properties) + # Default display properties for RSI + self.properties.setdefault('period', 14) def process(self, data: np.ndarray, period: int) -> np.ndarray: """ Calculate the Relative Strength Index (RSI) of the given data. - - :param data: A numpy array of data points. - :param period: The period over which to calculate the RSI. - :return: A numpy array containing the RSI values. """ return talib.RSI(data, period) -indicator_types.append('RSI') - - class LREG(SMA): - def __init__(self, name: str, indicator_type: str, properties: dict): - super().__init__(name, indicator_type, properties) - def process(self, data: np.ndarray, period: int) -> np.ndarray: """ Calculate the Linear Regression (LREG) of the given data. - - :param data: A numpy array of data points. - :param period: The period over which to calculate the LREG. - :return: A numpy array containing the LREG values. """ return talib.LINEARREG(data, period) -indicator_types.append('LREG') - - class ATR(SMA): - def __init__(self, name: str, indicator_type: str, properties: dict): - super().__init__(name, indicator_type, properties) - - def calculate(self, candles: pd.DataFrame, user_name: str, num_results: int = 1) -> dict: + def calculate(self, candles: pd.DataFrame, user_name: str, num_results: int = 1) -> pd.DataFrame: """ Calculate the Average True Range (ATR) indicator. - - :param candles: A DataFrame containing candlestick data. - :param user_name: The name of the user executing this function. - :param num_results: The number of results requested. - :return: A dictionary with the type and data of the indicator. """ highs = candles.high.to_numpy(dtype='float') lows = candles.low.to_numpy(dtype='float') closes = candles.close.to_numpy(dtype='float') - atr = talib.ATR(high=highs, - low=lows, - close=closes, - timeperiod=self.properties['period']) + # Calculate ATR using the talib library + atr = talib.ATR(high=highs, low=lows, close=closes, timeperiod=self.properties['period']) - df = pd.DataFrame({'time': candles.open_time, 'value': atr}) - - r_data = df.iloc[self.properties['period']:].to_dict('records') + # Create DataFrame with 'time' and 'value' columns + df = pd.DataFrame({'time': candles.time, 'value': atr}) + # Store the last calculated ATR value self.properties['value'] = round(float(atr[-1]), 2) - return {"type": self.properties['type'], "data": r_data} - - -indicator_types.append('ATR') + # Return the sliced DataFrame, excluding rows where the indicator is not fully calculated + return df.iloc[self.properties['period']:] class BolBands(Indicator): - def __init__(self, name: str, indicator_type: str, properties: dict): - super().__init__(name, indicator_type, properties) - ul_col = f"#{random.randrange(0x1000000):06x}" - self.properties.setdefault('period', 50) - self.properties.setdefault('color_1', ul_col) - self.properties.setdefault('color_2', f"#{random.randrange(0x1000000):06x}") - self.properties.setdefault('color_3', ul_col) - self.properties.setdefault('value', 0) - self.properties.setdefault('value2', 0) - self.properties.setdefault('value3', 0) - self.properties.setdefault('devup', 2) - self.properties.setdefault('devdn', 2) - self.properties.setdefault('ma', 1) - - def calculate(self, candles: pd.DataFrame, user_name: str, num_results: int = 1) -> dict: + def calculate(self, candles: pd.DataFrame, user_name: str, num_results: int = 1) -> pd.DataFrame: """ Calculate the Bollinger Bands indicator for the given candles. - - :param candles: The candlestick data. - :param user_name: The user_name. - :param num_results: The number of results requested. - :return: A dictionary containing the calculated Bollinger Bands values. """ np_real_data = candles.close.to_numpy(dtype='float') + # Calculate the Bollinger Bands (upper, middle, lower) upper, middle, lower = talib.BBANDS(np_real_data, timeperiod=self.properties['period'], nbdevup=self.properties['devup'], nbdevdn=self.properties['devdn'], matype=self.properties['ma']) + # Store the last calculated values in properties self.properties['value'] = round(float(upper[-1]), 2) self.properties['value2'] = round(float(middle[-1]), 2) self.properties['value3'] = round(float(lower[-1]), 2) - df1 = pd.DataFrame({'time': candles.open_time, 'value': upper}).dropna() - df2 = pd.DataFrame({'time': candles.open_time, 'value': middle}).dropna() - df3 = pd.DataFrame({'time': candles.open_time, 'value': lower}).dropna() + # Create a DataFrame with 'time', 'upper', 'middle', 'lower' + df = pd.DataFrame({ + 'time': candles.time, + 'upper': upper, + 'middle': middle, + 'lower': lower + }) - r_data = [df1.to_dict('records'), df2.to_dict('records'), df3.to_dict('records')] - - return {"type": self.properties['type'], "data": r_data} - - -indicator_types.append('BOLBands') + # Slice the DataFrame to skip initial rows where the indicator might be undefined + return df.iloc[self.properties['period']:] class MACD(Indicator): - def __init__(self, name, indicator_type, properties): - super().__init__(name, indicator_type, properties) - - self.properties.setdefault('fast_p', 12) - self.properties.setdefault('slow_p', 26) - self.properties.setdefault('signal_p', 9) - self.properties.setdefault('macd', 0) - self.properties.setdefault('signal', 0) - self.properties.setdefault('hist', 0) - self.properties.setdefault('color_1', f"#{random.randrange(0x1000000):06x}") - self.properties.setdefault('color_2', f"#{random.randrange(0x1000000):06x}") - self.properties['period'] = self.properties['slow_p'] + self.properties['signal_p'] - 2 - - # Adjusting the period - # Not sure about the lookback period for macd algorithm below was a result of trial and error. - num = self.properties['slow_p'] + self.properties['signal_p'] - 2 - self.properties['period'] = num - - def calculate(self, candles: pd.DataFrame, user_name: str, num_results: int = 800) -> dict: + def calculate(self, candles: pd.DataFrame, user_name: str, num_results: int = 1) -> pd.DataFrame: """ Calculate the MACD indicator for the given candles. - - :param candles: The candlestick data. - :param user_name: The user_name. - :param num_results: The number of results requested. - :return: A dictionary containing the calculated MACD values. """ - if self.properties['fast_p'] >= self.properties['slow_p']: - raise ValueError('The fast_period should be less than the slow_period.') + closes = candles.close.to_numpy(dtype='float') - closing_data = candles.close - if len(closing_data) < num_results: - print(f"Not enough data available to calculate {self.properties['type']} for the given time period.") - return {} - - closes = closing_data.to_numpy(dtype='float') - macd, signal, hist = talib.MACD(closes, self.properties['fast_p'], self.properties['slow_p'], - self.properties['signal_p']) + # Calculate MACD, Signal Line, and MACD Histogram + macd, signal, hist = talib.MACD(closes, + fastperiod=self.properties['fast_p'], + slowperiod=self.properties['slow_p'], + signalperiod=self.properties['signal_p']) + # Store the last calculated values self.properties['macd'] = round(float(macd[-1]), 2) self.properties['signal'] = round(float(signal[-1]), 2) + self.properties['hist'] = round(float(hist[-1]), 2) - df1 = pd.DataFrame({'time': closing_data.time, 'value': macd}).dropna() - df2 = pd.DataFrame({'time': closing_data.time, 'value': signal}).dropna() - df3 = pd.DataFrame({'time': closing_data.time, 'value': hist}).dropna() + # Create a DataFrame with 'time', 'macd', 'signal', 'hist' + df = pd.DataFrame({ + 'time': candles.time, + 'macd': macd, + 'signal': signal, + 'hist': hist + }) - r_data = [df1.to_dict('records'), df2.to_dict('records'), df3.to_dict('records')] - - return {"type": self.properties['type'], "data": r_data} + # Slice the DataFrame to skip initial rows where the indicator will be undefined + return df.iloc[self.properties['signal_p']:] -indicator_types.append('MACD') +# Register indicators in the registry +indicators_registry['Volume'] = Volume +indicators_registry['SMA'] = SMA +indicators_registry['EMA'] = EMA +indicators_registry['RSI'] = RSI +indicators_registry['LREG'] = LREG +indicators_registry['ATR'] = ATR +indicators_registry['BOLBands'] = BolBands +indicators_registry['MACD'] = MACD class Indicators: @@ -319,8 +234,8 @@ class Indicators: self.indicators = pd.DataFrame(columns=['creator', 'name', 'visible', 'kind', 'source', 'properties', 'ref']) - # Create an instance reference of all available indicator types in the global list. - self.indicator_types = indicator_types + # Available indicator types and classes from a global indicators_registry. + self.indicator_registry = indicators_registry # Enums values to use with Bolenger-bands. self.MV_AVERAGE_ENUM = {'SMA': 0, 'EMA': 1, 'WMA': 2, 'DEMA': 3, 'TEMA': 4, @@ -380,6 +295,10 @@ class Indicators: # } # return indicator_list + def get_available_indicator_types(self) -> list: + """Returns a list of all available indicator types.""" + return list(self.indicator_registry.keys()) + def get_indicator_list(self, username: str, only_enabled: bool = False) -> Dict[str, Dict[str, Any]]: """ Returns a dictionary of all indicators available to this user. @@ -394,7 +313,7 @@ class Indicators: raise ValueError(f"Invalid user_name: {username}") if only_enabled: - indicators_df = self.indicators.query("creator == @user_id and visible=='True'") + indicators_df = self.indicators.query("creator == @user_id and visible == 1") else: indicators_df = self.indicators.query('creator == @user_id') @@ -425,14 +344,14 @@ class Indicators: """ # Validate inputs if user_id not in self.indicators['creator'].unique(): - # raise ValueError(f"Invalid user_id: {user_id}") + # raise ValueError(f"Invalid user_name: {user_name}") # Nothing may be loaded. return # Set visibility for all indicators of the user - self.indicators.loc[self.indicators['creator'] == user_id, 'visible'] = False + self.indicators.loc[self.indicators['creator'] == user_id, 'visible'] = 0 # Set visibility for the specified indicator names - self.indicators.loc[self.indicators['name'].isin(indicator_names), 'visible'] = True + self.indicators.loc[self.indicators['name'].isin(indicator_names), 'visible'] = 1 def edit_indicator(self, user_name: str, params: Any): # if 'submit' in request.form: @@ -453,7 +372,7 @@ class Indicators: if 'delete' in params: indicator = params['delete'] # This will delete in both indicators and config. - self.delete_indicator(indicator) + self.delete_indicator(indicator_name=indicator, user_name=user_name) def new_indicator(self, user_name: str, params) -> None: """ @@ -474,9 +393,8 @@ class Indicators: # Create a dictionary of properties from the values in request form. source = { - 'source': 'price_data', - 'market': params['ei_symbol'], - 'time_frame': params['ei_timeframe'], + 'symbol': params['ei_symbol'], + 'timeframe': params['ei_timeframe'], 'exchange_name': params['ei_exchange_name'] } @@ -534,14 +452,14 @@ class Indicators: user_id = self.users.get_id(user_name=user_name) - # Construct the query based on user_id and visibility. + # Construct the query based on user_name and visibility. query = f"creator == {user_id}" if visible_only: - query += " and visible == True" + query += " and visible == 1" # Filter the indicators based on the query. indicators = self.indicators.loc[ - (self.indicators['creator'] == user_id) & (self.indicators['visible'] == 'True')] + (self.indicators['creator'] == user_id) & (self.indicators['visible'] == 1)] # Return None if no indicators matched the query. if indicators.empty: @@ -549,7 +467,7 @@ class Indicators: self.load_indicators(user_name=user_name) # query again. indicators = self.indicators.loc[ - (self.indicators['creator'] == user_id) & (self.indicators['visible'] == 'True')] + (self.indicators['creator'] == user_id) & (self.indicators['visible'] == 1)] if indicators.empty: return None @@ -560,29 +478,37 @@ class Indicators: timeframe = source['market']['timeframe'] exchange = source['market']['exchange'] indicators = indicators[indicators.source.apply(lambda x: x['symbol'] == symbol and - x['timeframe'] == timeframe and - x['exchange_name'] == exchange)] + x['timeframe'] == timeframe and + x['exchange_name'] == exchange)] else: raise ValueError(f'No implementation for source: {source}') - # Process each indicator and collect the results in a dictionary. - results = {} + # Process each indicator, convert DataFrame to JSON-serializable format, and collect the results + json_ready_results = {} for indicator in indicators.itertuples(index=False): indicator_results = self.process_indicator(indicator=indicator, num_results=num_results) - results[indicator.name] = indicator_results - return results + # Convert DataFrame to list of dictionaries if necessary + if isinstance(indicator_results, pd.DataFrame): + json_ready_results[indicator.name] = indicator_results.to_dict(orient='records') + else: + json_ready_results[indicator.name] = indicator_results # If not a DataFrame, leave as is - def delete_indicator(self, indicator_name: str) -> None: + return json_ready_results + + def delete_indicator(self, indicator_name: str, user_name: str) -> None: """ Remove the indicator by name + :param user_name: username :param indicator_name: The name of the indicator to remove. :return: None """ if not indicator_name: raise ValueError("No indicator name provided.") - self.indicators = self.indicators.query("name != @indicator_name").reset_index(drop=True) - self.users.save_indicators() + self.users.remove_indicator(indicator_name=indicator_name, user_name=user_name) + + # Force reload from database to refresh cache + self.load_indicators(user_name=user_name) def create_indicator(self, creator: str, name: str, kind: str, source: dict, properties: dict, visible: bool = True): @@ -599,21 +525,20 @@ class Indicators: :param visible: Whether to display it in the chart view. :return: None """ - indicator_classes = { - 'SMA': SMA, - 'EMA': EMA, - 'RSI': RSI, - 'LREG': LREG, - 'ATR': ATR, - 'BOLBands': BolBands, - 'MACD': MACD, - 'Volume': Volume - } # todo define this instead of indicator_types as a global - if kind not in indicator_classes: - raise ValueError(f"[INDICATORS.PY]: Requested an unsupported type of indicator: ({kind})") + self.indicators = self.indicators.reset_index(drop=True) + creator_id = self.users.get_id(creator) + # Check if an indicator with the same name already exists + existing_indicator = self.indicators.query('name == @name and creator == @creator_id') - indicator_class = indicator_classes[kind] + if not existing_indicator.empty: + print(f"Indicator '{name}' already exists for user '{creator}'. Skipping creation.") + return # Exit the method to prevent duplicate creation + + if kind not in self.indicator_registry: + raise ValueError(f"Requested an unsupported type of indicator: ({kind})") + + indicator_class = self.indicator_registry[kind] # Create an instance of the indicator. indicator = indicator_class(name, kind, properties) diff --git a/src/maintenence/debuging_testing.py b/src/maintenence/debuging_testing.py index 0f13369..42d956a 100644 --- a/src/maintenence/debuging_testing.py +++ b/src/maintenence/debuging_testing.py @@ -1,33 +1,4 @@ -import pandas as pd -import matplotlib.pyplot as plt -from matplotlib.table import Table - -# Simulating the cache as a DataFrame -data = { - 'key': ['BTC/USD_2h_binance', 'ETH/USD_1h_coinbase'], - 'data': ['{"open": 50000, "close": 50500}', '{"open": 1800, "close": 1825}'] -} -cache_df = pd.DataFrame(data) - - -# Visualization function -def visualize_cache(df): - fig, ax = plt.subplots(figsize=(6, 3)) - ax.set_axis_off() - tb = Table(ax, bbox=[0, 0, 1, 1]) - - # Adding column headers - for i, column in enumerate(df.columns): - tb.add_cell(0, i, width=0.4, height=0.3, text=column, loc='center', facecolor='lightgrey') - - # Adding rows and cells - for i in range(len(df)): - for j, value in enumerate(df.iloc[i]): - tb.add_cell(i + 1, j, width=0.4, height=0.3, text=value, loc='center', facecolor='white') - - ax.add_table(tb) - plt.title("Visualizing Cache Data") - plt.show() - - -visualize_cache(cache_df) +""" +set_cache_item + create_cache +""" \ No newline at end of file diff --git a/src/shared_utilities.py b/src/shared_utilities.py index ede448f..51f6a4e 100644 --- a/src/shared_utilities.py +++ b/src/shared_utilities.py @@ -18,7 +18,7 @@ def query_uptodate(records: pd.DataFrame, r_length_min: float) -> Union[float, N """ print('\nChecking if the records are up-to-date...') # Get the newest timestamp from the records passed in stored in ms - last_timestamp = float(records.open_time.max()) + last_timestamp = float(records.time.max()) print(f'The last timestamp on record is {last_timestamp}') # Get a timestamp of the UTC time in milliseconds to match the records in the DB @@ -104,7 +104,7 @@ def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length return float('nan') # Get the oldest timestamp from the records passed in - first_timestamp = float(records.open_time.min()) + first_timestamp = float(records.time.min()) print(f'First timestamp in records: {first_timestamp}') # Calculate the total duration of the records in milliseconds diff --git a/src/static/communication.js b/src/static/communication.js index 8aaba68..ea98a2d 100644 --- a/src/static/communication.js +++ b/src/static/communication.js @@ -176,6 +176,8 @@ class Comms { } else if (message.reply === 'trade_created') { const list_of_one = [message.data]; window.UI.trade.set_data(list_of_one); + } else if (message.reply === 'Exchange_connection_result') { + window.UI.exchanges.postConnection(message.data); } else { console.log(message.reply); console.log(message.data); diff --git a/src/static/exchanges.js b/src/static/exchanges.js index 9346419..a0032bc 100644 --- a/src/static/exchanges.js +++ b/src/static/exchanges.js @@ -7,26 +7,40 @@ class Exchanges { initialize() { let el = document.getElementById('conned_excs'); - this.connected_exchanges = el.innerHTML; + // Select all spans within the container + let spans = el.querySelectorAll('.avail_exchange'); + + // Extract the text content from each span and store it in the connected_exchanges array + this.connected_exchanges = Array.from(spans).map(span => span.textContent.trim()); } - status(){ + + status() { document.getElementById('exchanges_config_form').style.display = "grid"; } - closeForm(){ + closeForm() { document.getElementById('exchanges_config_form').style.display = "none"; } validateApiKey(data) { - if (data === undefined || data === null || data === "") { - alert('Enter a valid API key to register.'); - return false; - } else { - return true; + return !(data === undefined || data === null || data === ""); + } + + postConnection(data) { + if (data.status === 'success') { + // Trigger a page reload + location.reload(); + } + else if (data.status === 'already_connected') { + alert(data.message); + } + else if (data.status === 'failure') { + alert(data.message); } } + submitApi() { // Collect the data to submit. let exchange = document.getElementById('c_exchanges').value; @@ -35,15 +49,26 @@ class Exchanges { let secret_key = document.getElementById('api_secret_key').value; let keys = { 'key': key, 'secret': secret_key }; - // Validate the data. - let success = this.validateApiKey(key) && this.validateApiKey(secret_key); - if (success) { - // Send the valid data. - let payload = { 'user': user, 'exch': exchange, 'keys': keys }; - window.UI.data.comms.sendToApp("config_exchange", payload); - this.closeForm(); - // Refreshes the current page - setTimeout(function() {location.reload();}, 200); + // Determine if validation is required based on the exchange type. + const isPublicExchange = bt_data.public_exchanges.includes(exchange); + const isKeyValid = this.validateApiKey(key); + const isSecretKeyValid = this.validateApiKey(secret_key); + + // If it's a public exchange, we don't require API keys. + if (isPublicExchange) { + keys = {}; // Clear keys for public exchanges + } else if (!isKeyValid || !isSecretKeyValid) { + // Validate keys for non-public exchanges + alert('Enter a valid API key and secret key to register.'); + return; // Exit early if validation fails } + + // Send the data if validation passes or no validation is required. + let payload = { 'user': user, 'exch': exchange, 'keys': keys }; + window.UI.data.comms.sendToApp("config_exchange", payload); + this.closeForm(); + + // Refreshes the current page + // setTimeout(function() { location.reload(); }, 200); } -} \ No newline at end of file +} diff --git a/src/static/indicators.js b/src/static/indicators.js index 12948d5..850f97a 100644 --- a/src/static/indicators.js +++ b/src/static/indicators.js @@ -29,20 +29,33 @@ class Indicator_Output { } iOutput = new Indicator_Output(); +// Create a global map to store the mappings +const indicatorMap = new Map(); + class Indicator { constructor(name) { // The name of the indicator. this.name = name; - this.lines=[]; - this.hist=[]; + this.lines = []; + this.hist = []; } - init(data){ + + static getIndicatorConfig() { + return { + args: ['name'], + class: this + }; + } + + init(data) { console.log(this.name + ': init() unimplemented.'); } - update(data){ + + update(data) { console.log(this.name + ': update() unimplemented.'); } - addHist(name, chart, color='#26a69a'){ + + addHist(name, chart, color = '#26a69a') { this.hist[name] = chart.addHistogramSeries({ color: color, priceFormat: { @@ -55,18 +68,19 @@ class Indicator { }, }); } - addLine(name, chart, color, lineWidth){ + + addLine(name, chart, color, lineWidth) { this.lines[name] = chart.addLineSeries({ color: color, lineWidth: lineWidth }); - //Initialise the crosshair legend for the charts. + // Initialise the crosshair legend for the charts. iOutput.create_legend(this.name, chart, this.lines[name]); - } - setLine(name, data, value_name){ - console.log('indicators[68]: setLine takes:(name,data,value_name)') - console.log(name,data,value_name) + + setLine(name, data, value_name) { + console.log('indicators[68]: setLine takes:(name,data,value_name)'); + console.log(name, data, value_name); // Initialize the data with the data object provided. this.lines[name].setData(data); // Isolate the last value provided and round to 2 decimals places. @@ -75,15 +89,18 @@ class Indicator { // Update indicator output/crosshair legend. iOutput.set_legend_text(data.at(-1).value, this.name); } - updateDisplay(name, priceValue, value_name){ + + updateDisplay(name, priceValue, value_name) { let rounded_value = (Math.round(priceValue * 100) / 100).toFixed(2); // Update the data in the edit and view indicators panel document.getElementById(this.name + '_' + value_name).value = rounded_value; } - setHist(name, data){ + + setHist(name, data) { this.hist[name].setData(data); } - updateLine(name, data, value_name){ + + updateLine(name, data, value_name) { // Update the line-set data in the chart this.lines[name].update(data); // Update indicator output/crosshair legend. @@ -91,133 +108,175 @@ class Indicator { // Update the data in the edit and view indicators panel this.updateDisplay(name, data.value, value_name); } - updateHist(name,data){ + + updateHist(name, data) { this.hist[name].update(data); } } -class SMA extends Indicator{ +class SMA extends Indicator { constructor(name, chart, color, lineWidth = 2) { - // Call the inherited constructor. super(name); - - // Create a line series and append to the appropriate chart. - this.addLine('line', chart, color, lineWidth); - + this.addLine('line', chart, color, lineWidth); } - init(data){ - this.setLine('line',data, 'value'); - console.log('line data', data) + + static getIndicatorConfig() { + return { + args: ['name', 'chart_1', 'color'], + class: this + }; } - update(data){ + + init(data) { + this.setLine('line', data, 'value'); + console.log('line data', data); + } + + update(data) { this.updateLine('line', data[0], 'value'); } } +// Register SMA in the map +indicatorMap.set("SMA", SMA); -class Linear_Regression extends SMA{ +class Linear_Regression extends SMA { + // Inherits getIndicatorConfig from SMA } +indicatorMap.set("Linear_Regression", Linear_Regression); -class EMA extends SMA{ +class EMA extends SMA { + // Inherits getIndicatorConfig from SMA } +indicatorMap.set("EMA", EMA); -class RSI extends Indicator{ +class RSI extends Indicator { constructor(name, charts, color, lineWidth = 2) { - // Call the inherited constructor. super(name); - // If the chart doesn't exist create one. - if( !charts.hasOwnProperty('chart2') ) { charts.create_RSI_chart(); } + if (!charts.hasOwnProperty('chart2')) { + charts.create_RSI_chart(); + } let chart = charts.chart2; - // Create a line series and append to the appropriate chart. - this.addLine('line', chart, color, lineWidth); + this.addLine('line', chart, color, lineWidth); } - init(data){ - this.setLine('line',data, 'value'); + + static getIndicatorConfig() { + return { + args: ['name', 'charts', 'color'], + class: this + }; } - update(data){ + + init(data) { + this.setLine('line', data, 'value'); + } + + update(data) { this.updateLine('line', data[0], 'value'); } } +indicatorMap.set("RSI", RSI); -class MACD extends Indicator{ +class MACD extends Indicator { constructor(name, charts, color_m, color_s, lineWidth = 2) { - // Call the inherited constructor. super(name); - // If the chart doesn't exist create one. - if( !charts.hasOwnProperty('chart3') ) { charts.create_MACD_chart(); } + if (!charts.hasOwnProperty('chart3')) { + charts.create_MACD_chart(); + } let chart = charts.chart3; - // Create two line series and append to the chart. - this.addLine('line_m', chart, color_m, lineWidth); - this.addLine('line_s', chart, color_s, lineWidth); - this.addHist(name, chart); + this.addLine('line_m', chart, color_m, lineWidth); + this.addLine('line_s', chart, color_s, lineWidth); + this.addHist(name, chart); } - init(data){ - this.setLine('line_m',data[0], 'macd'); - this.setLine('line_s',data[1], 'signal'); + static getIndicatorConfig() { + return { + args: ['name', 'charts', 'color_1', 'color_2'], + class: this + }; + } + + init(data) { + this.setLine('line_m', data[0], 'macd'); + this.setLine('line_s', data[1], 'signal'); this.setHist(name, data[2]); } - update(data){ + update(data) { this.updateLine('line_m', data[0][0], 'macd'); this.updateLine('line_s', data[1][0], 'signal'); this.updateHist(name, data[2][0]); } } +indicatorMap.set("MACD", MACD); + + +class ATR extends Indicator { + // Inherits getIndicatorConfig from Indicator -class ATR extends Indicator{ init(data) { this.updateDisplay(this.name, data.at(-1).value, 'value'); } + update(data) { this.updateDisplay(this.name, data[0].value, 'value'); - } } +indicatorMap.set("ATR", ATR); -class Volume extends Indicator{ + +class Volume extends Indicator { constructor(name, chart) { - // Call the inherited constructor. super(name); this.addHist(name, chart); - this.hist[name].applyOptions( { scaleMargins: { top: 0.95, bottom: 0.0} } ); + this.hist[name].applyOptions({ scaleMargins: { top: 0.95, bottom: 0.0 } }); } - init(data){ + + static getIndicatorConfig() { + return { + args: ['name', 'chart_1'], + class: this + }; + } + + init(data) { this.setHist(this.name, data); } - update(data){ + + update(data) { this.updateHist(this.name, data[0]); } } +indicatorMap.set("Volume", Volume); -class Bolenger extends Indicator{ +class Bolenger extends Indicator { constructor(name, chart, color_u, color_m, color_l, lineWidth = 2) { - // Call the inherited constructor. super(name); - - // Create three line series and append to the chart. - this.addLine('line_u', chart, color_u, lineWidth); - this.addLine('line_m', chart, color_u, lineWidth); - this.addLine('line_l', chart, color_u, lineWidth); + this.addLine('line_u', chart, color_u, lineWidth); + this.addLine('line_m', chart, color_m, lineWidth); + this.addLine('line_l', chart, color_l, lineWidth); } - init(data){ - // Initialize the data with the data object provided. - this.setLine('line_u',data[0],'value'); - this.setLine('line_m',data[1],'value2'); - this.setLine('line_l',data[2],'value3'); + static getIndicatorConfig() { + return { + args: ['name', 'chart_1', 'color_1', 'color_2', 'color_3'], + class: this + }; } - update(data){ - // Update the line-set data in the chart + init(data) { + this.setLine('line_u', data[0], 'value'); + this.setLine('line_m', data[1], 'value2'); + this.setLine('line_l', data[2], 'value3'); + } + + update(data) { this.updateLine('line_u', data[0][0], 'value'); - // Update the line-set data in the chart this.updateLine('line_m', data[1][0], 'value2'); - // Update the line-set data in the chart this.updateLine('line_l', data[2][0], 'value3'); } - } +indicatorMap.set("Bolenger", Bolenger); class Indicators { constructor() { @@ -225,6 +284,33 @@ class Indicators { this.i_objs = {}; } + create_indicators(indicators, charts, bt_data) { + for (let name in indicators) { + if (!indicators[name].visible) continue; + + let i_type = indicators[name].type; + let IndicatorClass = indicatorMap.get(i_type); + + if (IndicatorClass) { + let { args, class: IndicatorConstructor } = IndicatorClass.getIndicatorConfig(); + + let preparedArgs = args.map(arg => { + if (arg === 'name') return name; + if (arg === 'charts') return charts; + if (arg === 'chart_1') return charts.chart_1; + if (arg === 'color') return indicators[name].color; + if (arg === 'color_1') return '#FF0000'; //bt_data.indicators[name].color_1 || red; + if (arg === 'color_2') return 'red'; // bt_data.indicators[name].color_2 || white; + if (arg === 'color_3') return 'red'; // bt_data.indicators[name].color_3 || blue; + }); + + this.i_objs[name] = new IndicatorConstructor(...preparedArgs); + } else { + console.error(`Unknown indicator type: ${i_type}`); + } + } + } + addToCharts(charts, idata){ /* Receives indicator data, creates and stores the indicator @@ -235,54 +321,6 @@ class Indicators { idata.indicator_data.then( (data) => { this.init_indicators(data); } ); } - create_indicators(indicators, charts){ - // loop through all the indicators received from the - // server and if the are enabled and create them. - for (let name in indicators) { - - // If this indicator is hidden skip to the next one - if (!indicators[name].visible) {continue;} - - // Get the type of indicator - let i_type = indicators[name].type; - - // Call the indicator creation function - if (i_type == 'SMA') { - // The color of the line - let color = indicators[name].color; - this.i_objs[name] = new SMA(name, charts.chart_1, color); - } - if (i_type == 'BOLBands') { - // The color of three lines - let color_u = bt_data.indicators[name].color_1; - let color_m = bt_data.indicators[name].color_2; - let color_l = bt_data.indicators[name].color_3; - this.i_objs[name] = new Bolenger(name, charts.chart_1, color_u, color_m, color_l); - } - if (i_type == 'MACD') { - // The color of two lines - let color_m = bt_data.indicators[name].color_1; - let color_s = bt_data.indicators[name].color_2; - this.i_objs[name] = new MACD(name, charts, color_m, color_s); - } - if (i_type == 'Volume') { this.i_objs[name] = new Volume(name, charts.chart_1); } - if (i_type == 'ATR') { this.i_objs[name] = new ATR(name); } - if (i_type == 'LREG') { - // The color of the line - let color = indicators[name].color; - this.i_objs[name] = new Linear_Regression(name, charts.chart_1, color); - } - if (i_type == 'RSI') { - let color = indicators[name].color; - this.i_objs[name] = new RSI(name, charts, color); - } - if (i_type == 'EMA') { - // The color of the line - let color = indicators[name].color; - this.i_objs[name] = new EMA(name, charts.chart_1, color); - } - } - } init_indicators(data){ // Loop through all the indicators. for (name in data){ @@ -291,13 +329,13 @@ class Indicators { console.log('could not load:', name); continue; } - this.i_objs[name].init(data[name]['data']); + this.i_objs[name].init(data[name]); } } update(updates){ for (name in updates){ - window.UI.indicators.i_objs[name].update(updates[name].data); + window.UI.indicators.i_objs[name].update(updates[name]); } } @@ -327,14 +365,71 @@ class Indicators { }else{ document.getElementById("new_prop_list").insertAdjacentHTML('beforeend', ',' + JSON.stringify(p)); } } - submit_new_i(){ + // Call to display Create new signal dialog. + open_form() { + // Show the form + document.getElementById("new_ind_form").style.display = "grid"; + + // Prefill the form fields with the current chart data (if available) + const marketField = document.querySelector('[name="ei_symbol"]'); + const timeframeField = document.querySelector('[name="ei_timeframe"]'); + const exchangeField = document.querySelector('[name="ei_exchange_name"]'); + + // Set default values if fields are empty + if (!marketField.value) { + marketField.value = window.UI.data.trading_pair || ''; // Set to trading pair or empty string + } + if (!timeframeField.value) { + timeframeField.value = window.UI.data.timeframe || ''; // Set to current timeframe or empty string + } + if (!exchangeField.value) { + exchangeField.value = window.UI.data.exchange || ''; // Set to current exchange or empty string + } + } + // Call to hide Create new signal dialog. + close_form() { document.getElementById("new_ind_form").style.display = "none"; } + submit_new_i() { /* Populates a hidden with a value from another element then submits the form Used in the create indicator panel.*/ - let pl=document.getElementById("new_prop_list").innerHTML; - if(pl) { pl = '[' + pl + ']'; } - document.getElementById("new_prop_obj").value=pl; + + // Perform validation + const name = document.querySelector('[name="newi_name"]').value; + const type = document.querySelector('[name="newi_type"]').value; + let market = document.querySelector('[name="ei_symbol"]').value; + const timeframe = document.querySelector('[name="ei_timeframe"]').value; + const exchange = document.querySelector('[name="ei_exchange_name"]').value; + + let errorMsg = ''; + + if (!name) { + errorMsg += 'Indicator name is required.\n'; + } + if (!type) { + errorMsg += 'Indicator type is required.\n'; + } + if (!market) { + market = window.UI.data.trading_pair; + document.querySelector('[name="ei_symbol"]').value = market; // Set the form field + } + if (!timeframe) { + errorMsg += 'Timeframe is required.\n'; + } + if (!exchange) { + errorMsg += 'Exchange name is required.\n'; + } + + if (errorMsg) { + alert(errorMsg); // Display the error messages + return; // Stop form submission if there are errors + } + + // If validation passes, proceed with form submission + let pl = document.getElementById("new_prop_list").innerHTML; + if (pl) { + pl = '[' + pl + ']'; + } + document.getElementById("new_prop_obj").value = pl; document.getElementById("new_i_form").submit(); } - } diff --git a/src/templates/exchange_config_popup.html b/src/templates/exchange_config_popup.html index 4c843f7..d47b143 100644 --- a/src/templates/exchange_config_popup.html +++ b/src/templates/exchange_config_popup.html @@ -20,7 +20,7 @@
- +
diff --git a/src/templates/exchange_info_hud.html b/src/templates/exchange_info_hud.html index 58c9920..a8bfdec 100644 --- a/src/templates/exchange_info_hud.html +++ b/src/templates/exchange_info_hud.html @@ -1,15 +1,15 @@
- + Connected: -
-
- {% for exchange in cond_exchanges %} - {{ exchange }} - {% endfor %} +
+
+ {% for exchange in cond_exchanges %} + {{ exchange }}{% if not loop.last %}, {% endif %} + {% endfor %} +
-
Configured:
diff --git a/src/templates/index.html b/src/templates/index.html index 8cdab63..c37f677 100644 --- a/src/templates/index.html +++ b/src/templates/index.html @@ -34,6 +34,7 @@ {% include "new_trade_popup.html" %} {% include "new_strategy_popup.html" %} {% include "new_signal_popup.html" %} + {% include "new_indicator_popup.html" %} {% include "trade_details_popup.html" %} {% include "indicator_popup.html" %} {% include "exchange_config_popup.html" %} diff --git a/src/templates/indicators_hud.html b/src/templates/indicators_hud.html index 8da40e7..cca0ecb 100644 --- a/src/templates/indicators_hud.html +++ b/src/templates/indicators_hud.html @@ -3,7 +3,7 @@
- +
@@ -77,61 +77,4 @@
-
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
\ No newline at end of file diff --git a/src/templates/new_indicator_popup.html b/src/templates/new_indicator_popup.html new file mode 100644 index 0000000..3a6a051 --- /dev/null +++ b/src/templates/new_indicator_popup.html @@ -0,0 +1,66 @@ +
+
+ +
+ + +

Create New Indicator

+ + + +
+ + + +

+ Properties: + + +

+
+ + +
+ + + + + +
+
+ + + + + {% for symbol in symbols %} + + {% endfor %} + + + + + + +
+ + +
+ + +
+ +
+
+
diff --git a/src/templates/price_chart.html b/src/templates/price_chart.html index f823ffc..972b38b 100644 --- a/src/templates/price_chart.html +++ b/src/templates/price_chart.html @@ -29,7 +29,7 @@
diff --git a/src/maintenence/working_public_exchanges.txt b/src/working_public_exchanges.txt similarity index 100% rename from src/maintenence/working_public_exchanges.txt rename to src/working_public_exchanges.txt diff --git a/tests/test_DataCache.py b/tests/test_DataCache.py index 05bbe93..ae0e9c2 100644 --- a/tests/test_DataCache.py +++ b/tests/test_DataCache.py @@ -1,7 +1,8 @@ +import pickle import time import pytz from DataCache_v3 import DataCache, timeframe_to_timedelta, estimate_record_count, InMemoryCache, DataCacheBase, \ - SnapshotDataCache + SnapshotDataCache, IndicatorCache from ExchangeInterface import ExchangeInterface import unittest import pandas as pd @@ -9,6 +10,12 @@ import datetime as dt import os from Database import SQLite, Database +import logging + +from indicators import Indicator + +logging.basicConfig(level=logging.DEBUG) + class DataGenerator: def __init__(self, timeframe_str): @@ -93,7 +100,7 @@ class DataGenerator: df = pd.DataFrame({ 'market_id': 1, - 'open_time': times, + 'time': times, 'open': [100 + i for i in range(num_rec)], 'high': [110 + i for i in range(num_rec)], 'low': [90 + i for i in range(num_rec)], @@ -198,49 +205,51 @@ class TestDataCache(unittest.TestCase): # Set up database and exchanges self.exchanges = ExchangeInterface() self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None) + self.exchanges.connect_exchange(exchange_name='binance', user_name='user_1', api_keys=None) + self.exchanges.connect_exchange(exchange_name='binance', user_name='user_2', api_keys=None) self.db_file = 'test_db.sqlite' self.database = Database(db_file=self.db_file) # Create necessary tables sql_create_table_1 = f""" - CREATE TABLE IF NOT EXISTS test_table ( - id INTEGER PRIMARY KEY, - market_id INTEGER, - open_time INTEGER UNIQUE ON CONFLICT IGNORE, - open REAL NOT NULL, - high REAL NOT NULL, - low REAL NOT NULL, - close REAL NOT NULL, - volume REAL NOT NULL, - FOREIGN KEY (market_id) REFERENCES market (id) - )""" + CREATE TABLE IF NOT EXISTS test_table ( + id INTEGER PRIMARY KEY, + market_id INTEGER, + time INTEGER UNIQUE ON CONFLICT IGNORE, + open REAL NOT NULL, + high REAL NOT NULL, + low REAL NOT NULL, + close REAL NOT NULL, + volume REAL NOT NULL, + FOREIGN KEY (market_id) REFERENCES market (id) + )""" sql_create_table_2 = """ - CREATE TABLE IF NOT EXISTS exchange ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT UNIQUE - )""" + CREATE TABLE IF NOT EXISTS exchange ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE + )""" sql_create_table_3 = """ - CREATE TABLE IF NOT EXISTS markets ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - symbol TEXT, - exchange_id INTEGER, - FOREIGN KEY (exchange_id) REFERENCES exchange(id) - )""" + CREATE TABLE IF NOT EXISTS markets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT, + exchange_id INTEGER, + FOREIGN KEY (exchange_id) REFERENCES exchange(id) + )""" sql_create_table_4 = f""" - CREATE TABLE IF NOT EXISTS test_table_2 ( - key TEXT PRIMARY KEY, - data TEXT NOT NULL - )""" + CREATE TABLE IF NOT EXISTS test_table_2 ( + key TEXT PRIMARY KEY, + data TEXT NOT NULL + )""" sql_create_table_5 = """ - CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_name TEXT, - age INTEGER, - users_data TEXT, - data TEXT, - password TEXT -- Moved to a new line and added a comma after 'data' - ) - """ + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT, + age INTEGER, + users_data TEXT, + data TEXT, + password TEXT -- Moved to a new line and added a comma after 'data' + ) + """ with SQLite(db_file=self.db_file) as con: con.execute(sql_create_table_1) @@ -248,9 +257,15 @@ class TestDataCache(unittest.TestCase): con.execute(sql_create_table_3) con.execute(sql_create_table_4) con.execute(sql_create_table_5) - self.data = DataCache(self.exchanges) - self.data.db = self.database + # Initialize DataCache, which inherits IndicatorCache + self.data = DataCache(self.exchanges) + self.data.db = self.database # Keep the database setup + + # Create caches needed for testing + self.data.create_cache('candles', cache_type=InMemoryCache) + + # Reuse details for exchange and market self.ex_details = ['BTC/USD', '2h', 'binance', 'test_guy'] self.key = f'{self.ex_details[0]}_{self.ex_details[1]}_{self.ex_details[2]}' @@ -463,13 +478,13 @@ class TestDataCache(unittest.TestCase): # Create the expected DataFrame expected = data_gen.create_table(num_rec=6, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc)) - print(f'The expected open_time values are:\n{expected["open_time"].tolist()}\n') + print(f'The expected time values are:\n{expected["time"].tolist()}\n') - # Assert that the open_time values in the result match those in the expected DataFrame, in order - assert result['open_time'].tolist() == expected['open_time'].tolist(), \ - f"open_time values in result are {result['open_time'].tolist()} expected {expected['open_time'].tolist()}" + # Assert that the time values in the result match those in the expected DataFrame, in order + assert result['time'].tolist() == expected['time'].tolist(), \ + f"time values in result are {result['time'].tolist()} expected {expected['time'].tolist()}" - print(f'The result open_time values match:\n{result["open_time"].tolist()}\n') + print(f'The result time values match:\n{result["time"].tolist()}\n') print(' - Update cache with new records passed.') def test_update_cached_dict(self): @@ -538,7 +553,7 @@ class TestDataCache(unittest.TestCase): df_initial = data_gen.generate_missing_section(df_initial, drop_start=2, drop_end=5) temp_df = df_initial.copy() - temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms') + temp_df['time'] = pd.to_datetime(temp_df['time'], unit='ms') print(f'Table Created:\n{temp_df}') if set_cache: @@ -559,14 +574,14 @@ class TestDataCache(unittest.TestCase): result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) - expected = df_initial[df_initial['open_time'] >= data_gen.unix_time_millis(start_datetime)].reset_index( + expected = df_initial[df_initial['time'] >= data_gen.unix_time_millis(start_datetime)].reset_index( drop=True) temp_df = expected.copy() - temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms') + temp_df['time'] = pd.to_datetime(temp_df['time'], unit='ms') print(f'Expected table:\n{temp_df}') temp_df = result.copy() - temp_df['open_time'] = pd.to_datetime(temp_df['open_time'], unit='ms') + temp_df['time'] = pd.to_datetime(temp_df['time'], unit='ms') print(f'Resulting table:\n{temp_df}') if simulate_scenarios in ['not_enough_data', 'incomplete_data', 'missing_section']: @@ -575,10 +590,10 @@ class TestDataCache(unittest.TestCase): print("\nThe returned DataFrame has filled in the missing data!") else: assert result.shape == expected.shape, f"Shape mismatch: {result.shape} vs {expected.shape}" - pd.testing.assert_series_equal(result['open_time'], expected['open_time'], check_dtype=False) - print("\nThe DataFrames have the same shape and the 'open_time' columns match.") + pd.testing.assert_series_equal(result['time'], expected['time'], check_dtype=False) + print("\nThe DataFrames have the same shape and the 'time' columns match.") - oldest_timestamp = pd.to_datetime(result['open_time'].min(), unit='ms').tz_localize('UTC') + oldest_timestamp = pd.to_datetime(result['time'].min(), unit='ms').tz_localize('UTC') time_diff = oldest_timestamp - start_datetime max_allowed_time_diff = dt.timedelta(**{data_gen.timeframe_unit: data_gen.timeframe_amount}) @@ -588,7 +603,7 @@ class TestDataCache(unittest.TestCase): print(f'The first timestamp is {time_diff} from {start_datetime}') - newest_timestamp = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') + newest_timestamp = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC') time_diff_end = abs(query_end_time - newest_timestamp) assert dt.timedelta(0) <= time_diff_end <= max_allowed_time_diff, \ @@ -630,7 +645,7 @@ class TestDataCache(unittest.TestCase): start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=2) # Query the records since the calculated start time. result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) - last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') + last_record_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC') assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=15.1) print('\nTest get_records_since with a different timeframe') @@ -638,7 +653,7 @@ class TestDataCache(unittest.TestCase): start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=1) # Query the records since the calculated start time. result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) - last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') + last_record_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC') assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=5.1) print('\nTest get_records_since with a different timeframe') @@ -646,7 +661,7 @@ class TestDataCache(unittest.TestCase): start_datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=12) # Query the records since the calculated start time. result = self.data.get_records_since(start_datetime=start_datetime, ex_details=ex_details) - last_record_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') + last_record_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC') assert last_record_time > dt.datetime.now(dt.timezone.utc) - dt.timedelta(hours=4.1) def test_populate_db(self): @@ -680,12 +695,12 @@ class TestDataCache(unittest.TestCase): # Validate that the DataFrame is not empty self.assertFalse(result.empty, "The DataFrame returned from the exchange is empty.") - # Ensure that the 'open_time' column exists in the DataFrame - self.assertIn('open_time', result.columns, "'open_time' column is missing in the result DataFrame.") + # Ensure that the 'time' column exists in the DataFrame + self.assertIn('time', result.columns, "'time' column is missing in the result DataFrame.") # Check if the DataFrame contains valid timestamps within the specified range - min_time = pd.to_datetime(result['open_time'].min(), unit='ms').tz_localize('UTC') - max_time = pd.to_datetime(result['open_time'].max(), unit='ms').tz_localize('UTC') + min_time = pd.to_datetime(result['time'].min(), unit='ms').tz_localize('UTC') + max_time = pd.to_datetime(result['time'].max(), unit='ms').tz_localize('UTC') self.assertTrue(start_time <= min_time <= end_time, f"Data starts outside the expected range: {min_time}") self.assertTrue(start_time <= max_time <= end_time, f"Data ends outside the expected range: {max_time}") @@ -776,27 +791,87 @@ class TestDataCache(unittest.TestCase): def test_estimate_record_count(self): print('Testing estimate_record_count() function:') + # Test with '1h' timeframe (24 records expected) start_time = dt.datetime(2023, 8, 1, 0, 0, 0, tzinfo=dt.timezone.utc) end_time = dt.datetime(2023, 8, 2, 0, 0, 0, tzinfo=dt.timezone.utc) - result = estimate_record_count(start_time, end_time, '1h') - expected = 24 - self.assertEqual(result, expected, "Failed to estimate record count for 1h timeframe") + self.assertEqual(result, 24, "Failed to estimate record count for 1h timeframe") + # Test with '1d' timeframe (1 record expected) result = estimate_record_count(start_time, end_time, '1d') - expected = 1 - self.assertEqual(result, expected, "Failed to estimate record count for 1d timeframe") + self.assertEqual(result, 1, "Failed to estimate record count for 1d timeframe") - start_time = int(start_time.timestamp() * 1000) # Convert to milliseconds - end_time = int(end_time.timestamp() * 1000) # Convert to milliseconds + # Test with '1h' timeframe and timestamps in milliseconds + start_time_ms = int(start_time.timestamp() * 1000) # Convert to milliseconds + end_time_ms = int(end_time.timestamp() * 1000) # Convert to milliseconds + result = estimate_record_count(start_time_ms, end_time_ms, '1h') + self.assertEqual(result, 24, "Failed to estimate record count for 1h timeframe with milliseconds") - result = estimate_record_count(start_time, end_time, '1h') - expected = 24 - self.assertEqual(result, expected, "Failed to estimate record count for 1h timeframe with milliseconds") + # Test with '5m' timeframe and Unix timestamps in milliseconds + start_time_ms = 1672531200000 # Equivalent to '2023-01-01 00:00:00 UTC' + end_time_ms = 1672534800000 # Equivalent to '2023-01-01 01:00:00 UTC' + result = estimate_record_count(start_time_ms, end_time_ms, '5m') + self.assertEqual(result, 12, "Failed to estimate record count for 5m timeframe with milliseconds") + # Test with '5m' timeframe (12 records expected for 1-hour duration) + start_time = dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone.utc) + end_time = dt.datetime(2023, 1, 1, 1, 0, tzinfo=dt.timezone.utc) + result = estimate_record_count(start_time, end_time, '5m') + self.assertEqual(result, 12, "Failed to estimate record count for 5m timeframe") + + # Test with '1M' (3 records expected for 3 months) + start_time = dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc) + end_time = dt.datetime(2023, 4, 1, tzinfo=dt.timezone.utc) + result = estimate_record_count(start_time, end_time, '1M') + self.assertEqual(result, 3, "Failed to estimate record count for 1M timeframe") + + # Test with invalid timeframe + with self.assertRaises(ValueError): + estimate_record_count(start_time, end_time, 'xyz') # Invalid timeframe + + # Test with invalid start_time passed in with self.assertRaises(ValueError): estimate_record_count("invalid_start", end_time, '1h') + # Cross-Year Transition (Months) + start_time = dt.datetime(2022, 12, 1, tzinfo=dt.timezone.utc) + end_time = dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc) + result = estimate_record_count(start_time, end_time, '1M') + self.assertEqual(result, 1, "Failed to estimate record count for month across years") + + # Leap Year (Months) + start_time = dt.datetime(2020, 2, 1, tzinfo=dt.timezone.utc) + end_time = dt.datetime(2021, 2, 1, tzinfo=dt.timezone.utc) + result = estimate_record_count(start_time, end_time, '1M') + self.assertEqual(result, 12, "Failed to estimate record count for months during leap year") + + # Sub-Minute Timeframes (e.g., 30 seconds) + start_time = dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone.utc) + end_time = dt.datetime(2023, 1, 1, 0, 1, tzinfo=dt.timezone.utc) + result = estimate_record_count(start_time, end_time, '30s') + self.assertEqual(result, 2, "Failed to estimate record count for 30 seconds timeframe") + + # Different Timezones + start_time = dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone(dt.timedelta(hours=5))) # UTC+5 + end_time = dt.datetime(2023, 1, 1, 1, 0, tzinfo=dt.timezone.utc) # UTC + result = estimate_record_count(start_time, end_time, '1h') + self.assertEqual(result, 6, + "Failed to estimate record count for different timezones") # Expect 6 records, not 1 + + # Test with zero-length interval (should return 0) + result = estimate_record_count(start_time, start_time, '1h') + self.assertEqual(result, 0, "Failed to return 0 for zero-length interval") + + # Test with negative interval (end_time earlier than start_time, should return 0) + result = estimate_record_count(end_time, start_time, '1h') + self.assertEqual(result, 0, "Failed to return 0 for negative interval") + + # Test with small interval compared to timeframe (should return 0) + start_time = dt.datetime(2023, 8, 1, 0, 0, tzinfo=dt.timezone.utc) + end_time = dt.datetime(2023, 8, 1, 0, 30, tzinfo=dt.timezone.utc) # 30 minutes + result = estimate_record_count(start_time, end_time, '1h') + self.assertEqual(result, 0, "Failed to return 0 for small interval compared to timeframe") + print(' - All estimate_record_count() tests passed.') def test_get_or_fetch_rows(self): @@ -941,7 +1016,8 @@ class TestDataCache(unittest.TestCase): # Assert that the data in the cache matches what was inserted self.assertIsNotNone(result, "No data found in the cache for the inserted ID.") - self.assertEqual(result.iloc[0]['user_name'], 'Alice', "The name in the cache doesn't match the inserted value.") + self.assertEqual(result.iloc[0]['user_name'], 'Alice', + "The name in the cache doesn't match the inserted value.") self.assertEqual(result.iloc[0]['age'], 30, "The age in the cache does not match the inserted value.") # Now test with skipping the cache @@ -963,11 +1039,11 @@ class TestDataCache(unittest.TestCase): # Create mock data with gaps df = pd.DataFrame({ - 'open_time': [dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc).timestamp() * 1000, - dt.datetime(2023, 1, 1, 2, tzinfo=dt.timezone.utc).timestamp() * 1000, - dt.datetime(2023, 1, 1, 6, tzinfo=dt.timezone.utc).timestamp() * 1000, - dt.datetime(2023, 1, 1, 8, tzinfo=dt.timezone.utc).timestamp() * 1000, - dt.datetime(2023, 1, 1, 12, tzinfo=dt.timezone.utc).timestamp() * 1000] + 'time': [dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc).timestamp() * 1000, + dt.datetime(2023, 1, 1, 2, tzinfo=dt.timezone.utc).timestamp() * 1000, + dt.datetime(2023, 1, 1, 6, tzinfo=dt.timezone.utc).timestamp() * 1000, + dt.datetime(2023, 1, 1, 8, tzinfo=dt.timezone.utc).timestamp() * 1000, + dt.datetime(2023, 1, 1, 12, tzinfo=dt.timezone.utc).timestamp() * 1000] }) # Call the method @@ -975,6 +1051,342 @@ class TestDataCache(unittest.TestCase): self.assertEqual(len(result), 7, "Data holes were not filled correctly.") print(' - _fill_data_holes passed.') + def test_get_cache_item(self): + # Case 1: Retrieve a stored Indicator instance (serialized) + indicator = Indicator(name='SMA', indicator_type='SMA', properties={'period': 5}) + self.data.set_cache_item('indicator_key', indicator, cache_name='indicators') + stored_data = self.data.get_cache_item('indicator_key', cache_name='indicators') + self.assertIsInstance(stored_data, Indicator, "Failed to retrieve and deserialize the Indicator instance") + + # Case 2: Retrieve non-Indicator data (e.g., dict) + data = {'key': 'value'} + self.data.set_cache_item('non_indicator_key', data) + stored_data = self.data.get_cache_item('non_indicator_key') + self.assertEqual(stored_data, data, "Failed to retrieve non-Indicator data correctly") + + # Case 3: Retrieve expired cache item (should return None) + self.data.set_cache_item('expiring_key', 'test_data', expire_delta=dt.timedelta(seconds=1)) + time.sleep(2) # Wait for the cache to expire + self.assertIsNone(self.data.get_cache_item('expiring_key'), "Expired cache item should return None") + + # Case 4: Retrieve non-existent key (should return None) + self.assertIsNone(self.data.get_cache_item('non_existent_key'), "Non-existent key should return None") + + # Case 5: Retrieve with invalid key type (should raise ValueError) + with self.assertRaises(ValueError): + self.data.get_cache_item(12345) # Invalid key type + + # Case 6: Test Deserialization Failure + # Simulate corrupted serialized data + corrupted_data = b'\x80\x03corrupted_data' + self.data.set_cache_item('corrupted_key', corrupted_data, cache_name='indicators') + with self.assertLogs(level='ERROR') as log: + self.assertIsNone(self.data.get_cache_item('corrupted_key', cache_name='indicators')) + self.assertIn("Deserialization failed", log.output[0]) + + # Case 7: Test Cache Eviction + # Create a cache with a limit of 2 items + self.data.set_cache_item('key1', 'data1', cache_name='test_cache', limit=2) + self.data.set_cache_item('key2', 'data2', cache_name='test_cache', limit=2) + self.data.set_cache_item('key3', 'data3', cache_name='test_cache', limit=2) + + # Verify that the oldest item (key1) has been evicted + self.assertIsNone(self.data.get_cache_item('key1', cache_name='test_cache')) + self.assertEqual(self.data.get_cache_item('key2', cache_name='test_cache'), 'data2') + self.assertEqual(self.data.get_cache_item('key3', cache_name='test_cache'), 'data3') + + def test_set_user_indicator_properties(self): + # Case 1: Store user-specific display properties + user_id = 'user123' + indicator_type = 'SMA' + symbol = 'AAPL' + timeframe = '1h' + exchange_name = 'NYSE' + display_properties = {'color': 'blue', 'line_width': 2} + + # Call the method to set properties + self.data.set_user_indicator_properties(user_id, indicator_type, symbol, timeframe, exchange_name, + display_properties) + + # Construct the cache key manually for validation + user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}" + + # Retrieve the stored properties + stored_properties = self.data.get_cache_item(user_cache_key, cache_name='user_display_properties') + + # Check if the properties were stored correctly + self.assertEqual(stored_properties, display_properties, "Failed to store user-specific display properties") + + # Case 2: Update existing user-specific properties + updated_properties = {'color': 'red', 'line_width': 3} + + # Update the properties + self.data.set_user_indicator_properties(user_id, indicator_type, symbol, timeframe, exchange_name, + updated_properties) + + # Retrieve the updated properties + updated_stored_properties = self.data.get_cache_item(user_cache_key, cache_name='user_display_properties') + + # Check if the properties were updated correctly + self.assertEqual(updated_stored_properties, updated_properties, + "Failed to update user-specific display properties") + + # Case 3: Handle invalid user properties (e.g., non-dict input) + with self.assertRaises(ValueError): + self.data.set_user_indicator_properties(user_id, indicator_type, symbol, timeframe, exchange_name, + "invalid_properties") + + def test_get_user_indicator_properties(self): + # Case 1: Retrieve existing user-specific display properties + user_id = 'user123' + indicator_type = 'SMA' + symbol = 'AAPL' + timeframe = '1h' + exchange_name = 'NYSE' + display_properties = {'color': 'blue', 'line_width': 2} + + # Set the properties first + self.data.set_user_indicator_properties(user_id, indicator_type, symbol, timeframe, exchange_name, + display_properties) + + # Retrieve the properties + retrieved_properties = self.data.get_user_indicator_properties(user_id, indicator_type, symbol, timeframe, + exchange_name) + self.assertEqual(retrieved_properties, display_properties, + "Failed to retrieve user-specific display properties") + + # Case 2: Handle missing key (should return None) + missing_properties = self.data.get_user_indicator_properties('nonexistent_user', indicator_type, symbol, + timeframe, exchange_name) + self.assertIsNone(missing_properties, "Expected None for missing user-specific display properties") + + # Case 3: Invalid argument handling + with self.assertRaises(TypeError): + self.data.get_user_indicator_properties(123, indicator_type, symbol, timeframe, + exchange_name) # Invalid user_id type + + def test_set_cache_item(self): + # Case 1: Store and retrieve an Indicator instance (serialized) + indicator = Indicator(name='SMA', indicator_type='SMA', properties={'period': 5}) + self.data.set_cache_item('indicator_key', indicator, cache_name='indicators') + stored_data = self.data.get_cache_item('indicator_key', cache_name='indicators') + self.assertIsInstance(stored_data, Indicator, "Failed to deserialize the Indicator instance") + + # Case 2: Store and retrieve non-Indicator data (e.g., dict) + data = {'key': 'value'} + self.data.set_cache_item('non_indicator_key', data) + stored_data = self.data.get_cache_item('non_indicator_key') + self.assertEqual(stored_data, data, "Non-Indicator data was modified or not stored correctly") + + # Case 3: Handle invalid key type (non-string) + with self.assertRaises(ValueError): + self.data.set_cache_item(12345, 'test_data') # Invalid key type + + # Case 4: Cache item expiration (item should expire after set time) + self.data.set_cache_item('expiring_key', 'test_data', expire_delta=dt.timedelta(seconds=1)) + time.sleep(2) # Wait for expiration time + self.assertIsNone(self.data.get_cache_item('expiring_key'), "Cached item did not expire as expected") + + def test_calculate_and_cache_indicator(self): + # Testing the calculation and caching of an indicator through DataCache (which includes IndicatorCache + # functionality) + + user_properties = {'color_line_1': 'blue', 'thickness_line_1': 2} + ex_details = ['BTC/USD', '5m', 'binance', 'test_guy'] + + # Define the time range for the calculation + start_datetime = dt.datetime(2023, 9, 1, 0, 0, 0, tzinfo=dt.timezone.utc) + end_datetime = dt.datetime(2023, 9, 2, 0, 0, 0, tzinfo=dt.timezone.utc) + + # Simulate calculating an indicator and caching it through DataCache + result = self.data.calculate_indicator( + user_name='test_guy', + symbol=ex_details[0], + timeframe=ex_details[1], + exchange_name=ex_details[2], + indicator_type='SMA', # Type of indicator + start_datetime=start_datetime, + end_datetime=end_datetime, + properties={'period': 5} # Add the necessary indicator properties like period + ) + + # Ensure that result is not None + self.assertIsNotNone(result, "Indicator calculation returned None.") + + def test_calculate_indicator_multiple_users(self): + """ + Test that the calculate_indicator method handles multiple users' requests with different properties. + """ + ex_details = ['BTC/USD', '5m', 'binance', 'test_guy'] + user1_properties = {'color': 'blue', 'thickness': 2} + user2_properties = {'color': 'red', 'thickness': 1} + + # Set user-specific properties + self.data.set_user_indicator_properties('user_1', 'SMA', 'BTC/USD', '5m', 'binance', user1_properties) + self.data.set_user_indicator_properties('user_2', 'SMA', 'BTC/USD', '5m', 'binance', user2_properties) + + # User 1 calculates the SMA indicator + result_user1 = self.data.calculate_indicator( + user_name='user_1', + symbol='BTC/USD', + timeframe='5m', + exchange_name='binance', + indicator_type='SMA', + start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc), + end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc), + properties={'period': 5} + ) + + # User 2 calculates the same SMA indicator but with different display properties + result_user2 = self.data.calculate_indicator( + user_name='user_2', + symbol='BTC/USD', + timeframe='5m', + exchange_name='binance', + indicator_type='SMA', + start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc), + end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc), + properties={'period': 5} + ) + + # Assert that the calculation data is the same + self.assertEqual(result_user1['calculation_data'], result_user2['calculation_data']) + + # Assert that the display properties are different + self.assertNotEqual(result_user1['display_properties'], result_user2['display_properties']) + + # Assert that the correct display properties are returned + self.assertEqual(result_user1['display_properties']['color'], 'blue') + self.assertEqual(result_user2['display_properties']['color'], 'red') + + def test_calculate_indicator_cache_retrieval(self): + """ + Test that cached data is retrieved efficiently without recalculating when the same request is made. + """ + ex_details = ['BTC/USD', '5m', 'binance', 'test_guy'] + properties = {'period': 5} + cache_key = 'BTC/USD_5m_binance_SMA_5' + + # First calculation (should store result in cache) + result_first = self.data.calculate_indicator( + user_name='user_1', + symbol='BTC/USD', + timeframe='5m', + exchange_name='binance', + indicator_type='SMA', + start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc), + end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc), + properties=properties + ) + + # Check if the data was cached after the first calculation + cached_data = self.data.get_cache_item(cache_key, cache_name='indicator_data') + print(f"Cached Data after first calculation: {cached_data}") + + # Ensure the data was cached correctly + self.assertIsNotNone(cached_data, "The first calculation did not cache the result properly.") + + # Second calculation with the same parameters (should retrieve from cache) + with self.assertLogs(level='INFO') as log: + result_second = self.data.calculate_indicator( + user_name='user_1', + symbol='BTC/USD', + timeframe='5m', + exchange_name='binance', + indicator_type='SMA', + start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc), + end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc), + properties=properties + ) + # Verify the log message for cache retrieval + self.assertTrue( + any(f"DataFrame retrieved from cache for key: {cache_key}" in message for message in log.output), + f"Cache retrieval log message not found for key: {cache_key}" + ) + + def test_calculate_indicator_partial_cache(self): + """ + Test handling of partial cache where some of the requested data is already cached, + and the rest needs to be fetched. + """ + ex_details = ['BTC/USD', '5m', 'binance', 'test_guy'] + properties = {'period': 5} + + # Simulate cache for part of the range (manual setup, no call to `get_records_since`) + cached_data = pd.DataFrame({ + 'time': pd.date_range(start="2023-01-01", periods=144, freq='5min', tz=dt.timezone.utc), + # Cached half a day of data + 'value': [16500 + i for i in range(144)] + }) + + # Generate cache key with correct format + cache_key = self.data._make_indicator_key('BTC/USD', '5m', 'binance', 'SMA', properties['period']) + + # Store the cached data as DataFrame (no need for to_dict('records')) + self.data.set_cache_item(cache_key, cached_data, cache_name='indicator_data') + + # Print cached data to inspect its range + print("Cached data time range:") + print(f"Min cached time: {cached_data['time'].min()}") + print(f"Max cached time: {cached_data['time'].max()}") + + # Now request a range that partially overlaps the cached data + result = self.data.calculate_indicator( + user_name='user_1', + symbol='BTC/USD', + timeframe='5m', + exchange_name='binance', + indicator_type='SMA', + start_datetime=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc), + end_datetime=dt.datetime(2023, 1, 2, tzinfo=dt.timezone.utc), + properties=properties + ) + + # Convert the result into a DataFrame + result_df = pd.DataFrame(result['calculation_data']) + + # Convert the 'time' column from Unix timestamp (ms) back to datetime with timezone + result_df['time'] = pd.to_datetime(result_df['time'], unit='ms', utc=True) + + # Debugging: print the full result to inspect the time range + print("Result data time range:") + print(f"Min result time: {result_df['time'].min()}") + print(f"Max result time: {result_df['time'].max()}") + + # Now you can safely find the min and max values + min_time = result_df['time'].min() + max_time = result_df['time'].max() + + # Debugging print statements to confirm the values + print(f"Min time in result: {min_time}") + print(f"Max time in result: {max_time}") + + # Assert that the min and max time in the result cover the full range from the cache and new data + self.assertEqual(min_time, pd.Timestamp("2023-01-01 00:00:00", tz=dt.timezone.utc)) + self.assertEqual(max_time, pd.Timestamp("2023-01-02 00:00:00", tz=dt.timezone.utc)) + + def test_calculate_indicator_no_data(self): + """ + Test that the indicator calculation handles cases where no data is available for the requested range. + """ + ex_details = ['BTC/USD', '5m', 'binance', 'test_guy'] + properties = {'period': 5} + + # Request data for a period where no data exists + result = self.data.calculate_indicator( + user_name='user_1', + symbol='BTC/USD', + timeframe='5m', + exchange_name='binance', + indicator_type='SMA', + start_datetime=dt.datetime(1900, 1, 1, tzinfo=dt.timezone.utc), + end_datetime=dt.datetime(1900, 1, 2, tzinfo=dt.timezone.utc), + properties=properties + ) + + # Ensure no calculation data is returned + self.assertEqual(len(result['calculation_data']), 0) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_Exchange.py b/tests/test_Exchange.py index 2f13c8e..e62d7bf 100644 --- a/tests/test_Exchange.py +++ b/tests/test_Exchange.py @@ -117,7 +117,7 @@ class TestExchange(unittest.TestCase): end_dt = datetime(2021, 1, 2) klines = self.exchange.get_historical_klines('BTC/USDT', '1d', start_dt, end_dt) expected_df = pd.DataFrame([ - {'open_time': 1609459200, 'open': 29000.0, 'high': 29500.0, 'low': 28800.0, 'close': 29400.0, + {'time': 1609459200, 'open': 29000.0, 'high': 29500.0, 'low': 28800.0, 'close': 29400.0, 'volume': 1000} ]) pd.testing.assert_frame_equal(klines, expected_df) diff --git a/tests/test_candles.py b/tests/test_candles.py index 0ba64f8..784efe1 100644 --- a/tests/test_candles.py +++ b/tests/test_candles.py @@ -121,19 +121,19 @@ def test_get_records_since(): print(f'\ntest_candles_get_records() starting @: {start_time}') result = candles_obj.get_records_since(symbol=symbol, timeframe=interval, exchange_name=exchange_name, start_time=start_time) - print(f'test_candles_received records from: {datetime.datetime.fromtimestamp(result.open_time.min() / 1000)}') + print(f'test_candles_received records from: {datetime.datetime.fromtimestamp(result.time.min() / 1000)}') start_time = datetime.datetime(year=2023, month=3, day=16, hour=1, minute=0) print(f'\ntest_candles_get_records() starting @: {start_time}') result = candles_obj.get_records_since(symbol=symbol, timeframe=interval, exchange_name=exchange_name, start_time=start_time) - print(f'test_candles_received records from: {datetime.datetime.fromtimestamp(result.open_time.min() / 1000)}') + print(f'test_candles_received records from: {datetime.datetime.fromtimestamp(result.time.min() / 1000)}') start_time = datetime.datetime(year=2023, month=3, day=13, hour=1, minute=0) print(f'\ntest_candles_get_records() starting @: {start_time}') result = candles_obj.get_records_since(symbol=symbol, timeframe=interval, exchange_name=exchange_name, start_time=start_time) - print(f'test_candles_received records from: {datetime.datetime.fromtimestamp(result.open_time.min() / 1000)}') + print(f'test_candles_received records from: {datetime.datetime.fromtimestamp(result.time.min() / 1000)}') assert result is not None diff --git a/tests/test_database.py b/tests/test_database.py index ebadc45..ce3ec03 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -111,7 +111,7 @@ class TestDatabase(unittest.TestCase): def test_get_timestamped_records(self): print("\nRunning test_get_timestamped_records...") df = pd.DataFrame({ - 'open_time': [unix_time_millis(dt.datetime.utcnow())], + 'time': [unix_time_millis(dt.datetime.utcnow())], 'open': [1.0], 'high': [1.0], 'low': [1.0], @@ -122,7 +122,7 @@ class TestDatabase(unittest.TestCase): self.cursor.execute(f""" CREATE TABLE {table_name} ( id INTEGER PRIMARY KEY, - open_time INTEGER UNIQUE, + time INTEGER UNIQUE, open REAL NOT NULL, high REAL NOT NULL, low REAL NOT NULL, @@ -134,7 +134,7 @@ class TestDatabase(unittest.TestCase): self.db.insert_dataframe(df, table_name) st = dt.datetime.utcnow() - dt.timedelta(minutes=1) et = dt.datetime.utcnow() - records = self.db.get_timestamped_records(table_name, 'open_time', st, et) + records = self.db.get_timestamped_records(table_name, 'time', st, et) self.assertIsInstance(records, pd.DataFrame) self.assertFalse(records.empty) print("Get timestamped records test passed.") @@ -153,7 +153,7 @@ class TestDatabase(unittest.TestCase): def test_insert_candles_into_db(self): print("\nRunning test_insert_candles_into_db...") df = pd.DataFrame({ - 'open_time': [unix_time_millis(dt.datetime.utcnow())], + 'time': [unix_time_millis(dt.datetime.utcnow())], 'open': [1.0], 'high': [1.0], 'low': [1.0], @@ -165,7 +165,7 @@ class TestDatabase(unittest.TestCase): CREATE TABLE {table_name} ( id INTEGER PRIMARY KEY, market_id INTEGER, - open_time INTEGER UNIQUE, + time INTEGER UNIQUE, open REAL NOT NULL, high REAL NOT NULL, low REAL NOT NULL, diff --git a/tests/test_shared_utilities.py b/tests/test_shared_utilities.py index 07bbbcd..4c2d6dd 100644 --- a/tests/test_shared_utilities.py +++ b/tests/test_shared_utilities.py @@ -15,7 +15,7 @@ class TestSharedUtilities(unittest.TestCase): # (Test case 1) The records should not be up-to-date (very old timestamps) records = pd.DataFrame({ - 'open_time': [1, 2, 3, 4, 5] + 'time': [1, 2, 3, 4, 5] }) result = query_uptodate(records, 1) if result is None: @@ -28,7 +28,7 @@ class TestSharedUtilities(unittest.TestCase): # (Test case 2) The records should be up-to-date (recent timestamps) now = unix_time_millis(dt.datetime.utcnow()) recent_records = pd.DataFrame({ - 'open_time': [now - 70000, now - 60000, now - 40000] + 'time': [now - 70000, now - 60000, now - 40000] }) result = query_uptodate(recent_records, 1) if result is None: @@ -44,7 +44,7 @@ class TestSharedUtilities(unittest.TestCase): tolerance_milliseconds = 10 * 1000 # tolerance in milliseconds recent_time = unix_time_millis(dt.datetime.utcnow()) borderline_records = pd.DataFrame({ - 'open_time': [recent_time - one_hour + (tolerance_milliseconds - 3)] # just within the tolerance + 'time': [recent_time - one_hour + (tolerance_milliseconds - 3)] # just within the tolerance }) result = query_uptodate(borderline_records, 60) if result is None: @@ -60,7 +60,7 @@ class TestSharedUtilities(unittest.TestCase): tolerance_milliseconds = 10 * 1000 # tolerance in milliseconds recent_time = unix_time_millis(dt.datetime.utcnow()) borderline_records = pd.DataFrame({ - 'open_time': [recent_time - one_hour + (tolerance_milliseconds + 3)] # just within the tolerance + 'time': [recent_time - one_hour + (tolerance_milliseconds + 3)] # just within the tolerance }) result = query_uptodate(borderline_records, 60) if result is None: @@ -91,7 +91,7 @@ class TestSharedUtilities(unittest.TestCase): # Test case where the records should satisfy the query (records cover the start time) start_datetime = dt.datetime(2020, 1, 1) records = pd.DataFrame({ - 'open_time': [1577836800000.0 - 600000, 1577836800000.0 - 300000, 1577836800000.0] + 'time': [1577836800000.0 - 600000, 1577836800000.0 - 300000, 1577836800000.0] # Covering the start time }) result = query_satisfied(start_datetime, records, 1) @@ -105,7 +105,7 @@ class TestSharedUtilities(unittest.TestCase): # Test case where the records should not satisfy the query (recent records but not enough) recent_time = unix_time_millis(dt.datetime.utcnow()) records = pd.DataFrame({ - 'open_time': [recent_time - 300 * 60 * 1000, recent_time - 240 * 60 * 1000, recent_time - 180 * 60 * 1000] + 'time': [recent_time - 300 * 60 * 1000, recent_time - 240 * 60 * 1000, recent_time - 180 * 60 * 1000] }) result = query_satisfied(start_datetime, records, 1) if result is None: @@ -118,7 +118,7 @@ class TestSharedUtilities(unittest.TestCase): # Additional test case for partial coverage start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=300) records = pd.DataFrame({ - 'open_time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=240)), + 'time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=240)), unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=180)), unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=120))] })