Fix backtest performance: cache compilation and skip unnecessary DB writes

Performance optimizations for backtesting:
- Cache compiled strategy code instead of recompiling every iteration
- Skip save_context() for backtests (no DB persistence needed per tick)
- Check cache existence before recreating broker state caches
- Add margin position processing to backtest strategy

These fixes eliminate O(n) slowdown that caused backtests to progressively
slow to a crawl. Backtests now maintain consistent speed throughout.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
rob 2026-03-15 16:49:18 -03:00
parent b2d40d3d2f
commit 533a00dce7
4 changed files with 652 additions and 26 deletions

View File

@ -7,11 +7,226 @@ from trade import Trades
import datetime as dt
import json
import traceback
from typing import Any
from typing import Any, Optional
# Configure logging
logger = logging.getLogger(__name__)
class UndefinedValue:
"""
Sentinel object returned when accessing undefined strategy variables.
- All comparisons return False (and log a warning)
- Math operations return 0 (and log a warning)
- String conversion returns empty string
This prevents strategy crashes while alerting the user to the issue.
"""
def __init__(self, var_name: str, warn_callback=None):
self._var_name = var_name
self._warn_callback = warn_callback
self._warned = False # Only warn once per variable per tick
def set_warn_callback(self, warn_callback):
"""Update the warning callback used by this sentinel."""
self._warn_callback = warn_callback
def reset_warning(self):
"""Allow this undefined value to warn again on the next strategy tick."""
self._warned = False
def _warn(self, operation: str):
"""Log warning about undefined variable usage."""
if not self._warned:
msg = f"Variable '{self._var_name}' used in {operation} before being set"
logger.warning(msg)
if self._warn_callback:
self._warn_callback(msg)
self._warned = True
# Comparison operators - all return False
def __lt__(self, other):
self._warn("comparison (<)")
return False
def __le__(self, other):
self._warn("comparison (<=)")
return False
def __gt__(self, other):
self._warn("comparison (>)")
return False
def __ge__(self, other):
self._warn("comparison (>=)")
return False
def __eq__(self, other):
if other is None:
return True # UndefinedValue == None is True for backwards compat
self._warn("comparison (==)")
return False
def __ne__(self, other):
if other is None:
return False # UndefinedValue != None is False for backwards compat
self._warn("comparison (!=)")
return True
# Math operators - return 0
def __add__(self, other):
self._warn("math operation (+)")
return 0
def __radd__(self, other):
self._warn("math operation (+)")
return other # other + undefined = other
def __sub__(self, other):
self._warn("math operation (-)")
return 0
def __rsub__(self, other):
self._warn("math operation (-)")
return other # other - undefined = other
def __mul__(self, other):
self._warn("math operation (*)")
return 0
def __rmul__(self, other):
self._warn("math operation (*)")
return 0
def __truediv__(self, other):
self._warn("math operation (/)")
return 0
def __rtruediv__(self, other):
self._warn("math operation (/)")
return 0 # Can't divide by undefined
def __floordiv__(self, other):
self._warn("math operation (//)")
return 0
def __rfloordiv__(self, other):
self._warn("math operation (//)")
return 0
def __mod__(self, other):
self._warn("math operation (%)")
return 0
def __neg__(self):
self._warn("math operation (negation)")
return 0
def __pos__(self):
self._warn("math operation (positive)")
return 0
def __abs__(self):
self._warn("math operation (abs)")
return 0
# Type conversions
def __str__(self):
return ""
def __repr__(self):
return f"UndefinedValue('{self._var_name}')"
def __bool__(self):
# Undefined is falsy
return False
def __float__(self):
self._warn("conversion to float")
return 0.0
def __int__(self):
self._warn("conversion to int")
return 0
def __format__(self, format_spec):
"""Handle f-string formatting like f'{value:.2f}'."""
if not format_spec:
# Plain f"{value}" - use __str__ behavior (empty string)
return ""
self._warn("format string")
# Return formatted 0 using the same format spec
return format(0, format_spec)
class StrategyVariables(dict):
"""
Custom dict for strategy variables that returns UndefinedValue for missing keys.
This allows strategies to reference variables before they're defined without crashing,
while still warning the user about the issue.
"""
_UNSET = object()
def __init__(self, warn_callback=None, *args, **kwargs):
self._undefined_cache: dict[str, UndefinedValue] = {}
super().__init__(*args, **kwargs)
self._warn_callback = warn_callback
def get(self, key, default=_UNSET):
"""
Return dict values normally, but surface UndefinedValue when no real fallback is supplied.
Missing keys use UndefinedValue when:
- no default is provided, or
- the caller explicitly passes ``None`` (for backwards compatibility with
previously generated strategy code).
"""
if key in self:
return super().__getitem__(key)
if default is self._UNSET or default is None:
return self.get_undefined(key)
return default
def get_undefined(self, key):
"""Return a cached UndefinedValue sentinel for a missing variable name."""
if key not in self._undefined_cache:
self._undefined_cache[key] = UndefinedValue(key, self._warn_callback)
return self._undefined_cache[key]
def set_warn_callback(self, callback):
"""Set the warning callback after initialization."""
self._warn_callback = callback
for undefined_value in self._undefined_cache.values():
undefined_value.set_warn_callback(callback)
def reset_undefined_warnings(self):
"""Reset cached UndefinedValue warning state at the start of a strategy tick."""
for undefined_value in self._undefined_cache.values():
undefined_value.reset_warning()
def __setitem__(self, key, value):
self._undefined_cache.pop(key, None)
super().__setitem__(key, value)
def __delitem__(self, key):
self._undefined_cache.pop(key, None)
super().__delitem__(key)
def pop(self, key, default=_UNSET):
self._undefined_cache.pop(key, None)
if default is self._UNSET:
return super().pop(key)
return super().pop(key, default)
def clear(self):
self._undefined_cache.clear()
super().clear()
class StrategyInstance:
def __init__(self, strategy_instance_id: str, strategy_id: str, strategy_name: str,
user_id: int, generated_code: str, data_cache: Any, indicators: Any | None, trades: Any | None,
@ -49,7 +264,7 @@ class StrategyInstance:
# Initialize context variables
self.flags: dict[str, Any] = {}
self.variables: dict[str, Any] = {}
self.variables = StrategyVariables(warn_callback=self._variable_warning)
self.starting_balance: float = 0.0
self.current_balance: float = 0.0
self.available_balance: float = 0.0
@ -103,6 +318,16 @@ class StrategyInstance:
'current_balance': self.current_balance,
'available_balance': self.available_balance,
'available_strategy_balance': self.available_strategy_balance,
# Margin trading methods (Paper/Live mode only)
'set_default_margin_leverage': self.set_default_margin_leverage,
'get_total_margin_used': self.get_total_margin_used,
'open_margin_position': self.open_margin_position,
'close_margin_position': self.close_margin_position,
'has_margin_position': self.has_margin_position,
'get_liquidation_buffer_pct': self.get_liquidation_buffer_pct,
'get_unrealized_pnl': self.get_unrealized_pnl,
# Asset conversion
'convert_asset': self.convert_asset,
}
# Automatically load or initialize the context
@ -136,7 +361,7 @@ class StrategyInstance:
Initializes a new context for the strategy instance and saves it to the cache.
"""
self.flags = {}
self.variables = {}
self.variables = StrategyVariables(warn_callback=self._variable_warning)
self.profit_loss = 0.0
self.active = True
self.paused = False
@ -167,7 +392,8 @@ class StrategyInstance:
try:
context = context_data.iloc[0].to_dict()
self.flags = json.loads(context.get('flags', '{}'))
self.variables = json.loads(context.get('variables', '{}'))
loaded_vars = json.loads(context.get('variables', '{}'))
self.variables = StrategyVariables(warn_callback=self._variable_warning, **loaded_vars)
self.profit_loss = context.get('profit_loss', 0.0)
self.active = bool(context.get('active', 1))
self.paused = bool(context.get('paused', 0))
@ -370,14 +596,18 @@ class StrategyInstance:
:return: Result of the execution.
"""
try:
if isinstance(self.variables, StrategyVariables):
self.variables.reset_undefined_warnings()
# Log the generated code once for debugging
if not hasattr(self, '_code_logged'):
logger.info(f"Strategy {self.strategy_id} generated code:\n{self.generated_code}")
self._code_logged = True
# Compile the generated code with a meaningful filename
compiled_code = compile(self.generated_code, '<strategy_code>', 'exec')
exec(compiled_code, self.exec_context)
# Compile the generated code ONCE and cache it (compile is expensive)
if not hasattr(self, '_compiled_code'):
self._compiled_code = compile(self.generated_code, '<strategy_code>', 'exec')
exec(self._compiled_code, self.exec_context)
# Call the 'next()' method if defined
if 'next' in self.exec_context and callable(self.exec_context['next']):
@ -396,6 +626,10 @@ class StrategyInstance:
# Retrieve and update profit/loss
self.profit_loss = self.exec_context.get('profit_loss', self.profit_loss)
# Skip save_context() for backtests - it's expensive and not needed
# Backtests don't need to persist state to DB on every tick
if not hasattr(self, 'backtrader_strategy') or self.backtrader_strategy is None:
self.save_context()
return {"success": True, "profit_loss": self.profit_loss}
@ -729,6 +963,14 @@ class StrategyInstance:
exc_info=True)
traceback.print_exc()
def _variable_warning(self, message: str):
"""
Callback for undefined variable warnings.
Override in subclasses to customize warning behavior (e.g., collect alerts).
"""
# Base implementation just logs - subclasses can add to alerts
logger.warning(f"[Strategy {self.strategy_id}] {message}")
def notify_user(self, message: str):
"""
Sends a notification to the user.
@ -975,3 +1217,346 @@ class StrategyInstance:
logger.debug(f"Accumulated fee: {result['fee_charged']} sats for trade worth ${trade_value_usd:.2f}")
return result
# ==============================
# Margin Trading Methods (Stubs)
# ==============================
# These are base implementations that should be overridden by subclasses
# that support margin trading (PaperStrategyInstance, LiveStrategyInstance).
def set_default_margin_leverage(self, leverage: float) -> None:
"""
Set the default leverage for margin trades.
This is idempotent - calling it multiple times has no additional effect.
Stores the leverage in variables for persistence across ticks.
Base implementation stores leverage but does nothing with it.
Override in subclasses that support margin trading.
:param leverage: Leverage multiplier (1-10x).
"""
new_leverage = max(1, min(int(leverage), 10))
current_leverage = self.variables['_margin_leverage_default'] if '_margin_leverage_default' in self.variables else None
if current_leverage != new_leverage:
self.variables['_margin_leverage_default'] = new_leverage
logger.debug(f"Strategy '{self.strategy_id}' margin leverage set to {new_leverage}x")
def get_margin_leverage(self) -> float:
"""
Get the current default margin leverage.
:return: The default leverage (1-10x), defaults to 3.
"""
return self.variables.get('_margin_leverage_default', 3)
def get_total_margin_used(self) -> float:
"""
Get the total margin (collateral) currently locked in margin positions.
Base implementation returns 0. Override in subclasses.
:return: Total margin used in quote currency.
"""
logger.warning("get_total_margin_used called on base StrategyInstance - margin trading not available")
return 0.0
def _ensure_margin_runtime_state(self) -> tuple[dict[str, Any], dict[str, Any]]:
"""Ensure persistent margin runtime state containers exist in strategy variables."""
if not hasattr(self, 'variables') or not isinstance(self.variables, dict):
self.variables = StrategyVariables(warn_callback=self._variable_warning)
pending_entries = self.variables.get('_pending_margin_entries')
if not isinstance(pending_entries, dict):
pending_entries = {}
self.variables['_pending_margin_entries'] = pending_entries
trailing_stops = self.variables.get('_margin_trailing_stops')
if not isinstance(trailing_stops, dict):
trailing_stops = {}
self.variables['_margin_trailing_stops'] = trailing_stops
return pending_entries, trailing_stops
def _get_pending_margin_entries(self) -> dict[str, Any]:
"""Return persisted pending margin-entry state."""
pending_entries, _ = self._ensure_margin_runtime_state()
return pending_entries
def _get_margin_trailing_stops(self) -> dict[str, Any]:
"""Return persisted trailing-stop state for margin positions."""
_, trailing_stops = self._ensure_margin_runtime_state()
return trailing_stops
def _normalize_margin_tif(self, tif: str | None) -> str:
"""Normalize time-in-force values to GTC/IOC/FOK."""
tif_value = (tif or 'GTC')
if isinstance(tif_value, dict):
tif_value = tif_value.get('time_in_force', 'GTC')
tif_value = str(tif_value).upper()
if tif_value not in {'GTC', 'IOC', 'FOK'}:
logger.warning(f"Invalid margin time-in-force '{tif_value}', defaulting to GTC")
return 'GTC'
return tif_value
def _extract_margin_value(self, option: Any, default_key: str = 'value') -> Optional[float]:
"""Extract a numeric value from shared Blockly option payloads."""
if option is None:
return None
if isinstance(option, (int, float)):
return float(option)
if isinstance(option, dict):
value = option.get(default_key)
if value is None and len(option) == 1:
value = next(iter(option.values()))
if isinstance(value, (int, float)):
return float(value)
return None
def _extract_margin_limit_price(self, limit: Any) -> Optional[float]:
"""Extract a limit-entry price from the shared limit option."""
return self._extract_margin_value(limit, default_key='limit')
def _extract_margin_trailing_distance(self, option: Any, key: str) -> Optional[float]:
"""Extract a trailing distance from the shared trailing option."""
return self._extract_margin_value(option, default_key=key)
def _resolve_margin_target(self, target_market: Optional[dict[str, Any]]) -> tuple[str, Optional[str], Optional[str]]:
"""Resolve symbol/exchange/timeframe for a margin action."""
symbol = self.exec_context.get('current_symbol', 'BTC/USDT')
exchange = None
timeframe = None
if isinstance(target_market, dict) and target_market:
symbol = (
target_market.get('symbol')
or target_market.get('market')
or symbol
)
exchange = target_market.get('exchange')
timeframe = target_market.get('time_frame') or target_market.get('timeframe')
return symbol, exchange, timeframe
def _normalize_margin_order_name(self, name_order: Any) -> Optional[str]:
"""Extract an optional user-facing order name from shared option payloads."""
if isinstance(name_order, dict):
order_name = name_order.get('order_name')
elif isinstance(name_order, str):
order_name = name_order
else:
order_name = None
if order_name is None:
return None
normalized = str(order_name).strip()
return normalized or None
def _register_margin_trailing_stop(
self,
symbol: str,
side: str,
trail_distance: float,
current_price: float,
order_name: Optional[str] = None
) -> None:
"""Persist a trailing-stop configuration for an open margin position."""
trailing_stops = self._get_margin_trailing_stops()
trailing_stops[symbol] = {
'side': side,
'trail_distance': float(trail_distance),
'best_price': float(current_price),
'order_name': order_name,
}
def _clear_margin_trailing_stop(self, symbol: str) -> None:
"""Remove trailing-stop state for a symbol if present."""
trailing_stops = self._get_margin_trailing_stops()
trailing_stops.pop(symbol, None)
def open_margin_position(
self,
side: str,
collateral: float,
leverage: float = None,
stop_loss: float = None,
take_profit: float = None,
tif: str = 'GTC',
trailing_stop: dict = None,
limit: dict = None,
trailing_limit: dict = None,
target_market: dict = None,
name_order: Any = None
) -> dict:
"""
Open a margin position.
Base implementation pauses strategy and raises error.
Override in subclasses that support margin trading.
:param side: 'long' or 'short'.
:param collateral: Amount of collateral in quote currency.
:param leverage: Leverage multiplier (uses default if not specified).
:param stop_loss: Optional stop loss price.
:param take_profit: Optional take profit price.
:param tif: Time in force for limit-style entry flows.
:param trailing_stop: Optional trailing-stop config.
:param limit: Optional limit-entry config.
:param trailing_limit: Optional trailing-limit entry config.
:param target_market: Optional symbol/exchange/timeframe override.
:param name_order: Optional user-facing name for the order.
:return: Position details dict.
:raises NotImplementedError: In base class.
"""
self.notify_user("MARGIN TRADING NOT AVAILABLE: This mode does not support margin trading")
self.set_paused(True)
raise NotImplementedError("Margin trading requires Paper or Live mode")
def close_margin_position(self, symbol: str = None, percentage: float = 100.0) -> dict:
"""
Close a margin position.
Base implementation pauses strategy and raises error.
Override in subclasses that support margin trading.
:param symbol: Symbol to close (uses current symbol if not specified).
:param percentage: Percentage of position to close (default 100%).
:return: Close result dict.
:raises NotImplementedError: In base class.
"""
self.notify_user("MARGIN TRADING NOT AVAILABLE: This mode does not support margin trading")
self.set_paused(True)
raise NotImplementedError("Margin trading requires Paper or Live mode")
def has_margin_position(self, symbol: str = None) -> bool:
"""
Check if there is an open margin position for the symbol.
Base implementation returns False. Override in subclasses.
:param symbol: Symbol to check (uses current symbol if not specified).
:return: True if position exists, False otherwise.
"""
return False
def get_liquidation_buffer_pct(self, symbol: str = None) -> float:
"""
Get the liquidation buffer percentage for a margin position.
Returns how far current price is from liquidation:
- 100% = price at entry (maximum safety)
- 50% = price halfway to liquidation
- 0% = at liquidation price
Base implementation returns 100 (no position). Override in subclasses.
:param symbol: Symbol to check (uses current symbol if not specified).
:return: Liquidation buffer percentage (0-100).
"""
return 100.0 # No position = no liquidation risk
def get_unrealized_pnl(self, symbol: str = None) -> float:
"""
Get the unrealized P/L for a margin position.
Base implementation returns 0. Override in subclasses.
:param symbol: Symbol to check (uses current symbol if not specified).
:return: Unrealized P/L in quote currency.
"""
return 0.0
# ==============================
# Asset Conversion
# ==============================
def convert_asset(self, amount: float, from_asset: str, to_asset: str,
exchange: str = 'binance', timeframe: str = '1h') -> float:
"""
Convert an amount from one asset to another using current market prices.
Examples:
convert_asset(10, 'USD', 'BTC') -> Returns how much BTC equals $10
convert_asset(0.001, 'BTC', 'USD') -> Returns USD value of 0.001 BTC
convert_asset(100, 'USDT', 'ETH') -> Returns how much ETH equals 100 USDT
:param amount: The amount to convert.
:param from_asset: The source asset (e.g., 'USD', 'BTC', 'ETH').
:param to_asset: The target asset (e.g., 'BTC', 'USD', 'ETH').
:param exchange: Exchange to use for price data (default: 'binance').
:param timeframe: Timeframe for price data (default: '1h').
:return: The converted amount in the target asset.
"""
try:
# Normalize asset names
from_asset = from_asset.upper().strip()
to_asset = to_asset.upper().strip()
# If same asset, no conversion needed
if from_asset == to_asset:
return amount
# Define stablecoin equivalents (treat as USD)
stablecoins = {'USD', 'USDT', 'USDC', 'BUSD', 'DAI'}
# Check if both are stablecoins (1:1 conversion)
if from_asset in stablecoins and to_asset in stablecoins:
return amount
# Normalize stablecoins to USDT for pair lookup (most common on exchanges)
from_normalized = 'USDT' if from_asset in stablecoins else from_asset
to_normalized = 'USDT' if to_asset in stablecoins else to_asset
# Case 1: Converting stablecoin to crypto (e.g., USD -> BTC)
# Look up CRYPTO/STABLECOIN pair (e.g., BTC/USDT) and divide
if from_asset in stablecoins and to_asset not in stablecoins:
price = self._get_conversion_price(to_normalized, 'USDT', exchange, timeframe)
if price is not None and price > 0:
return amount / price
return 0.0
# Case 2: Converting crypto to stablecoin (e.g., BTC -> USD)
# Look up CRYPTO/STABLECOIN pair (e.g., BTC/USDT) and multiply
if from_asset not in stablecoins and to_asset in stablecoins:
price = self._get_conversion_price(from_normalized, 'USDT', exchange, timeframe)
if price is not None:
return amount * price
return 0.0
# Case 3: Converting between cryptos (e.g., ETH -> BTC)
# Go through USD: ETH -> USD -> BTC
from_usd_price = self._get_conversion_price(from_normalized, 'USDT', exchange, timeframe)
to_usd_price = self._get_conversion_price(to_normalized, 'USDT', exchange, timeframe)
if from_usd_price is not None and to_usd_price is not None and to_usd_price > 0:
# Convert from_asset to USD, then USD to to_asset
usd_amount = amount * from_usd_price
return usd_amount / to_usd_price
logger.warning(f"Could not find price for {from_asset}/{to_asset} conversion")
return 0.0
except Exception as e:
logger.error(f"Error converting {amount} {from_asset} to {to_asset}: {e}", exc_info=True)
return 0.0
def _get_conversion_price(self, base: str, quote: str, exchange: str, timeframe: str) -> Optional[float]:
"""
Get the price for a trading pair.
:param base: Base asset (e.g., 'BTC').
:param quote: Quote asset (e.g., 'USDT').
:param exchange: Exchange name.
:param timeframe: Timeframe for price data.
:return: Price or None if not available.
"""
try:
symbol = f"{base}/{quote}"
price = self.get_current_price(timeframe=timeframe, exchange=exchange, symbol=symbol)
if price is not None and price > 0:
return price
return None
except Exception as e:
logger.debug(f"Could not get price for {base}/{quote}: {e}")
return None

View File

@ -1061,6 +1061,8 @@ class LiveBroker(BaseBroker):
[]
)
# Only create the cache if it doesn't already exist
if 'live_broker_states' not in self._data_cache.caches:
self._data_cache.create_cache(
name='live_broker_states',
cache_type='table',

View File

@ -625,6 +625,8 @@ class PaperBroker(BaseBroker):
[]
)
# Only create the cache if it doesn't already exist
if 'paper_broker_states' not in self._data_cache.caches:
self._data_cache.create_cache(
name='paper_broker_states',
cache_type='table',

View File

@ -53,6 +53,7 @@ class MappedStrategy(bt.Strategy):
# Initialize lists to store orders and trades
self.orders = []
self.trade_list = []
self.margin_trade_list = [] # Margin trade history
# Initialize other needed variables
self.starting_balance = self.broker.getvalue()
@ -160,6 +161,21 @@ class MappedStrategy(bt.Strategy):
# Execute the strategy logic
self.execute_strategy()
# Process margin positions
if hasattr(self.strategy_instance, 'process_margin_tick'):
margin_events = self.strategy_instance.process_margin_tick()
for event in margin_events:
if event.get('type') == 'liquidation':
self.strategy_instance.notify_user(
f"LIQUIDATION: {event['symbol']} at ${event.get('price', 0):.2f}, "
f"P&L: ${event.get('pnl', 0):.2f}"
)
elif event.get('type') == 'sltp_triggered':
trigger = event.get('trigger', 'unknown').upper()
self.strategy_instance.notify_user(
f"{trigger}: {event['symbol']} at ${event.get('price', 0):.2f}"
)
# Advance indicator pointers for the next candle
for name in self.indicator_names:
if name in self.indicator_pointers:
@ -209,10 +225,31 @@ class MappedStrategy(bt.Strategy):
self.last_progress = progress
def stop(self):
# Close all open positions
# Close all open spot positions
if self.position:
self.close()
self.log(f"Closing remaining position at the end of backtest.")
self.log("Closing remaining spot position at the end of backtest.")
# Close any open margin positions (mark-to-market)
if hasattr(self.strategy_instance, 'paper_margin_broker'):
broker = self.strategy_instance.paper_margin_broker
if broker:
positions = list(broker.get_all_positions())
for pos in positions:
self.log(f"Closing margin position {pos.symbol} at end of backtest.")
try:
broker.close_position(pos.symbol)
except Exception as e:
self.log(f"Error closing margin position {pos.symbol}: {e}")
# Copy margin trade history for results
self.margin_trade_list = broker.get_position_history()
# Transfer final balance back to spot
margin_balance = broker.get_balance()
if margin_balance > 0:
self.broker.setcash(self.broker.getcash() + margin_balance)
broker._balance = 0
def t_order(
self,