diff --git a/src/BrighterTrades.py b/src/BrighterTrades.py index 8b2bd36..c98567b 100644 --- a/src/BrighterTrades.py +++ b/src/BrighterTrades.py @@ -645,11 +645,14 @@ class BrighterTrades: logger.error(f"Error editing strategy: {e}", exc_info=True) return {"success": False, "message": "An unexpected error occurred while editing the strategy"} - def delete_strategy(self, data: dict) -> dict: + def delete_strategy(self, data: dict, user_id: int = None) -> dict: """ Deletes the specified strategy identified by tbl_key from the strategies instance. - :param data: Dictionary containing 'tbl_key' and 'user_name'. + Security: Ownership is verified before deletion to prevent unauthorized access. + + :param data: Dictionary containing 'tbl_key'. + :param user_id: The authenticated user ID (for ownership verification). :return: A dictionary indicating success or failure with an appropriate message and the tbl_key. """ # Validate tbl_key @@ -657,8 +660,12 @@ class BrighterTrades: if not tbl_key: return {"success": False, "message": "tbl_key not provided", "tbl_key": None} - # Call the delete_strategy method to remove the strategy - result = self.strategies.delete_strategy(tbl_key=tbl_key) + # Get user_id from data if not passed directly (backwards compatibility) + if user_id is None: + user_id = data.get('user_id') or data.get('userId') + + # Call the delete_strategy method to remove the strategy (with ownership check) + result = self.strategies.delete_strategy(tbl_key=tbl_key, user_id=user_id) # Return the result with tbl_key included if result.get('success'): @@ -724,41 +731,26 @@ class BrighterTrades: strategy_row = strategy_data.iloc[0] strategy_name = strategy_row.get('name', 'Unknown') - # Authorization check: user must own the strategy or strategy must be public + # Authorization check: user must own the strategy OR be subscribed to it strategy_creator = strategy_row.get('creator') - is_public = bool(strategy_row.get('public', False)) + try: + creator_id = int(strategy_creator) if strategy_creator is not None else None + except (ValueError, TypeError): + creator_id = None - if not is_public: - requester_name = None - try: - requester_name = self.users.get_username(user_id=user_id) - except Exception: - logger.warning(f"Unable to resolve username for user id '{user_id}'.") + is_owner = (creator_id == user_id) if (creator_id is not None and user_id is not None) else False + is_subscribed = self.strategies.is_subscribed(user_id, strategy_id) - creator_str = str(strategy_creator) if strategy_creator is not None else '' - requester_id_str = str(user_id) + # Must be owner OR subscribed to run + if not is_owner and not is_subscribed: + return { + "success": False, + "message": "Subscribe to this strategy first" + } - creator_matches_user = False - if creator_str: - # Support creator being stored as user_name or user_id. - creator_matches_user = ( - (requester_name is not None and creator_str == requester_name) or - (creator_str == requester_id_str) - ) - - if not creator_matches_user and creator_str: - # Also check if creator is a username that resolves to the current user id. - try: - creator_id = self.get_user_info(user_name=creator_str, info='User_id') - creator_matches_user = creator_id == user_id - except Exception: - creator_matches_user = False - - if not creator_matches_user: - return { - "success": False, - "message": "You do not have permission to run this strategy." - } + # For subscribed strategies, use creator's indicators + # This ensures subscribers run with the creator's indicator definitions + indicator_owner_id = creator_id if is_subscribed and not is_owner else None # Check if already running instance_key = (user_id, strategy_id, effective_mode) @@ -917,6 +909,7 @@ class BrighterTrades: testnet=actual_testnet, max_position_pct=max_position_pct, circuit_breaker_pct=circuit_breaker_pct, + indicator_owner_id=indicator_owner_id, # For subscribed strategies, use creator's indicators ) # Store the active instance @@ -1115,11 +1108,14 @@ class BrighterTrades: def get_strategies_json(self, user_id) -> list: """ - Retrieve all public and user strategies from the strategies instance and return them as a list of dictionaries. + Retrieve strategies that the user owns or is subscribed to. + Returns owned strategies with full data and subscribed strategies with redacted internals. + + :param user_id: The user's ID. :return: list - A list of dictionaries, each representing a strategy. """ - return self.strategies.get_all_strategies(user_id, 'dict') + return self.strategies.get_user_strategies(user_id) def connect_or_config_exchange(self, user_name: str, exchange_name: str, api_keys: dict = None) -> dict: """ @@ -1427,14 +1423,21 @@ class BrighterTrades: return - def process_incoming_message(self, msg_type: str, msg_data: dict, socket_conn_id: str) -> dict | None: - + def process_incoming_message( + self, + msg_type: str, + msg_data: dict, + socket_conn_id: str, + authenticated_user_id: int = None + ) -> dict | None: """ Processes an incoming message and performs the corresponding actions based on the message type and data. - :param socket_conn_id: The WebSocket connection to send updates back to the client. :param msg_type: The type of the incoming message. :param msg_data: The data associated with the incoming message. + :param socket_conn_id: The WebSocket connection to send updates back to the client. + :param authenticated_user_id: Server-verified user ID from socket mapping. If provided, this takes + precedence over any user identity in msg_data. :return: dict|None - A dictionary containing the response message and data, or None if no response is needed or no data is found to ensure the WebSocket channel isn't burdened with unnecessary communication. @@ -1447,8 +1450,14 @@ class BrighterTrades: """ Formats a standard reply message. """ return {"reply": reply_msg, "data": reply_data} - user_name = self.resolve_user_name(msg_data) - user_id = self.resolve_user_id(msg_data, user_name=user_name) + # Use authenticated_user_id if provided (from secure socket mapping) + # Otherwise fall back to resolving from msg_data (for backwards compatibility) + if authenticated_user_id is not None: + user_id = authenticated_user_id + user_name = self.users.get_username(user_id=authenticated_user_id) + else: + user_name = self.resolve_user_name(msg_data) + user_id = self.resolve_user_id(msg_data, user_name=user_name) if user_name: msg_data.setdefault('user_name', user_name) @@ -1470,8 +1479,9 @@ class BrighterTrades: elif request_for == 'strategies': if user_id is None: return standard_reply("strategy_error", {"message": "User not specified"}) - if strategies := self.get_strategies_json(user_id): - return standard_reply("strategies", strategies) + # Always return response, even if empty list + strategies = self.get_strategies_json(user_id) + return standard_reply("strategies", strategies or []) elif request_for == 'trades': trades = self.get_trades(user_id) @@ -1496,7 +1506,8 @@ class BrighterTrades: }) if msg_type == 'delete_strategy': - result = self.delete_strategy(msg_data) + # Pass authenticated user_id for ownership verification + result = self.delete_strategy(msg_data, user_id=user_id) if result.get('success'): return standard_reply("strategy_deleted", { "message": result.get('message'), @@ -1716,6 +1727,39 @@ class BrighterTrades: logger.error(f"Error getting strategy status: {e}", exc_info=True) return standard_reply("strategy_status_error", {"message": f"Failed to get status: {str(e)}"}) + # ===== Strategy Subscription Handlers ===== + + if msg_type == 'subscribe_strategy': + strategy_tbl_key = msg_data.get('strategy_tbl_key') or msg_data.get('tbl_key') + if not strategy_tbl_key: + return standard_reply("subscription_error", {"message": "Strategy not specified"}) + + result = self.strategies.subscribe_to_strategy(user_id, strategy_tbl_key) + if result.get('success'): + return standard_reply("strategy_subscribed", result) + else: + return standard_reply("subscription_error", result) + + if msg_type == 'unsubscribe_strategy': + strategy_tbl_key = msg_data.get('strategy_tbl_key') or msg_data.get('tbl_key') + if not strategy_tbl_key: + return standard_reply("subscription_error", {"message": "Strategy not specified"}) + + result = self.strategies.unsubscribe_from_strategy(user_id, strategy_tbl_key) + if result.get('success'): + return standard_reply("strategy_unsubscribed", result) + else: + return standard_reply("subscription_error", result) + + if msg_type == 'get_public_strategies': + # Returns all public strategies for the browse dialog + try: + strategies = self.strategies.get_public_strategies_catalog(user_id) + return standard_reply("public_strategies", {"strategies": strategies or []}) + except Exception as e: + logger.error(f"Error getting public strategies: {e}", exc_info=True) + return standard_reply("public_strategies_error", {"message": str(e)}) + if msg_type == 'reply': # If the message is a reply log the response to the terminal. print(f"\napp.py:Received reply: {msg_data}") diff --git a/src/DataCache_v3.py b/src/DataCache_v3.py index 8f09c2e..43ec29c 100644 --- a/src/DataCache_v3.py +++ b/src/DataCache_v3.py @@ -923,12 +923,24 @@ class DatabaseInteractions(SnapshotDataCache): # Case 1: Retrieve all rows from the cache if isinstance(cache, RowBasedCache): - return pd.DataFrame.from_dict(cache.get_all_items(), orient='index') + result = pd.DataFrame.from_dict(cache.get_all_items(), orient='index') + if not result.empty: + return result elif isinstance(cache, TableBasedCache): - return cache.get_all_items() + result = cache.get_all_items() + if not result.empty: + return result - # Case 2: Fallback to retrieve all rows from the database using Database class - return self.db.get_all_rows(cache_name) + # Case 2: Fallback to retrieve all rows from the database if cache is empty + db_result = self.db.get_all_rows(cache_name) + + # Populate the cache with database results for future queries + if db_result is not None and not db_result.empty and cache is not None: + if isinstance(cache, TableBasedCache): + cache.add_table(df=db_result) + # For RowBasedCache, we'd need a key which we don't have here + + return db_result if db_result is not None else pd.DataFrame() def _fetch_from_database_with_list_support(self, cache_name: str, filter_vals: List[tuple[str, Any]]) -> pd.DataFrame: diff --git a/src/Strategies.py b/src/Strategies.py index d8f6072..9599160 100644 --- a/src/Strategies.py +++ b/src/Strategies.py @@ -62,6 +62,19 @@ class Strategies: ] ) + # Create a cache for strategy subscriptions + self.data_cache.create_cache( + name='strategy_subscriptions', + cache_type='table', + size_limit=1000, + eviction_policy='deny', + default_expiration=dt.timedelta(hours=24), + columns=["id", "user_id", "strategy_tbl_key", "subscribed_at"] + ) + + # Ensure the subscriptions table exists in SQLite + self._ensure_subscriptions_table() + # Initialize default settings self.default_timeframe = '5m' self.default_exchange = 'Binance' @@ -97,6 +110,46 @@ class Strategies: except Exception as e: logger.warning(f"Migration check failed (may be expected on fresh install): {e}") + def _ensure_subscriptions_table(self) -> None: + """ + Create the strategy_subscriptions table if it doesn't exist. + + Note: execute_sql runs single statements only, so we run each statement separately. + We don't rely on FK ON DELETE CASCADE since SQLite requires PRAGMA foreign_keys = ON. + """ + import config + + try: + db = self.data_cache.db + + # Statement 1: Create table + db.execute_sql(""" + CREATE TABLE IF NOT EXISTS strategy_subscriptions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + strategy_tbl_key TEXT NOT NULL, + subscribed_at TEXT NOT NULL, + UNIQUE(user_id, strategy_tbl_key) + ) + """) + + # Statement 2: Create user index + db.execute_sql(""" + CREATE INDEX IF NOT EXISTS idx_subscriptions_user + ON strategy_subscriptions(user_id) + """) + + # Statement 3: Create strategy index + db.execute_sql(""" + CREATE INDEX IF NOT EXISTS idx_subscriptions_strategy + ON strategy_subscriptions(strategy_tbl_key) + """) + + logger.info("Strategy subscriptions table initialized") + + except Exception as e: + logger.warning(f"Failed to create subscriptions table (may already exist): {e}") + def update(self, candle_data: dict = None) -> list: """ Update all active strategy instances with new price data. @@ -188,6 +241,7 @@ class Strategies: testnet: bool = True, max_position_pct: float = 0.5, circuit_breaker_pct: float = -0.10, + indicator_owner_id: int = None, ) -> StrategyInstance: """ Factory method to create the appropriate strategy instance based on mode. @@ -206,6 +260,7 @@ class Strategies: :param testnet: Use testnet for live trading (default True for safety). :param max_position_pct: Maximum position size as % of balance for live trading. :param circuit_breaker_pct: Drawdown % to halt trading for live trading. + :param indicator_owner_id: For subscribed strategies, the creator's user ID for indicator lookup. :return: Strategy instance appropriate for the mode. """ mode = mode.lower() @@ -226,6 +281,7 @@ class Strategies: commission=commission, slippage=slippage if slippage > 0 else 0.0005, price_provider=price_provider, + indicator_owner_id=indicator_owner_id, ) elif mode == TradingMode.BACKTEST: @@ -240,6 +296,7 @@ class Strategies: indicators=self.indicators_manager, trades=self.trades, edm_client=self.edm_client, + indicator_owner_id=indicator_owner_id, ) elif mode == TradingMode.LIVE: @@ -280,6 +337,7 @@ class Strategies: slippage=slippage, max_position_pct=max_position_pct, circuit_breaker_pct=circuit_breaker_pct, + indicator_owner_id=indicator_owner_id, ) else: @@ -295,6 +353,7 @@ class Strategies: indicators=self.indicators_manager, trades=self.trades, edm_client=self.edm_client, + indicator_owner_id=indicator_owner_id, ) def _save_strategy(self, strategy_data: dict, default_source: dict) -> dict: @@ -319,6 +378,16 @@ class Strategies: ) if existing_strategy.empty: return {"success": False, "message": "Strategy not found."} + + # Check if strategy is being made private (public -> private transition) + was_public = bool(existing_strategy.iloc[0].get('public', 0)) + is_public = bool(strategy_data.get('public', 0)) + + if was_public and not is_public: + # Strategy being made private - remove all subscriptions + self._remove_all_subscriptions_for_strategy(tbl_key) + logger.info(f"Strategy '{tbl_key}' made private - removed all subscriptions") + else: # Check for duplicate strategy name filter_conditions = [ @@ -333,6 +402,13 @@ class Strategies: if not existing_strategy.empty: return {"success": False, "message": "A strategy with this name already exists"} + # Check unique public name requirement + is_public = bool(strategy_data.get('public', 0)) + if is_public: + exclude_key = tbl_key if is_edit else None + if not self._check_unique_public_name(strategy_data['name'], exclude_tbl_key=exclude_key): + return {"success": False, "message": "A public strategy with this name already exists"} + # Validate and serialize 'workspace' workspace_data = strategy_data.get('workspace') if not isinstance(workspace_data, str) or not workspace_data.strip(): @@ -448,14 +524,352 @@ class Strategies: """ return self._save_strategy(strategy_data, default_source) - def delete_strategy(self, tbl_key: str) -> dict: + def verify_ownership(self, user_id: int, strategy_tbl_key: str) -> bool: + """ + Verify that the user owns the strategy. + + :param user_id: The ID of the user to check. + :param strategy_tbl_key: The tbl_key of the strategy. + :return: True if the user owns the strategy, False otherwise. + """ + if user_id is None or not strategy_tbl_key: + return False + + strategy = self.get_strategy_by_tbl_key(strategy_tbl_key) + if not strategy: + return False + + creator = strategy.get('creator') + if creator is None: + return False + + try: + return int(creator) == int(user_id) + except (ValueError, TypeError): + return False + + def _remove_all_subscriptions_for_strategy(self, strategy_tbl_key: str) -> None: + """ + Remove all subscriptions to a strategy. + + Called when a strategy is deleted or made private. + We explicitly delete subscriptions rather than relying on FK cascade + since SQLite requires PRAGMA foreign_keys = ON. + + :param strategy_tbl_key: The tbl_key of the strategy. + """ + try: + # Remove from cache and database (remove_row_from_datacache handles both) + self.data_cache.remove_row_from_datacache( + cache_name='strategy_subscriptions', + filter_vals=[('strategy_tbl_key', strategy_tbl_key)] + ) + + logger.info(f"Removed all subscriptions for strategy '{strategy_tbl_key}'") + + except Exception as e: + logger.warning(f"Failed to remove subscriptions for strategy '{strategy_tbl_key}': {e}") + + def _get_username_for_id(self, user_id: int) -> str: + """ + Get username for a user ID. Used for displaying creator names. + + :param user_id: The user's ID. + :return: The username or a fallback string. + """ + if user_id is None: + return "Unknown" + + try: + # Direct lookup via data_cache + users_cache = self.data_cache.get_rows_from_datacache( + cache_name='users', + filter_vals=[('id', int(user_id))], + include_tbl_key=True + ) + if not users_cache.empty: + return users_cache.iloc[0].get('user_name', f'User #{user_id}') + except Exception as e: + logger.debug(f"Failed to get username for user_id {user_id}: {e}") + + return f'User #{user_id}' # Fallback + + def _redact_strategy_internals(self, strategy: dict) -> dict: + """ + Remove sensitive internals from strategy for non-owners. + + This prevents subscribers from seeing the strategy's implementation details. + + :param strategy: The full strategy dictionary. + :return: A redacted copy of the strategy. + """ + redacted = strategy.copy() + redacted['workspace'] = None + redacted['code'] = None + redacted['strategy_components'] = None # CRITICAL: contains generated code + return redacted + + def is_subscribed(self, user_id: int, strategy_tbl_key: str) -> bool: + """ + Check if a user is subscribed to a strategy. + + :param user_id: The user's ID. + :param strategy_tbl_key: The strategy's tbl_key. + :return: True if subscribed, False otherwise. + """ + if user_id is None or not strategy_tbl_key: + return False + + try: + subscriptions = self.data_cache.get_rows_from_datacache( + cache_name='strategy_subscriptions', + filter_vals=[('user_id', user_id), ('strategy_tbl_key', strategy_tbl_key)] + ) + return not subscriptions.empty + except Exception: + return False + + def _is_strategy_running_for_user(self, user_id: int, strategy_tbl_key: str) -> bool: + """ + Check if user has a running instance of the given strategy. + + :param user_id: The user's ID. + :param strategy_tbl_key: The strategy's tbl_key. + :return: True if the strategy is running for this user. + """ + for (uid, sid, mode) in self.active_instances.keys(): + if uid == user_id and sid == strategy_tbl_key: + return True + return False + + def subscribe_to_strategy(self, user_id: int, strategy_tbl_key: str) -> dict: + """ + Subscribe a user to a public strategy. + + :param user_id: The user's ID. + :param strategy_tbl_key: The strategy's tbl_key. + :return: A dictionary indicating success or failure. + """ + if user_id is None or not strategy_tbl_key: + return {"success": False, "message": "Invalid user or strategy"} + + # Get the strategy + strategy = self.get_strategy_by_tbl_key(strategy_tbl_key) + if not strategy: + return {"success": False, "message": "Strategy not found"} + + # Check if strategy is public + if not strategy.get('public'): + return {"success": False, "message": "Cannot subscribe to private strategy"} + + # Check if user is the owner + if self.verify_ownership(user_id, strategy_tbl_key): + return {"success": False, "message": "You cannot subscribe to your own strategy"} + + # Check if already subscribed + if self.is_subscribed(user_id, strategy_tbl_key): + return {"success": False, "message": "Already subscribed to this strategy"} + + try: + # Add subscription via datacache (handles both cache and DB insert) + subscribed_at = dt.datetime.now(dt.timezone.utc).isoformat() + self.data_cache.insert_row_into_datacache( + cache_name='strategy_subscriptions', + columns=("user_id", "strategy_tbl_key", "subscribed_at"), + values=(user_id, strategy_tbl_key, subscribed_at) + ) + + logger.info(f"User {user_id} subscribed to strategy '{strategy_tbl_key}'") + return { + "success": True, + "message": "Successfully subscribed to strategy", + "strategy_name": strategy.get('name') + } + + except Exception as e: + logger.error(f"Failed to subscribe user {user_id} to strategy '{strategy_tbl_key}': {e}") + return {"success": False, "message": f"Failed to subscribe: {str(e)}"} + + def unsubscribe_from_strategy(self, user_id: int, strategy_tbl_key: str) -> dict: + """ + Unsubscribe a user from a strategy. + + :param user_id: The user's ID. + :param strategy_tbl_key: The strategy's tbl_key. + :return: A dictionary indicating success or failure. + """ + if user_id is None or not strategy_tbl_key: + return {"success": False, "message": "Invalid user or strategy"} + + # Check if subscribed + if not self.is_subscribed(user_id, strategy_tbl_key): + return {"success": False, "message": "Not subscribed to this strategy"} + + # Check if strategy is running for this user + if self._is_strategy_running_for_user(user_id, strategy_tbl_key): + return {"success": False, "message": "Stop the strategy before unsubscribing"} + + try: + # Remove from cache and database (remove_row_from_datacache handles both) + self.data_cache.remove_row_from_datacache( + cache_name='strategy_subscriptions', + filter_vals=[('user_id', user_id), ('strategy_tbl_key', strategy_tbl_key)] + ) + + logger.info(f"User {user_id} unsubscribed from strategy '{strategy_tbl_key}'") + return {"success": True, "message": "Successfully unsubscribed from strategy"} + + except Exception as e: + logger.error(f"Failed to unsubscribe user {user_id} from strategy '{strategy_tbl_key}': {e}") + return {"success": False, "message": f"Failed to unsubscribe: {str(e)}"} + + def get_user_strategies(self, user_id: int) -> list: + """ + Get strategies that a user owns OR is subscribed to. + + This is the primary method for getting a user's strategy list. + Subscribed strategies are redacted (no code/workspace visible). + + :param user_id: The user's ID. + :return: List of strategy dictionaries. + """ + result = [] + + if user_id is None: + return result + + # Get user's own strategies (public or private) + owned = self.data_cache.get_rows_from_datacache( + cache_name='strategies', + filter_vals=[('creator', user_id)], + include_tbl_key=True + ) + + # Add owned strategies (full access) + if owned is not None and not owned.empty: + for _, row in owned.iterrows(): + strat = row.to_dict() + strat['is_owner'] = True + strat['is_subscribed'] = False + # Deserialize JSON fields + if isinstance(strat.get('default_source'), str): + try: + strat['default_source'] = json.loads(strat['default_source']) + except json.JSONDecodeError: + strat['default_source'] = {} + if isinstance(strat.get('stats'), str): + try: + strat['stats'] = json.loads(strat['stats']) + except json.JSONDecodeError: + strat['stats'] = {} + result.append(strat) + + # Get subscribed strategy keys + subscriptions = self.data_cache.get_rows_from_datacache( + cache_name='strategy_subscriptions', + filter_vals=[('user_id', user_id)] + ) + + # Add subscribed strategies (redacted) + if subscriptions is not None and not subscriptions.empty: + for _, sub in subscriptions.iterrows(): + strategy_key = sub.get('strategy_tbl_key') + strategy = self.get_strategy_by_tbl_key(strategy_key) + + if strategy and strategy.get('public'): # Still public + strat = self._redact_strategy_internals(strategy) + strat['is_owner'] = False + strat['is_subscribed'] = True + strat['creator_name'] = self._get_username_for_id(strategy.get('creator')) + result.append(strat) + + return result + + def get_public_strategies_catalog(self, user_id: int) -> list: + """ + Get all public strategies for the browse dialog (sanitized). + + Does not include the user's own strategies. + + :param user_id: The user's ID (to exclude their own strategies). + :return: List of sanitized strategy dictionaries. + """ + result = [] + + # Get all public strategies + public = self.data_cache.get_rows_from_datacache( + cache_name='strategies', + filter_vals=[('public', 1)], + include_tbl_key=True + ) + + if public is None or public.empty: + return result + + for _, row in public.iterrows(): + creator = row.get('creator') + tbl_key = row.get('tbl_key') + + # Skip user's own strategies + try: + if user_id is not None and int(creator) == int(user_id): + continue + except (ValueError, TypeError): + pass + + strat = self._redact_strategy_internals(row.to_dict()) + strat['creator_name'] = self._get_username_for_id(creator) + strat['is_subscribed'] = self.is_subscribed(user_id, tbl_key) + result.append(strat) + + return result + + def _check_unique_public_name(self, name: str, exclude_tbl_key: str = None) -> bool: + """ + Check if a public strategy name is unique. + + :param name: The strategy name to check. + :param exclude_tbl_key: Optional tbl_key to exclude (for edit operations). + :return: True if the name is unique among public strategies. + """ + public_with_name = self.data_cache.get_rows_from_datacache( + cache_name='strategies', + filter_vals=[('public', 1), ('name', name)], + include_tbl_key=True + ) + + if public_with_name is None or public_with_name.empty: + return True # Name is unique + + # Check all matches, not just first + for _, row in public_with_name.iterrows(): + if row.get('tbl_key') != exclude_tbl_key: + return False # Found another public strategy with same name + + return True + + def delete_strategy(self, tbl_key: str, user_id: int = None) -> dict: """ Deletes a strategy identified by its tbl_key. + Security: If user_id is provided, ownership is verified before deletion. + Also removes all subscriptions to this strategy (explicit cleanup, not FK cascade). + :param tbl_key: The unique identifier of the strategy to delete. + :param user_id: The ID of the user requesting deletion (for ownership check). :return: A dictionary indicating success or failure with an appropriate message. """ + # Ownership check if user_id is provided + if user_id is not None: + if not self.verify_ownership(user_id, tbl_key): + logger.warning(f"User {user_id} attempted to delete strategy '{tbl_key}' without ownership") + return {"success": False, "message": "You don't own this strategy."} + try: + # Remove all subscriptions first (explicit cleanup, don't rely on FK cascade) + self._remove_all_subscriptions_for_strategy(tbl_key) + + # Then delete the strategy self.data_cache.remove_row_from_datacache( cache_name='strategies', filter_vals=[('tbl_key', tbl_key)] diff --git a/src/StrategyInstance.py b/src/StrategyInstance.py index 3a797db..16be176 100644 --- a/src/StrategyInstance.py +++ b/src/StrategyInstance.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) class StrategyInstance: def __init__(self, strategy_instance_id: str, strategy_id: str, strategy_name: str, user_id: int, generated_code: str, data_cache: Any, indicators: Any | None, trades: Any | None, - edm_client: Any = None): + edm_client: Any = None, indicator_owner_id: int = None): """ Initializes a StrategyInstance. @@ -28,6 +28,8 @@ class StrategyInstance: :param indicators: Reference to the Indicators manager. :param trades: Reference to the Trades manager. :param edm_client: Reference to the EDM client for candle data. + :param indicator_owner_id: For subscribed strategies, the creator's user ID for indicator lookup. + If None, uses user_id. """ # Initialize the backtrader_strategy attribute self.backtrader_strategy = None # Will be set by Backtrader's MappedStrategy @@ -42,6 +44,9 @@ class StrategyInstance: self.trades = trades self.edm_client = edm_client + # For subscribed strategies, indicator lookup uses the creator's indicators + self.indicator_owner_id = indicator_owner_id if indicator_owner_id is not None else user_id + # Initialize context variables self.flags: dict[str, Any] = {} self.variables: dict[str, Any] = {} @@ -741,17 +746,22 @@ class StrategyInstance: """ Retrieves the latest value of an indicator. + For subscribed strategies, indicators are looked up using indicator_owner_id + (the strategy creator's ID) rather than the running user's ID. + :param indicator_name: Name of the indicator. :param output_field: Specific field of the indicator. :return: Indicator value. """ - logger.debug(f"StrategyInstance is Retrieving indicator '{indicator_name}' from Indicators for user '{self.user_id}'.") + # Use indicator_owner_id for lookup (creator's indicators for subscribed strategies) + lookup_user_id = self.indicator_owner_id + logger.debug(f"StrategyInstance is Retrieving indicator '{indicator_name}' from Indicators for user '{lookup_user_id}'.") try: - user_indicators = self.indicators.get_indicator_list(user_id=self.user_id) + user_indicators = self.indicators.get_indicator_list(user_id=lookup_user_id) indicator = user_indicators.get(indicator_name) if not indicator: - logger.error(f"Indicator '{indicator_name}' not found for user '{self.user_id}'.") + logger.error(f"Indicator '{indicator_name}' not found for user '{lookup_user_id}'.") return None indicator_value = self.indicators.process_indicator(indicator) value = indicator_value.get(output_field, None) diff --git a/src/app.py b/src/app.py index 66fae19..374d343 100644 --- a/src/app.py +++ b/src/app.py @@ -69,6 +69,10 @@ brighter_trades = BrighterTrades(socketio) # Set server configuration globals. CORS_HEADERS = 'Content-Type' +# Socket ID to authenticated user_id mapping +# This is the source of truth for WebSocket authentication - never trust client payloads +socket_user_mapping = {} # request.sid -> user_id + # Set the app directly with the globals. app.config.from_object(__name__) app.secret_key = '1_BAD_secrete_KEY_is_not_2' @@ -343,27 +347,59 @@ def index(): @socketio.on('connect') def handle_connect(): - user_name = request.args.get('user_name') - if not user_name: - user_name = resolve_user_name({ - 'userId': request.args.get('userId'), - 'user_id': request.args.get('user_id') - }) - if user_name and brighter_trades.get_user_info(user_name=user_name, info='Is logged in?'): - # Join a room specific to the user for targeted messaging - room = user_name # You can choose an appropriate room naming strategy - join_room(room) - emit('message', {'reply': 'connected', 'data': 'Connection established'}) - else: - emit('message', {'reply': 'error', 'data': 'User not authenticated'}) - # Disconnect the client if not authenticated + """ + Handle WebSocket connection. + + Security: User identity is determined from Flask session (set during HTTP login), + NOT from query parameters. This prevents identity spoofing. + """ + # Get user from Flask session - this is set during HTTP login (/login route) + session_user = session.get('user') + + if not session_user: + # No session user - reject connection + emit('message', {'reply': 'error', 'data': 'User not authenticated - no session'}) disconnect() + return + + # Verify user is logged in + if not brighter_trades.get_user_info(user_name=session_user, info='Is logged in?'): + emit('message', {'reply': 'error', 'data': 'User not logged in'}) + disconnect() + return + + # Get user_id from username + try: + user_id = brighter_trades.get_user_info(user_name=session_user, info='User_id') + except Exception: + emit('message', {'reply': 'error', 'data': 'Could not resolve user identity'}) + disconnect() + return + + # Store the authenticated user_id for this socket - THIS IS THE SOURCE OF TRUTH + socket_user_mapping[request.sid] = { + 'user_id': user_id, + 'user_name': session_user + } + + # Join a room specific to the user for targeted messaging + join_room(session_user) + emit('message', {'reply': 'connected', 'data': 'Connection established'}) + + +@socketio.on('disconnect') +def handle_disconnect(): + """Clean up socket mapping on disconnect.""" + socket_user_mapping.pop(request.sid, None) @socketio.on('message') def handle_message(data): """ Handle incoming JSON messages with authentication. + + Security: User identity is determined from socket_user_mapping (set at connect time), + NOT from message payload. This prevents identity spoofing attacks. """ # Validate input if 'message_type' not in data or 'data' not in data: @@ -372,28 +408,30 @@ def handle_message(data): msg_type, msg_data = data['message_type'], data['data'] - # Extract user_name from the incoming message data - user_name = resolve_user_name(msg_data) - if not user_name: - emit('message', {"success": False, "message": "User not specified"}) + # Get authenticated user from our mapping - THIS IS THE SOURCE OF TRUTH + # DO NOT trust msg_data.get('user_name') or msg_data.get('user_id') + auth_info = socket_user_mapping.get(request.sid) + if not auth_info: + emit('message', {"success": False, "message": "Not authenticated"}) return - msg_data.setdefault('user_name', user_name) - try: - user_id = brighter_trades.get_user_info(user_name=user_name, info='User_id') - if user_id is not None: - msg_data.setdefault('user_id', user_id) - msg_data.setdefault('userId', user_id) - except Exception: - pass + # Use server-verified identity, ignoring any identity claims in payload + authenticated_user_id = auth_info['user_id'] + authenticated_user_name = auth_info['user_name'] - # Check if the user is logged in - if not brighter_trades.get_user_info(user_name=user_name, info='Is logged in?'): - emit('message', {"success": False, "message": "User not logged in"}) - return + # Inject authenticated identity into msg_data (overwriting any client-provided values) + msg_data['user_name'] = authenticated_user_name + msg_data['user'] = authenticated_user_name + msg_data['user_id'] = authenticated_user_id + msg_data['userId'] = authenticated_user_id - # Process the incoming message based on the type - resp = brighter_trades.process_incoming_message(msg_type=msg_type, msg_data=msg_data, socket_conn_id=request.sid) + # Process the incoming message with server-verified user identity + resp = brighter_trades.process_incoming_message( + msg_type=msg_type, + msg_data=msg_data, + socket_conn_id=request.sid, + authenticated_user_id=authenticated_user_id + ) # Send the response back to the client if resp: @@ -479,6 +517,21 @@ def signout(): return redirect('/') if brighter_trades.log_user_in_out(user_name=user_name, cmd='logout'): + # Disconnect any active WebSocket connections for this user + # to prevent continued access after logout + sids_to_remove = [] + for sid, auth_info in list(socket_user_mapping.items()): + if auth_info.get('user_name') == user_name: + sids_to_remove.append(sid) + + for sid in sids_to_remove: + socket_user_mapping.pop(sid, None) + try: + # Disconnect the socket (they'll need to re-authenticate) + socketio.server.disconnect(sid) + except Exception: + pass # Socket may already be closed + # If the user was logged out successfully delete the session var. del session['user'] diff --git a/src/backtest_strategy_instance.py b/src/backtest_strategy_instance.py index 39ae834..c007025 100644 --- a/src/backtest_strategy_instance.py +++ b/src/backtest_strategy_instance.py @@ -22,13 +22,14 @@ class BacktestStrategyInstance(StrategyInstance): def __init__(self, strategy_instance_id: str, strategy_id: str, strategy_name: str, user_id: int, generated_code: str, data_cache: Any, indicators: Any | None, trades: Any | None, backtrader_strategy: Optional[bt.Strategy] = None, - edm_client: Any = None): + edm_client: Any = None, indicator_owner_id: int = None): # Set 'self.broker' and 'self.backtrader_strategy' to None before calling super().__init__() self.broker = None self.backtrader_strategy = None super().__init__(strategy_instance_id, strategy_id, strategy_name, user_id, - generated_code, data_cache, indicators, trades, edm_client) + generated_code, data_cache, indicators, trades, edm_client, + indicator_owner_id=indicator_owner_id) # Set the backtrader_strategy instance after super().__init__() self.backtrader_strategy = backtrader_strategy diff --git a/src/backtesting.py b/src/backtesting.py index e348f50..9dfd652 100644 --- a/src/backtesting.py +++ b/src/backtesting.py @@ -321,13 +321,19 @@ class Backtester: logger.error(f"Error preparing data feed: {e}") return pd.DataFrame() - def precompute_indicators(self, indicators_definitions: list, user_name: str, data_feed: pd.DataFrame) -> dict: + def precompute_indicators(self, indicators_definitions: list, user_name: str, data_feed: pd.DataFrame, + indicator_owner_id: int = None) -> dict: """ Precompute indicator values directly on the backtest data feed. IMPORTANT: This computes indicators on the actual backtest candle data, ensuring the indicator values align with the price data used in the backtest. Previously, this fetched fresh/latest candles which caused misalignment. + + :param indicators_definitions: List of indicator definitions needed. + :param user_name: The user running the backtest. + :param data_feed: The candle data for backtesting. + :param indicator_owner_id: For subscribed strategies, the creator's user ID for indicator lookup. """ import json as json_module # Local import to avoid conflicts @@ -363,11 +369,16 @@ class Backtester: indicator_outputs[indicator_name] = None # None indicates all outputs # Get user ID for indicator lookup - user_id = self.data_cache.get_datacache_item( - item_name='id', - cache_name='users', - filter_vals=('user_name', user_name) - ) + # For subscribed strategies, use indicator_owner_id (creator's ID) instead of running user's ID + if indicator_owner_id is not None: + user_id = indicator_owner_id + logger.info(f"[BACKTEST] Using indicator_owner_id {indicator_owner_id} for indicator lookup (subscribed strategy)") + else: + user_id = self.data_cache.get_datacache_item( + item_name='id', + cache_name='users', + filter_vals=('user_name', user_name) + ) logger.info(f"[BACKTEST] indicator_outputs to precompute: {indicator_outputs}") @@ -442,23 +453,29 @@ class Backtester: return precomputed_indicators - def _calculate_warmup_period(self, indicators_definitions: list, user_name: str) -> int: + def _calculate_warmup_period(self, indicators_definitions: list, user_name: str, indicator_owner_id: int = None) -> int: """ Calculate the maximum warmup period needed based on indicator periods. :param indicators_definitions: List of indicator definitions from strategy :param user_name: Username for looking up indicator configs + :param indicator_owner_id: For subscribed strategies, the creator's user ID for indicator lookup. :return: Maximum warmup period in candles """ import json as json_module max_period = 0 - user_id = self.data_cache.get_datacache_item( - item_name='id', - cache_name='users', - filter_vals=('user_name', user_name) - ) + # For subscribed strategies, use indicator_owner_id (creator's ID) instead of running user's ID + if indicator_owner_id is not None: + user_id = indicator_owner_id + logger.info(f"[BACKTEST] Using indicator_owner_id {indicator_owner_id} for warmup calculation (subscribed strategy)") + else: + user_id = self.data_cache.get_datacache_item( + item_name='id', + cache_name='users', + filter_vals=('user_name', user_name) + ) for indicator_def in indicators_definitions: indicator_name = indicator_def.get('name') @@ -498,11 +515,12 @@ class Backtester: } return timeframe_map.get(timeframe.lower(), 60) # Default to 1h - def prepare_backtest_data(self, msg_data: dict, strategy_components: dict) -> tuple: + def prepare_backtest_data(self, msg_data: dict, strategy_components: dict, indicator_owner_id: int = None) -> tuple: """ Prepare the data feed and precomputed indicators for backtesting. :param msg_data: Message data containing backtest parameters. :param strategy_components: Components of the user-defined strategy. + :param indicator_owner_id: For subscribed strategies, the creator's user ID for indicator lookup. :return: Tuple of (data_feed, precomputed_indicators). :raises ValueError: If data sources are invalid or data feed cannot be prepared. """ @@ -524,7 +542,7 @@ class Backtester: # Calculate warmup period needed for indicators indicators_definitions = strategy_components.get('indicators', []) - warmup_candles = self._calculate_warmup_period(indicators_definitions, user_name) + warmup_candles = self._calculate_warmup_period(indicators_definitions, user_name, indicator_owner_id=indicator_owner_id) # Get timeframe to calculate how far back to fetch for warmup timeframe = main_source.get('timeframe', '1h') @@ -549,7 +567,9 @@ class Backtester: raise ValueError("Data feed could not be prepared. Please check the data source.") # Precompute indicator values on the full dataset (including warmup candles) - precomputed_indicators = self.precompute_indicators(indicators_definitions, user_name, data_feed) + precomputed_indicators = self.precompute_indicators( + indicators_definitions, user_name, data_feed, indicator_owner_id=indicator_owner_id + ) # Now trim BOTH the data feed AND indicators to start at the user's original start_date # This ensures the first indicator values in the backtest have full warmup context @@ -740,6 +760,27 @@ class Backtester: tbl_key = msg_data.get('strategy') # Expecting tbl_key instead of strategy_name backtest_name = msg_data.get('backtest_name') # Use the client-provided backtest_name + # Authorization check: user must own the strategy OR be subscribed to it + strategy = self.strategies.get_strategy_by_tbl_key(tbl_key) + if not strategy: + return {"error": "Strategy not found."} + + strategy_creator = strategy.get('creator') + try: + creator_id = int(strategy_creator) if strategy_creator is not None else None + except (ValueError, TypeError): + creator_id = None + + is_owner = (creator_id == user_id) if (creator_id is not None and user_id is not None) else False + is_subscribed = self.strategies.is_subscribed(user_id, tbl_key) + + # Must be owner OR subscribed to run backtest + if not is_owner and not is_subscribed: + return {"error": "Subscribe to this strategy first"} + + # For subscribed strategies, use creator's indicators + indicator_owner_id = creator_id if is_subscribed and not is_owner else None + if not backtest_name: # If backtest_name is not provided, generate a unique name backtest_name = f"{tbl_key}_backtest" @@ -777,7 +818,9 @@ class Backtester: msg_data['trading_source'] = source try: - data_feed, precomputed_indicators = self.prepare_backtest_data(msg_data, strategy_components) + data_feed, precomputed_indicators = self.prepare_backtest_data( + msg_data, strategy_components, indicator_owner_id=indicator_owner_id + ) except ValueError as ve: logger.error(f"Error preparing backtest data: {ve}") return {"error": str(ve)} @@ -791,7 +834,8 @@ class Backtester: generated_code=strategy_components.get("generated_code", ""), data_cache=self.data_cache, indicators=None, # Custom handling in BacktestStrategyInstance - trades=None # Custom handling in BacktestStrategyInstance + trades=None, # Custom handling in BacktestStrategyInstance + indicator_owner_id=indicator_owner_id, # For subscribed strategies, use creator's indicators ) # Cache the backtest diff --git a/src/live_strategy_instance.py b/src/live_strategy_instance.py index 1aa6459..f3b1100 100644 --- a/src/live_strategy_instance.py +++ b/src/live_strategy_instance.py @@ -47,6 +47,7 @@ class LiveStrategyInstance(StrategyInstance): circuit_breaker_pct: float = -0.10, rate_limit: float = 2.0, edm_client: Any = None, + indicator_owner_id: int = None, ): """ Initialize the LiveStrategyInstance. @@ -67,6 +68,7 @@ class LiveStrategyInstance(StrategyInstance): :param max_position_pct: Maximum position size as percentage of balance (0.5 = 50%). :param circuit_breaker_pct: Drawdown percentage to trigger circuit breaker (-0.10 = -10%). :param rate_limit: API calls per second limit. + :param indicator_owner_id: For subscribed strategies, the creator's user ID for indicator lookup. """ # Safety checks if not testnet: @@ -102,7 +104,8 @@ class LiveStrategyInstance(StrategyInstance): # Initialize parent (will call _initialize_or_load_context) super().__init__( strategy_instance_id, strategy_id, strategy_name, user_id, - generated_code, data_cache, indicators, trades, edm_client + generated_code, data_cache, indicators, trades, edm_client, + indicator_owner_id=indicator_owner_id ) # Connect to exchange and sync state diff --git a/src/paper_strategy_instance.py b/src/paper_strategy_instance.py index db61998..48b7a28 100644 --- a/src/paper_strategy_instance.py +++ b/src/paper_strategy_instance.py @@ -38,6 +38,7 @@ class PaperStrategyInstance(StrategyInstance): slippage: float = 0.0005, price_provider: Any = None, edm_client: Any = None, + indicator_owner_id: int = None, ): """ Initialize the PaperStrategyInstance. @@ -54,6 +55,7 @@ class PaperStrategyInstance(StrategyInstance): :param commission: Commission rate for paper trades. :param slippage: Slippage rate for market orders. :param price_provider: Callable to get current prices. + :param indicator_owner_id: For subscribed strategies, the creator's user ID for indicator lookup. """ # Initialize the paper broker self.paper_broker = PaperBroker( @@ -69,7 +71,8 @@ class PaperStrategyInstance(StrategyInstance): super().__init__( strategy_instance_id, strategy_id, strategy_name, user_id, - generated_code, data_cache, indicators, trades, edm_client + generated_code, data_cache, indicators, trades, edm_client, + indicator_owner_id=indicator_owner_id ) # Initialize balance attributes from paper broker diff --git a/src/static/Strategies.js b/src/static/Strategies.js index 83ee5d4..16feaeb 100644 --- a/src/static/Strategies.js +++ b/src/static/Strategies.js @@ -1,3 +1,32 @@ +/** + * Escapes HTML special characters to prevent XSS attacks. + * @param {string} str - The string to escape. + * @returns {string} - The escaped string. + */ +function escapeHtml(str) { + if (str == null) return ''; + return String(str) + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, '''); +} + +/** + * Escapes a string for safe embedding inside a single-quoted JS string literal. + * @param {string} str - Raw string value. + * @returns {string} - JS-escaped string. + */ +function escapeJsString(str) { + if (str == null) return ''; + return String(str) + .replace(/\\/g, '\\\\') + .replace(/'/g, "\\'") + .replace(/\r/g, '\\r') + .replace(/\n/g, '\\n'); +} + class StratUIManager { constructor(workspaceManager) { this.workspaceManager = workspaceManager; @@ -341,28 +370,57 @@ class StratUIManager { strategyItem.className = 'strategy-item'; strategyItem.setAttribute('data-strategy-id', strat.tbl_key); + // Check if this is a subscribed strategy (not owned) + const isSubscribed = strat.is_subscribed && !strat.is_owner; + const isOwner = strat.is_owner !== false; // Default to owner if not specified + // Check if strategy is running const isRunning = UI.strats && UI.strats.isStrategyRunning(strat.tbl_key); const runningInfo = isRunning ? UI.strats.getRunningInfo(strat.tbl_key) : null; - // Delete button - const deleteButton = document.createElement('button'); - deleteButton.className = 'delete-button'; - deleteButton.innerHTML = '✘'; - deleteButton.addEventListener('click', (e) => { - e.stopPropagation(); - if (isRunning) { - alert('Cannot delete a running strategy. Stop it first.'); - return; - } - console.log(`Delete button clicked for strategy: ${strat.name}`); - if (this.onDeleteStrategy) { - this.onDeleteStrategy(strat.tbl_key); - } else { - console.error("Delete strategy callback is not set."); - } - }); - strategyItem.appendChild(deleteButton); + // Add subscribed class if applicable + if (isSubscribed) { + strategyItem.classList.add('subscribed'); + } + + // Delete/Unsubscribe button + if (isSubscribed) { + // Show unsubscribe button for subscribed strategies + const unsubscribeButton = document.createElement('button'); + unsubscribeButton.className = 'unsubscribe-button'; + unsubscribeButton.innerHTML = '−'; // Minus sign + unsubscribeButton.title = 'Unsubscribe from strategy'; + unsubscribeButton.addEventListener('click', (e) => { + e.stopPropagation(); + if (isRunning) { + alert('Cannot unsubscribe while strategy is running. Stop it first.'); + return; + } + if (UI.strats && UI.strats.unsubscribeFromStrategy) { + UI.strats.unsubscribeFromStrategy(strat.tbl_key); + } + }); + strategyItem.appendChild(unsubscribeButton); + } else { + // Delete button for owned strategies + const deleteButton = document.createElement('button'); + deleteButton.className = 'delete-button'; + deleteButton.innerHTML = '✘'; + deleteButton.addEventListener('click', (e) => { + e.stopPropagation(); + if (isRunning) { + alert('Cannot delete a running strategy. Stop it first.'); + return; + } + console.log(`Delete button clicked for strategy: ${strat.name}`); + if (this.onDeleteStrategy) { + this.onDeleteStrategy(strat.tbl_key); + } else { + console.error("Delete strategy callback is not set."); + } + }); + strategyItem.appendChild(deleteButton); + } // Run/Stop button const runButton = document.createElement('button'); @@ -383,11 +441,20 @@ class StratUIManager { // Strategy icon const strategyIcon = document.createElement('div'); strategyIcon.className = isRunning ? 'strategy-icon running' : 'strategy-icon'; + if (isSubscribed) { + strategyIcon.classList.add('subscribed'); + } strategyIcon.addEventListener('click', () => { console.log(`Strategy icon clicked for strategy: ${strat.name}`); - this.displayForm('edit', strat).catch(error => { - console.error('Error displaying form:', error); - }); + if (isSubscribed) { + // Show info modal for subscribed strategies (can't edit) + this.showSubscribedStrategyInfo(strat); + } else { + // Normal edit behavior for owned strategies + this.displayForm('edit', strat).catch(error => { + console.error('Error displaying form:', error); + }); + } }); // Strategy name @@ -395,18 +462,31 @@ class StratUIManager { strategyName.className = 'strategy-name'; strategyName.textContent = strat.name || 'Unnamed Strategy'; strategyIcon.appendChild(strategyName); + + // Creator badge for subscribed strategies + if (isSubscribed && strat.creator_name) { + const creatorBadge = document.createElement('div'); + creatorBadge.className = 'creator-badge'; + creatorBadge.textContent = `by @${strat.creator_name}`; + strategyIcon.appendChild(creatorBadge); + } + strategyItem.appendChild(strategyIcon); // Strategy hover details with run controls const strategyHover = document.createElement('div'); strategyHover.className = 'strategy-hover'; + const strategyKey = String(strat.tbl_key || ''); + const strategyKeyHtml = escapeHtml(strategyKey); + const strategyKeyJs = escapeHtml(escapeJsString(strategyKey)); - // Build hover content - let hoverHtml = `${strat.name || 'Unnamed Strategy'}`; + // Build hover content (escape user-controlled values) + let hoverHtml = `${escapeHtml(strat.name || 'Unnamed Strategy')}`; // Show running status if applicable if (isRunning) { let modeDisplay = runningInfo.mode; + const safeModeDisplay = escapeHtml(modeDisplay); let modeBadge = ''; // Add testnet/production badge for live mode @@ -420,7 +500,7 @@ class StratUIManager { let statusHtml = `