Indicators are fixed after last update.
This commit is contained in:
parent
1ff21b56dd
commit
f1d0f2a4b1
|
|
@ -33,7 +33,7 @@ class BrighterTrades:
|
|||
self.signals = Signals(self.config)
|
||||
|
||||
# Object that maintains candlestick and price data.
|
||||
self.candles = Candles(users=self.users, exchanges=self.exchanges, data_source=self.data,
|
||||
self.candles = Candles(users=self.users, exchanges=self.exchanges, datacache=self.data,
|
||||
config=self.config)
|
||||
|
||||
# Object that interacts with and maintains data from available indicators
|
||||
|
|
@ -343,7 +343,8 @@ class BrighterTrades:
|
|||
|
||||
self.strategies.delete_strategy(strategy_name)
|
||||
try:
|
||||
self.config.remove('strategies', strategy_name)
|
||||
# self.config.remove('strategies', strategy_name)TODO
|
||||
pass
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to remove the strategy '{strategy_name}' from the configuration file: {str(e)}")
|
||||
|
||||
|
|
@ -358,8 +359,8 @@ class BrighterTrades:
|
|||
# Delete the signal from the signals instance.
|
||||
self.signals.delete_signal(signal_name)
|
||||
|
||||
# Delete the signal from the configuration file.
|
||||
self.config.remove('signals', signal_name)
|
||||
# # Delete the signal from the configuration file.TODO
|
||||
# self.config.remove('signals', signal_name)
|
||||
|
||||
def get_signals_json(self) -> str:
|
||||
"""
|
||||
|
|
@ -394,8 +395,8 @@ class BrighterTrades:
|
|||
}
|
||||
|
||||
try:
|
||||
if self.data.get_cache_item().get_cache('exchange_data').query([('user', user_name),
|
||||
('name', exchange_name)]).empty:
|
||||
if self.data.get_serialized_datacache(cache_name='exchange_data',
|
||||
filter_vals=([('user', user_name), ('name', exchange_name)])).empty:
|
||||
# Exchange is not connected, try to connect
|
||||
success = self.exchanges.connect_exchange(exchange_name=exchange_name, user_name=user_name,
|
||||
api_keys=api_keys)
|
||||
|
|
@ -409,11 +410,19 @@ class BrighterTrades:
|
|||
result['status'] = 'failure'
|
||||
result['message'] = f'Failed to connect to {exchange_name}.'
|
||||
else:
|
||||
# Exchange is already connected, update API keys if provided
|
||||
# Exchange is already connected, check if API keys need updating
|
||||
if api_keys:
|
||||
self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name)
|
||||
# Get current API keys
|
||||
current_keys = self.users.get_api_keys(user_name, exchange_name)
|
||||
|
||||
# Compare current keys with provided keys
|
||||
if current_keys != api_keys:
|
||||
self.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name)
|
||||
result['message'] = f'{exchange_name}: API keys updated.'
|
||||
else:
|
||||
result['message'] = f'{exchange_name}: API keys unchanged.'
|
||||
|
||||
result['status'] = 'already_connected'
|
||||
result['message'] = f'{exchange_name}: API keys updated.'
|
||||
except Exception as e:
|
||||
result['status'] = 'error'
|
||||
result['message'] = f"Failed to connect to {exchange_name} for user '{user_name}': {str(e)}"
|
||||
|
|
@ -427,8 +436,9 @@ class BrighterTrades:
|
|||
:param trade_id: The ID of the trade to be closed.
|
||||
"""
|
||||
if self.trades.is_valid_trade_id(trade_id):
|
||||
self.trades.close_trade(trade_id)
|
||||
self.config.remove('trades', trade_id)
|
||||
pass
|
||||
# self.trades.close_trade(trade_id)TODO
|
||||
# self.config.remove('trades', trade_id)
|
||||
print(f"Trade {trade_id} has been closed.")
|
||||
else:
|
||||
print(f"Invalid trade ID: {trade_id}. Unable to close the trade.")
|
||||
|
|
@ -469,8 +479,8 @@ class BrighterTrades:
|
|||
f'quantity={vld("quantity")}, '
|
||||
f'price={vld("price")}')
|
||||
|
||||
# Update config's list of trades and save to file.
|
||||
self.config.update_data('trades', self.trades.get_trades('dict'))
|
||||
# Update config's list of trades and save to file.TODO
|
||||
# self.config.update_data('trades', self.trades.get_trades('dict'))
|
||||
|
||||
trade_obj = self.trades.get_trade_by_id(result)
|
||||
if trade_obj:
|
||||
|
|
@ -547,8 +557,9 @@ class BrighterTrades:
|
|||
print(f'ERROR SETTING VALUE')
|
||||
print(f'The string received by the server was: /n{params}')
|
||||
|
||||
# Todo this doesn't seem necessary anymore, because the cache now updates per request.
|
||||
# Now that the state is changed reload price history.
|
||||
self.candles.set_cache(user_name=user_name)
|
||||
# self.candles.set_cache(user_name=user_name)
|
||||
return
|
||||
|
||||
def process_incoming_message(self, msg_type: str, msg_data: dict | str) -> dict | None:
|
||||
|
|
|
|||
|
|
@ -166,22 +166,27 @@ class RowBasedCache:
|
|||
def query(self, conditions: List[Tuple[str, Any]]) -> pd.DataFrame:
|
||||
"""Query cache entries by conditions, ignoring expired entries."""
|
||||
self._check_purge() # Check if purge is needed
|
||||
key_value = next((value for key, value in conditions if key == 'key'), None)
|
||||
|
||||
# Get the value of tbl_key out of the list of key-value pairs.
|
||||
key_value = next((value for key, value in conditions if key == 'tbl_key'), None)
|
||||
if key_value is None or key_value not in self.cache:
|
||||
return pd.DataFrame() # Return an empty DataFrame if key is not found
|
||||
|
||||
entry = self.cache[key_value]
|
||||
|
||||
# Expire entry if expired.
|
||||
if entry.metadata.is_expired():
|
||||
del self.cache[key_value] # Remove expired entry
|
||||
return pd.DataFrame() # Return an empty DataFrame if the entry has expired
|
||||
# Remove expired entry and Return an empty DataFrame
|
||||
del self.cache[key_value]
|
||||
return pd.DataFrame()
|
||||
|
||||
data = entry.data
|
||||
|
||||
# If the data is a DataFrame, apply the conditions using pandas .query()
|
||||
if isinstance(data, pd.DataFrame):
|
||||
# Construct the query string and prepare local variables for the query
|
||||
query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions if col != 'key'])
|
||||
query_vars = {f'val_{col}': val for col, val in conditions if col != 'key'}
|
||||
query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions if col != 'tbl_key'])
|
||||
query_vars = {f'val_{col}': val for col, val in conditions if col != 'tbl_key'}
|
||||
|
||||
# Use pandas .query() with local_dict to pass the variables
|
||||
return data.query(query_conditions, local_dict=query_vars) if query_conditions else data
|
||||
|
|
@ -226,10 +231,10 @@ class RowBasedCache:
|
|||
|
||||
def remove_item(self, conditions: List[Tuple[str, Any]]) -> bool:
|
||||
"""Remove an item from the cache using key-value conditions.
|
||||
In row cache, only 'key' is used to identify the entry.
|
||||
In row cache, only 'tbl_key' is used to identify the entry.
|
||||
"""
|
||||
# Find the value of 'key' from the conditions
|
||||
key_value = next((value for key, value in conditions if key == 'key'), None)
|
||||
# Find the value of 'tbl_key' from the conditions
|
||||
key_value = next((value for key, value in conditions if key == 'tbl_key'), None)
|
||||
if key_value is None or key_value not in self.cache:
|
||||
return False # Key not found, so nothing to remove
|
||||
|
||||
|
|
@ -243,8 +248,8 @@ class RowBasedCache:
|
|||
# If the data is a DataFrame, apply additional filtering
|
||||
if isinstance(entry.data, pd.DataFrame):
|
||||
# Construct the query string and prepare local variables for the query
|
||||
query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions if col != 'key'])
|
||||
query_vars = {f'val_{col}': val for col, val in conditions if col != 'key'}
|
||||
query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions if col != 'tbl_key'])
|
||||
query_vars = {f'val_{col}': val for col, val in conditions if col != 'tbl_key'}
|
||||
|
||||
# Apply the query to the DataFrame, removing matching rows
|
||||
remaining_data = entry.data.query(f'not ({query_conditions})', local_dict=query_vars)
|
||||
|
|
@ -256,7 +261,7 @@ class RowBasedCache:
|
|||
# Update the entry with the remaining rows
|
||||
entry.data = remaining_data
|
||||
else:
|
||||
# If the data is not a DataFrame, remove the entire entry if the 'key' matches
|
||||
# If the data is not a DataFrame, remove the entire entry if the 'tbl_key' matches
|
||||
del self.cache[key_value]
|
||||
self.access_order.remove(key_value)
|
||||
return True # Successfully removed the item
|
||||
|
|
@ -304,7 +309,7 @@ class TableBasedCache:
|
|||
df_with_metadata = df.copy()
|
||||
df_with_metadata['metadata'] = metadata
|
||||
|
||||
# If a key is provided, add a 'key' column to the DataFrame
|
||||
# If a key is provided, add a 'tbl_key' column to the DataFrame
|
||||
if key is not None:
|
||||
df_with_metadata['tbl_key'] = key
|
||||
|
||||
|
|
@ -343,9 +348,6 @@ class TableBasedCache:
|
|||
# Start with the entire cache
|
||||
result = self.cache.copy()
|
||||
|
||||
# Replace any query for 'key' with 'tbl_key' since that's what we are using in the table-based cache
|
||||
conditions = [(('tbl_key' if col == 'key' else col), val) for col, val in conditions]
|
||||
|
||||
# Apply conditions using pandas .query()
|
||||
if not result.empty:
|
||||
query_conditions = ' and '.join([f'`{col}` == @val_{col}' for col, _ in conditions])
|
||||
|
|
@ -355,7 +357,7 @@ class TableBasedCache:
|
|||
result = result.query(query_conditions, local_dict=query_vars) if query_conditions else result
|
||||
|
||||
# Remove the metadata and tbl_key columns for the result
|
||||
return result.drop(columns=['metadata', 'tbl_key'], errors='ignore')
|
||||
return result.drop(columns=['metadata'], errors='ignore')
|
||||
|
||||
def is_attr_taken(self, column: str, value: Any) -> bool:
|
||||
"""Check if a column contains the specified value in the Table-Based Cache."""
|
||||
|
|
@ -406,10 +408,6 @@ class CacheManager:
|
|||
def __init__(self):
|
||||
self.caches = {}
|
||||
|
||||
import pandas as pd
|
||||
import datetime as dt
|
||||
from typing import Optional
|
||||
|
||||
def create_cache(self, name: str, cache_type: str,
|
||||
size_limit: Optional[int] = None,
|
||||
eviction_policy: str = 'evict',
|
||||
|
|
@ -469,20 +467,20 @@ class CacheManager:
|
|||
cache = self.get_cache(cache_name)
|
||||
|
||||
# Ensure the cache contains DataFrames (required for querying)
|
||||
if isinstance(cache, (TableBasedCache, RowBasedCache)):
|
||||
# Perform the query on the cache using filter_vals
|
||||
filtered_cache = cache.query(filter_vals) # Pass the list of filters
|
||||
if not isinstance(cache, (TableBasedCache, RowBasedCache)):
|
||||
raise ValueError(f"Cache '{cache_name}' does not contain TableBasedCache or RowBasedCache.")
|
||||
|
||||
# If data is found in the cache, return it
|
||||
if not filtered_cache.empty:
|
||||
return filtered_cache
|
||||
else:
|
||||
raise ValueError(f"Cache '{cache_name}' does not contain DataFrames.")
|
||||
# Perform the query on the cache using filter_vals
|
||||
filtered_cache = cache.query(filter_vals) # Pass the list of filters
|
||||
|
||||
# If data is found in the cache, return it
|
||||
if not filtered_cache.empty:
|
||||
return filtered_cache
|
||||
|
||||
# No result return an empty Dataframe
|
||||
return pd.DataFrame()
|
||||
|
||||
def fetch_cache_item(self, item_name: str, cache_name: str, filter_vals: tuple[str, any]) -> any:
|
||||
def get_cache_item(self, item_name: str, cache_name: str, filter_vals: tuple[str, any]) -> any:
|
||||
"""
|
||||
Retrieves a specific item from the cache.
|
||||
|
||||
|
|
@ -600,10 +598,10 @@ class CacheManager:
|
|||
|
||||
# Set the updated row in the cache
|
||||
if isinstance(cache, RowBasedCache):
|
||||
# For row-based cache, the 'key' must be in filter_vals
|
||||
key_value = next((val for key, val in filter_vals if key == 'key'), None)
|
||||
# For row-based cache, the 'tbl_key' must be in filter_vals
|
||||
key_value = next((val for key, val in filter_vals if key == 'tbl_key'), None)
|
||||
if key_value is None:
|
||||
raise ValueError("'key' must be present in filter_vals for row-based caches.")
|
||||
raise ValueError("'tbl_key' must be present in filter_vals for row-based caches.")
|
||||
# Update the cache entry with the modified row
|
||||
cache.add_entry(key=key_value, data=rows)
|
||||
elif isinstance(cache, TableBasedCache):
|
||||
|
|
@ -612,6 +610,15 @@ class CacheManager:
|
|||
else:
|
||||
raise ValueError(f"Unsupported cache type for {cache_name}")
|
||||
|
||||
@staticmethod
|
||||
def key_exists(cache, key):
|
||||
# Handle different cache types
|
||||
if isinstance(cache, RowBasedCache):
|
||||
return True if key in cache.cache else False
|
||||
if isinstance(cache, TableBasedCache):
|
||||
existing_rows = cache.query([("tbl_key", key)])
|
||||
return False if existing_rows.empty else True
|
||||
|
||||
|
||||
class SnapshotDataCache(CacheManager):
|
||||
"""
|
||||
|
|
@ -719,21 +726,37 @@ class DatabaseInteractions(SnapshotDataCache):
|
|||
super().__init__()
|
||||
self.db = Database()
|
||||
|
||||
def get_rows_from_datacache(self, cache_name: str, filter_vals: list[tuple[str, Any]]) -> pd.DataFrame | None:
|
||||
def get_rows_from_datacache(self, cache_name: str, filter_vals: list[tuple[str, Any]] = None,
|
||||
key: str = None) -> pd.DataFrame | None:
|
||||
"""
|
||||
Retrieves rows from the cache if available; otherwise, queries the database and caches the result.
|
||||
|
||||
:param key: Optional
|
||||
:param cache_name: The key used to identify the cache (also the name of the database table).
|
||||
:param filter_vals: A list of tuples, each containing a column name and the value(s) to filter by.
|
||||
:return: A DataFrame containing the requested rows, or None if no matching rows are found.
|
||||
:raises ValueError: If the cache is not a DataFrame or does not contain DataFrames in the 'data' column.
|
||||
"""
|
||||
# Ensure at least one of filter_vals or key is provided
|
||||
if not filter_vals and not key:
|
||||
raise ValueError("At least one of 'filter_vals' or 'key' must be provided.")
|
||||
|
||||
# Use an empty list if filter_vals is None
|
||||
filter_vals = filter_vals or []
|
||||
|
||||
# Insert the key if provided
|
||||
if key:
|
||||
filter_vals.insert(0, ('tbl_key', key))
|
||||
|
||||
result = self.get_rows_from_cache(cache_name, filter_vals)
|
||||
if result.empty:
|
||||
# Fallback: fetch from the database and cache the result if necessary
|
||||
return self._fetch_from_database(cache_name, filter_vals)
|
||||
# Fallback: Fetch from the database and cache the result if necessary
|
||||
result = self._fetch_from_database(cache_name, filter_vals)
|
||||
|
||||
def _fetch_from_database(self, cache_name: str, filter_vals: List[tuple[str, Any]]) -> pd.DataFrame | None:
|
||||
# Take the key out on return.
|
||||
return result.drop(columns=['tbl_key'], errors='ignore')
|
||||
|
||||
def _fetch_from_database(self, cache_name: str, filter_vals: List[tuple[str, Any]]) -> pd.DataFrame:
|
||||
"""
|
||||
Fetch rows from the database and cache the result.
|
||||
|
||||
|
|
@ -754,15 +777,15 @@ class DatabaseInteractions(SnapshotDataCache):
|
|||
cache.add_entry(key=key_value, data=rows)
|
||||
else:
|
||||
# For table-based cache, add the entire DataFrame to the cache
|
||||
cache.add_table(df=rows)
|
||||
cache.add_table(df=rows, overwrite='tbl_key')
|
||||
|
||||
# Return the fetched rows
|
||||
return rows
|
||||
|
||||
# If no rows are found, return None
|
||||
return None
|
||||
return pd.DataFrame()
|
||||
|
||||
def fetch_datacache_item(self, item_name: str, cache_name: str, filter_vals: tuple[str, any]) -> any:
|
||||
def get_datacache_item(self, item_name: str, cache_name: str, filter_vals: tuple[str, any]) -> any:
|
||||
"""
|
||||
Retrieves a specific item from the cache or database, caching the result if necessary.
|
||||
|
||||
|
|
@ -781,9 +804,8 @@ class DatabaseInteractions(SnapshotDataCache):
|
|||
# Return the specific item from the first matching row.
|
||||
return rows.iloc[0][item_name]
|
||||
|
||||
# If the item is not found, raise an error.todo do I want to raise an error or return empty?
|
||||
raise ValueError(
|
||||
f"Item '{item_name}' not found in cache or table '{cache_name}' where {filter_vals[0]} = {filter_vals[1]}")
|
||||
# The item was not found.
|
||||
return None
|
||||
|
||||
def insert_row_into_datacache(self, cache_name: str, columns: tuple, values: tuple, key: str = None,
|
||||
skip_cache: bool = False) -> None:
|
||||
|
|
@ -796,16 +818,14 @@ class DatabaseInteractions(SnapshotDataCache):
|
|||
:param key: Optional key for the cache item. If None, the auto-incremented ID from the database will be used.
|
||||
:param skip_cache: If True, skips inserting the row into the cache. Default is False.
|
||||
"""
|
||||
# Insert the row into the database and fetch the auto-incremented ID
|
||||
auto_incremented_id = self.db.insert_row(table=cache_name, columns=columns, values=values)
|
||||
if key:
|
||||
columns, values = columns + ('tbl_key',), values + (key,)
|
||||
|
||||
# Insert the row into the database
|
||||
self.db.insert_row(table=cache_name, columns=columns, values=values)
|
||||
# Insert the row into the cache
|
||||
if skip_cache:
|
||||
return
|
||||
|
||||
# Use the auto-incremented ID as the key if none was provided (for row-based caches)
|
||||
if key is None:
|
||||
key = str(auto_incremented_id)
|
||||
|
||||
self.insert_row_into_cache(cache_name, columns, values, key)
|
||||
|
||||
def insert_df_into_datacache(self, df: pd.DataFrame, cache_name: str, skip_cache: bool = False) -> None:
|
||||
|
|
@ -823,17 +843,21 @@ class DatabaseInteractions(SnapshotDataCache):
|
|||
self.insert_df_into_cache(df, cache_name)
|
||||
|
||||
def remove_row_from_datacache(self, cache_name: str, filter_vals: List[tuple[str, Any]],
|
||||
remove_from_db: bool = True) -> None:
|
||||
remove_from_db: bool = True, key: str = None) -> None:
|
||||
"""
|
||||
Removes rows from the cache and optionally from the database based on multiple filter criteria.
|
||||
|
||||
This method is specifically designed for caches stored as DataFrames.
|
||||
|
||||
:param key: Optional key
|
||||
:param cache_name: The name of the cache (or table) from which to remove rows.
|
||||
:param filter_vals: A list of tuples, each containing a column name and the value to filter by.
|
||||
:param remove_from_db: If True, also removes the rows from the database. Default is True.
|
||||
:raises ValueError: If the cache is not a DataFrame or if no valid cache is found.
|
||||
"""
|
||||
if key:
|
||||
filter_vals.insert(0, ('tbl_key', key))
|
||||
|
||||
self.remove_row_from_cache(cache_name, filter_vals)
|
||||
|
||||
# Remove from the database if required
|
||||
|
|
@ -846,16 +870,21 @@ class DatabaseInteractions(SnapshotDataCache):
|
|||
self.db.execute_sql(sql, params)
|
||||
|
||||
def modify_datacache_item(self, cache_name: str, filter_vals: List[Tuple[str, any]], field_name: str,
|
||||
new_data: any) -> None:
|
||||
new_data: any, key: str = None, overwrite: str = None) -> None:
|
||||
"""
|
||||
Modifies a specific field in a row within the cache and updates the database accordingly.
|
||||
|
||||
:param overwrite:
|
||||
:param key: optional key
|
||||
:param cache_name: The name used to identify the cache (also the name of the database table).
|
||||
:param filter_vals: A list of tuples containing column names and values to filter by.
|
||||
:param field_name: The field to be updated.
|
||||
:param new_data: The new data to be set.
|
||||
:raises ValueError: If the row is not found in the cache or the database, or if multiple rows are returned.
|
||||
"""
|
||||
if key:
|
||||
filter_vals.insert(0, ('tbl_key', key))
|
||||
|
||||
# Retrieve the row from the cache or database
|
||||
rows = self.get_rows_from_datacache(cache_name=cache_name, filter_vals=filter_vals)
|
||||
|
||||
|
|
@ -880,15 +909,15 @@ class DatabaseInteractions(SnapshotDataCache):
|
|||
|
||||
# Set the updated row in the cache
|
||||
if isinstance(cache, RowBasedCache):
|
||||
# For row-based cache, the 'key' must be in filter_vals
|
||||
key_value = next((val for key, val in filter_vals if key == 'key'), None)
|
||||
# For row-based cache, the 'tbl_key' must be in filter_vals
|
||||
key_value = next((val for key, val in filter_vals if key == 'tbl_key'), None)
|
||||
if key_value is None:
|
||||
raise ValueError("'key' must be present in filter_vals for row-based caches.")
|
||||
raise ValueError("'tbl_key' must be present in filter_vals for row-based caches.")
|
||||
# Update the cache entry with the modified row
|
||||
cache.add_entry(key=key_value, data=rows)
|
||||
elif isinstance(cache, TableBasedCache):
|
||||
# For table-based cache, use the existing query method to update the correct rows
|
||||
cache.add_table(rows)
|
||||
cache.add_table(rows, overwrite=overwrite)
|
||||
else:
|
||||
raise ValueError(f"Unsupported cache type for {cache_name}")
|
||||
|
||||
|
|
@ -900,6 +929,152 @@ class DatabaseInteractions(SnapshotDataCache):
|
|||
# Execute the SQL update to modify the database
|
||||
self.db.execute_sql(sql_update, params)
|
||||
|
||||
def serialized_datacache_insert(self, cache_name: str, data: Any, key: str = None,
|
||||
do_not_overwrite: bool = False):
|
||||
"""
|
||||
Stores an item in the cache, with custom serialization for object instances.
|
||||
If the data is not a DataFrame, the entire object is serialized and stored under a column named 'data'.
|
||||
|
||||
:param cache_name: The name of the cache.
|
||||
:param data: Any object to store in the cache, but should be a DataFrame with one row for normal operations.
|
||||
:param key: The key for row-based caches, used to identify the entry. Required for row-based caches.
|
||||
:param do_not_overwrite: If True, prevents overwriting existing entries in the cache.
|
||||
"""
|
||||
|
||||
# Retrieve the cache
|
||||
cache = self.get_cache(cache_name)
|
||||
|
||||
# If overwrite is disabled and the key already exists, prevent overwriting
|
||||
if do_not_overwrite and self.key_exists(cache, key):
|
||||
logging.warning(f"Key '{key}' already exists in cache '{cache_name}'. Overwrite prevented.")
|
||||
return
|
||||
|
||||
# If the data is a DataFrame, ensure it contains exactly one row
|
||||
if isinstance(data, pd.DataFrame):
|
||||
if len(data) != 1:
|
||||
raise ValueError('This method is for inserting a DataFrame with exactly one row.')
|
||||
|
||||
# Ensure key is provided for RowBasedCache
|
||||
if isinstance(cache, RowBasedCache) and key is None:
|
||||
raise ValueError("RowBasedCache requires a key to store the data.")
|
||||
|
||||
# List of types to exclude from serialization
|
||||
excluded_objects = (str, int, float, bool, type(None), bytes)
|
||||
|
||||
# Process and serialize non-excluded objects in the row
|
||||
row = data.iloc[0] # Access the first (and only) row
|
||||
row_values = []
|
||||
for col_value in row:
|
||||
# Serialize column value if it's not one of the excluded types
|
||||
if not isinstance(col_value, excluded_objects):
|
||||
col_value = pickle.dumps(col_value)
|
||||
row_values.append(col_value)
|
||||
|
||||
# Insert the row into the cache and database (key is handled in insert_row_into_datacache)
|
||||
self.insert_row_into_datacache(cache_name=cache_name, columns=tuple(data.columns),
|
||||
values=tuple(row_values), key=key)
|
||||
|
||||
else:
|
||||
# For non-DataFrame data, serialize the entire object
|
||||
serialized_data = pickle.dumps(data)
|
||||
|
||||
# Insert the serialized object under a column named 'data'
|
||||
self.insert_row_into_datacache(cache_name=cache_name, columns=('data',),
|
||||
values=(serialized_data,), key=key)
|
||||
|
||||
return
|
||||
|
||||
def get_serialized_datacache(self,
|
||||
cache_name: str,
|
||||
filter_vals: List[Tuple[str, Any]] = None,
|
||||
key: str = None) -> pd.DataFrame | Any:
|
||||
"""
|
||||
Retrieves an item from the specified cache and deserializes object columns if necessary.
|
||||
If the stored data is a serialized object (not a DataFrame), it returns the deserialized object.
|
||||
|
||||
:param key: The key to identify the cache entry.
|
||||
:param filter_vals: List of column filters (name, value) for the cache query.
|
||||
:param cache_name: The name of the cache.
|
||||
:return Any: Cached data with deserialized objects or the original non-DataFrame object, or None if not found.
|
||||
"""
|
||||
|
||||
# Ensure at least one of filter_vals or key is provided
|
||||
if not filter_vals and not key:
|
||||
raise ValueError("At least one of 'filter_vals' or 'key' must be provided.")
|
||||
|
||||
# Prepare filter values
|
||||
filter_vals = filter_vals or []
|
||||
if key:
|
||||
filter_vals.insert(0, ('tbl_key', key))
|
||||
|
||||
# Retrieve rows from the cache using the key
|
||||
data = self.get_rows_from_datacache(cache_name=cache_name, filter_vals=filter_vals)
|
||||
|
||||
# Return None if no data is found
|
||||
if data is None or (isinstance(data, pd.DataFrame) and data.empty):
|
||||
logging.info(f"No data found in cache '{cache_name}' for key: {key}")
|
||||
return pd.DataFrame()
|
||||
|
||||
# Handle non-DataFrame data
|
||||
if not isinstance(data, pd.DataFrame):
|
||||
logging.warning(f"Unexpected data format from cache '{cache_name}'. Returning None.")
|
||||
return pd.DataFrame()
|
||||
|
||||
# Check for single column 'data' (serialized object case)
|
||||
if 'data' in data.columns and len(data.columns) == 1:
|
||||
return self._deserialize_object(data.iloc[0]['data'], cache_name)
|
||||
|
||||
# Handle deserialization of DataFrame columns
|
||||
return self._deserialize_dataframe_row(data, cache_name)
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_object(serialized_data: Any, cache_name: str) -> Any:
|
||||
"""
|
||||
Deserializes an object stored as a serialized byte stream in the cache.
|
||||
|
||||
:param serialized_data: Serialized byte data to deserialize.
|
||||
:param cache_name: The name of the cache (used for logging).
|
||||
:return: Deserialized object, or the raw bytes if deserialization fails.
|
||||
"""
|
||||
if not isinstance(serialized_data, bytes):
|
||||
logging.warning(f"Expected bytes for deserialization in cache '{cache_name}', got {type(serialized_data)}.")
|
||||
return serialized_data
|
||||
|
||||
try:
|
||||
deserialized_data = pickle.loads(serialized_data)
|
||||
logging.info(f"Serialized object retrieved and deserialized from cache '{cache_name}'")
|
||||
return deserialized_data
|
||||
except (pickle.PickleError, TypeError) as e:
|
||||
logging.warning(f"Failed to deserialize object from cache '{cache_name}': {e}")
|
||||
return serialized_data # Fallback to the raw serialized data
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_dataframe_row(data: pd.DataFrame, cache_name: str) -> pd.DataFrame:
|
||||
"""
|
||||
Deserializes any serialized columns in a DataFrame row.
|
||||
|
||||
:param data: The DataFrame containing serialized columns.
|
||||
:param cache_name: The name of the cache (used for logging).
|
||||
:return: DataFrame with deserialized values.
|
||||
"""
|
||||
row = data.iloc[0] # Assuming we only retrieve one row
|
||||
deserialized_row = []
|
||||
|
||||
for col_value in row:
|
||||
if isinstance(col_value, bytes):
|
||||
try:
|
||||
deserialized_col_value = pickle.loads(col_value)
|
||||
deserialized_row.append(deserialized_col_value)
|
||||
except (pickle.PickleError, TypeError) as e:
|
||||
logging.warning(f"Failed to deserialize column value in cache '{cache_name}': {e}")
|
||||
deserialized_row.append(col_value) # Fallback to the original value
|
||||
else:
|
||||
deserialized_row.append(col_value)
|
||||
|
||||
deserialized_data = pd.DataFrame([deserialized_row], columns=data.columns)
|
||||
logging.info(f"Data retrieved and deserialized from cache '{cache_name}'")
|
||||
return deserialized_data
|
||||
|
||||
|
||||
class ServerInteractions(DatabaseInteractions):
|
||||
"""
|
||||
|
|
@ -908,10 +1083,7 @@ class ServerInteractions(DatabaseInteractions):
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# !SET THE MAXIMUM NUMBER OF MARKETS TO KEEP IN RAM HERE!
|
||||
self.exchanges = None
|
||||
self.create_cache(name='candles', cache_type='row', default_expiration=dt.timedelta(days=5),
|
||||
size_limit=100, eviction_policy='evict')
|
||||
|
||||
def set_exchange(self, exchanges):
|
||||
"""
|
||||
|
|
@ -1333,103 +1505,6 @@ class IndicatorCache(ServerInteractions):
|
|||
indicator_class = self.indicator_registry[indicator_type]
|
||||
return indicator_class(name=indicator_type, indicator_type=indicator_type, properties=properties)
|
||||
|
||||
def set_cache_item(self, cache_name: str, data: Any, key: str = None,
|
||||
expire_delta: Optional[dt.timedelta] = None,
|
||||
do_not_overwrite: bool = False):
|
||||
"""
|
||||
Stores an item in the cache, with custom serialization for Indicator instances.
|
||||
Handles both row-based and table-based caches differently.
|
||||
|
||||
:param cache_name: The name of the cache.
|
||||
:param data: The data to store in the cache. Can be a DataFrame or an Indicator instance.
|
||||
:param key: The key for row-based caches, used to identify the entry. Required for row-based caches.
|
||||
:param expire_delta: An optional expiration timedelta. If not provided, the cache's default expiration is used.
|
||||
:param do_not_overwrite: If True, prevents overwriting existing entries in the cache.
|
||||
"""
|
||||
|
||||
# Convert expiration delta (if provided) to seconds
|
||||
expiration_time = expire_delta.total_seconds() if expire_delta else None
|
||||
|
||||
# Retrieve the specified cache by its name
|
||||
cache = self.get_cache(cache_name)
|
||||
|
||||
# Handle Row-Based Cache
|
||||
if isinstance(cache, RowBasedCache):
|
||||
if key is None:
|
||||
raise ValueError("RowBasedCache requires a key to store the data.")
|
||||
|
||||
# If the data is an Indicator instance, serialize it
|
||||
if isinstance(data, Indicator):
|
||||
data = pickle.dumps(data)
|
||||
|
||||
# If overwrite is disabled and the key already exists, prevent overwrite
|
||||
if do_not_overwrite and key in cache.cache:
|
||||
logging.warning(f"Key '{key}' already exists in cache '{cache_name}'. Overwrite prevented.")
|
||||
return
|
||||
|
||||
# Add the entry to the row-based cache
|
||||
cache.add_entry(key=key, data=data, expiration_time=expiration_time)
|
||||
|
||||
# Handle Table-Based Cache (only accepts DataFrame)
|
||||
elif isinstance(cache, TableBasedCache):
|
||||
# Ensure data is a DataFrame, as only DataFrames are allowed in table-based caches
|
||||
if isinstance(data, pd.DataFrame):
|
||||
if do_not_overwrite:
|
||||
existing_rows = cache.query([("key", key)])
|
||||
if not existing_rows.empty:
|
||||
logging.warning(
|
||||
f"Entry with key '{key}' already exists in cache '{cache_name}'. Overwrite prevented."
|
||||
)
|
||||
return
|
||||
# Add the DataFrame to the table-based cache
|
||||
cache.add_table(df=data, expiration_time=expiration_time, key=key)
|
||||
else:
|
||||
raise ValueError("TableBasedCache can only store DataFrames.")
|
||||
else:
|
||||
raise ValueError(f"Unsupported cache type for '{cache_name}'")
|
||||
|
||||
def get_cache_item(self, key: str, cache_name: str = 'default_cache') -> Any:
|
||||
"""
|
||||
Retrieves an item from the specified cache.
|
||||
|
||||
:param cache_name: The name of the cache.
|
||||
:param key: The key associated with the cache item.
|
||||
:return Any: The cached data, or None if the key does not exist or the item is expired.
|
||||
"""
|
||||
# Retrieve the cache instance
|
||||
cache = self.get_cache(cache_name)
|
||||
|
||||
# Handle different cache types
|
||||
if isinstance(cache, RowBasedCache):
|
||||
data = cache.get_entry(key=key)
|
||||
elif isinstance(cache, TableBasedCache):
|
||||
data = cache.query([('key', key)]) # Assuming 'key' is a valid query parameter
|
||||
else:
|
||||
logging.error(f"Unsupported cache type for '{cache_name}'")
|
||||
return None
|
||||
|
||||
# If no data is found, log and return None
|
||||
if data is None or (isinstance(data, pd.DataFrame) and data.empty):
|
||||
logging.info(f"No data found in cache '{cache_name}' for key: {key}")
|
||||
return None
|
||||
|
||||
# Handle Indicator case (deserialize using pickle)
|
||||
if cache_name == 'indicators':
|
||||
logging.info(f"Indicator data retrieved from cache for key: {key}")
|
||||
try:
|
||||
deserialized_data = pickle.loads(data)
|
||||
if isinstance(deserialized_data, Indicator):
|
||||
return deserialized_data
|
||||
else:
|
||||
logging.warning(f"Expected Indicator instance, got {type(deserialized_data)}")
|
||||
return deserialized_data # Fallback: Return deserialized data even if it's not an Indicator
|
||||
except (pickle.PickleError, TypeError) as e:
|
||||
logging.error(f"Deserialization failed for key '{key}' in cache '{cache_name}': {e}")
|
||||
return None
|
||||
|
||||
logging.info(f"Data retrieved from cache '{cache_name}' for key: {key}")
|
||||
return data
|
||||
|
||||
def set_user_indicator_properties(self, user_id: str, indicator_type: str, symbol: str, timeframe: str,
|
||||
exchange_name: str, display_properties: dict):
|
||||
"""
|
||||
|
|
@ -1439,7 +1514,8 @@ class IndicatorCache(ServerInteractions):
|
|||
raise ValueError("display_properties must be a dictionary")
|
||||
|
||||
user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}"
|
||||
self.set_cache_item(key=user_cache_key, data=display_properties, cache_name='user_display_properties')
|
||||
self.serialized_datacache_insert(key=user_cache_key, data=display_properties,
|
||||
cache_name='user_display_properties')
|
||||
|
||||
def get_user_indicator_properties(self, user_id: str, indicator_type: str, symbol: str, timeframe: str,
|
||||
exchange_name: str) -> dict:
|
||||
|
|
@ -1452,7 +1528,7 @@ class IndicatorCache(ServerInteractions):
|
|||
not isinstance(exchange_name, str):
|
||||
raise TypeError("All arguments must be of type str")
|
||||
user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}"
|
||||
return self.get_cache_item(user_cache_key, cache_name='user_display_properties')
|
||||
return self.get_rows_from_datacache(key=user_cache_key, cache_name='user_display_properties')
|
||||
|
||||
def find_gaps_in_intervals(self, cached_data, start_idx, end_idx, timeframe, min_gap_size=None):
|
||||
"""
|
||||
|
|
@ -1511,7 +1587,7 @@ class IndicatorCache(ServerInteractions):
|
|||
)
|
||||
|
||||
# Step 4: Cache the newly calculated data
|
||||
self.set_cache_item(key=cache_key, data=calculated_data, cache_name='indicator_data')
|
||||
self.serialized_datacache_insert(key=cache_key, data=calculated_data, cache_name='indicator_data')
|
||||
|
||||
# Step 5: Retrieve and merge user-specific display properties with defaults
|
||||
merged_properties = self._get_merged_properties(user_name, indicator_type, symbol, timeframe, exchange_name,
|
||||
|
|
@ -1528,7 +1604,7 @@ class IndicatorCache(ServerInteractions):
|
|||
Fetches cached data for the given time range.
|
||||
"""
|
||||
# Retrieve cached data (expected to be a DataFrame with 'time' in Unix ms)
|
||||
cached_df = self.get_cache_item(cache_key, cache_name='indicator_data')
|
||||
cached_df = self.get_rows_from_datacache(key=cache_key, cache_name='indicator_data')
|
||||
|
||||
# If no cached data, return an empty DataFrame
|
||||
if cached_df is None or cached_df.empty:
|
||||
|
|
@ -1556,7 +1632,7 @@ class IndicatorCache(ServerInteractions):
|
|||
:return: A list of tuples representing the missing intervals (start, end).
|
||||
"""
|
||||
# Convert start and end datetime to Unix timestamps (milliseconds)
|
||||
start_timestamp = int(start_datetime.timestamp() * 1000)
|
||||
# start_timestamp = int(start_datetime.timestamp() * 1000)
|
||||
end_timestamp = int(end_datetime.timestamp() * 1000)
|
||||
|
||||
if cached_data is not None and not cached_data.empty:
|
||||
|
|
@ -1630,7 +1706,7 @@ class DataCache(IndicatorCache):
|
|||
'limit': [limit],
|
||||
'eviction_policy': [eviction_policy],
|
||||
'cache': pd.DataFrame({
|
||||
'key': [key],
|
||||
'tbl_key': [key],
|
||||
'creation_time': [creation_time],
|
||||
'expire_delta': [expire_delta],
|
||||
'data': pd.DataFrame({
|
||||
|
|
@ -1649,7 +1725,7 @@ class DataCache(IndicatorCache):
|
|||
'limit': 100,
|
||||
'eviction_policy': 'deny',
|
||||
'cache': pd.DataFrame({
|
||||
'key': ['user1', 'user2'],
|
||||
'tbl_key': ['user1', 'user2'],
|
||||
'creation_time': [
|
||||
datetime.datetime(2024, 8, 24, 12, 45, 31, 117684),
|
||||
datetime.datetime(2023, 8, 12, 6, 35, 21, 113644)
|
||||
|
|
@ -1676,7 +1752,7 @@ class DataCache(IndicatorCache):
|
|||
'limit': 100,
|
||||
'eviction_policy': 'expire',
|
||||
'cache': pd.DataFrame({
|
||||
'key': ['strategy1', 'strategy2'],
|
||||
'tbl_key': ['strategy1', 'strategy2'],
|
||||
'creation_time': [
|
||||
datetime.datetime(2024, 3, 13, 12, 45, 31, 117684),
|
||||
datetime.datetime(2023, 4, 12, 6, 35, 21, 113644)
|
||||
|
|
|
|||
|
|
@ -58,17 +58,21 @@ def make_query(item: str, table: str, columns: List[str]) -> str:
|
|||
return f"SELECT {item} FROM {table} WHERE {placeholders};"
|
||||
|
||||
|
||||
def make_insert(table: str, columns: Tuple[str, ...]) -> str:
|
||||
def make_insert(table: str, columns: Tuple[str, ...], replace: bool = False) -> str:
|
||||
"""
|
||||
Creates a SQL insert query string with the required number of placeholders.
|
||||
|
||||
:param replace: bool will replace if set.
|
||||
:param table: The table to insert into.
|
||||
:param columns: Tuple of column names.
|
||||
:return: The query string.
|
||||
"""
|
||||
col_names = ", ".join([f"'{col}'" for col in columns])
|
||||
col_names = ", ".join([f'"{col}"' for col in columns]) # Use double quotes for column names
|
||||
placeholders = ", ".join(["?" for _ in columns])
|
||||
return f"INSERT INTO {table} ({col_names}) VALUES ({placeholders});"
|
||||
if replace:
|
||||
return f'INSERT OR REPLACE INTO "{table}" ({col_names}) VALUES ({placeholders});'
|
||||
return f'INSERT INTO "{table}" ({col_names}) VALUES ({placeholders});'
|
||||
|
||||
|
||||
|
||||
class Database:
|
||||
|
|
|
|||
|
|
@ -67,8 +67,7 @@ class Exchange:
|
|||
|
||||
def _check_authentication(self):
|
||||
try:
|
||||
# Perform an authenticated request to check if the API keys are valid
|
||||
self.client.fetch_balance()
|
||||
self.client.fetch_open_orders() # Much faster than fetch_balance
|
||||
self.configured = True
|
||||
logger.info("Authentication successful.")
|
||||
except ccxt.AuthenticationError:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import ccxt
|
|||
from Exchange import Exchange
|
||||
from DataCache_v3 import DataCache
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -91,8 +90,7 @@ class ExchangeInterface:
|
|||
'user': user_name, 'name': exchange.name,
|
||||
'reference': exchange, 'balances': exchange.balances}])
|
||||
|
||||
cache = self.cache_manager.get_cache('exchange_data')
|
||||
cache.add_table(df=row)
|
||||
self.cache_manager.serialized_datacache_insert(cache_name='exchange_data', data=row)
|
||||
except Exception as e:
|
||||
logger.error(f"Couldn't create an instance of the exchange! {str(e)}")
|
||||
raise
|
||||
|
|
@ -108,12 +106,12 @@ class ExchangeInterface:
|
|||
if not ename or not uname:
|
||||
raise ValueError('Missing argument!')
|
||||
|
||||
cache = self.cache_manager.get_cache('exchange_data')
|
||||
exchange_data = cache.query([('name', ename), ('user', uname)])
|
||||
exchange_data = self.cache_manager.get_serialized_datacache(cache_name='exchange_data',
|
||||
filter_vals=[('name', ename), ('user', uname)])
|
||||
|
||||
if exchange_data.empty:
|
||||
raise ValueError('No matching exchange found.')
|
||||
|
||||
# todo check this
|
||||
return exchange_data.at[exchange_data.index[0], 'reference']
|
||||
|
||||
def get_connected_exchanges(self, user_name: str) -> List[str]:
|
||||
|
|
@ -123,8 +121,8 @@ class ExchangeInterface:
|
|||
:param user_name: The name of the user.
|
||||
:return: A list of connected exchange names.
|
||||
"""
|
||||
cache = self.cache_manager.get_cache('exchange_data')
|
||||
exchanges = cache.query([('user', user_name)])
|
||||
exchanges = self.cache_manager.get_serialized_datacache(
|
||||
cache_name='exchange_data', filter_vals=[('user', user_name)])
|
||||
return exchanges['name'].tolist()
|
||||
|
||||
def get_available_exchanges(self) -> List[str]:
|
||||
|
|
@ -139,8 +137,11 @@ class ExchangeInterface:
|
|||
:param name: The name of the exchange.
|
||||
:return: A Series containing the balances.
|
||||
"""
|
||||
cache = self.cache_manager.get_cache('exchange_data')
|
||||
exchange = cache.query([('user', user_name), ('name', name)])
|
||||
exchange = self.cache_manager.get_serialized_datacache(
|
||||
cache_name='exchange_data',
|
||||
filter_vals=[('user', user_name), ('name', name)]
|
||||
)
|
||||
|
||||
if not exchange.empty:
|
||||
return exchange.iloc[0]['balances']
|
||||
else:
|
||||
|
|
@ -154,8 +155,10 @@ class ExchangeInterface:
|
|||
:return: A dictionary containing the balances of all connected exchanges.
|
||||
"""
|
||||
# Query exchange data for the given user
|
||||
cache = self.cache_manager.get_cache('exchange_data')
|
||||
exchanges = cache.query([('user', user_name)])
|
||||
exchanges = self.cache_manager.get_serialized_datacache(
|
||||
cache_name='exchange_data',
|
||||
filter_vals=[('user', user_name)]
|
||||
)
|
||||
|
||||
# Select 'name' and 'balances' columns for all rows
|
||||
filtered_data = exchanges.loc[:, ['name', 'balances']]
|
||||
|
|
@ -171,8 +174,10 @@ class ExchangeInterface:
|
|||
:param fetch_type: The type of data to fetch ('trades' or 'orders').
|
||||
:return: A dictionary indexed by exchange name with lists of active trades or open orders.
|
||||
"""
|
||||
cache = self.cache_manager.get_cache('exchange_data')
|
||||
exchanges = cache.query([('user', user_name)])
|
||||
exchanges = self.cache_manager.get_serialized_datacache(
|
||||
cache_name='exchange_data',
|
||||
filter_vals=[('user', user_name)]
|
||||
)
|
||||
|
||||
# Select the 'name' and 'reference' columns
|
||||
filtered_data = exchanges.loc[:, ['name', 'reference']]
|
||||
|
|
|
|||
58
src/Users.py
58
src/Users.py
|
|
@ -39,23 +39,23 @@ class BaseUser:
|
|||
:param user_name: The name of the user.
|
||||
:return: The ID of the user as an integer.
|
||||
"""
|
||||
return self.data.fetch_datacache_item(
|
||||
return self.data.get_datacache_item(
|
||||
item_name='id',
|
||||
cache_name='users',
|
||||
filter_vals=('user_name', user_name)
|
||||
)
|
||||
|
||||
def get_username(self, id: int) -> str:
|
||||
def get_username(self, user_id: int) -> str:
|
||||
"""
|
||||
Retrieves the user username based on the ID.
|
||||
|
||||
:param id: The id of the user.
|
||||
:param user_id: The id of the user.
|
||||
:return: The name of the user as a str.
|
||||
"""
|
||||
return self.data.fetch_datacache_item(
|
||||
return self.data.get_datacache_item(
|
||||
item_name='user_name',
|
||||
cache_name='users',
|
||||
filter_vals=('id', id)
|
||||
filter_vals=('id', user_id)
|
||||
)
|
||||
|
||||
def _remove_user_from_memory(self, user_name: str) -> None:
|
||||
|
|
@ -109,8 +109,8 @@ class BaseUser:
|
|||
cache_name='users',
|
||||
filter_vals=[('user_name', username)],
|
||||
field_name=field_name,
|
||||
new_data=new_data
|
||||
)
|
||||
new_data=new_data,
|
||||
overwrite='user_name')
|
||||
|
||||
|
||||
class UserAccountManagement(BaseUser):
|
||||
|
|
@ -481,28 +481,28 @@ class UserIndicatorManagement(UserExchangeManagement):
|
|||
|
||||
return df
|
||||
|
||||
def save_indicators(self, indicators: pd.DataFrame) -> None:
|
||||
"""
|
||||
Stores one or many indicators in the database.
|
||||
|
||||
:param indicators: A DataFrame containing indicator attributes and properties.
|
||||
"""
|
||||
for _, indicator in indicators.iterrows():
|
||||
try:
|
||||
# Convert necessary fields to JSON strings
|
||||
src_string = json.dumps(indicator['source'])
|
||||
prop_string = json.dumps(indicator['properties'])
|
||||
|
||||
# Prepare the values and columns for insertion
|
||||
values = (indicator['creator'], indicator['name'], indicator['visible'],
|
||||
indicator['kind'], src_string, prop_string)
|
||||
columns = ('creator', 'name', 'visible', 'kind', 'source', 'properties')
|
||||
|
||||
# Insert the row into the database and cache using DataCache
|
||||
self.data.insert_row_into_datacache(cache_name='indicators', columns=columns, values=values)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error saving indicator {indicator['name']} for creator {indicator['creator']}: {str(e)}")
|
||||
# def save_indicators(self, indicators: pd.DataFrame) -> None:
|
||||
# """
|
||||
# Stores one or many indicators in the database.
|
||||
#
|
||||
# :param indicators: A DataFrame containing indicator attributes and properties.
|
||||
# """
|
||||
# for _, indicator in indicators.iterrows():
|
||||
# try:
|
||||
# # Convert necessary fields to JSON strings
|
||||
# src_string = json.dumps(indicator['source'])
|
||||
# prop_string = json.dumps(indicator['properties'])
|
||||
#
|
||||
# # Prepare the values and columns for insertion
|
||||
# values = (indicator['creator'], indicator['name'], indicator['visible'],
|
||||
# indicator['kind'], src_string, prop_string)
|
||||
# columns = ('creator', 'name', 'visible', 'kind', 'source', 'properties')
|
||||
#
|
||||
# # Insert the row into the database and cache using DataCache
|
||||
# self.data.insert_row_into_datacache(cache_name='indicators', columns=columns, values=values)
|
||||
#
|
||||
# except Exception as e:
|
||||
# print(f"Error saving indicator {indicator['name']} for creator {indicator['creator']}: {str(e)}")
|
||||
|
||||
def get_chart_view(self, user_name: str, prop: str | None = None):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -9,17 +9,21 @@ from shared_utilities import timeframe_to_minutes, ts_of_n_minutes_ago
|
|||
# log.basicConfig(level=log.ERROR)
|
||||
|
||||
class Candles:
|
||||
def __init__(self, exchanges, users, data_source, config):
|
||||
def __init__(self, exchanges, users, datacache, config):
|
||||
|
||||
# A reference to the app configuration
|
||||
self.users = users
|
||||
|
||||
# This object maintains all the cached data.
|
||||
self.data = datacache
|
||||
|
||||
# size_limit is the max number of lists of candle(ohlc) data allowed.
|
||||
self.data.create_cache(name='candles', cache_type='row', default_expiration=dt.timedelta(days=5),
|
||||
size_limit=100, eviction_policy='evict')
|
||||
|
||||
# The maximum amount of candles to load at one time.
|
||||
self.max_records = config.get_setting('max_data_loaded')
|
||||
|
||||
# This object maintains all the cached data.
|
||||
self.data = data_source
|
||||
|
||||
# print('Setting the candle data.')
|
||||
# # Populate the data:
|
||||
# self.set_cache(symbol=self.users.get_chart_view(user_name='guest', specific_property='market'),
|
||||
|
|
@ -84,6 +88,7 @@ class Candles:
|
|||
def set_cache(self, symbol=None, interval=None, exchange_name=None, user_name=None):
|
||||
"""
|
||||
This method requests a chart from memory to ensure the data is initialized.
|
||||
TODO: This method is un-used.
|
||||
|
||||
:param user_name:
|
||||
:param symbol: str - The symbol of the market.
|
||||
|
|
@ -107,10 +112,9 @@ class Candles:
|
|||
# Log the completion to the console.
|
||||
log.info('set_candle_history(): Loading candle data...')
|
||||
|
||||
# Todo this doesn't seem necessary.
|
||||
# Load candles from database
|
||||
# _cdata = self.get_last_n_candles(num_candles=self.max_records,
|
||||
# asset=symbol, timeframe=interval, exchange=exchange_name, user_name=user_name)
|
||||
_cdata = self.get_last_n_candles(num_candles=self.max_records,
|
||||
asset=symbol, timeframe=interval, exchange=exchange_name, user_name=user_name)
|
||||
|
||||
# Log the completion to the console.
|
||||
log.info('set_candle_history(): Candle data Loaded.')
|
||||
|
|
@ -178,8 +182,6 @@ class Candles:
|
|||
|
||||
new_candles = candles.loc[:, ['time', 'open', 'high', 'low', 'close', 'volume']]
|
||||
|
||||
new_candles.rename(columns={'time': 'time'}, inplace=True)
|
||||
|
||||
# The timestamps are in milliseconds but lightweight charts needs it divided by 1000.
|
||||
new_candles.loc[:, ['time']] = new_candles.loc[:, ['time']].div(1000)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class Indicator:
|
|||
closes = candles.close.to_numpy(dtype='float')
|
||||
|
||||
# Processing the close prices to calculate the Indicator
|
||||
i_values = self.process(closes, self.properties['period'])
|
||||
i_values = self.process(closes, int(self.properties['period']))
|
||||
|
||||
# Stores the last calculated value.
|
||||
self.properties['value'] = round(float(i_values[-1]), 2)
|
||||
|
|
@ -43,7 +43,7 @@ class Indicator:
|
|||
df = pd.DataFrame({'time': candles.time, 'value': i_values.tolist()})
|
||||
|
||||
# Slice the DataFrame to skip initial rows where the indicator will be undefined
|
||||
return df.iloc[self.properties['period']:]
|
||||
return df.iloc[int(self.properties['period']):]
|
||||
|
||||
def process(self, data, period):
|
||||
"""
|
||||
|
|
@ -101,7 +101,7 @@ class SMA(Indicator):
|
|||
"""
|
||||
Calculate the Simple Moving Average (SMA) of the given data.
|
||||
"""
|
||||
return talib.SMA(data, period)
|
||||
return talib.SMA(data, int(period))
|
||||
|
||||
|
||||
class EMA(SMA):
|
||||
|
|
@ -109,7 +109,7 @@ class EMA(SMA):
|
|||
"""
|
||||
Calculate the Exponential Moving Average (EMA) of the given data.
|
||||
"""
|
||||
return talib.EMA(data, period)
|
||||
return talib.EMA(data, int(period))
|
||||
|
||||
|
||||
class RSI(SMA):
|
||||
|
|
@ -122,7 +122,7 @@ class RSI(SMA):
|
|||
"""
|
||||
Calculate the Relative Strength Index (RSI) of the given data.
|
||||
"""
|
||||
return talib.RSI(data, period)
|
||||
return talib.RSI(data, int(period))
|
||||
|
||||
|
||||
class LREG(SMA):
|
||||
|
|
@ -130,7 +130,7 @@ class LREG(SMA):
|
|||
"""
|
||||
Calculate the Linear Regression (LREG) of the given data.
|
||||
"""
|
||||
return talib.LINEARREG(data, period)
|
||||
return talib.LINEARREG(data, int(period))
|
||||
|
||||
|
||||
class ATR(SMA):
|
||||
|
|
@ -143,7 +143,7 @@ class ATR(SMA):
|
|||
closes = candles.close.to_numpy(dtype='float')
|
||||
|
||||
# Calculate ATR using the talib library
|
||||
atr = talib.ATR(high=highs, low=lows, close=closes, timeperiod=self.properties['period'])
|
||||
atr = talib.ATR(high=highs, low=lows, close=closes, timeperiod=int(self.properties['period']))
|
||||
|
||||
# Create DataFrame with 'time' and 'value' columns
|
||||
df = pd.DataFrame({'time': candles.time, 'value': atr})
|
||||
|
|
@ -152,7 +152,7 @@ class ATR(SMA):
|
|||
self.properties['value'] = round(float(atr[-1]), 2)
|
||||
|
||||
# Return the sliced DataFrame, excluding rows where the indicator is not fully calculated
|
||||
return df.iloc[self.properties['period']:]
|
||||
return df.iloc[int(self.properties['period']):]
|
||||
|
||||
|
||||
class BolBands(Indicator):
|
||||
|
|
@ -175,10 +175,10 @@ class BolBands(Indicator):
|
|||
|
||||
# Calculate the Bollinger Bands (upper, middle, lower)
|
||||
upper, middle, lower = talib.BBANDS(np_real_data,
|
||||
timeperiod=self.properties['period'],
|
||||
nbdevup=self.properties['devup'],
|
||||
nbdevdn=self.properties['devdn'],
|
||||
matype=self.properties['ma'])
|
||||
timeperiod=int(self.properties['period']),
|
||||
nbdevup=int(self.properties['devup']),
|
||||
nbdevdn=int(self.properties['devdn']),
|
||||
matype=int(self.properties['ma']))
|
||||
|
||||
# Store the last calculated values in properties
|
||||
self.properties['value'] = round(float(upper[-1]), 2)
|
||||
|
|
@ -197,7 +197,7 @@ class BolBands(Indicator):
|
|||
df = df.round({'upper': 2, 'middle': 2, 'lower': 2})
|
||||
|
||||
# Slice the DataFrame to skip initial rows where the indicator might be undefined
|
||||
return df.iloc[self.properties['period']:]
|
||||
return df.iloc[int(self.properties['period']):]
|
||||
|
||||
|
||||
class MACD(Indicator):
|
||||
|
|
@ -219,9 +219,9 @@ class MACD(Indicator):
|
|||
|
||||
# Calculate MACD, Signal Line, and MACD Histogram
|
||||
macd, signal, hist = talib.MACD(closes,
|
||||
fastperiod=self.properties['fast_p'],
|
||||
slowperiod=self.properties['slow_p'],
|
||||
signalperiod=self.properties['signal_p'])
|
||||
fastperiod=int(self.properties['fast_p']),
|
||||
slowperiod=int(self.properties['slow_p']),
|
||||
signalperiod=int(self.properties['signal_p']))
|
||||
|
||||
# Store the last calculated values
|
||||
self.properties['macd'] = round(float(macd[-1]), 2)
|
||||
|
|
@ -255,6 +255,10 @@ indicators_registry['MACD'] = MACD
|
|||
|
||||
|
||||
class Indicators:
|
||||
"""
|
||||
Indicators are stored along
|
||||
"""
|
||||
|
||||
def __init__(self, candles, users, cache_manager):
|
||||
# Object manages and serves price and candle data.
|
||||
self.candles = candles
|
||||
|
|
@ -267,13 +271,13 @@ class Indicators:
|
|||
|
||||
# Cache for storing instantiated indicator objects
|
||||
cache_manager.create_cache(
|
||||
name='indicators',
|
||||
cache_type='table',
|
||||
size_limit=100,
|
||||
eviction_policy='deny',
|
||||
default_expiration=dt.timedelta(days=1),
|
||||
columns=['creator', 'name', 'visible', 'kind', 'source', 'properties', 'ref']
|
||||
)
|
||||
name='indicators',
|
||||
cache_type='table',
|
||||
size_limit=100,
|
||||
eviction_policy='deny',
|
||||
default_expiration=dt.timedelta(days=1),
|
||||
columns=['creator', 'name', 'visible', 'kind', 'source', 'properties', 'ref']
|
||||
)
|
||||
|
||||
# Cache for storing calculated indicator data
|
||||
cache_manager.create_cache('indicator_data', cache_type='row', size_limit=100,
|
||||
|
|
@ -292,60 +296,6 @@ class Indicators:
|
|||
self.MV_AVERAGE_ENUM = {'SMA': 0, 'EMA': 1, 'WMA': 2, 'DEMA': 3, 'TEMA': 4,
|
||||
'TRIMA': 5, 'KAMA': 6, 'MAMA': 7, 'T3': 8}
|
||||
|
||||
def load_indicators(self, user_name):
|
||||
"""
|
||||
Get the users watch-list from the database and load the indicators into a dataframe.
|
||||
:return: None
|
||||
"""
|
||||
active_indicators: pd.DataFrame = self.users.get_indicators(user_name)
|
||||
|
||||
if active_indicators is not None:
|
||||
# Create an instance for each indicator.
|
||||
for i in active_indicators.itertuples():
|
||||
self.create_indicator(
|
||||
creator=user_name, name=i.name,
|
||||
kind=i.kind, source=i.source,
|
||||
visible=i.visible, properties=i.properties
|
||||
)
|
||||
|
||||
def save_indicator(self, indicator):
|
||||
"""
|
||||
Saves the indicators in the database indexed by the user id.
|
||||
:return: None
|
||||
"""
|
||||
self.users.save_indicators(indicator)
|
||||
|
||||
# @staticmethod
|
||||
# def get_indicator_defaults():
|
||||
# """Set the default settings for each indicator"""
|
||||
#
|
||||
# indicator_list = {
|
||||
# 'EMA 5': {'type': 'EMA', 'period': 5, 'visible': True, 'color': f"#{random.randrange(0x1000000):06x}",
|
||||
# 'value': 0, 'market': 'BTC/USD', 'time_frame': '5m', 'exchange_name': 'alpaca'},
|
||||
# 'EMA 15': {'type': 'EMA', 'period': 15, 'visible': True, 'color': f"#{random.randrange(0x1000000):06x}",
|
||||
# 'value': 0, 'market': 'BTC/USD', 'time_frame': '5m', 'exchange_name': 'alpaca'},
|
||||
# 'EMA 20': {'type': 'EMA', 'period': 20, 'visible': True, 'color': f"#{random.randrange(0x1000000):06x}",
|
||||
# 'value': 0, 'market': 'BTC/USD', 'time_frame': '5m', 'exchange_name': 'alpaca'},
|
||||
# 'EMA 50': {'type': 'EMA', 'period': 50, 'visible': True, 'color': f"#{random.randrange(0x1000000):06x}",
|
||||
# 'value': 0, 'market': 'BTC/USD', 'time_frame': '5m', 'exchange_name': 'alpaca'},
|
||||
# 'EMA 100': {'type': 'EMA', 'period': 100, 'visible': True, 'color': f"#{random.randrange(0x1000000):06x}",
|
||||
# 'value': 0, 'market': 'BTC/USD', 'time_frame': '5m', 'exchange_name': 'alpaca'},
|
||||
# 'EMA 200': {'type': 'EMA', 'period': 200, 'visible': True, 'color': f"#{random.randrange(0x1000000):06x}",
|
||||
# 'value': 0, 'market': 'BTC/USD', 'time_frame': '5m', 'exchange_name': 'alpaca'},
|
||||
# 'RSI 14': {'type': 'RSI', 'period': 14, 'visible': True, 'color': f"#{random.randrange(0x1000000):06x}",
|
||||
# 'value': 0, 'market': 'BTC/USD', 'time_frame': '5m', 'exchange_name': 'alpaca'},
|
||||
# 'RSI 8': {'type': 'RSI', 'period': 8, 'visible': True, 'color': f"#{random.randrange(0x1000000):06x}",
|
||||
# 'value': 0, 'market': 'BTC/USD', 'time_frame': '5m', 'exchange_name': 'alpaca'},
|
||||
# 'Bolenger': {'color_1': '#5ad858', 'color_2': '#f0f664', 'color_3': '#5ad858', 'devdn': 2, 'devup': 2,
|
||||
# 'ma': 1, 'period': 20, 'type': 'BOLBands', 'value': '38691.58',
|
||||
# 'value2': '38552.36',
|
||||
# 'value3': '38413.14', 'visible': True, 'market': 'BTC/USD', 'time_frame': '5m',
|
||||
# 'exchange_name': 'alpaca'},
|
||||
# 'vol': {'type': 'Volume', 'visible': True, 'value': 0, 'market': 'BTC/USD', 'time_frame': '5m',
|
||||
# 'exchange_name': 'alpaca'}
|
||||
# }
|
||||
# return indicator_list
|
||||
|
||||
def get_available_indicator_types(self) -> list:
|
||||
"""Returns a list of all available indicator types."""
|
||||
return list(self.indicator_registry.keys())
|
||||
|
|
@ -359,14 +309,15 @@ class Indicators:
|
|||
:param only_enabled: bool - If True, return only indicators marked as visible.
|
||||
:return: dict - A dictionary of indicator names as keys and their attributes as values.
|
||||
"""
|
||||
user_id = self.users.get_id(username)
|
||||
user_id = str(self.users.get_id(username))
|
||||
|
||||
if not user_id:
|
||||
raise ValueError(f"Invalid user_name: {username}")
|
||||
|
||||
# Fetch indicators based on visibility status
|
||||
if only_enabled:
|
||||
indicators_df = self.cache_manager.get_rows_from_datacache('indicators', [('creator', user_id), ('visible', 1)])
|
||||
indicators_df = self.cache_manager.get_rows_from_datacache('indicators',
|
||||
[('creator', user_id), ('visible', str(1))])
|
||||
else:
|
||||
indicators_df = self.cache_manager.get_rows_from_datacache('indicators', [('creator', user_id)])
|
||||
|
||||
|
|
@ -380,8 +331,7 @@ class Indicators:
|
|||
for _, row in indicators_df.iterrows():
|
||||
# Ensure that row['properties'] is a dictionary
|
||||
properties = row.get('properties', {})
|
||||
if not isinstance(properties, dict):
|
||||
properties = {}
|
||||
properties = json.loads(properties) if isinstance(properties, str) else properties
|
||||
|
||||
# Construct the result dictionary for each indicator
|
||||
result[row['name']] = {
|
||||
|
|
@ -406,11 +356,12 @@ class Indicators:
|
|||
return
|
||||
|
||||
# Set visibility for all indicators off
|
||||
self.cache_manager.modify_datacache_item('indicators', [('creator', user_id)], field_name='visible', new_data=0)
|
||||
self.cache_manager.modify_datacache_item('indicators', [('creator', user_id)],
|
||||
field_name='visible', new_data=0, overwrite='name')
|
||||
|
||||
# Set visibility for the specified indicators on
|
||||
self.cache_manager.modify_datacache_item('indicators', [('creator', user_id), ('name', indicator_names)],
|
||||
field_name='visible', new_data=1)
|
||||
field_name='visible', new_data=1, overwrite='name')
|
||||
|
||||
def edit_indicator(self, user_name: str, params: dict):
|
||||
"""
|
||||
|
|
@ -423,16 +374,27 @@ class Indicators:
|
|||
raise ValueError("Indicator name is required for editing.")
|
||||
|
||||
# Get the indicator from the user's indicator list
|
||||
user_id = self.users.get_id(user_name)
|
||||
indicator = self.cache_manager.get_rows_from_datacache('indicators', [('name', indicator_name), ('creator', user_id)])
|
||||
user_id = str(self.users.get_id(user_name))
|
||||
indicator = self.cache_manager.get_rows_from_datacache('indicators',
|
||||
[('name', indicator_name), ('creator', user_id)])
|
||||
|
||||
if indicator.empty:
|
||||
raise ValueError(f"Indicator '{indicator_name}' not found for user '{user_name}'.")
|
||||
|
||||
# Modify indicator.
|
||||
self.cache_manager.modify_datacache_item('indicators',
|
||||
[('creator', params.get('user_name')), ('name', params.get('name'))],
|
||||
field_name=params.get('setting'), new_data=params.get('value'))
|
||||
[('creator', user_id), ('name', indicator_name)],
|
||||
field_name='properties', new_data=params.get('properties'),
|
||||
overwrite='name')
|
||||
|
||||
new_visible = params.get('visible')
|
||||
current_visible = indicator['visible'].iloc[0]
|
||||
|
||||
if current_visible != new_visible:
|
||||
self.cache_manager.modify_datacache_item('indicators',
|
||||
[('creator', user_id), ('name', indicator_name)],
|
||||
field_name='visible', new_data=new_visible,
|
||||
overwrite='name')
|
||||
|
||||
def new_indicator(self, user_name: str, params) -> None:
|
||||
"""
|
||||
|
|
@ -485,7 +447,7 @@ class Indicators:
|
|||
|
||||
# Adjust num_results to account for the lookup period if specified in the indicator properties.
|
||||
if 'period' in properties:
|
||||
num_results += properties['period']
|
||||
num_results += int(properties['period'])
|
||||
|
||||
# Request the data from the defined source.
|
||||
data = self.candles.get_last_n_candles(num_candles=num_results,
|
||||
|
|
@ -522,7 +484,8 @@ class Indicators:
|
|||
visible = 1 if visible_only else 0
|
||||
|
||||
# Filter the indicators based on the query.
|
||||
indicators = self.cache_manager.get_rows_from_datacache('indicators', [('creator', user_id), ('visible', visible)])
|
||||
indicators = self.cache_manager.get_rows_from_datacache('indicators',
|
||||
[('creator', user_id), ('visible', visible)])
|
||||
|
||||
# Return None if no indicators matched the query.
|
||||
if indicators.empty:
|
||||
|
|
@ -552,7 +515,7 @@ class Indicators:
|
|||
# Process each indicator, convert DataFrame to JSON-serializable format, and collect the results
|
||||
json_ready_results = {}
|
||||
|
||||
for indicator in indicators.itertuples(index=False):
|
||||
for indicator in filtered_indicators.itertuples(index=False):
|
||||
indicator_results = self.process_indicator(indicator=indicator, num_results=num_results)
|
||||
|
||||
# Convert DataFrame to list of dictionaries if necessary
|
||||
|
|
@ -616,7 +579,7 @@ class Indicators:
|
|||
'source': source,
|
||||
'properties': properties
|
||||
}])
|
||||
self.cache_manager.insert_df_into_datacache(df=row_data, cache_name="users", skip_cache=False)
|
||||
self.cache_manager.insert_df_into_datacache(df=row_data, cache_name="indicators", skip_cache=False)
|
||||
|
||||
# def update_indicators(self, user_name):
|
||||
# """
|
||||
|
|
|
|||
|
|
@ -469,9 +469,22 @@ class Indicators {
|
|||
const indicatorName = nameDiv.innerText.trim(); // Get the indicator name
|
||||
|
||||
// Gather input data
|
||||
const formObj = { name: indicatorName }; // Initialize formObj with the name
|
||||
// Initialize formObj with the name of the indicator
|
||||
const formObj = {
|
||||
name: indicatorName,
|
||||
visible: false, // Default value for visible (will be updated based on the checkbox input)
|
||||
properties: {}
|
||||
};
|
||||
|
||||
// Iterate over each input (text, checkbox, select) and add its name and value to formObj
|
||||
inputs.forEach(input => {
|
||||
formObj[input.name] = input.type === 'checkbox' ? input.checked : input.value;
|
||||
if (input.name === 'visible') {
|
||||
// Handle the visible checkbox separately
|
||||
formObj.visible = input.checked;
|
||||
} else {
|
||||
// Add all other inputs (type, period, color) to the properties object
|
||||
formObj.properties[input.name] = input.type === 'checkbox' ? input.checked : input.value;
|
||||
}
|
||||
});
|
||||
|
||||
// Call comms to send data to the server
|
||||
|
|
|
|||
|
|
@ -205,8 +205,8 @@ class DataGenerator:
|
|||
class TestDataCache(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Initialize DataCache
|
||||
self.exchanges = ExchangeInterface()
|
||||
self.data = DataCache(self.exchanges)
|
||||
self.data = DataCache()
|
||||
self.exchanges = ExchangeInterface(self.data)
|
||||
|
||||
self.exchanges_connected = False
|
||||
self.database_is_setup = False
|
||||
|
|
@ -947,7 +947,7 @@ class TestDataCache(unittest.TestCase):
|
|||
df_initial = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 0, 0, tzinfo=dt.timezone.utc))
|
||||
|
||||
print(f'Inserting this table into cache:\n{df_initial}\n')
|
||||
self.data.set_cache_item(key=candle_cache_key, data=df_initial, cache_name='candles')
|
||||
self.data.serialized_datacache_insert(key=candle_cache_key, data=df_initial, cache_name='candles')
|
||||
|
||||
# Create new DataFrame to be added to the cache
|
||||
df_new = data_gen.create_table(num_rec=3, start=dt.datetime(2024, 8, 9, 0, 15, 0, tzinfo=dt.timezone.utc))
|
||||
|
|
@ -956,7 +956,7 @@ class TestDataCache(unittest.TestCase):
|
|||
self.data._update_candle_cache(more_records=df_new, key=candle_cache_key)
|
||||
|
||||
# Retrieve the resulting DataFrame from the cache
|
||||
result = self.data.get_cache_item(key=candle_cache_key, cache_name='candles')
|
||||
result = self.data.get_serialized_datacache(key=candle_cache_key, cache_name='candles')
|
||||
print(f'The resulting table in cache is:\n{result}\n')
|
||||
|
||||
# Create the expected DataFrame
|
||||
|
|
@ -1089,7 +1089,7 @@ class TestDataCache(unittest.TestCase):
|
|||
|
||||
if set_cache:
|
||||
print('Ensuring the cache exists and then inserting table into the cache.')
|
||||
self.data.set_cache_item(data=df_initial, key=key, cache_name='candles')
|
||||
self.data.serialized_datacache_insert(data=df_initial, key=key, cache_name='candles')
|
||||
|
||||
if set_db:
|
||||
print('Inserting table into the database.')
|
||||
|
|
@ -1432,26 +1432,26 @@ class TestDataCache(unittest.TestCase):
|
|||
|
||||
# Insert these DataFrames into the 'users' cache with row-based caching
|
||||
self.data.create_cache('users', cache_type='row') # Assuming 'row' cache type for this test
|
||||
self.data.set_cache_item(key='user_billy', data=df1, cache_name='users')
|
||||
self.data.set_cache_item(key='user_john', data=df2, cache_name='users')
|
||||
self.data.set_cache_item(key='user_alice', data=df3, cache_name='users')
|
||||
self.data.serialized_datacache_insert(key='user_billy', data=df1, cache_name='users')
|
||||
self.data.serialized_datacache_insert(key='user_john', data=df2, cache_name='users')
|
||||
self.data.serialized_datacache_insert(key='user_alice', data=df3, cache_name='users')
|
||||
|
||||
print('Testing get_or_fetch_rows() method:')
|
||||
|
||||
# Fetch user directly by key since this is a row-based cache
|
||||
result = self.data.get_cache_item(key='user_billy', cache_name='users')
|
||||
result = self.data.get_serialized_datacache(key='user_billy', cache_name='users')
|
||||
self.assertIsInstance(result, pd.DataFrame, "Failed to fetch DataFrame from cache")
|
||||
self.assertFalse(result.empty, "The fetched DataFrame is empty")
|
||||
self.assertEqual(result.iloc[0]['password'], '1234', "Incorrect data fetched from cache")
|
||||
|
||||
# Fetch another user by key
|
||||
result = self.data.get_cache_item(key='user_john', cache_name='users')
|
||||
result = self.data.get_serialized_datacache(key='user_john', cache_name='users')
|
||||
self.assertIsInstance(result, pd.DataFrame, "Failed to fetch DataFrame from cache")
|
||||
self.assertFalse(result.empty, "The fetched DataFrame is empty")
|
||||
self.assertEqual(result.iloc[0]['password'], '5678', "Incorrect data fetched from cache")
|
||||
|
||||
# Test fetching a user that does not exist in the cache
|
||||
result = self.data.get_cache_item(key='non_existent_user', cache_name='users')
|
||||
result = self.data.get_serialized_datacache(key='non_existent_user', cache_name='users')
|
||||
|
||||
# Check if result is None (indicating that no data was found)
|
||||
self.assertIsNone(result, "Expected result to be None for a non-existent user")
|
||||
|
|
@ -1480,9 +1480,9 @@ class TestDataCache(unittest.TestCase):
|
|||
})
|
||||
|
||||
# Insert mock data into the cache
|
||||
self.data.set_cache_item(cache_name='users', data=user_data_1)
|
||||
self.data.set_cache_item(cache_name='users', data=user_data_2)
|
||||
self.data.set_cache_item(cache_name='users', data=user_data_3)
|
||||
self.data.serialized_datacache_insert(cache_name='users', data=user_data_1)
|
||||
self.data.serialized_datacache_insert(cache_name='users', data=user_data_2)
|
||||
self.data.serialized_datacache_insert(cache_name='users', data=user_data_3)
|
||||
|
||||
# Test when attribute value is taken
|
||||
result_taken = user_cache.is_attr_taken('user_name', 'billy')
|
||||
|
|
@ -1606,10 +1606,10 @@ class TestDataCache(unittest.TestCase):
|
|||
|
||||
# Create a row-based cache for indicators and store serialized Indicator data
|
||||
self.data.create_cache('indicators', cache_type='row')
|
||||
self.data.set_cache_item(key='indicator_key', data=indicator, cache_name='indicators')
|
||||
self.data.serialized_datacache_insert(key='indicator_key', data=indicator, cache_name='indicators')
|
||||
|
||||
# Retrieve the indicator and check for deserialization
|
||||
stored_data = self.data.get_cache_item('indicator_key', cache_name='indicators')
|
||||
stored_data = self.data.get_serialized_datacache('indicator_key', cache_name='indicators')
|
||||
self.assertIsInstance(stored_data, Indicator, "Failed to retrieve and deserialize the Indicator instance")
|
||||
|
||||
# Case 2: Retrieve non-Indicator data (e.g., dict)
|
||||
|
|
@ -1619,20 +1619,20 @@ class TestDataCache(unittest.TestCase):
|
|||
self.data.create_cache('default_cache', cache_type='row')
|
||||
|
||||
# Store a dictionary
|
||||
self.data.set_cache_item(key='dict_key', data=data_dict, cache_name='default_cache')
|
||||
self.data.serialized_datacache_insert(key='dict_key', data=data_dict, cache_name='default_cache')
|
||||
|
||||
# Retrieve and check if the data matches the original dict
|
||||
stored_data = self.data.get_cache_item('dict_key', cache_name='default_cache')
|
||||
stored_data = self.data.get_serialized_datacache('dict_key', cache_name='default_cache')
|
||||
self.assertEqual(stored_data, data_dict, "Failed to retrieve non-Indicator data correctly")
|
||||
|
||||
# Case 3: Retrieve a list stored in the cache
|
||||
data_list = [1, 2, 3, 4, 5]
|
||||
|
||||
# Store a list in row-based cache
|
||||
self.data.set_cache_item(key='list_key', data=data_list, cache_name='default_cache')
|
||||
self.data.serialized_datacache_insert(key='list_key', data=data_list, cache_name='default_cache')
|
||||
|
||||
# Retrieve and check if the data matches the original list
|
||||
stored_data = self.data.get_cache_item('list_key', cache_name='default_cache')
|
||||
stored_data = self.data.get_serialized_datacache('list_key', cache_name='default_cache')
|
||||
self.assertEqual(stored_data, data_list, "Failed to retrieve list data correctly")
|
||||
|
||||
# Case 4: Retrieve a DataFrame stored in the cache (Table-Based Cache)
|
||||
|
|
@ -1645,14 +1645,14 @@ class TestDataCache(unittest.TestCase):
|
|||
self.data.create_cache('table_cache', cache_type='table')
|
||||
|
||||
# Store a DataFrame in table-based cache
|
||||
self.data.set_cache_item(key='testkey', data=data_df, cache_name='table_cache')
|
||||
self.data.serialized_datacache_insert(key='testkey', data=data_df, cache_name='table_cache')
|
||||
|
||||
# Retrieve and check if the DataFrame matches the original
|
||||
stored_data = self.data.get_cache_item(key='testkey', cache_name='table_cache')
|
||||
stored_data = self.data.get_serialized_datacache(key='testkey', cache_name='table_cache')
|
||||
pd.testing.assert_frame_equal(stored_data, data_df)
|
||||
|
||||
# Case 5: Attempt to retrieve a non-existent key
|
||||
non_existent = self.data.get_cache_item('non_existent_key', cache_name='default_cache')
|
||||
non_existent = self.data.get_serialized_datacache('non_existent_key', cache_name='default_cache')
|
||||
self.assertIsNone(non_existent, "Expected None for non-existent cache key")
|
||||
|
||||
print(" - All get_cache_item tests passed.")
|
||||
|
|
@ -1674,7 +1674,7 @@ class TestDataCache(unittest.TestCase):
|
|||
user_cache_key = f"user_{user_id}_{indicator_type}_{symbol}_{timeframe}_{exchange_name}"
|
||||
|
||||
# Retrieve the stored properties
|
||||
stored_properties = self.data.get_cache_item(user_cache_key, cache_name='user_display_properties')
|
||||
stored_properties = self.data.get_serialized_datacache(user_cache_key, cache_name='user_display_properties')
|
||||
|
||||
# Check if the properties were stored correctly
|
||||
self.assertEqual(stored_properties, display_properties, "Failed to store user-specific display properties")
|
||||
|
|
@ -1687,7 +1687,7 @@ class TestDataCache(unittest.TestCase):
|
|||
updated_properties)
|
||||
|
||||
# Retrieve the updated properties
|
||||
updated_stored_properties = self.data.get_cache_item(user_cache_key, cache_name='user_display_properties')
|
||||
updated_stored_properties = self.data.get_serialized_datacache(user_cache_key, cache_name='user_display_properties')
|
||||
|
||||
# Check if the properties were updated correctly
|
||||
self.assertEqual(updated_stored_properties, updated_properties,
|
||||
|
|
@ -1738,27 +1738,27 @@ class TestDataCache(unittest.TestCase):
|
|||
key = 'row_key'
|
||||
data = {'some': 'data'}
|
||||
|
||||
data_cache.set_cache_item(cache_name='row_cache', data=data, key=key)
|
||||
cached_item = data_cache.get_cache_item(key, cache_name='row_cache')
|
||||
data_cache.serialized_datacache_insert(cache_name='row_cache', data=data, key=key)
|
||||
cached_item = data_cache.get_serialized_datacache(key, cache_name='row_cache')
|
||||
self.assertEqual(cached_item, data, "Failed to store and retrieve data in RowBasedCache")
|
||||
|
||||
# Case 2: Store and retrieve an Indicator instance (serialization)
|
||||
indicator = Indicator(name='SMA', indicator_type='SMA', properties={'period': 5})
|
||||
data_cache.set_cache_item(cache_name='row_cache', data=indicator, key='indicator_key')
|
||||
cached_indicator = data_cache.get_cache_item('indicator_key', cache_name='row_cache')
|
||||
data_cache.serialized_datacache_insert(cache_name='row_cache', data=indicator, key='indicator_key')
|
||||
cached_indicator = data_cache.get_serialized_datacache('indicator_key', cache_name='row_cache')
|
||||
|
||||
# Assert that the data was correctly serialized and deserialized
|
||||
self.assertIsInstance(pickle.loads(cached_indicator), Indicator, "Failed to deserialize Indicator instance")
|
||||
|
||||
# Case 3: Prevent overwriting an existing key if do_not_overwrite=True
|
||||
new_data = {'new': 'data'}
|
||||
data_cache.set_cache_item(cache_name='row_cache', data=new_data, key=key, do_not_overwrite=True)
|
||||
cached_item_after = data_cache.get_cache_item(key, cache_name='row_cache')
|
||||
data_cache.serialized_datacache_insert(cache_name='row_cache', data=new_data, key=key, do_not_overwrite=True)
|
||||
cached_item_after = data_cache.get_serialized_datacache(key, cache_name='row_cache')
|
||||
self.assertEqual(cached_item_after, data, "Overwriting occurred when it should have been prevented")
|
||||
|
||||
# Case 4: Raise ValueError if key is None in RowBasedCache
|
||||
with self.assertRaises(ValueError, msg="RowBasedCache requires a key to store the data."):
|
||||
data_cache.set_cache_item(cache_name='row_cache', data=data, key=None)
|
||||
data_cache.serialized_datacache_insert(cache_name='row_cache', data=data, key=None)
|
||||
|
||||
# -------------------------
|
||||
# Table-Based Cache Test Cases
|
||||
|
|
@ -1767,19 +1767,19 @@ class TestDataCache(unittest.TestCase):
|
|||
data_cache.create_cache('table_cache', cache_type='table') # Create table-based cache
|
||||
df = pd.DataFrame({'col1': [1, 2], 'col2': ['A', 'B']})
|
||||
|
||||
data_cache.set_cache_item(cache_name='table_cache', data=df, key='table_key')
|
||||
cached_df = data_cache.get_cache_item('table_key', cache_name='table_cache')
|
||||
data_cache.serialized_datacache_insert(cache_name='table_cache', data=df, key='table_key')
|
||||
cached_df = data_cache.get_serialized_datacache('table_key', cache_name='table_cache')
|
||||
pd.testing.assert_frame_equal(cached_df, df, "Failed to store and retrieve DataFrame in TableBasedCache")
|
||||
|
||||
# Case 6: Prevent overwriting an existing key if do_not_overwrite=True in TableBasedCache
|
||||
new_df = pd.DataFrame({'col1': [3, 4], 'col2': ['C', 'D']})
|
||||
data_cache.set_cache_item(cache_name='table_cache', data=new_df, key='table_key', do_not_overwrite=True)
|
||||
cached_df_after = data_cache.get_cache_item('table_key', cache_name='table_cache')
|
||||
data_cache.serialized_datacache_insert(cache_name='table_cache', data=new_df, key='table_key', do_not_overwrite=True)
|
||||
cached_df_after = data_cache.get_serialized_datacache('table_key', cache_name='table_cache')
|
||||
pd.testing.assert_frame_equal(cached_df_after, df, "Overwriting occurred when it should have been prevented")
|
||||
|
||||
# Case 7: Raise ValueError if non-DataFrame data is provided in TableBasedCache
|
||||
with self.assertRaises(ValueError, msg="TableBasedCache can only store DataFrames."):
|
||||
data_cache.set_cache_item(cache_name='table_cache', data={'not': 'a dataframe'}, key='table_key')
|
||||
data_cache.serialized_datacache_insert(cache_name='table_cache', data={'not': 'a dataframe'}, key='table_key')
|
||||
|
||||
# -------------------------
|
||||
# Expiration Handling Test Case
|
||||
|
|
@ -1789,14 +1789,14 @@ class TestDataCache(unittest.TestCase):
|
|||
data = {'some': 'data'}
|
||||
expire_delta = dt.timedelta(seconds=5)
|
||||
|
||||
data_cache.set_cache_item(cache_name='row_cache', data=data, key=key, expire_delta=expire_delta)
|
||||
cached_item = data_cache.get_cache_item(key, cache_name='row_cache')
|
||||
data_cache.serialized_datacache_insert(cache_name='row_cache', data=data, key=key, expire_delta=expire_delta)
|
||||
cached_item = data_cache.get_serialized_datacache(key, cache_name='row_cache')
|
||||
self.assertEqual(cached_item, data, "Failed to store and retrieve data with expiration")
|
||||
|
||||
# Wait for expiration to occur (ensure data is removed after expiration)
|
||||
import time
|
||||
time.sleep(6)
|
||||
expired_item = data_cache.get_cache_item(key, cache_name='row_cache')
|
||||
expired_item = data_cache.get_serialized_datacache(key, cache_name='row_cache')
|
||||
self.assertIsNone(expired_item, "Data was not removed after expiration time")
|
||||
|
||||
# -------------------------
|
||||
|
|
@ -1804,7 +1804,7 @@ class TestDataCache(unittest.TestCase):
|
|||
# -------------------------
|
||||
# Case 9: Raise ValueError if unsupported cache type is provided
|
||||
with self.assertRaises(KeyError, msg="Unsupported cache type for 'unsupported_cache'"):
|
||||
data_cache.set_cache_item(cache_name='unsupported_cache', data={'some': 'data'}, key='some_key')
|
||||
data_cache.serialized_datacache_insert(cache_name='unsupported_cache', data={'some': 'data'}, key='some_key')
|
||||
|
||||
def test_calculate_and_cache_indicator(self):
|
||||
# Testing the calculation and caching of an indicator through DataCache (which includes IndicatorCache
|
||||
|
|
@ -1899,7 +1899,7 @@ class TestDataCache(unittest.TestCase):
|
|||
)
|
||||
|
||||
# Check if the data was cached after the first calculation
|
||||
cached_data = self.data.get_cache_item(cache_key, cache_name='indicator_data')
|
||||
cached_data = self.data.get_serialized_datacache(cache_key, cache_name='indicator_data')
|
||||
print(f"Cached Data after first calculation: {cached_data}")
|
||||
|
||||
# Ensure the data was cached correctly
|
||||
|
|
@ -1942,7 +1942,7 @@ class TestDataCache(unittest.TestCase):
|
|||
cache_key = self.data._make_indicator_key('BTC/USD', '5m', 'binance', 'SMA', properties['period'])
|
||||
|
||||
# Store the cached data as DataFrame (no need for to_dict('records'))
|
||||
self.data.set_cache_item(cache_name='indicator_data',data=cached_data, key=cache_key)
|
||||
self.data.serialized_datacache_insert(cache_name='indicator_data', data=cached_data, key=cache_key)
|
||||
|
||||
# Print cached data to inspect its range
|
||||
print("Cached data time range:")
|
||||
|
|
|
|||
Loading…
Reference in New Issue