Fix backtest performance: cache compilation and skip unnecessary DB writes
Performance optimizations for backtesting: - Cache compiled strategy code instead of recompiling every iteration - Skip save_context() for backtests (no DB persistence needed per tick) - Check cache existence before recreating broker state caches - Add margin position processing to backtest strategy These fixes eliminate O(n) slowdown that caused backtests to progressively slow to a crawl. Backtests now maintain consistent speed throughout. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
b2d40d3d2f
commit
533a00dce7
|
|
@ -7,11 +7,226 @@ from trade import Trades
|
|||
import datetime as dt
|
||||
import json
|
||||
import traceback
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UndefinedValue:
|
||||
"""
|
||||
Sentinel object returned when accessing undefined strategy variables.
|
||||
|
||||
- All comparisons return False (and log a warning)
|
||||
- Math operations return 0 (and log a warning)
|
||||
- String conversion returns empty string
|
||||
|
||||
This prevents strategy crashes while alerting the user to the issue.
|
||||
"""
|
||||
|
||||
def __init__(self, var_name: str, warn_callback=None):
|
||||
self._var_name = var_name
|
||||
self._warn_callback = warn_callback
|
||||
self._warned = False # Only warn once per variable per tick
|
||||
|
||||
def set_warn_callback(self, warn_callback):
|
||||
"""Update the warning callback used by this sentinel."""
|
||||
self._warn_callback = warn_callback
|
||||
|
||||
def reset_warning(self):
|
||||
"""Allow this undefined value to warn again on the next strategy tick."""
|
||||
self._warned = False
|
||||
|
||||
def _warn(self, operation: str):
|
||||
"""Log warning about undefined variable usage."""
|
||||
if not self._warned:
|
||||
msg = f"Variable '{self._var_name}' used in {operation} before being set"
|
||||
logger.warning(msg)
|
||||
if self._warn_callback:
|
||||
self._warn_callback(msg)
|
||||
self._warned = True
|
||||
|
||||
# Comparison operators - all return False
|
||||
def __lt__(self, other):
|
||||
self._warn("comparison (<)")
|
||||
return False
|
||||
|
||||
def __le__(self, other):
|
||||
self._warn("comparison (<=)")
|
||||
return False
|
||||
|
||||
def __gt__(self, other):
|
||||
self._warn("comparison (>)")
|
||||
return False
|
||||
|
||||
def __ge__(self, other):
|
||||
self._warn("comparison (>=)")
|
||||
return False
|
||||
|
||||
def __eq__(self, other):
|
||||
if other is None:
|
||||
return True # UndefinedValue == None is True for backwards compat
|
||||
self._warn("comparison (==)")
|
||||
return False
|
||||
|
||||
def __ne__(self, other):
|
||||
if other is None:
|
||||
return False # UndefinedValue != None is False for backwards compat
|
||||
self._warn("comparison (!=)")
|
||||
return True
|
||||
|
||||
# Math operators - return 0
|
||||
def __add__(self, other):
|
||||
self._warn("math operation (+)")
|
||||
return 0
|
||||
|
||||
def __radd__(self, other):
|
||||
self._warn("math operation (+)")
|
||||
return other # other + undefined = other
|
||||
|
||||
def __sub__(self, other):
|
||||
self._warn("math operation (-)")
|
||||
return 0
|
||||
|
||||
def __rsub__(self, other):
|
||||
self._warn("math operation (-)")
|
||||
return other # other - undefined = other
|
||||
|
||||
def __mul__(self, other):
|
||||
self._warn("math operation (*)")
|
||||
return 0
|
||||
|
||||
def __rmul__(self, other):
|
||||
self._warn("math operation (*)")
|
||||
return 0
|
||||
|
||||
def __truediv__(self, other):
|
||||
self._warn("math operation (/)")
|
||||
return 0
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
self._warn("math operation (/)")
|
||||
return 0 # Can't divide by undefined
|
||||
|
||||
def __floordiv__(self, other):
|
||||
self._warn("math operation (//)")
|
||||
return 0
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
self._warn("math operation (//)")
|
||||
return 0
|
||||
|
||||
def __mod__(self, other):
|
||||
self._warn("math operation (%)")
|
||||
return 0
|
||||
|
||||
def __neg__(self):
|
||||
self._warn("math operation (negation)")
|
||||
return 0
|
||||
|
||||
def __pos__(self):
|
||||
self._warn("math operation (positive)")
|
||||
return 0
|
||||
|
||||
def __abs__(self):
|
||||
self._warn("math operation (abs)")
|
||||
return 0
|
||||
|
||||
# Type conversions
|
||||
def __str__(self):
|
||||
return ""
|
||||
|
||||
def __repr__(self):
|
||||
return f"UndefinedValue('{self._var_name}')"
|
||||
|
||||
def __bool__(self):
|
||||
# Undefined is falsy
|
||||
return False
|
||||
|
||||
def __float__(self):
|
||||
self._warn("conversion to float")
|
||||
return 0.0
|
||||
|
||||
def __int__(self):
|
||||
self._warn("conversion to int")
|
||||
return 0
|
||||
|
||||
def __format__(self, format_spec):
|
||||
"""Handle f-string formatting like f'{value:.2f}'."""
|
||||
if not format_spec:
|
||||
# Plain f"{value}" - use __str__ behavior (empty string)
|
||||
return ""
|
||||
self._warn("format string")
|
||||
# Return formatted 0 using the same format spec
|
||||
return format(0, format_spec)
|
||||
|
||||
|
||||
class StrategyVariables(dict):
|
||||
"""
|
||||
Custom dict for strategy variables that returns UndefinedValue for missing keys.
|
||||
|
||||
This allows strategies to reference variables before they're defined without crashing,
|
||||
while still warning the user about the issue.
|
||||
"""
|
||||
|
||||
_UNSET = object()
|
||||
|
||||
def __init__(self, warn_callback=None, *args, **kwargs):
|
||||
self._undefined_cache: dict[str, UndefinedValue] = {}
|
||||
super().__init__(*args, **kwargs)
|
||||
self._warn_callback = warn_callback
|
||||
|
||||
def get(self, key, default=_UNSET):
|
||||
"""
|
||||
Return dict values normally, but surface UndefinedValue when no real fallback is supplied.
|
||||
|
||||
Missing keys use UndefinedValue when:
|
||||
- no default is provided, or
|
||||
- the caller explicitly passes ``None`` (for backwards compatibility with
|
||||
previously generated strategy code).
|
||||
"""
|
||||
if key in self:
|
||||
return super().__getitem__(key)
|
||||
if default is self._UNSET or default is None:
|
||||
return self.get_undefined(key)
|
||||
return default
|
||||
|
||||
def get_undefined(self, key):
|
||||
"""Return a cached UndefinedValue sentinel for a missing variable name."""
|
||||
if key not in self._undefined_cache:
|
||||
self._undefined_cache[key] = UndefinedValue(key, self._warn_callback)
|
||||
return self._undefined_cache[key]
|
||||
|
||||
def set_warn_callback(self, callback):
|
||||
"""Set the warning callback after initialization."""
|
||||
self._warn_callback = callback
|
||||
for undefined_value in self._undefined_cache.values():
|
||||
undefined_value.set_warn_callback(callback)
|
||||
|
||||
def reset_undefined_warnings(self):
|
||||
"""Reset cached UndefinedValue warning state at the start of a strategy tick."""
|
||||
for undefined_value in self._undefined_cache.values():
|
||||
undefined_value.reset_warning()
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._undefined_cache.pop(key, None)
|
||||
super().__setitem__(key, value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
self._undefined_cache.pop(key, None)
|
||||
super().__delitem__(key)
|
||||
|
||||
def pop(self, key, default=_UNSET):
|
||||
self._undefined_cache.pop(key, None)
|
||||
if default is self._UNSET:
|
||||
return super().pop(key)
|
||||
return super().pop(key, default)
|
||||
|
||||
def clear(self):
|
||||
self._undefined_cache.clear()
|
||||
super().clear()
|
||||
|
||||
|
||||
class StrategyInstance:
|
||||
def __init__(self, strategy_instance_id: str, strategy_id: str, strategy_name: str,
|
||||
user_id: int, generated_code: str, data_cache: Any, indicators: Any | None, trades: Any | None,
|
||||
|
|
@ -49,7 +264,7 @@ class StrategyInstance:
|
|||
|
||||
# Initialize context variables
|
||||
self.flags: dict[str, Any] = {}
|
||||
self.variables: dict[str, Any] = {}
|
||||
self.variables = StrategyVariables(warn_callback=self._variable_warning)
|
||||
self.starting_balance: float = 0.0
|
||||
self.current_balance: float = 0.0
|
||||
self.available_balance: float = 0.0
|
||||
|
|
@ -103,6 +318,16 @@ class StrategyInstance:
|
|||
'current_balance': self.current_balance,
|
||||
'available_balance': self.available_balance,
|
||||
'available_strategy_balance': self.available_strategy_balance,
|
||||
# Margin trading methods (Paper/Live mode only)
|
||||
'set_default_margin_leverage': self.set_default_margin_leverage,
|
||||
'get_total_margin_used': self.get_total_margin_used,
|
||||
'open_margin_position': self.open_margin_position,
|
||||
'close_margin_position': self.close_margin_position,
|
||||
'has_margin_position': self.has_margin_position,
|
||||
'get_liquidation_buffer_pct': self.get_liquidation_buffer_pct,
|
||||
'get_unrealized_pnl': self.get_unrealized_pnl,
|
||||
# Asset conversion
|
||||
'convert_asset': self.convert_asset,
|
||||
}
|
||||
|
||||
# Automatically load or initialize the context
|
||||
|
|
@ -136,7 +361,7 @@ class StrategyInstance:
|
|||
Initializes a new context for the strategy instance and saves it to the cache.
|
||||
"""
|
||||
self.flags = {}
|
||||
self.variables = {}
|
||||
self.variables = StrategyVariables(warn_callback=self._variable_warning)
|
||||
self.profit_loss = 0.0
|
||||
self.active = True
|
||||
self.paused = False
|
||||
|
|
@ -167,7 +392,8 @@ class StrategyInstance:
|
|||
try:
|
||||
context = context_data.iloc[0].to_dict()
|
||||
self.flags = json.loads(context.get('flags', '{}'))
|
||||
self.variables = json.loads(context.get('variables', '{}'))
|
||||
loaded_vars = json.loads(context.get('variables', '{}'))
|
||||
self.variables = StrategyVariables(warn_callback=self._variable_warning, **loaded_vars)
|
||||
self.profit_loss = context.get('profit_loss', 0.0)
|
||||
self.active = bool(context.get('active', 1))
|
||||
self.paused = bool(context.get('paused', 0))
|
||||
|
|
@ -370,14 +596,18 @@ class StrategyInstance:
|
|||
:return: Result of the execution.
|
||||
"""
|
||||
try:
|
||||
if isinstance(self.variables, StrategyVariables):
|
||||
self.variables.reset_undefined_warnings()
|
||||
|
||||
# Log the generated code once for debugging
|
||||
if not hasattr(self, '_code_logged'):
|
||||
logger.info(f"Strategy {self.strategy_id} generated code:\n{self.generated_code}")
|
||||
self._code_logged = True
|
||||
|
||||
# Compile the generated code with a meaningful filename
|
||||
compiled_code = compile(self.generated_code, '<strategy_code>', 'exec')
|
||||
exec(compiled_code, self.exec_context)
|
||||
# Compile the generated code ONCE and cache it (compile is expensive)
|
||||
if not hasattr(self, '_compiled_code'):
|
||||
self._compiled_code = compile(self.generated_code, '<strategy_code>', 'exec')
|
||||
exec(self._compiled_code, self.exec_context)
|
||||
|
||||
# Call the 'next()' method if defined
|
||||
if 'next' in self.exec_context and callable(self.exec_context['next']):
|
||||
|
|
@ -396,6 +626,10 @@ class StrategyInstance:
|
|||
|
||||
# Retrieve and update profit/loss
|
||||
self.profit_loss = self.exec_context.get('profit_loss', self.profit_loss)
|
||||
|
||||
# Skip save_context() for backtests - it's expensive and not needed
|
||||
# Backtests don't need to persist state to DB on every tick
|
||||
if not hasattr(self, 'backtrader_strategy') or self.backtrader_strategy is None:
|
||||
self.save_context()
|
||||
|
||||
return {"success": True, "profit_loss": self.profit_loss}
|
||||
|
|
@ -729,6 +963,14 @@ class StrategyInstance:
|
|||
exc_info=True)
|
||||
traceback.print_exc()
|
||||
|
||||
def _variable_warning(self, message: str):
|
||||
"""
|
||||
Callback for undefined variable warnings.
|
||||
Override in subclasses to customize warning behavior (e.g., collect alerts).
|
||||
"""
|
||||
# Base implementation just logs - subclasses can add to alerts
|
||||
logger.warning(f"[Strategy {self.strategy_id}] {message}")
|
||||
|
||||
def notify_user(self, message: str):
|
||||
"""
|
||||
Sends a notification to the user.
|
||||
|
|
@ -975,3 +1217,346 @@ class StrategyInstance:
|
|||
logger.debug(f"Accumulated fee: {result['fee_charged']} sats for trade worth ${trade_value_usd:.2f}")
|
||||
|
||||
return result
|
||||
|
||||
# ==============================
|
||||
# Margin Trading Methods (Stubs)
|
||||
# ==============================
|
||||
# These are base implementations that should be overridden by subclasses
|
||||
# that support margin trading (PaperStrategyInstance, LiveStrategyInstance).
|
||||
|
||||
def set_default_margin_leverage(self, leverage: float) -> None:
|
||||
"""
|
||||
Set the default leverage for margin trades.
|
||||
|
||||
This is idempotent - calling it multiple times has no additional effect.
|
||||
Stores the leverage in variables for persistence across ticks.
|
||||
|
||||
Base implementation stores leverage but does nothing with it.
|
||||
Override in subclasses that support margin trading.
|
||||
|
||||
:param leverage: Leverage multiplier (1-10x).
|
||||
"""
|
||||
new_leverage = max(1, min(int(leverage), 10))
|
||||
current_leverage = self.variables['_margin_leverage_default'] if '_margin_leverage_default' in self.variables else None
|
||||
if current_leverage != new_leverage:
|
||||
self.variables['_margin_leverage_default'] = new_leverage
|
||||
logger.debug(f"Strategy '{self.strategy_id}' margin leverage set to {new_leverage}x")
|
||||
|
||||
def get_margin_leverage(self) -> float:
|
||||
"""
|
||||
Get the current default margin leverage.
|
||||
|
||||
:return: The default leverage (1-10x), defaults to 3.
|
||||
"""
|
||||
return self.variables.get('_margin_leverage_default', 3)
|
||||
|
||||
def get_total_margin_used(self) -> float:
|
||||
"""
|
||||
Get the total margin (collateral) currently locked in margin positions.
|
||||
|
||||
Base implementation returns 0. Override in subclasses.
|
||||
|
||||
:return: Total margin used in quote currency.
|
||||
"""
|
||||
logger.warning("get_total_margin_used called on base StrategyInstance - margin trading not available")
|
||||
return 0.0
|
||||
|
||||
def _ensure_margin_runtime_state(self) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Ensure persistent margin runtime state containers exist in strategy variables."""
|
||||
if not hasattr(self, 'variables') or not isinstance(self.variables, dict):
|
||||
self.variables = StrategyVariables(warn_callback=self._variable_warning)
|
||||
|
||||
pending_entries = self.variables.get('_pending_margin_entries')
|
||||
if not isinstance(pending_entries, dict):
|
||||
pending_entries = {}
|
||||
self.variables['_pending_margin_entries'] = pending_entries
|
||||
|
||||
trailing_stops = self.variables.get('_margin_trailing_stops')
|
||||
if not isinstance(trailing_stops, dict):
|
||||
trailing_stops = {}
|
||||
self.variables['_margin_trailing_stops'] = trailing_stops
|
||||
|
||||
return pending_entries, trailing_stops
|
||||
|
||||
def _get_pending_margin_entries(self) -> dict[str, Any]:
|
||||
"""Return persisted pending margin-entry state."""
|
||||
pending_entries, _ = self._ensure_margin_runtime_state()
|
||||
return pending_entries
|
||||
|
||||
def _get_margin_trailing_stops(self) -> dict[str, Any]:
|
||||
"""Return persisted trailing-stop state for margin positions."""
|
||||
_, trailing_stops = self._ensure_margin_runtime_state()
|
||||
return trailing_stops
|
||||
|
||||
def _normalize_margin_tif(self, tif: str | None) -> str:
|
||||
"""Normalize time-in-force values to GTC/IOC/FOK."""
|
||||
tif_value = (tif or 'GTC')
|
||||
if isinstance(tif_value, dict):
|
||||
tif_value = tif_value.get('time_in_force', 'GTC')
|
||||
tif_value = str(tif_value).upper()
|
||||
if tif_value not in {'GTC', 'IOC', 'FOK'}:
|
||||
logger.warning(f"Invalid margin time-in-force '{tif_value}', defaulting to GTC")
|
||||
return 'GTC'
|
||||
return tif_value
|
||||
|
||||
def _extract_margin_value(self, option: Any, default_key: str = 'value') -> Optional[float]:
|
||||
"""Extract a numeric value from shared Blockly option payloads."""
|
||||
if option is None:
|
||||
return None
|
||||
if isinstance(option, (int, float)):
|
||||
return float(option)
|
||||
if isinstance(option, dict):
|
||||
value = option.get(default_key)
|
||||
if value is None and len(option) == 1:
|
||||
value = next(iter(option.values()))
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
return None
|
||||
|
||||
def _extract_margin_limit_price(self, limit: Any) -> Optional[float]:
|
||||
"""Extract a limit-entry price from the shared limit option."""
|
||||
return self._extract_margin_value(limit, default_key='limit')
|
||||
|
||||
def _extract_margin_trailing_distance(self, option: Any, key: str) -> Optional[float]:
|
||||
"""Extract a trailing distance from the shared trailing option."""
|
||||
return self._extract_margin_value(option, default_key=key)
|
||||
|
||||
def _resolve_margin_target(self, target_market: Optional[dict[str, Any]]) -> tuple[str, Optional[str], Optional[str]]:
|
||||
"""Resolve symbol/exchange/timeframe for a margin action."""
|
||||
symbol = self.exec_context.get('current_symbol', 'BTC/USDT')
|
||||
exchange = None
|
||||
timeframe = None
|
||||
|
||||
if isinstance(target_market, dict) and target_market:
|
||||
symbol = (
|
||||
target_market.get('symbol')
|
||||
or target_market.get('market')
|
||||
or symbol
|
||||
)
|
||||
exchange = target_market.get('exchange')
|
||||
timeframe = target_market.get('time_frame') or target_market.get('timeframe')
|
||||
|
||||
return symbol, exchange, timeframe
|
||||
|
||||
def _normalize_margin_order_name(self, name_order: Any) -> Optional[str]:
|
||||
"""Extract an optional user-facing order name from shared option payloads."""
|
||||
if isinstance(name_order, dict):
|
||||
order_name = name_order.get('order_name')
|
||||
elif isinstance(name_order, str):
|
||||
order_name = name_order
|
||||
else:
|
||||
order_name = None
|
||||
|
||||
if order_name is None:
|
||||
return None
|
||||
|
||||
normalized = str(order_name).strip()
|
||||
return normalized or None
|
||||
|
||||
def _register_margin_trailing_stop(
|
||||
self,
|
||||
symbol: str,
|
||||
side: str,
|
||||
trail_distance: float,
|
||||
current_price: float,
|
||||
order_name: Optional[str] = None
|
||||
) -> None:
|
||||
"""Persist a trailing-stop configuration for an open margin position."""
|
||||
trailing_stops = self._get_margin_trailing_stops()
|
||||
trailing_stops[symbol] = {
|
||||
'side': side,
|
||||
'trail_distance': float(trail_distance),
|
||||
'best_price': float(current_price),
|
||||
'order_name': order_name,
|
||||
}
|
||||
|
||||
def _clear_margin_trailing_stop(self, symbol: str) -> None:
|
||||
"""Remove trailing-stop state for a symbol if present."""
|
||||
trailing_stops = self._get_margin_trailing_stops()
|
||||
trailing_stops.pop(symbol, None)
|
||||
|
||||
def open_margin_position(
|
||||
self,
|
||||
side: str,
|
||||
collateral: float,
|
||||
leverage: float = None,
|
||||
stop_loss: float = None,
|
||||
take_profit: float = None,
|
||||
tif: str = 'GTC',
|
||||
trailing_stop: dict = None,
|
||||
limit: dict = None,
|
||||
trailing_limit: dict = None,
|
||||
target_market: dict = None,
|
||||
name_order: Any = None
|
||||
) -> dict:
|
||||
"""
|
||||
Open a margin position.
|
||||
|
||||
Base implementation pauses strategy and raises error.
|
||||
Override in subclasses that support margin trading.
|
||||
|
||||
:param side: 'long' or 'short'.
|
||||
:param collateral: Amount of collateral in quote currency.
|
||||
:param leverage: Leverage multiplier (uses default if not specified).
|
||||
:param stop_loss: Optional stop loss price.
|
||||
:param take_profit: Optional take profit price.
|
||||
:param tif: Time in force for limit-style entry flows.
|
||||
:param trailing_stop: Optional trailing-stop config.
|
||||
:param limit: Optional limit-entry config.
|
||||
:param trailing_limit: Optional trailing-limit entry config.
|
||||
:param target_market: Optional symbol/exchange/timeframe override.
|
||||
:param name_order: Optional user-facing name for the order.
|
||||
:return: Position details dict.
|
||||
:raises NotImplementedError: In base class.
|
||||
"""
|
||||
self.notify_user("MARGIN TRADING NOT AVAILABLE: This mode does not support margin trading")
|
||||
self.set_paused(True)
|
||||
raise NotImplementedError("Margin trading requires Paper or Live mode")
|
||||
|
||||
def close_margin_position(self, symbol: str = None, percentage: float = 100.0) -> dict:
|
||||
"""
|
||||
Close a margin position.
|
||||
|
||||
Base implementation pauses strategy and raises error.
|
||||
Override in subclasses that support margin trading.
|
||||
|
||||
:param symbol: Symbol to close (uses current symbol if not specified).
|
||||
:param percentage: Percentage of position to close (default 100%).
|
||||
:return: Close result dict.
|
||||
:raises NotImplementedError: In base class.
|
||||
"""
|
||||
self.notify_user("MARGIN TRADING NOT AVAILABLE: This mode does not support margin trading")
|
||||
self.set_paused(True)
|
||||
raise NotImplementedError("Margin trading requires Paper or Live mode")
|
||||
|
||||
def has_margin_position(self, symbol: str = None) -> bool:
|
||||
"""
|
||||
Check if there is an open margin position for the symbol.
|
||||
|
||||
Base implementation returns False. Override in subclasses.
|
||||
|
||||
:param symbol: Symbol to check (uses current symbol if not specified).
|
||||
:return: True if position exists, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
def get_liquidation_buffer_pct(self, symbol: str = None) -> float:
|
||||
"""
|
||||
Get the liquidation buffer percentage for a margin position.
|
||||
|
||||
Returns how far current price is from liquidation:
|
||||
- 100% = price at entry (maximum safety)
|
||||
- 50% = price halfway to liquidation
|
||||
- 0% = at liquidation price
|
||||
|
||||
Base implementation returns 100 (no position). Override in subclasses.
|
||||
|
||||
:param symbol: Symbol to check (uses current symbol if not specified).
|
||||
:return: Liquidation buffer percentage (0-100).
|
||||
"""
|
||||
return 100.0 # No position = no liquidation risk
|
||||
|
||||
def get_unrealized_pnl(self, symbol: str = None) -> float:
|
||||
"""
|
||||
Get the unrealized P/L for a margin position.
|
||||
|
||||
Base implementation returns 0. Override in subclasses.
|
||||
|
||||
:param symbol: Symbol to check (uses current symbol if not specified).
|
||||
:return: Unrealized P/L in quote currency.
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
# ==============================
|
||||
# Asset Conversion
|
||||
# ==============================
|
||||
|
||||
def convert_asset(self, amount: float, from_asset: str, to_asset: str,
|
||||
exchange: str = 'binance', timeframe: str = '1h') -> float:
|
||||
"""
|
||||
Convert an amount from one asset to another using current market prices.
|
||||
|
||||
Examples:
|
||||
convert_asset(10, 'USD', 'BTC') -> Returns how much BTC equals $10
|
||||
convert_asset(0.001, 'BTC', 'USD') -> Returns USD value of 0.001 BTC
|
||||
convert_asset(100, 'USDT', 'ETH') -> Returns how much ETH equals 100 USDT
|
||||
|
||||
:param amount: The amount to convert.
|
||||
:param from_asset: The source asset (e.g., 'USD', 'BTC', 'ETH').
|
||||
:param to_asset: The target asset (e.g., 'BTC', 'USD', 'ETH').
|
||||
:param exchange: Exchange to use for price data (default: 'binance').
|
||||
:param timeframe: Timeframe for price data (default: '1h').
|
||||
:return: The converted amount in the target asset.
|
||||
"""
|
||||
try:
|
||||
# Normalize asset names
|
||||
from_asset = from_asset.upper().strip()
|
||||
to_asset = to_asset.upper().strip()
|
||||
|
||||
# If same asset, no conversion needed
|
||||
if from_asset == to_asset:
|
||||
return amount
|
||||
|
||||
# Define stablecoin equivalents (treat as USD)
|
||||
stablecoins = {'USD', 'USDT', 'USDC', 'BUSD', 'DAI'}
|
||||
|
||||
# Check if both are stablecoins (1:1 conversion)
|
||||
if from_asset in stablecoins and to_asset in stablecoins:
|
||||
return amount
|
||||
|
||||
# Normalize stablecoins to USDT for pair lookup (most common on exchanges)
|
||||
from_normalized = 'USDT' if from_asset in stablecoins else from_asset
|
||||
to_normalized = 'USDT' if to_asset in stablecoins else to_asset
|
||||
|
||||
# Case 1: Converting stablecoin to crypto (e.g., USD -> BTC)
|
||||
# Look up CRYPTO/STABLECOIN pair (e.g., BTC/USDT) and divide
|
||||
if from_asset in stablecoins and to_asset not in stablecoins:
|
||||
price = self._get_conversion_price(to_normalized, 'USDT', exchange, timeframe)
|
||||
if price is not None and price > 0:
|
||||
return amount / price
|
||||
return 0.0
|
||||
|
||||
# Case 2: Converting crypto to stablecoin (e.g., BTC -> USD)
|
||||
# Look up CRYPTO/STABLECOIN pair (e.g., BTC/USDT) and multiply
|
||||
if from_asset not in stablecoins and to_asset in stablecoins:
|
||||
price = self._get_conversion_price(from_normalized, 'USDT', exchange, timeframe)
|
||||
if price is not None:
|
||||
return amount * price
|
||||
return 0.0
|
||||
|
||||
# Case 3: Converting between cryptos (e.g., ETH -> BTC)
|
||||
# Go through USD: ETH -> USD -> BTC
|
||||
from_usd_price = self._get_conversion_price(from_normalized, 'USDT', exchange, timeframe)
|
||||
to_usd_price = self._get_conversion_price(to_normalized, 'USDT', exchange, timeframe)
|
||||
|
||||
if from_usd_price is not None and to_usd_price is not None and to_usd_price > 0:
|
||||
# Convert from_asset to USD, then USD to to_asset
|
||||
usd_amount = amount * from_usd_price
|
||||
return usd_amount / to_usd_price
|
||||
|
||||
logger.warning(f"Could not find price for {from_asset}/{to_asset} conversion")
|
||||
return 0.0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting {amount} {from_asset} to {to_asset}: {e}", exc_info=True)
|
||||
return 0.0
|
||||
|
||||
def _get_conversion_price(self, base: str, quote: str, exchange: str, timeframe: str) -> Optional[float]:
|
||||
"""
|
||||
Get the price for a trading pair.
|
||||
|
||||
:param base: Base asset (e.g., 'BTC').
|
||||
:param quote: Quote asset (e.g., 'USDT').
|
||||
:param exchange: Exchange name.
|
||||
:param timeframe: Timeframe for price data.
|
||||
:return: Price or None if not available.
|
||||
"""
|
||||
try:
|
||||
symbol = f"{base}/{quote}"
|
||||
price = self.get_current_price(timeframe=timeframe, exchange=exchange, symbol=symbol)
|
||||
if price is not None and price > 0:
|
||||
return price
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get price for {base}/{quote}: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1061,6 +1061,8 @@ class LiveBroker(BaseBroker):
|
|||
[]
|
||||
)
|
||||
|
||||
# Only create the cache if it doesn't already exist
|
||||
if 'live_broker_states' not in self._data_cache.caches:
|
||||
self._data_cache.create_cache(
|
||||
name='live_broker_states',
|
||||
cache_type='table',
|
||||
|
|
|
|||
|
|
@ -625,6 +625,8 @@ class PaperBroker(BaseBroker):
|
|||
[]
|
||||
)
|
||||
|
||||
# Only create the cache if it doesn't already exist
|
||||
if 'paper_broker_states' not in self._data_cache.caches:
|
||||
self._data_cache.create_cache(
|
||||
name='paper_broker_states',
|
||||
cache_type='table',
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ class MappedStrategy(bt.Strategy):
|
|||
# Initialize lists to store orders and trades
|
||||
self.orders = []
|
||||
self.trade_list = []
|
||||
self.margin_trade_list = [] # Margin trade history
|
||||
|
||||
# Initialize other needed variables
|
||||
self.starting_balance = self.broker.getvalue()
|
||||
|
|
@ -160,6 +161,21 @@ class MappedStrategy(bt.Strategy):
|
|||
# Execute the strategy logic
|
||||
self.execute_strategy()
|
||||
|
||||
# Process margin positions
|
||||
if hasattr(self.strategy_instance, 'process_margin_tick'):
|
||||
margin_events = self.strategy_instance.process_margin_tick()
|
||||
for event in margin_events:
|
||||
if event.get('type') == 'liquidation':
|
||||
self.strategy_instance.notify_user(
|
||||
f"LIQUIDATION: {event['symbol']} at ${event.get('price', 0):.2f}, "
|
||||
f"P&L: ${event.get('pnl', 0):.2f}"
|
||||
)
|
||||
elif event.get('type') == 'sltp_triggered':
|
||||
trigger = event.get('trigger', 'unknown').upper()
|
||||
self.strategy_instance.notify_user(
|
||||
f"{trigger}: {event['symbol']} at ${event.get('price', 0):.2f}"
|
||||
)
|
||||
|
||||
# Advance indicator pointers for the next candle
|
||||
for name in self.indicator_names:
|
||||
if name in self.indicator_pointers:
|
||||
|
|
@ -209,10 +225,31 @@ class MappedStrategy(bt.Strategy):
|
|||
self.last_progress = progress
|
||||
|
||||
def stop(self):
|
||||
# Close all open positions
|
||||
# Close all open spot positions
|
||||
if self.position:
|
||||
self.close()
|
||||
self.log(f"Closing remaining position at the end of backtest.")
|
||||
self.log("Closing remaining spot position at the end of backtest.")
|
||||
|
||||
# Close any open margin positions (mark-to-market)
|
||||
if hasattr(self.strategy_instance, 'paper_margin_broker'):
|
||||
broker = self.strategy_instance.paper_margin_broker
|
||||
if broker:
|
||||
positions = list(broker.get_all_positions())
|
||||
for pos in positions:
|
||||
self.log(f"Closing margin position {pos.symbol} at end of backtest.")
|
||||
try:
|
||||
broker.close_position(pos.symbol)
|
||||
except Exception as e:
|
||||
self.log(f"Error closing margin position {pos.symbol}: {e}")
|
||||
|
||||
# Copy margin trade history for results
|
||||
self.margin_trade_list = broker.get_position_history()
|
||||
|
||||
# Transfer final balance back to spot
|
||||
margin_balance = broker.get_balance()
|
||||
if margin_balance > 0:
|
||||
self.broker.setcash(self.broker.getcash() + margin_balance)
|
||||
broker._balance = 0
|
||||
|
||||
def t_order(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Reference in New Issue