509 lines
22 KiB
Python
509 lines
22 KiB
Python
# backtest_strategy_instance.py
|
|
|
|
import logging
|
|
from typing import Any, Optional, TYPE_CHECKING
|
|
|
|
import pandas as pd
|
|
import datetime as dt
|
|
import backtrader as bt
|
|
from StrategyInstance import StrategyInstance
|
|
|
|
if TYPE_CHECKING:
|
|
from brokers import BacktestBroker
|
|
|
|
# Module-level logger; every BacktestStrategyInstance method logs through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BacktestStrategyInstance(StrategyInstance):
    """
    Extends StrategyInstance with custom methods for backtesting.

    Prices, candles, precomputed indicator/signal values, order placement and
    balances are served from an attached Backtrader strategy
    (``self.backtrader_strategy``) and its broker (``self.broker``).
    User notifications are collected in memory and returned by
    ``get_collected_alerts()``.

    Backtrader line objects (strategies, data feeds) overload truth testing,
    so attachment is always checked with ``is None`` / ``is not None``.
    """

    def __init__(self, strategy_instance_id: str, strategy_id: str, strategy_name: str,
                 user_id: int, generated_code: str, data_cache: Any, indicators: Any | None,
                 trades: Any | None, backtrader_strategy: Optional[bt.Strategy] = None,
                 edm_client: Any = None, indicator_owner_id: int | None = None):
        """
        Create a backtest strategy instance.

        All parameters except ``backtrader_strategy`` are forwarded unchanged
        to ``StrategyInstance.__init__``.

        :param backtrader_strategy: Optional Backtrader strategy to attach right
            away. When given it is attached through ``set_backtrader_strategy``,
            so balances are read from its broker. Otherwise balances stay at 0.0
            until that method is called.
        """
        # Defined before super().__init__() so the attributes exist if the base
        # initialiser ends up calling the balance helpers below, which read
        # self.broker. NOTE(review): confirm which base-class calls need this.
        self.broker = None
        self.backtrader_strategy = None

        super().__init__(strategy_instance_id, strategy_id, strategy_name, user_id,
                         generated_code, data_cache, indicators, trades, edm_client,
                         indicator_owner_id=indicator_owner_id)

        # Balances are zero until a Backtrader strategy (and its broker) is attached.
        self.starting_balance = 0.0
        self.current_balance = 0.0
        self.available_balance = 0.0
        self.available_strategy_balance = 0.0
        self._bt_sync_balances()

        # Last non-NaN value per indicator: {indicator_name: {output_field: value}}.
        self.last_valid_values = {}

        # Alerts gathered by notify_user() for inclusion in backtest results.
        self.collected_alerts = []

        # Attach through the setter so balances come from the broker exactly as
        # when set_backtrader_strategy() is called later.
        if backtrader_strategy is not None:
            self.set_backtrader_strategy(backtrader_strategy)

    def _bt_sync_balances(self) -> None:
        """Mirror the four balance attributes into ``exec_context``."""
        self.exec_context['starting_balance'] = self.starting_balance
        self.exec_context['current_balance'] = self.current_balance
        self.exec_context['available_balance'] = self.available_balance
        self.exec_context['available_strategy_balance'] = self.available_strategy_balance

    def set_backtrader_strategy(self, backtrader_strategy: bt.Strategy):
        """
        Sets the backtrader_strategy and initializes broker-dependent attributes.

        Starting, current and available-strategy balances are set to the
        broker's total value. The available balance is set to the broker's cash.
        All four are then mirrored into ``exec_context``.
        """
        self.backtrader_strategy = backtrader_strategy
        self.broker = backtrader_strategy.broker

        self.starting_balance = self.fetch_user_balance()
        self.current_balance = self.starting_balance
        self.available_balance = self.calculate_available_balance()
        self.available_strategy_balance = self.starting_balance
        self._bt_sync_balances()

    # 1. Override trade_order
    def trade_order(
            self,
            trade_type: str,
            size: float,
            order_type: str,
            source: dict | None = None,
            tif: str = 'GTC',
            stop_loss: dict | None = None,
            trailing_stop: dict | None = None,
            take_profit: dict | None = None,
            limit: dict | None = None,
            trailing_limit: dict | None = None,
            target_market: dict | None = None,
            name_order: dict | None = None
    ):
        """
        Custom trade_order method for backtesting.

        Prepares order parameters and passes them to the Backtrader strategy's
        ``t_order`` for execution. It then records a notification via
        ``notify_user``.

        Only ``trade_type``, ``size``, ``order_type`` ('MARKET' or 'LIMIT',
        case-insensitive), ``source``, ``tif``, and the ``'value'`` entries of
        ``stop_loss`` / ``take_profit`` are used. ``trailing_stop``, ``limit``,
        ``trailing_limit``, ``target_market`` and ``name_order`` are accepted for
        interface compatibility but ignored. LIMIT orders use the current close
        as their limit price.

        :return: None in every case; errors are logged, not raised.
        """
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return

        # No source at all -> placeholder symbol 'Unknown' and the order proceeds.
        # A source lacking both 'symbol' and 'market' is rejected.
        symbol = (source.get('symbol') or source.get('market')) if source else 'Unknown'
        if not symbol:
            logger.error("Symbol not provided in source. Order not executed.")
            return

        # Current close from Backtrader's data feed (used as the LIMIT price).
        price = self.get_current_price()

        stop_loss_price = stop_loss.get('value') if stop_loss else None
        take_profit_price = take_profit.get('value') if take_profit else None

        order_type_upper = order_type.upper()
        if order_type_upper == 'MARKET':
            exectype = bt.Order.Market
            order_price = None  # Do not set price for market orders
        elif order_type_upper == 'LIMIT':
            exectype = bt.Order.Limit
            order_price = price  # Use current price as the limit price
        else:
            logger.error(f"Invalid order_type '{order_type}'. Order not executed.")
            return

        order_params = {
            'trade_type': trade_type,
            'size': size,
            'exectype': exectype,
            'price': order_price,
            'symbol': symbol,
            'stop_loss_price': stop_loss_price,
            'take_profit_price': take_profit_price,
            'tif': tif,
            'order_type': order_type_upper
        }
        self.backtrader_strategy.t_order(**order_params)

        action = trade_type.upper()
        message = f"{action} order placed for {size} {symbol} at {order_type_upper} price."
        self.notify_user(message)
        logger.info(message)

    # 2. Override process_indicator
    def _bt_resolve_indicator(self, indicator_name: str):
        """
        Look up precomputed data for ``indicator_name``.

        Tries, in order: the exact key; the name with underscores replaced by
        spaces (Blockly sanitizes names with underscores); then the first key
        that matches case-insensitively.

        :return: ``(resolved_name, df)``. ``df`` is None when nothing matched.
        """
        precomputed = self.backtrader_strategy.precomputed_indicators

        df = precomputed.get(indicator_name)
        if df is not None:
            return indicator_name, df

        alt_name = indicator_name.replace('_', ' ')
        if alt_name in precomputed:
            logger.debug(f"[BACKTEST] Found indicator '{alt_name}' (converted from '{indicator_name}')")
            df = precomputed.get(alt_name)
            if df is not None:
                return alt_name, df

        wanted = indicator_name.lower()
        for avail_name in list(precomputed.keys()):
            if avail_name.lower() == wanted:
                logger.debug(f"[BACKTEST] Found indicator '{avail_name}' (case-insensitive match)")
                return avail_name, precomputed.get(avail_name)

        return indicator_name, None

    def _bt_candle_time(self):
        """Datetime of the current candle, or None if it cannot be read (logging only)."""
        data = getattr(self.backtrader_strategy, 'data', None)
        if data is None:
            return None
        try:
            return data.datetime.datetime(0)
        except Exception:
            return None

    def process_indicator(self, indicator_name: str, output_field: str):
        """
        Retrieves precomputed indicator values for backtesting.

        The row is selected by the strategy's ``indicator_pointers`` entry for
        the indicator. The resolved key is tried first, then the requested name;
        a missing entry means index 0.

        If the current value is NaN, returns the last non-NaN value if available.
        If no last valid value exists, searches forward for the next valid value.
        If no valid value is found, returns a default value of 1.

        :return: The indicator value, or None if the strategy is not attached,
            the indicator is unknown, or the pointer is past the end of its data.
        """
        logger.debug("[BACKTEST] process_indicator called: indicator='%s', output='%s'",
                     indicator_name, output_field)

        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return None

        resolved_name, df = self._bt_resolve_indicator(indicator_name)
        if df is None:
            logger.error(f"[BACKTEST DEBUG] Indicator '{indicator_name}' not found in precomputed_indicators!")
            logger.error(f"[BACKTEST DEBUG] Available indicators: {list(self.backtrader_strategy.precomputed_indicators.keys())}")
            return None

        # The pointer must be read with the key the data was actually found
        # under. Using the unresolved alias always gave index 0 (first row).
        pointers = self.backtrader_strategy.indicator_pointers
        idx = pointers.get(resolved_name, pointers.get(indicator_name, 0))
        if idx >= len(df):
            logger.warning(f"No more data for indicator '{indicator_name}' at index {idx}.")
            return None

        row = df.iloc[idx]
        value = row.get(output_field)

        # Throttled diagnostics: first 10 rows and every 50th.
        verbose = idx < 10 or idx % 50 == 0
        if verbose:
            logger.info("[BACKTEST] process_indicator('%s', '%s') at idx=%s: value=%s, "
                        "indicator_time=%s, candle_time=%s",
                        indicator_name, output_field, idx, value,
                        row.get('time', 'N/A'), self._bt_candle_time())

        if pd.isna(value):
            last_valid_value = self.last_valid_values.get(indicator_name, {}).get(output_field)
            if last_valid_value is not None:
                logger.debug(f"Using cached last valid value for indicator '{indicator_name}': {last_valid_value}")
                return last_valid_value

            logger.debug(
                f"No cached last valid value for indicator '{indicator_name}'. Searching ahead for next valid value.")
            # NOTE(review): this returns a value from a *future* row during the
            # warm-up period, which introduces look-ahead bias into the backtest.
            for valid_idx in range(idx + 1, len(df)):
                next_value = df.iloc[valid_idx].get(output_field)
                if not pd.isna(next_value):
                    logger.debug(f"Found valid value at index {valid_idx}: {next_value}")
                    self.last_valid_values.setdefault(indicator_name, {})[output_field] = next_value
                    return next_value

            logger.warning(
                f"No valid value found for indicator '{indicator_name}' after index {idx}. Returning default value 1.")
            return 1  # Default value to prevent errors

        self.last_valid_values.setdefault(indicator_name, {})[output_field] = value
        if verbose:
            logger.info("[BACKTEST] process_indicator returning: %s.%s = %s",
                        indicator_name, output_field, value)
        return value

    # 2b. Override process_signal for backtesting
    def process_signal(self, signal_name: str, output_field: str = 'triggered'):
        """
        Evaluates a signal condition during backtesting.

        Reads the row at ``signal_pointers[signal_name]`` (default 0) of
        ``precomputed_signals[signal_name]`` on the Backtrader strategy.

        :param output_field: ``'triggered'`` returns the row's ``'triggered'``
            flag as a bool; a missing or NaN flag counts as False. Any other
            value returns the row's ``'value'`` entry.
        :return: For ``'triggered'``, False when the strategy is not attached,
            the signal is not precomputed, or the pointer is out of range; in
            those cases other fields return None.
        """
        logger.debug(f"[BACKTEST] process_signal called: signal='{signal_name}', output='{output_field}'")
        default = False if output_field == 'triggered' else None

        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return default

        precomputed_signals = getattr(self.backtrader_strategy, 'precomputed_signals', None) or {}
        signal_data = precomputed_signals.get(signal_name)

        if signal_data is None or len(signal_data) == 0:
            # Not precomputed: warn once per signal name, then return the default.
            if not hasattr(self, '_signal_warnings'):
                self._signal_warnings = set()
            if signal_name not in self._signal_warnings:
                logger.warning(f"[BACKTEST] Signal '{signal_name}' not precomputed. "
                               "Signal blocks in backtesting require precomputed signal data.")
                self._signal_warnings.add(signal_name)
            return default

        signal_pointers = getattr(self.backtrader_strategy, 'signal_pointers', None) or {}
        signal_pointer = signal_pointers.get(signal_name, 0)
        if signal_pointer >= len(signal_data):
            logger.debug(f"[BACKTEST] Signal '{signal_name}' pointer out of bounds: {signal_pointer}")
            return default

        signal_row = signal_data.iloc[signal_pointer]
        if output_field == 'triggered':
            triggered = signal_row.get('triggered', False)
            # bool(float('nan')) is True, so a missing flag must not fire the signal.
            return False if pd.isna(triggered) else bool(triggered)
        return signal_row.get('value', None)

    # 3. Override get_current_price
    def get_current_price(self, timeframe: str = '1h', exchange: str = 'binance',
                          symbol: str = 'BTC/USD') -> float:
        """
        Retrieves the current market price from Backtrader's data feed.

        Always returns ``close[0]`` of the strategy's ``data`` feed.
        ``timeframe``, ``exchange`` and ``symbol`` are accepted for interface
        compatibility only. Returns 0.0 if no strategy is attached.
        """
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set.")
            return 0.0
        price = self.backtrader_strategy.data.close[0]
        logger.debug(f"Current price from Backtrader's data feed: {price}")
        return price

    # 4. Override get_last_candle
    def get_last_candle(self, candle_part: str, timeframe: str, exchange: str, symbol: str):
        """
        Retrieves the specified part of the last candle from Backtrader's data feed.

        :param candle_part: 'open', 'high', 'low', 'close' or 'volume'
            (case-insensitive). ``timeframe``, ``exchange`` and ``symbol`` are
            used only in the debug log.
        :return: The value at index 0 of that line, or None if the strategy is
            not attached or ``candle_part`` is invalid.
        """
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return None

        data = self.backtrader_strategy.data
        candle_map = {
            'open': data.open[0],
            'high': data.high[0],
            'low': data.low[0],
            'close': data.close[0],
            'volume': data.volume[0],
        }
        value = candle_map.get(candle_part.lower())
        if value is None:
            logger.error(f"Invalid candle_part '{candle_part}'. Must be one of {list(candle_map.keys())}.")
        else:
            logger.debug(
                f"Retrieved '{candle_part}' from last candle for {symbol} on {exchange} ({timeframe}): {value}"
            )
        return value

    # 5. Override get_filled_orders
    def get_filled_orders(self) -> int:
        """
        Retrieves the number of filled orders from Backtrader's broker.

        Counts entries of ``broker.orders`` whose status is
        ``bt.Order.Completed``. Returns 0 on error or when no strategy is attached.
        """
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return 0
        try:
            filled_orders = sum(1 for o in self.backtrader_strategy.broker.orders
                                if o.status == bt.Order.Completed)
            logger.debug(f"Number of filled orders: {filled_orders}")
            return filled_orders
        except Exception as e:
            logger.error(f"Error retrieving filled orders: {e}", exc_info=True)
            return 0

    # 6. Override get_available_balance
    def get_available_balance(self) -> float:
        """
        Retrieves the available cash balance from Backtrader's broker.

        Also stores it in ``self.available_balance`` and ``exec_context``.
        Requires ``self.broker`` to be attached.
        """
        self.available_balance = self.broker.getcash()
        self.exec_context['available_balance'] = self.available_balance
        logger.debug(f"Available balance retrieved from Backtrader's broker: {self.available_balance}")
        return self.available_balance

    # 7. Override get_current_balance
    def get_current_balance(self) -> float:
        """
        Retrieves the current total value from Backtrader's broker.

        Also stores it in ``self.current_balance`` and ``exec_context``.
        Requires ``self.broker`` to be attached.
        """
        self.current_balance = self.broker.getvalue()
        self.exec_context['current_balance'] = self.current_balance
        logger.debug(f"Current balance retrieved from Backtrader's broker: {self.current_balance}")
        return self.current_balance

    # 8. Override get_filled_orders_details (Optional but Recommended)
    @staticmethod
    def _bt_order_created_at(order):
        """
        Creation time of ``order`` as a naive datetime, or None if unavailable.

        Backtrader stores order dates as float date numbers, which are converted
        with ``bt.num2date``. A datetime value is kept via the original
        timestamp round-trip (aware values become naive local time).
        """
        created_dt = getattr(getattr(order, 'created', None), 'dt', None)
        if created_dt is None:
            return None
        if isinstance(created_dt, dt.datetime):
            return dt.datetime.fromtimestamp(created_dt.timestamp())
        return bt.num2date(created_dt)

    def get_filled_orders_details(self) -> list:
        """
        Retrieves detailed information about filled orders.

        Uses ``broker.filled`` when the broker provides it. Otherwise it uses
        the completed entries of ``broker.orders``, the same filter as
        ``get_filled_orders``.

        :return: One dict per order with keys 'ref', 'size', 'price', 'value',
            'commission', 'status', 'created_at'. Returns [] on error or when no
            strategy is attached.
        """
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return []
        try:
            broker = self.backtrader_strategy.broker
            orders = getattr(broker, 'filled', None)
            if orders is None:
                orders = [o for o in broker.orders if o.status == bt.Order.Completed]

            filled_orders = [
                {
                    'ref': order.ref,
                    'size': order.size,
                    'price': order.executed.price,
                    'value': order.executed.value,
                    'commission': order.executed.comm,
                    'status': order.status,
                    'created_at': self._bt_order_created_at(order),
                }
                for order in orders
            ]
            logger.debug(f"Filled orders details: {filled_orders}")
            return filled_orders
        except Exception as e:
            logger.error(f"Error retrieving filled orders details: {e}", exc_info=True)
            return []

    # 9. Override notify_user
    @staticmethod
    def _bt_iso_utc(timestamp):
        """Format a datetime as ISO-8601 with a trailing 'Z'; aware values are converted to UTC first."""
        if timestamp is None:
            return None
        if timestamp.tzinfo is not None:
            timestamp = timestamp.astimezone(dt.timezone.utc).replace(tzinfo=None)
        return timestamp.isoformat() + 'Z'

    def notify_user(self, message: str):
        """
        Collects notifications with timestamps for display in backtest results.

        Appends ``{'timestamp': <ISO-8601 string ending in 'Z'>, 'message': message}``
        to ``self.collected_alerts``. The timestamp is the current candle time.

        :param message: Notification message.
        """
        timestamp = self.get_current_candle_datetime()
        alert = {
            # 'Z' marks UTC; the candle time is assumed to be UTC (Backtrader
            # datetimes are naive). NOTE(review): confirm feeds are loaded in UTC.
            'timestamp': self._bt_iso_utc(timestamp),
            'message': message
        }
        self.collected_alerts.append(alert)
        logger.debug(f"Backtest notification: {message} (at {timestamp})")

    def get_current_candle_datetime(self) -> dt.datetime:
        """
        Gets the datetime of the current candle from backtrader's data feed.

        Falls back to the current time as a *naive UTC* datetime if no strategy
        is attached or the feed cannot be read. This keeps the fallback
        consistent with the 'Z' suffix notify_user() appends.
        """
        strategy = self.backtrader_strategy
        if strategy is not None:
            try:
                return strategy.data.datetime.datetime(0)
            except Exception as e:
                logger.warning(f"Could not get candle datetime: {e}")
        return dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)

    def get_collected_alerts(self) -> list:
        """
        Returns the list of collected alerts for inclusion in backtest results.
        """
        return self.collected_alerts

    def save_context(self):
        """
        Saves the current strategy execution context to the cache and database.

        Adjusted for backtesting: balances are refreshed from the broker and
        written into ``exec_context`` before delegating to the base class.
        Errors are logged, not raised.
        """
        try:
            self.current_balance = self.get_current_balance()
            self.available_balance = self.get_available_balance()

            self.exec_context['current_balance'] = self.current_balance
            self.exec_context['available_balance'] = self.available_balance
            self.exec_context['available_strategy_balance'] = self.available_strategy_balance

            super().save_context()
        except Exception as e:
            logger.error(f"Error saving context for backtest: {e}", exc_info=True)

    def fetch_user_balance(self) -> float:
        """
        Fetches the starting balance (broker total value) from Backtrader's broker.

        Returns 0.0 if no broker is attached yet. ``getattr`` is used because
        this may run before ``self.broker`` exists.
        """
        broker = getattr(self, 'broker', None)
        if broker is None:
            logger.error("Broker is not set. Cannot fetch starting balance.")
            return 0.0
        balance = broker.getvalue()
        logger.debug(f"Fetched starting balance from Backtrader's broker: {balance}")
        return balance

    def calculate_available_balance(self) -> float:
        """
        Calculates the available cash balance from Backtrader's broker.

        Returns 0.0 if no broker is attached.
        """
        broker = getattr(self, 'broker', None)
        if broker is None:
            logger.error("Broker is not set. Cannot calculate available balance.")
            return 0.0
        available_balance = broker.getcash()
        logger.debug(f"Calculated available cash balance from Backtrader's broker: {available_balance}")
        return available_balance

    def set_available_strategy_balance(self, balance: float):
        """
        Sets the available strategy balance in backtesting.

        :raises ValueError: If ``balance`` exceeds the broker's available cash.
        """
        if balance > self.get_available_balance():
            raise ValueError("Cannot allocate more than the available balance in backtest.")
        self.available_strategy_balance = balance
        self.exec_context['available_strategy_balance'] = self.available_strategy_balance
        self.save_context()
        logger.debug(f"Available strategy balance set to {balance} in backtest.")

    def get_available_strategy_balance(self) -> float:
        """
        Retrieves the available strategy balance in backtesting.
        """
        logger.debug(f"Available strategy balance in backtest: {self.available_strategy_balance}")
        return self.available_strategy_balance

    def get_starting_balance(self) -> float:
        """
        Returns the starting balance in backtesting.
        """
        logger.debug(f"Starting balance in backtest: {self.starting_balance}")
        return self.starting_balance

    def get_active_trades(self) -> int:
        """
        Retrieves the number of active trades (open positions) from Backtrader's broker.

        Counts broker positions with a non-zero size. Returns 0 on error or
        when no strategy is attached.
        """
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return 0
        try:
            positions = self.broker.positions
            active_trades_count = sum(1 for position in positions.values() if position.size != 0)
            logger.debug(f"Number of active trades: {active_trades_count}")
            return active_trades_count
        except Exception as e:
            logger.error(f"Error retrieving active trades: {e}", exc_info=True)
            return 0