Add exchange validation, fix indicators bug, and improve balance display

- Add exchange_validation.py module for validating exchange requirements
  before running strategies (backtest, paper, live modes)
- Fix AttributeError in Signals.py: 'Indicators' object has no attribute
  'indicators' - created IndicatorWrapper class for proper data access
- Fix testnet balance issue: explicitly pass testnet=False to all
  connect_exchange calls to prevent pickle corruption from old testnet
  Exchange objects
- Add balance exchange selector: display one exchange at a time with
  dropdown to switch between connected exchanges (defaults to chart view)
- Add unique tbl_key generation for exchange_data to prevent duplicate
  entries (format: user:exchange)
- Fix DataCache balance serialization for list types
- Update frontend error handling for exchange validation errors

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
rob 2026-03-08 19:41:41 -03:00
parent ee16023b6b
commit cd9a69f1d4
14 changed files with 1099 additions and 59 deletions

View File

@ -752,6 +752,44 @@ class BrighterTrades:
# This ensures subscribers run with the creator's indicator definitions # This ensures subscribers run with the creator's indicator definitions
indicator_owner_id = creator_id if is_subscribed and not is_owner else None indicator_owner_id = creator_id if is_subscribed and not is_owner else None
# Early exchange requirements validation
from exchange_validation import extract_required_exchanges, validate_exchange_requirements
strategy_full = self.strategies.get_strategy_by_tbl_key(strategy_id)
required_exchanges = extract_required_exchanges(strategy_full)
if required_exchanges:
# Get user's configured exchanges
try:
user_name = self.users.get_username(user_id=user_id)
user_configured = self.users.get_exchanges(user_name, category='configured_exchanges') or []
except Exception:
user_configured = []
# Get EDM available exchanges (List[str])
edm_available = []
if self.edm_client:
try:
edm_available = self.edm_client.get_exchanges_sync()
except Exception as e:
logger.warning(f"Could not fetch EDM exchanges: {e}")
# For backtest mode, fail if EDM unreachable (can't proceed without data)
# Paper/live can continue since they use ccxt/exchange directly
validation_result = validate_exchange_requirements(
required_exchanges=required_exchanges,
user_configured_exchanges=user_configured,
edm_available_exchanges=edm_available,
mode=mode
)
if not validation_result.valid:
return {
"success": False,
"message": validation_result.message,
"error_code": validation_result.error_code.value if validation_result.error_code else None,
"missing_exchanges": list(validation_result.missing_exchanges)
}
# Check if already running # Check if already running
instance_key = (user_id, strategy_id, effective_mode) instance_key = (user_id, strategy_id, effective_mode)
if instance_key in self.strategies.active_instances: if instance_key in self.strategies.active_instances:
@ -1140,9 +1178,9 @@ class BrighterTrades:
try: try:
if self.data.get_serialized_datacache(cache_name='exchange_data', if self.data.get_serialized_datacache(cache_name='exchange_data',
filter_vals=([('user', user_name), ('name', exchange_name)])).empty: filter_vals=([('user', user_name), ('name', exchange_name)])).empty:
# Exchange is not connected, try to connect # Exchange is not connected, try to connect (always use production mode, not testnet)
success = self.exchanges.connect_exchange(exchange_name=exchange_name, user_name=user_name, success = self.exchanges.connect_exchange(exchange_name=exchange_name, user_name=user_name,
api_keys=api_keys) api_keys=api_keys, testnet=False)
if success: if success:
self.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set') self.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set')
# Check if api_keys has actual key/secret values (not just empty dict) # Check if api_keys has actual key/secret values (not just empty dict)
@ -1166,10 +1204,12 @@ class BrighterTrades:
) )
# Force reconnection to get fresh ccxt client and balances # Force reconnection to get fresh ccxt client and balances
# Always use production mode (testnet=False) unless explicitly requested
reconnect_ok = self.exchanges.connect_exchange( reconnect_ok = self.exchanges.connect_exchange(
exchange_name=exchange_name, exchange_name=exchange_name,
user_name=user_name, user_name=user_name,
api_keys=api_keys api_keys=api_keys,
testnet=False
) )
if reconnect_ok: if reconnect_ok:
# Update stored credentials if they changed # Update stored credentials if they changed
@ -1621,7 +1661,13 @@ class BrighterTrades:
if 'error' in resp: if 'error' in resp:
# If there's an error, send a backtest_error message # If there's an error, send a backtest_error message
return standard_reply("backtest_error", {"message": resp['error']}) # Preserve structured error fields (error_code, missing_exchanges) if present
error_data = {"message": resp['error']}
if 'error_code' in resp:
error_data['error_code'] = resp['error_code']
if 'missing_exchanges' in resp:
error_data['missing_exchanges'] = resp['missing_exchanges']
return standard_reply("backtest_error", error_data)
else: else:
# If successful, send a backtest_submitted message # If successful, send a backtest_submitted message
return standard_reply("backtest_submitted", resp) return standard_reply("backtest_submitted", resp)

View File

@ -1129,9 +1129,20 @@ class DatabaseInteractions(SnapshotDataCache):
if len(rows) > 1: if len(rows) > 1:
raise ValueError(f"Multiple rows found for {filter_vals}. Please provide more specific filter.") raise ValueError(f"Multiple rows found for {filter_vals}. Please provide more specific filter.")
# Update the DataFrame with the new values # Types that don't need serialization (same as serialized_datacache_insert)
for field_name, new_value in zip(field_names, new_values): excluded_objects = (str, int, float, bool, type(None), bytes)
rows[field_name] = new_value
# Serialize non-primitive values for database storage
serialized_values = []
for new_value in new_values:
if not isinstance(new_value, excluded_objects):
serialized_values.append(pickle.dumps(new_value))
else:
serialized_values.append(new_value)
# Update the DataFrame with the serialized values (for cache consistency)
for field_name, serialized_value in zip(field_names, serialized_values):
rows[field_name] = serialized_value
# Get the cache instance # Get the cache instance
cache = self.get_cache(cache_name) cache = self.get_cache(cache_name)
@ -1146,11 +1157,11 @@ class DatabaseInteractions(SnapshotDataCache):
else: else:
raise ValueError(f"Unsupported cache type for {cache_name}") raise ValueError(f"Unsupported cache type for {cache_name}")
# Update the values in the database # Update the values in the database with serialized values
set_clause = ", ".join([f"{field} = ?" for field in field_names]) set_clause = ", ".join([f"{field} = ?" for field in field_names])
where_clause = " AND ".join([f"{col} = ?" for col, _ in filter_vals]) where_clause = " AND ".join([f"{col} = ?" for col, _ in filter_vals])
sql_update = f"UPDATE {cache_name} SET {set_clause} WHERE {where_clause}" sql_update = f"UPDATE {cache_name} SET {set_clause} WHERE {where_clause}"
params = list(new_values) + [val for _, val in filter_vals] params = serialized_values + [val for _, val in filter_vals]
# Execute the SQL update to modify the database # Execute the SQL update to modify the database
self.db.execute_sql(sql_update, params) self.db.execute_sql(sql_update, params)

View File

@ -114,13 +114,19 @@ class ExchangeInterface:
:return: True if successful, False otherwise. :return: True if successful, False otherwise.
""" """
try: try:
# Get existing exchange to preserve EDM session ID for cleanup
existing = None existing = None
old_session_id = None
try: try:
# Preserve existing exchange until the replacement is created successfully.
existing = self.get_exchange(exchange_name, user_name) existing = self.get_exchange(exchange_name, user_name)
old_session_id = existing.edm_session_id if hasattr(existing, 'edm_session_id') else None
old_testnet = getattr(existing, 'testnet', 'unknown')
logger.info(f"Replacing existing {exchange_name} for {user_name} (old testnet={old_testnet}, new testnet={testnet})")
except Exception: except Exception:
pass # No existing entry to replace, that's fine pass # No existing entry, that's fine
# Create new exchange with explicit testnet setting
logger.info(f"Creating {exchange_name} for {user_name} with testnet={testnet}")
exchange = Exchange(name=exchange_name, api_keys=api_keys, exchange_id=exchange_name.lower(), exchange = Exchange(name=exchange_name, api_keys=api_keys, exchange_id=exchange_name.lower(),
testnet=testnet) testnet=testnet)
@ -141,20 +147,29 @@ class ExchangeInterface:
except Exception as e: except Exception as e:
logger.warning(f"Failed to create EDM session for {exchange_name}: {e}") logger.warning(f"Failed to create EDM session for {exchange_name}: {e}")
# Replace existing entry only after new exchange initialization. # ALWAYS try to remove existing entry before adding new one
if existing is not None: # This prevents duplicate entries even if get_exchange failed
old_session_id = existing.edm_session_id if hasattr(existing, 'edm_session_id') else None # Use tbl_key for precise targeting, with fallback to user+name filter
tbl_key = f"{user_name}:{exchange_name}"
try:
self.cache_manager.remove_row_from_datacache( self.cache_manager.remove_row_from_datacache(
cache_name='exchange_data', cache_name='exchange_data',
filter_vals=[('user', user_name), ('name', exchange_name)] filter_vals=[('user', user_name), ('name', exchange_name)],
key=tbl_key
) )
if old_session_id and self.edm_client: logger.info(f"Removed old exchange entry for {user_name}/{exchange_name}")
try: except Exception as e:
self.edm_client.delete_session_sync(old_session_id) logger.debug(f"No existing entry to remove for {user_name}/{exchange_name}: {e}")
except Exception as e:
logger.warning(f"Failed to delete old EDM session: {e}") # Clean up old EDM session if we had one
if old_session_id and self.edm_client:
try:
self.edm_client.delete_session_sync(old_session_id)
except Exception as e:
logger.warning(f"Failed to delete old EDM session: {e}")
self.add_exchange(user_name, exchange) self.add_exchange(user_name, exchange)
logger.info(f"Connected {exchange_name} for {user_name} (testnet={testnet}, balances={len(exchange.balances)} assets)")
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}") logger.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}")
@ -168,6 +183,9 @@ class ExchangeInterface:
:param exchange: The Exchange object to add. :param exchange: The Exchange object to add.
""" """
try: try:
# Generate a unique tbl_key to prevent duplicates
tbl_key = f"{user_name}:{exchange.name}"
row_data = { row_data = {
'user': user_name, 'user': user_name,
'name': exchange.name, 'name': exchange.name,
@ -181,7 +199,8 @@ class ExchangeInterface:
row = pd.DataFrame([row_data]) row = pd.DataFrame([row_data])
self.cache_manager.serialized_datacache_insert(cache_name='exchange_data', data=row) # Pass key to let serialized_datacache_insert add the tbl_key column
self.cache_manager.serialized_datacache_insert(cache_name='exchange_data', data=row, key=tbl_key)
except Exception as e: except Exception as e:
logger.error(f"Couldn't create an instance of the exchange! {str(e)}") logger.error(f"Couldn't create an instance of the exchange! {str(e)}")
raise raise

View File

@ -3,7 +3,7 @@ import logging
import uuid import uuid
import datetime as dt import datetime as dt
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any, Dict
import pandas as pd import pandas as pd
from DataCache_v3 import DataCache from DataCache_v3 import DataCache
@ -12,6 +12,18 @@ from DataCache_v3 import DataCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class IndicatorWrapper:
"""
Wrapper to make indicator dict data accessible via .properties attribute.
This bridges the gap between get_indicator_list() return format
(flat dict with properties merged in) and the expected access pattern
(indicator.properties.get(prop_name)).
"""
def __init__(self, data: dict):
self.properties = data
@dataclass() @dataclass()
class Signal: class Signal:
"""Class for individual signal properties and state.""" """Class for individual signal properties and state."""
@ -507,8 +519,30 @@ class Signals:
:return: Dictionary of signals that changed state. :return: Dictionary of signals that changed state.
""" """
state_changes = {} state_changes = {}
# Cache indicator data per user to avoid repeated lookups
user_indicator_cache: Dict[int, dict] = {}
for signal in self.signals: for signal in self.signals:
change_in_state = self.process_signal(signal, indicators) # Get or fetch indicator data for this signal's creator
user_id = signal.creator
if user_id not in user_indicator_cache:
try:
# Fetch indicator list for this user
indicator_list = indicators.get_indicator_list(user_id=user_id)
# Wrap each indicator's data so it has a .properties attribute
user_indicator_cache[user_id] = {
name: IndicatorWrapper(data)
for name, data in indicator_list.items()
}
except Exception as e:
logger.debug(f"Could not fetch indicators for user {user_id}: {e}")
user_indicator_cache[user_id] = {}
# Get the wrapped indicators for this user
user_indicators = user_indicator_cache.get(user_id, {})
change_in_state = self.process_signal(signal, user_indicators)
if change_in_state: if change_in_state:
state_changes[signal.name] = signal.state state_changes[signal.name] = signal.state
# Persist state change to database # Persist state change to database
@ -521,12 +555,12 @@ class Signals:
) )
return state_changes return state_changes
def process_signal(self, signal: Signal, indicators, candles=None) -> bool: def process_signal(self, signal: Signal, indicator_data: dict, candles=None) -> bool:
""" """
Process a signal by comparing indicator values. Process a signal by comparing indicator values.
:param signal: The signal to process. :param signal: The signal to process.
:param indicators: The Indicators instance with calculated values. :param indicator_data: Dict mapping indicator names to IndicatorWrapper objects.
:param candles: Optional candles for recalculation. :param candles: Optional candles for recalculation.
:return: True if the signal state changed, False otherwise. :return: True if the signal state changed, False otherwise.
""" """
@ -534,8 +568,8 @@ class Signals:
# Get the source of the first signal # Get the source of the first signal
source_1 = signal.source1 source_1 = signal.source1
# Ask the indicator for the last result # Ask the indicator for the last result
if source_1 in indicators.indicators: if source_1 in indicator_data:
signal.value1 = indicators.indicators[source_1].properties.get(signal.prop1) signal.value1 = indicator_data[source_1].properties.get(signal.prop1)
else: else:
logger.debug(f'Could not calculate signal: source indicator "{source_1}" not found.') logger.debug(f'Could not calculate signal: source indicator "{source_1}" not found.')
return False return False
@ -550,8 +584,8 @@ class Signals:
signal.value2 = signal.prop2 signal.value2 = signal.prop2
else: else:
# Ask the indicator for the last result # Ask the indicator for the last result
if source_2 in indicators.indicators: if source_2 in indicator_data:
signal.value2 = indicators.indicators[source_2].properties.get(signal.prop2) signal.value2 = indicator_data[source_2].properties.get(signal.prop2)
else: else:
logger.debug(f'Could not calculate signal: source2 indicator "{source_2}" not found.') logger.debug(f'Could not calculate signal: source2 indicator "{source_2}" not found.')
return False return False

View File

@ -1116,6 +1116,20 @@ class Strategies:
return strategy_row return strategy_row
def get_required_exchanges(self, strategy_tbl_key: str) -> set[str]:
"""
Get the set of exchange names required by a strategy.
Extracts unique exchange names from the strategy's data sources
and default source settings.
:param strategy_tbl_key: The unique identifier of the strategy.
:return: Set of canonicalized exchange names required by the strategy.
"""
from exchange_validation import extract_required_exchanges
strategy = self.get_strategy_by_tbl_key(strategy_tbl_key)
return extract_required_exchanges(strategy)
def update_strategy_stats(self, strategy_id: str, profit_loss: float) -> None: def update_strategy_stats(self, strategy_id: str, profit_loss: float) -> None:
""" """
Updates the strategy's statistics based on the latest profit or loss. Updates the strategy's statistics based on the latest profit or loss.

View File

@ -781,6 +781,27 @@ class Backtester:
# For subscribed strategies, use creator's indicators # For subscribed strategies, use creator's indicators
indicator_owner_id = creator_id if is_subscribed and not is_owner else None indicator_owner_id = creator_id if is_subscribed and not is_owner else None
# Validate exchange requirements for backtest
from exchange_validation import extract_required_exchanges, validate_for_backtest
required_exchanges = extract_required_exchanges(strategy)
if required_exchanges and self.edm_client:
try:
edm_available = self.edm_client.get_exchanges_sync()
validation_result = validate_for_backtest(required_exchanges, edm_available)
if not validation_result.valid:
return {
"error": validation_result.message,
"error_code": validation_result.error_code.value if validation_result.error_code else None,
"missing_exchanges": list(validation_result.missing_exchanges)
}
except Exception as e:
logger.warning(f"Could not validate EDM exchanges: {e}")
return {
"error": "Cannot validate exchange availability - EDM unreachable",
"error_code": "edm_unreachable"
}
if not backtest_name: if not backtest_name:
# If backtest_name is not provided, generate a unique name # If backtest_name is not provided, generate a unique name
backtest_name = f"{tbl_key}_backtest" backtest_name = f"{tbl_key}_backtest"

314
src/exchange_validation.py Normal file
View File

@ -0,0 +1,314 @@
"""
Exchange requirements validation for strategies.
Centralized validator with structured error codes for validating that users
have access to required exchanges before running strategies in different
trading modes (backtest, paper, live).
"""
import logging
from typing import Set, List, Dict, Any
from enum import Enum
logger = logging.getLogger(__name__)
class ValidationErrorCode(Enum):
"""Structured error codes for exchange validation failures."""
MISSING_EDM_DATA = "missing_edm_data"
MISSING_CONFIG = "missing_config"
INVALID_EXCHANGE = "invalid_exchange"
INVALID_KEYS = "invalid_keys"
EDM_UNREACHABLE = "edm_unreachable"
class ExchangeValidationResult:
"""Structured validation result."""
def __init__(
self,
valid: bool,
error_code: ValidationErrorCode = None,
missing_exchanges: Set[str] = None,
message: str = None
):
self.valid = valid
self.error_code = error_code
self.missing_exchanges = missing_exchanges or set()
self.message = message
def to_dict(self) -> Dict[str, Any]:
"""Convert result to dictionary for JSON serialization."""
result = {"valid": self.valid}
if not self.valid:
result["error_code"] = self.error_code.value if self.error_code else None
result["missing_exchanges"] = list(self.missing_exchanges)
result["message"] = self.message
return result
# Exchange name canonicalization map
# Maps variations to canonical names. Names not in map pass through lowercased.
EXCHANGE_ALIASES = {
'binance': 'binance',
'binanceus': 'binanceus',
'binanceusdm': 'binanceusdm',
'binancecoinm': 'binancecoinm',
'kucoin': 'kucoin',
'kucoinfutures': 'kucoinfutures',
'kraken': 'kraken',
'krakenfutures': 'krakenfutures',
'coinbase': 'coinbase',
'coinbasepro': 'coinbasepro',
'bybit': 'bybit',
'okx': 'okx',
'okex': 'okx', # Alias: okex -> okx
'gateio': 'gateio',
'gate': 'gateio', # Alias: gate -> gateio
'htx': 'htx',
'huobi': 'htx', # Alias: huobi -> htx (rebranded)
}
def canonicalize_exchange(name: str) -> str:
"""
Normalize exchange name to canonical form.
Handles case normalization and common aliases.
:param name: Exchange name in any case
:return: Canonical lowercase exchange name
"""
if not name:
return ''
lower = name.lower().strip()
return EXCHANGE_ALIASES.get(lower, lower)
def extract_required_exchanges(strategy: dict) -> Set[str]:
"""
Extract unique exchange names required by a strategy.
Single canonical implementation - use everywhere to avoid drift.
Parses strategy_components.data_sources to identify all exchanges
the strategy needs for execution. Falls back to default_source if
data_sources is empty.
:param strategy: Strategy dictionary from database
:return: Set of canonicalized exchange names
"""
if not strategy:
return set()
components = strategy.get('strategy_components', {})
if isinstance(components, str):
# Handle case where components is JSON string
import json
try:
components = json.loads(components)
except (json.JSONDecodeError, TypeError):
components = {}
data_sources = components.get('data_sources', [])
exchanges = set()
# Extract from data_sources (list of tuples or dicts)
for source in data_sources:
if isinstance(source, (list, tuple)) and len(source) >= 1:
exchange_name = source[0]
if exchange_name:
exchanges.add(canonicalize_exchange(exchange_name))
elif isinstance(source, dict) and source.get('exchange'):
exchanges.add(canonicalize_exchange(source['exchange']))
# Also check default_source as fallback
default_source = strategy.get('default_source', {})
if isinstance(default_source, str):
import json
try:
default_source = json.loads(default_source)
except (json.JSONDecodeError, TypeError):
default_source = {}
if default_source and default_source.get('exchange'):
exchanges.add(canonicalize_exchange(default_source['exchange']))
return exchanges
def get_valid_ccxt_exchanges() -> Set[str]:
"""
Get set of valid ccxt exchange names.
Used to validate that an exchange is supported for paper trading.
:return: Set of lowercase exchange names supported by ccxt
"""
import ccxt
return {ex.lower() for ex in ccxt.exchanges}
def _extract_exchange_name(ex) -> str:
"""
Extract exchange name from EDM response item.
EDM may return either strings or dicts with a 'name' key.
:param ex: Exchange item (str or dict)
:return: Exchange name string
"""
if isinstance(ex, dict):
return ex.get('name', '')
return str(ex) if ex else ''
def validate_for_backtest(
required_exchanges: Set[str],
edm_available_exchanges: List[str]
) -> ExchangeValidationResult:
"""
Validate exchanges for backtest mode.
Backtest requires historical data from EDM. All required exchanges
must be available in EDM's exchange list.
:param required_exchanges: Set of canonicalized exchange names
:param edm_available_exchanges: List of exchanges available in EDM (strings or dicts)
:return: Validation result
"""
if not required_exchanges:
return ExchangeValidationResult(valid=True)
edm_available = {
canonicalize_exchange(_extract_exchange_name(ex))
for ex in (edm_available_exchanges or [])
}
missing = required_exchanges - edm_available
if missing:
return ExchangeValidationResult(
valid=False,
error_code=ValidationErrorCode.MISSING_EDM_DATA,
missing_exchanges=missing,
message=f"Historical data not available for: {', '.join(sorted(missing))}"
)
return ExchangeValidationResult(valid=True)
def validate_for_paper(
required_exchanges: Set[str],
edm_available_exchanges: List[str]
) -> ExchangeValidationResult:
"""
Validate exchanges for paper mode.
Paper mode uses ccxt public endpoints for price fetching, so exchanges
just need to be valid ccxt exchanges. Warns if not in EDM since some
data fetching may fail.
:param required_exchanges: Set of canonicalized exchange names
:param edm_available_exchanges: List of exchanges available in EDM
:return: Validation result
"""
if not required_exchanges:
return ExchangeValidationResult(valid=True)
# Paper mode uses ccxt public endpoints - just need valid ccxt exchange
ccxt_exchanges = get_valid_ccxt_exchanges()
invalid = required_exchanges - ccxt_exchanges
if invalid:
return ExchangeValidationResult(
valid=False,
error_code=ValidationErrorCode.INVALID_EXCHANGE,
missing_exchanges=invalid,
message=f"Unknown exchanges: {', '.join(sorted(invalid))}"
)
# Warn if not in EDM (price fetching may use ccxt fallback)
edm_available = {
canonicalize_exchange(_extract_exchange_name(ex))
for ex in (edm_available_exchanges or [])
}
not_in_edm = required_exchanges - edm_available
if not_in_edm:
logger.warning(
f"Exchanges not in EDM (will use ccxt fallback): {not_in_edm}"
)
return ExchangeValidationResult(valid=True)
def validate_for_live(
required_exchanges: Set[str],
user_configured_exchanges: List[str]
) -> ExchangeValidationResult:
"""
Validate exchanges for live mode.
Live mode requires user to have API keys configured for each exchange.
This is an early check - full validation (including key validity) happens
later in start_strategy when it attempts to connect.
:param required_exchanges: Set of canonicalized exchange names
:param user_configured_exchanges: Exchanges with API keys configured
:return: Validation result
"""
if not required_exchanges:
return ExchangeValidationResult(valid=True)
configured = {
canonicalize_exchange(ex)
for ex in (user_configured_exchanges or [])
}
missing = required_exchanges - configured
if missing:
return ExchangeValidationResult(
valid=False,
error_code=ValidationErrorCode.MISSING_CONFIG,
missing_exchanges=missing,
message=f"API keys required for: {', '.join(sorted(missing))}"
)
return ExchangeValidationResult(valid=True)
def validate_exchange_requirements(
required_exchanges: Set[str],
user_configured_exchanges: List[str],
edm_available_exchanges: List[str],
mode: str
) -> ExchangeValidationResult:
"""
Main validation entrypoint. Routes to mode-specific validator.
Use this from both start_strategy and backtest entry points.
:param required_exchanges: Set of exchange names required by strategy
:param user_configured_exchanges: Exchanges with API keys configured
:param edm_available_exchanges: Exchanges available in EDM
:param mode: Trading mode ('backtest', 'paper', 'live')
:return: Validation result with error details if invalid
"""
# Canonicalize all required exchanges
required = {canonicalize_exchange(ex) for ex in required_exchanges if ex}
if not required:
return ExchangeValidationResult(valid=True)
if mode == 'live':
return validate_for_live(required, user_configured_exchanges)
elif mode == 'paper':
return validate_for_paper(required, edm_available_exchanges)
elif mode == 'backtest':
return validate_for_backtest(required, edm_available_exchanges)
else:
# Unknown mode - reject explicitly
return ExchangeValidationResult(
valid=False,
error_code=ValidationErrorCode.INVALID_EXCHANGE,
missing_exchanges=set(),
message=f"Invalid trading mode: {mode}"
)

View File

@ -2000,8 +2000,38 @@ class Strategies {
* @param {Object} data - Error data from server. * @param {Object} data - Error data from server.
*/ */
handleStrategyRunError(data) { handleStrategyRunError(data) {
console.error("Strategy run error:", data.message); console.error("Strategy run error:", data.message, data);
alert(`Failed to start strategy: ${data.message}`);
const errorCode = data.error_code;
const missing = data.missing_exchanges;
// Handle exchange requirement errors with detailed messages
if (missing && missing.length > 0) {
const exchanges = missing.join(', ');
let message;
switch (errorCode) {
case 'missing_edm_data':
message = `Historical data not available for these exchanges:\n\n${exchanges}\n\n` +
`These exchanges may not be supported by the Exchange Data Manager.`;
break;
case 'missing_config':
message = `Please configure API keys for:\n\n${exchanges}\n\n` +
`Go to Exchange Settings to add your credentials.`;
break;
case 'invalid_exchange':
message = `Unknown or unsupported exchanges:\n\n${exchanges}`;
break;
case 'edm_unreachable':
message = `Cannot validate exchange availability - data service unreachable.`;
break;
default:
message = `This strategy requires: ${exchanges}`;
}
alert(message);
} else {
alert(`Failed to start strategy: ${data.message || data.error || 'Unknown error'}`);
}
} }
/** /**

View File

@ -101,7 +101,7 @@ class Backtesting {
} }
handleBacktestError(data) { handleBacktestError(data) {
console.error("Backtest error:", data.message); console.error("Backtest error:", data.message || data.error, data);
const test = this.tests.find(t => t.name === this.currentTest); const test = this.tests.find(t => t.name === this.currentTest);
if (test) { if (test) {
@ -110,7 +110,29 @@ class Backtesting {
this.updateHTML(); this.updateHTML();
} }
this.displayMessage(`Backtest error: ${data.message}`, 'red'); // Build error message with exchange requirement details if present
let errorMessage;
const errorCode = data.error_code;
const missing = data.missing_exchanges;
if (missing && missing.length > 0) {
const exchanges = missing.join(', ');
switch (errorCode) {
case 'missing_edm_data':
errorMessage = `Historical data not available for: ${exchanges}. ` +
`These exchanges may not be supported.`;
break;
case 'edm_unreachable':
errorMessage = `Cannot validate exchange availability - data service unreachable.`;
break;
default:
errorMessage = `This strategy requires exchanges: ${exchanges}`;
}
} else {
errorMessage = data.message || data.error || 'Unknown error';
}
this.displayMessage(`Backtest error: ${errorMessage}`, 'red');
// Hide progress bar and results // Hide progress bar and results
this.hideElement(this.progressContainer); this.hideElement(this.progressContainer);

View File

@ -1,8 +1,9 @@
class Exchanges { class Exchanges {
constructor() { constructor() {
this.exchanges = {}; this.exchanges = {};
this.balances = {}; this.balances = {}; // All balances by exchange name
this.connected_exchanges = []; this.connected_exchanges = [];
this.selectedBalanceExchange = null; // Currently selected exchange for balance display
this.isSubmitting = false; this.isSubmitting = false;
} }
@ -14,6 +15,15 @@ class Exchanges {
// Extract the text content from each span and store it in the connected_exchanges array // Extract the text content from each span and store it in the connected_exchanges array
this.connected_exchanges = Array.from(spans).map(span => span.textContent.trim()); this.connected_exchanges = Array.from(spans).map(span => span.textContent.trim());
// Get the currently selected exchange from the selector (defaults to chart view exchange)
const selector = document.getElementById('balance_exchange_selector');
if (selector && selector.value) {
this.selectedBalanceExchange = selector.value;
} else {
// No selector or no options - will show empty state
this.selectedBalanceExchange = null;
}
// Register handlers for exchange events // Register handlers for exchange events
if (window.UI && window.UI.data && window.UI.data.comms) { if (window.UI && window.UI.data && window.UI.data.comms) {
window.UI.data.comms.on('Exchange_connection_result', this.handleConnectionResult.bind(this)); window.UI.data.comms.on('Exchange_connection_result', this.handleConnectionResult.bind(this));
@ -21,6 +31,45 @@ class Exchanges {
} }
} }
onBalanceExchangeChange(exchangeName) {
this.selectedBalanceExchange = exchangeName;
// If we have cached balances for this exchange, display them
if (this.balances[exchangeName]) {
this.displaySingleExchangeBalances(exchangeName, this.balances[exchangeName]);
} else {
// No cached balances, show empty state
this.displaySingleExchangeBalances(exchangeName, []);
}
}
displaySingleExchangeBalances(exchangeName, balanceList) {
const tbl = document.getElementById('balances_tbl');
if (!tbl) return;
let html = '<table><tr><th>Asset</th><th>Balance</th><th>Profit & Loss</th></tr>';
let hasValidBalances = false;
if (Array.isArray(balanceList) && balanceList.length > 0) {
for (const balance of balanceList) {
// Skip N/A placeholder entries
if (balance.asset === 'N/A' && balance.balance === 0) continue;
hasValidBalances = true;
html += `<tr>
<td>${balance.asset || ''}</td>
<td>${this.formatBalance(balance.balance)}</td>
<td>${this.formatBalance(balance.pnl)}</td>
</tr>`;
}
}
if (!hasValidBalances) {
html += '<tr><td colspan="3" style="text-align: center; color: #888;">No balances available</td></tr>';
}
html += '</table>';
tbl.innerHTML = html;
}
status() { status() {
// Reset form state when opening // Reset form state when opening
this.resetFormState(); this.resetFormState();
@ -205,7 +254,10 @@ class Exchanges {
} }
if (data.success && data.balances) { if (data.success && data.balances) {
console.log('Updating balances table with:', data.balances); console.log('Updating balances cache with:', data.balances);
// Store all balances in cache
this.balances = data.balances;
// Display only the selected exchange's balances
this.updateBalancesTable(data.balances); this.updateBalancesTable(data.balances);
} else if (!data.success) { } else if (!data.success) {
console.error('Failed to refresh balances:', data.message); console.error('Failed to refresh balances:', data.message);
@ -213,27 +265,47 @@ class Exchanges {
} }
updateBalancesTable(balances) { updateBalancesTable(balances) {
const tbl = document.getElementById('balances_tbl'); // Update the selector with available exchanges
if (!tbl) return; this.updateBalanceExchangeSelector(Object.keys(balances));
// Build new table HTML // Only display the selected exchange's balances
let html = '<table><tr><th>Asset</th><th>Balance</th><th>Profit & Loss</th></tr>'; let selectedExchange = this.selectedBalanceExchange;
for (const [exchangeName, exchangeBalances] of Object.entries(balances)) { // If no exchange is selected or selected doesn't exist, pick the first one
html += `<tr><td class="name-row" colspan="4">${exchangeName}</td></tr>`; if (!selectedExchange || !balances[selectedExchange]) {
if (Array.isArray(exchangeBalances)) { const exchanges = Object.keys(balances);
for (const balance of exchangeBalances) { if (exchanges.length > 0) {
html += `<tr> selectedExchange = exchanges[0];
<td>${balance.asset || ''}</td> this.selectedBalanceExchange = selectedExchange;
<td>${this.formatBalance(balance.balance)}</td> const selector = document.getElementById('balance_exchange_selector');
<td>${this.formatBalance(balance.pnl)}</td> if (selector) {
</tr>`; selector.value = selectedExchange;
} }
} }
} }
html += '</table>'; const exchangeBalances = balances[selectedExchange] || [];
tbl.innerHTML = html; this.displaySingleExchangeBalances(selectedExchange, exchangeBalances);
}
updateBalanceExchangeSelector(exchanges) {
const selector = document.getElementById('balance_exchange_selector');
if (!selector) return;
// Get current selection
const currentSelection = selector.value;
// Rebuild options
selector.innerHTML = '';
for (const exchange of exchanges) {
const option = document.createElement('option');
option.value = exchange;
option.textContent = exchange;
if (exchange === currentSelection || exchange === this.selectedBalanceExchange) {
option.selected = true;
}
selector.appendChild(option);
}
} }
formatBalance(value) { formatBalance(value) {

View File

@ -23,6 +23,16 @@
<div> <div>
<h3 style="display: flex; align-items: center; gap: 10px;"> <h3 style="display: flex; align-items: center; gap: 10px;">
Balances Balances
{% set balance_exchanges = my_balances.keys()|list %}
{% set default_balance_exchange = selected_exchange if selected_exchange in balance_exchanges else (balance_exchanges[0] if balance_exchanges else '') %}
{% if balance_exchanges %}
<select id="balance_exchange_selector" onchange="UI.exchanges.onBalanceExchangeChange(this.value)"
style="font-size: 12px; padding: 2px 4px;">
{% for exchange in balance_exchanges %}
<option value="{{ exchange }}" {% if exchange == default_balance_exchange %}selected{% endif %}>{{ exchange }}</option>
{% endfor %}
</select>
{% endif %}
<button id="refresh_balances_btn" onclick="UI.exchanges.refreshBalances()" <button id="refresh_balances_btn" onclick="UI.exchanges.refreshBalances()"
style="font-size: 12px; padding: 2px 8px; cursor: pointer;" style="font-size: 12px; padding: 2px 8px; cursor: pointer;"
title="Refresh balances from exchange">&#x21bb;</button> title="Refresh balances from exchange">&#x21bb;</button>
@ -35,20 +45,21 @@
<th>Balance</th> <th>Balance</th>
<th>Profit & Loss</th> <th>Profit & Loss</th>
</tr> </tr>
{% for name, balances in my_balances.items() %} {% set selected_balances = my_balances.get(default_balance_exchange, []) %}
<tr> {% set valid_balances = selected_balances|selectattr('asset', 'ne', 'N/A')|list if selected_balances else [] %}
<td class="name-row" colspan="4">{{ name }}</td> {% if valid_balances %}
</tr> {% for balance in valid_balances %}
{% if balances %}
{% for balance in balances %}
<tr> <tr>
<td>{{ balance['asset'] }}</td> <td>{{ balance['asset'] }}</td>
<td>{{ balance['balance']|format_balance }}</td> <td>{{ balance['balance']|format_balance }}</td>
<td>{{ balance['pnl']|format_balance }}</td> <td>{{ balance['pnl']|format_balance }}</td>
</tr> </tr>
{% endfor %} {% endfor %}
{% endif %} {% else %}
{% endfor %} <tr>
<td colspan="3" style="text-align: center; color: #888;">No balances available</td>
</tr>
{% endif %}
</table> </table>
</div> </div>
</div> </div>

View File

@ -170,7 +170,8 @@ class TestBrighterTrades(unittest.TestCase):
self.mock_exchanges.connect_exchange.assert_called_with( self.mock_exchanges.connect_exchange.assert_called_with(
exchange_name='kucoin', exchange_name='kucoin',
user_name='testuser', user_name='testuser',
api_keys=new_keys api_keys=new_keys,
testnet=False
) )
self.mock_users.update_api_keys.assert_called_with( self.mock_users.update_api_keys.assert_called_with(
api_keys=new_keys, api_keys=new_keys,

View File

@ -0,0 +1,425 @@
"""
Unit tests for exchange_validation module.
Tests cover:
- Exchange extraction from strategy data
- Mode-specific validation (backtest, paper, live)
- Exchange name canonicalization
- Error code generation
- Edge cases (empty data, EDM unavailable, etc.)
"""
import pytest
import sys
import os
# Add src to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from exchange_validation import (
canonicalize_exchange,
extract_required_exchanges,
validate_for_backtest,
validate_for_paper,
validate_for_live,
validate_exchange_requirements,
ValidationErrorCode,
ExchangeValidationResult,
)
class TestCanonicalizeExchange:
"""Tests for exchange name canonicalization."""
def test_lowercase_normalization(self):
"""Test that exchange names are normalized to lowercase."""
assert canonicalize_exchange('BINANCE') == 'binance'
assert canonicalize_exchange('Binance') == 'binance'
assert canonicalize_exchange('binance') == 'binance'
def test_alias_mapping(self):
"""Test that known aliases are mapped correctly."""
assert canonicalize_exchange('okex') == 'okx'
assert canonicalize_exchange('OKEX') == 'okx'
assert canonicalize_exchange('huobi') == 'htx'
assert canonicalize_exchange('gate') == 'gateio'
def test_passthrough_unknown(self):
"""Test that unknown exchange names pass through lowercased."""
assert canonicalize_exchange('someexchange') == 'someexchange'
assert canonicalize_exchange('NEWEXCHANGE') == 'newexchange'
def test_empty_string(self):
"""Test handling of empty string."""
assert canonicalize_exchange('') == ''
assert canonicalize_exchange(' ') == ''
def test_whitespace_stripping(self):
"""Test that whitespace is stripped."""
assert canonicalize_exchange(' binance ') == 'binance'
assert canonicalize_exchange('\tkucoin\n') == 'kucoin'
class TestExtractRequiredExchanges:
"""Tests for extracting required exchanges from strategy data."""
def test_empty_strategy(self):
"""Test handling of empty/None strategy."""
assert extract_required_exchanges(None) == set()
assert extract_required_exchanges({}) == set()
def test_extract_from_data_sources_tuple(self):
"""Test extraction from data_sources as tuples."""
strategy = {
'strategy_components': {
'data_sources': [
('binance', 'BTC/USDT', '1h'),
('kucoin', 'ETH/USDT', '5m'),
]
}
}
result = extract_required_exchanges(strategy)
assert result == {'binance', 'kucoin'}
def test_extract_from_data_sources_list(self):
"""Test extraction from data_sources as lists."""
strategy = {
'strategy_components': {
'data_sources': [
['binance', 'BTC/USDT', '1h'],
]
}
}
result = extract_required_exchanges(strategy)
assert result == {'binance'}
def test_extract_from_data_sources_dict(self):
"""Test extraction from data_sources as dicts."""
strategy = {
'strategy_components': {
'data_sources': [
{'exchange': 'binance', 'symbol': 'BTC/USDT', 'timeframe': '1h'},
]
}
}
result = extract_required_exchanges(strategy)
assert result == {'binance'}
def test_extract_from_default_source(self):
"""Test extraction from default_source."""
strategy = {
'default_source': {
'exchange': 'kucoin',
'symbol': 'BTC/USDT',
'timeframe': '15m'
}
}
result = extract_required_exchanges(strategy)
assert result == {'kucoin'}
def test_extract_combined_sources(self):
"""Test extraction from both data_sources and default_source."""
strategy = {
'strategy_components': {
'data_sources': [
('binance', 'BTC/USDT', '1h'),
]
},
'default_source': {
'exchange': 'kucoin',
'symbol': 'ETH/USDT',
}
}
result = extract_required_exchanges(strategy)
assert result == {'binance', 'kucoin'}
def test_extract_with_canonicalization(self):
"""Test that extracted exchanges are canonicalized."""
strategy = {
'strategy_components': {
'data_sources': [
('BINANCE', 'BTC/USDT', '1h'),
('okex', 'ETH/USDT', '5m'), # Should become 'okx'
]
}
}
result = extract_required_exchanges(strategy)
assert result == {'binance', 'okx'}
def test_extract_deduplication(self):
"""Test that duplicate exchanges are deduplicated."""
strategy = {
'strategy_components': {
'data_sources': [
('binance', 'BTC/USDT', '1h'),
('Binance', 'ETH/USDT', '5m'),
('BINANCE', 'LTC/USDT', '15m'),
]
}
}
result = extract_required_exchanges(strategy)
assert result == {'binance'}
def test_extract_json_string_components(self):
"""Test extraction when strategy_components is JSON string."""
import json
strategy = {
'strategy_components': json.dumps({
'data_sources': [
['binance', 'BTC/USDT', '1h'],
]
})
}
result = extract_required_exchanges(strategy)
assert result == {'binance'}
class TestValidateForBacktest:
"""Tests for backtest mode validation."""
def test_empty_requirements(self):
"""Test that empty requirements are valid."""
result = validate_for_backtest(set(), ['binance', 'kucoin'])
assert result.valid is True
def test_all_available(self):
"""Test validation when all exchanges are available."""
result = validate_for_backtest(
{'binance', 'kucoin'},
['binance', 'kucoin', 'kraken']
)
assert result.valid is True
def test_missing_exchanges(self):
"""Test validation when exchanges are missing."""
result = validate_for_backtest(
{'binance', 'kucoin', 'unknown_exchange'},
['binance', 'kucoin']
)
assert result.valid is False
assert result.error_code == ValidationErrorCode.MISSING_EDM_DATA
assert result.missing_exchanges == {'unknown_exchange'}
assert 'unknown_exchange' in result.message
def test_empty_edm_list(self):
"""Test validation when EDM list is empty."""
result = validate_for_backtest({'binance'}, [])
assert result.valid is False
assert result.error_code == ValidationErrorCode.MISSING_EDM_DATA
def test_none_edm_list(self):
"""Test validation when EDM list is None."""
result = validate_for_backtest({'binance'}, None)
assert result.valid is False
assert result.error_code == ValidationErrorCode.MISSING_EDM_DATA
def test_canonicalization_in_comparison(self):
"""Test that EDM exchanges are canonicalized during comparison."""
result = validate_for_backtest(
{'binance'},
['BINANCE'] # Should match after canonicalization
)
assert result.valid is True
def test_edm_returns_dicts(self):
"""Test that EDM response with dicts (containing 'name' key) is handled."""
result = validate_for_backtest(
{'binance', 'kucoin'},
[{'name': 'binance', 'timeframes': ['1m', '1h']},
{'name': 'kucoin', 'timeframes': ['5m', '15m']}]
)
assert result.valid is True
def test_edm_returns_dicts_missing(self):
"""Test that missing exchanges are detected with dict format."""
result = validate_for_backtest(
{'binance', 'kraken'},
[{'name': 'binance'}]
)
assert result.valid is False
assert result.missing_exchanges == {'kraken'}
class TestValidateForPaper:
"""Tests for paper mode validation."""
def test_empty_requirements(self):
"""Test that empty requirements are valid."""
result = validate_for_paper(set(), [])
assert result.valid is True
def test_valid_ccxt_exchange(self):
"""Test validation with valid ccxt exchange."""
result = validate_for_paper({'binance'}, [])
assert result.valid is True
def test_invalid_exchange(self):
"""Test validation with invalid exchange name."""
result = validate_for_paper({'totally_fake_exchange_xyz'}, [])
assert result.valid is False
assert result.error_code == ValidationErrorCode.INVALID_EXCHANGE
assert 'totally_fake_exchange_xyz' in result.missing_exchanges
def test_mixed_valid_invalid(self):
"""Test validation with mix of valid and invalid exchanges."""
result = validate_for_paper({'binance', 'fake_exchange'}, [])
assert result.valid is False
assert result.missing_exchanges == {'fake_exchange'}
class TestValidateForLive:
"""Tests for live mode validation."""
def test_empty_requirements(self):
"""Test that empty requirements are valid."""
result = validate_for_live(set(), [])
assert result.valid is True
def test_configured_exchange(self):
"""Test validation when exchange is configured."""
result = validate_for_live({'binance'}, ['binance', 'kucoin'])
assert result.valid is True
def test_missing_config(self):
"""Test validation when exchange is not configured."""
result = validate_for_live({'kraken'}, ['binance', 'kucoin'])
assert result.valid is False
assert result.error_code == ValidationErrorCode.MISSING_CONFIG
assert result.missing_exchanges == {'kraken'}
assert 'API keys' in result.message
def test_multiple_missing(self):
"""Test validation with multiple missing exchanges."""
result = validate_for_live(
{'binance', 'kraken', 'bybit'},
['binance']
)
assert result.valid is False
assert result.missing_exchanges == {'kraken', 'bybit'}
def test_canonicalization_in_comparison(self):
"""Test that configured exchanges are canonicalized."""
result = validate_for_live({'binance'}, ['BINANCE'])
assert result.valid is True
class TestValidateExchangeRequirements:
"""Tests for main validation entrypoint."""
def test_routes_to_backtest(self):
"""Test that backtest mode routes correctly."""
result = validate_exchange_requirements(
required_exchanges={'binance'},
user_configured_exchanges=[],
edm_available_exchanges=['binance'],
mode='backtest'
)
assert result.valid is True
def test_routes_to_paper(self):
"""Test that paper mode routes correctly."""
result = validate_exchange_requirements(
required_exchanges={'binance'},
user_configured_exchanges=[],
edm_available_exchanges=[],
mode='paper'
)
assert result.valid is True # binance is valid ccxt exchange
def test_routes_to_live(self):
"""Test that live mode routes correctly."""
result = validate_exchange_requirements(
required_exchanges={'binance'},
user_configured_exchanges=['binance'],
edm_available_exchanges=[],
mode='live'
)
assert result.valid is True
def test_empty_requirements_all_modes(self):
"""Test that empty requirements are valid for all modes."""
for mode in ['backtest', 'paper', 'live']:
result = validate_exchange_requirements(
required_exchanges=set(),
user_configured_exchanges=[],
edm_available_exchanges=[],
mode=mode
)
assert result.valid is True, f"Failed for mode: {mode}"
def test_unknown_mode_rejected(self):
"""Test that unknown mode is rejected."""
result = validate_exchange_requirements(
required_exchanges={'binance'},
user_configured_exchanges=[],
edm_available_exchanges=[],
mode='invalid_mode'
)
assert result.valid is False
assert 'Invalid trading mode' in result.message
class TestExchangeValidationResult:
"""Tests for ExchangeValidationResult class."""
def test_valid_result_to_dict(self):
"""Test serialization of valid result."""
result = ExchangeValidationResult(valid=True)
d = result.to_dict()
assert d == {"valid": True}
def test_invalid_result_to_dict(self):
"""Test serialization of invalid result."""
result = ExchangeValidationResult(
valid=False,
error_code=ValidationErrorCode.MISSING_CONFIG,
missing_exchanges={'binance', 'kucoin'},
message="API keys required"
)
d = result.to_dict()
assert d["valid"] is False
assert d["error_code"] == "missing_config"
assert set(d["missing_exchanges"]) == {'binance', 'kucoin'}
assert d["message"] == "API keys required"
class TestEdgeCases:
"""Tests for edge cases and error handling."""
def test_none_values_in_data_sources(self):
"""Test handling of None values in data sources."""
strategy = {
'strategy_components': {
'data_sources': [
(None, 'BTC/USDT', '1h'),
('binance', 'BTC/USDT', '1h'),
]
}
}
result = extract_required_exchanges(strategy)
assert result == {'binance'}
def test_empty_tuple_in_data_sources(self):
"""Test handling of empty tuples in data sources."""
strategy = {
'strategy_components': {
'data_sources': [
(),
('binance', 'BTC/USDT', '1h'),
]
}
}
result = extract_required_exchanges(strategy)
assert result == {'binance'}
def test_single_element_tuple(self):
"""Test handling of single element tuple (just exchange)."""
strategy = {
'strategy_components': {
'data_sources': [
('binance',),
]
}
}
result = extract_required_exchanges(strategy)
assert result == {'binance'}

View File

@ -47,6 +47,16 @@ class TestStartStrategyValidation:
bt.exchanges.get_exchange = MagicMock(return_value=mock_exchange) bt.exchanges.get_exchange = MagicMock(return_value=mock_exchange)
bt.exchanges.connect_exchange = MagicMock(return_value=True) bt.exchanges.connect_exchange = MagicMock(return_value=True)
# Mock EDM client for exchange validation
bt.edm_client = MagicMock()
bt.edm_client.get_exchanges_sync = MagicMock(return_value=['binance', 'kucoin'])
# Mock strategies.get_strategy_by_tbl_key for exchange validation
bt.strategies.get_strategy_by_tbl_key = MagicMock(return_value={
'strategy_components': {},
'default_source': {}
})
return bt return bt
def test_start_strategy_invalid_mode(self, mock_brighter_trades): def test_start_strategy_invalid_mode(self, mock_brighter_trades):
@ -517,6 +527,10 @@ class TestLiveModeWarning:
bt.exchanges.get_exchange = MagicMock(return_value=mock_exchange) bt.exchanges.get_exchange = MagicMock(return_value=mock_exchange)
bt.exchanges.connect_exchange = MagicMock(return_value=True) bt.exchanges.connect_exchange = MagicMock(return_value=True)
# Mock EDM client for exchange validation
bt.edm_client = MagicMock()
bt.edm_client.get_exchanges_sync = MagicMock(return_value=['binance', 'kucoin'])
# Set up valid strategy # Set up valid strategy
mock_strategy = pd.DataFrame([{ mock_strategy = pd.DataFrame([{
'tbl_key': 'test-strategy', 'tbl_key': 'test-strategy',
@ -531,6 +545,12 @@ class TestLiveModeWarning:
strategy_name='Test Strategy' strategy_name='Test Strategy'
) )
# Mock strategies.get_strategy_by_tbl_key for exchange validation
bt.strategies.get_strategy_by_tbl_key = MagicMock(return_value={
'strategy_components': {},
'default_source': {}
})
return bt return bt
def test_live_mode_returns_success(self, mock_brighter_trades): def test_live_mode_returns_success(self, mock_brighter_trades):