brighter-trading/src/backtest_strategy_instance.py

509 lines
22 KiB
Python

# backtest_strategy_instance.py
import logging
from typing import Any, Optional, TYPE_CHECKING
import pandas as pd
import datetime as dt
import backtrader as bt
from StrategyInstance import StrategyInstance
if TYPE_CHECKING:
    # Imported only for static type checkers; never executed at runtime.
    from brokers import BacktestBroker  # noqa: F401

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class BacktestStrategyInstance(StrategyInstance):
    """
    Extends StrategyInstance with custom methods for backtesting.

    Prices, candles, balances and order placement are served from a Backtrader
    strategy (``self.backtrader_strategy``) and its broker (``self.broker``)
    instead of live services. The Backtrader strategy can be passed to the
    constructor or attached later with :meth:`set_backtrader_strategy`; while it
    is missing, most methods log an error and return a neutral default.
    """

    def __init__(self, strategy_instance_id: str, strategy_id: str, strategy_name: str,
                 user_id: int, generated_code: str, data_cache: Any, indicators: Any | None,
                 trades: Any | None, backtrader_strategy: Optional[bt.Strategy] = None,
                 edm_client: Any = None, indicator_owner_id: Optional[int] = None):
        """
        Create a backtest strategy instance.

        :param backtrader_strategy: Backtrader strategy that supplies market data
            and executes orders. May be None and attached later with
            set_backtrader_strategy().
        All other parameters are forwarded unchanged to StrategyInstance.

        Balances start at 0.0 here even when a strategy is given; only
        set_backtrader_strategy() reads them from the broker.
        """
        # Define these before super().__init__() so they already exist if the base
        # initializer touches them (presumably through overridden methods such as
        # fetch_user_balance() -- confirm against StrategyInstance).
        self.broker = None
        self.backtrader_strategy = None
        # Signal names already warned about in process_signal() (warn once per name).
        self._signal_warnings = set()
        super().__init__(strategy_instance_id, strategy_id, strategy_name, user_id,
                         generated_code, data_cache, indicators, trades, edm_client,
                         indicator_owner_id=indicator_owner_id)

        self.backtrader_strategy = backtrader_strategy
        # Explicit None check: Backtrader objects may overload truthiness.
        self.broker = backtrader_strategy.broker if backtrader_strategy is not None else None

        self.starting_balance = 0.0
        self.current_balance = 0.0
        self.available_balance = 0.0
        self.available_strategy_balance = 0.0
        self._publish_balances()

        # Last non-NaN value per indicator output: {indicator_name: {output_field: value}}.
        self.last_valid_values = {}
        # Alerts gathered by notify_user() for inclusion in backtest results.
        self.collected_alerts = []

    def _publish_balances(self) -> None:
        """Copy the four balance attributes into exec_context."""
        self.exec_context['starting_balance'] = self.starting_balance
        self.exec_context['current_balance'] = self.current_balance
        self.exec_context['available_balance'] = self.available_balance
        self.exec_context['available_strategy_balance'] = self.available_strategy_balance

    def set_backtrader_strategy(self, backtrader_strategy: bt.Strategy):
        """
        Attach the Backtrader strategy and initialize balances from its broker.

        The starting, current and available-strategy balances are set to the
        broker's total value; the available balance is set to its cash.
        """
        self.backtrader_strategy = backtrader_strategy
        self.broker = self.backtrader_strategy.broker

        self.starting_balance = self.fetch_user_balance()
        self.current_balance = self.starting_balance
        self.available_balance = self.calculate_available_balance()
        self.available_strategy_balance = self.starting_balance
        self._publish_balances()

    # 1. Override trade_order
    def trade_order(
        self,
        trade_type: str,
        size: float,
        order_type: str,
        source: Optional[dict] = None,
        tif: str = 'GTC',
        stop_loss: Optional[dict] = None,
        trailing_stop: Optional[dict] = None,
        take_profit: Optional[dict] = None,
        limit: Optional[dict] = None,
        trailing_limit: Optional[dict] = None,
        target_market: Optional[dict] = None,
        name_order: Optional[dict] = None
    ):
        """
        Place an order through the Backtrader strategy (backtest override).

        Builds the order parameters, forwards them to
        ``backtrader_strategy.t_order()`` and records a user notification.

        :param trade_type: Forwarded as-is; upper-cased in the notification text.
        :param size: Order quantity.
        :param order_type: 'market' or 'limit' (case-insensitive). Limit orders use
            the current close price as the limit price.
        :param source: Dict supplying 'symbol' (or, failing that, 'market'). When
            omitted, the symbol is reported as 'Unknown'.
        :param tif: Time-in-force, forwarded unchanged.
        :param stop_loss: Optional dict whose 'value' becomes stop_loss_price.
        :param take_profit: Optional dict whose 'value' becomes take_profit_price.
        trailing_stop, limit, trailing_limit, target_market and name_order are
        accepted for interface compatibility but are not used in backtesting.
        :returns: None. Invalid input is logged and the order is skipped.
        """
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return

        # No source at all -> 'Unknown'; a source without symbol/market aborts.
        symbol = (source.get('symbol') or source.get('market')) if source else 'Unknown'
        if not symbol:
            logger.error("Symbol not provided in source. Order not executed.")
            return

        stop_loss_price = stop_loss.get('value') if stop_loss else None
        take_profit_price = take_profit.get('value') if take_profit else None

        # `or ''` routes a missing order_type to the "invalid" branch instead of
        # raising AttributeError.
        order_type_upper = (order_type or '').upper()
        if order_type_upper == 'MARKET':
            exectype = bt.Order.Market
            order_price = None  # Do not set price for market orders
        elif order_type_upper == 'LIMIT':
            exectype = bt.Order.Limit
            order_price = self.get_current_price()  # Current close as the limit price
        else:
            logger.error(f"Invalid order_type '{order_type}'. Order not executed.")
            return

        self.backtrader_strategy.t_order(
            trade_type=trade_type,
            size=size,
            exectype=exectype,
            price=order_price,
            symbol=symbol,
            stop_loss_price=stop_loss_price,
            take_profit_price=take_profit_price,
            tif=tif,
            order_type=order_type_upper,
        )

        action = trade_type.upper()
        message = f"{action} order placed for {size} {symbol} at {order_type_upper} price."
        self.notify_user(message)
        logger.info(message)

    # 2. Override process_indicator
    def process_indicator(self, indicator_name: str, output_field: str):
        """
        Return the precomputed value of an indicator output for the current candle.

        The row is selected by ``backtrader_strategy.indicator_pointers``. A NaN
        value is replaced, in order of preference, by:

        1. the last non-NaN value seen for this indicator/output;
        2. the next non-NaN value further ahead in the data (this looks into the
           future relative to the current row);
        3. the default value 1.

        :param indicator_name: Indicator name as used by the strategy code; also
            matched with underscores replaced by spaces, then case-insensitively.
        :param output_field: Column of the indicator DataFrame to read.
        :returns: The value, or None if no strategy is attached, the indicator is
            unknown, or the pointer is past the end of its data.
        """
        # Called for every indicator on every candle: keep it at DEBUG and lazy.
        logger.debug("[BACKTEST] process_indicator called: indicator='%s', output='%s'",
                     indicator_name, output_field)
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return None

        resolved_name, df = self._resolve_precomputed_indicator(indicator_name)
        if df is None:
            logger.error(f"[BACKTEST DEBUG] Indicator '{indicator_name}' not found in precomputed_indicators!")
            logger.error(f"[BACKTEST DEBUG] Available indicators: {list(self.backtrader_strategy.precomputed_indicators.keys())}")
            return None

        # Prefer the pointer stored under the resolved name. When the indicator was
        # found through an alias, a lookup by the requested name would typically
        # miss (assuming pointers are keyed like precomputed_indicators) and pin
        # every read to row 0.
        pointers = self.backtrader_strategy.indicator_pointers
        idx = pointers.get(resolved_name, pointers.get(indicator_name, 0))
        if idx >= len(df):
            logger.warning(f"No more data for indicator '{indicator_name}' at index {idx}.")
            return None

        row = df.iloc[idx]
        value = row.get(output_field)

        # Throttle verbose logging to the first 10 rows and every 50th row.
        verbose = idx < 10 or idx % 50 == 0
        if verbose:
            indicator_time = row.get('time', 'N/A')
            try:
                candle_time = self.backtrader_strategy.data.datetime.datetime(0)
            except Exception:
                candle_time = None  # Only used for this log line.
            logger.info(f"[BACKTEST] process_indicator('{indicator_name}', '{output_field}') at idx={idx}: value={value}, indicator_time={indicator_time}, candle_time={candle_time}")

        if pd.isna(value):
            return self._fallback_indicator_value(indicator_name, output_field, df, idx)

        self.last_valid_values.setdefault(indicator_name, {})[output_field] = value
        if verbose:
            logger.info(f"[BACKTEST] process_indicator returning: {indicator_name}.{output_field} = {value}")
        return value

    def _resolve_precomputed_indicator(self, indicator_name: str) -> tuple[Optional[str], Any]:
        """
        Find a precomputed indicator DataFrame, tolerating name mismatches.

        Tries the exact name, then the name with underscores replaced by spaces
        (Blockly sanitizes names with underscores), then a case-insensitive match.

        :returns: ``(matched_name, dataframe)``, or ``(None, None)`` if not found.
        """
        precomputed = self.backtrader_strategy.precomputed_indicators
        df = precomputed.get(indicator_name)
        if df is not None:
            return indicator_name, df

        alt_name = indicator_name.replace('_', ' ')
        df = precomputed.get(alt_name)
        if df is not None:
            logger.debug(f"[BACKTEST] Found indicator '{alt_name}' (converted from '{indicator_name}')")
            return alt_name, df

        wanted = indicator_name.lower()
        for avail_name, avail_df in precomputed.items():
            if avail_name.lower() == wanted:
                logger.debug(f"[BACKTEST] Found indicator '{avail_name}' (case-insensitive match)")
                return (avail_name, avail_df) if avail_df is not None else (None, None)
        return None, None

    def _fallback_indicator_value(self, indicator_name: str, output_field: str,
                                  df: pd.DataFrame, idx: int):
        """Return a stand-in for a NaN indicator value: cached, next valid ahead, or 1."""
        last_valid_value = self.last_valid_values.get(indicator_name, {}).get(output_field)
        if last_valid_value is not None:
            logger.debug(f"Using cached last valid value for indicator '{indicator_name}': {last_valid_value}")
            return last_valid_value

        logger.debug(
            f"No cached last valid value for indicator '{indicator_name}'. Searching ahead for next valid value.")
        if output_field in df.columns:
            # Walk the column directly rather than building one row Series per index.
            for offset, next_value in enumerate(df[output_field].iloc[idx + 1:]):
                if not pd.isna(next_value):
                    valid_idx = idx + 1 + offset
                    logger.debug(f"Found valid value at index {valid_idx}: {next_value}")
                    self.last_valid_values.setdefault(indicator_name, {})[output_field] = next_value
                    return next_value

        logger.warning(
            f"No valid value found for indicator '{indicator_name}' after index {idx}. Returning default value 1.")
        return 1  # Default value to prevent errors

    # 2b. Override process_signal for backtesting
    def process_signal(self, signal_name: str, output_field: str = 'triggered'):
        """
        Evaluate a signal during backtesting from precomputed signal data.

        Reads the row at the strategy's current signal pointer from
        ``backtrader_strategy.precomputed_signals[signal_name]``. Signals that
        were not precomputed produce a one-time warning per name.

        :param signal_name: Name of the signal.
        :param output_field: 'triggered' returns a bool; anything else returns
            the row's 'value' entry.
        :returns: For 'triggered': False when data is unavailable or the flag is
            NaN/None, otherwise the flag as bool. For other fields: the 'value'
            entry, or None when data is unavailable.
        """
        logger.debug(f"[BACKTEST] process_signal called: signal='{signal_name}', output='{output_field}'")
        unavailable = False if output_field == 'triggered' else None
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return unavailable

        precomputed_signals = getattr(self.backtrader_strategy, 'precomputed_signals', None) or {}
        signal_data = precomputed_signals.get(signal_name)
        if signal_data is None or len(signal_data) == 0:
            if signal_name not in self._signal_warnings:
                logger.warning(f"[BACKTEST] Signal '{signal_name}' not precomputed. "
                               "Signal blocks in backtesting require precomputed signal data.")
                self._signal_warnings.add(signal_name)
            return unavailable

        signal_pointers = getattr(self.backtrader_strategy, 'signal_pointers', None) or {}
        signal_pointer = signal_pointers.get(signal_name, 0)
        if signal_pointer >= len(signal_data):
            logger.debug(f"[BACKTEST] Signal '{signal_name}' pointer out of bounds: {signal_pointer}")
            return unavailable

        signal_row = signal_data.iloc[signal_pointer]
        if output_field == 'triggered':
            triggered = signal_row.get('triggered', False)
            # bool(NaN) is True, so a missing/NaN flag must be mapped to False explicitly.
            return False if pd.isna(triggered) else bool(triggered)
        return signal_row.get('value', None)

    # 3. Override get_current_price
    def get_current_price(self, timeframe: str = '1h', exchange: str = 'binance',
                          symbol: str = 'BTC/USD') -> float:
        """
        Return the current close price from Backtrader's data feed.

        timeframe, exchange and symbol are ignored in backtesting. Returns 0.0
        when no Backtrader strategy is attached.
        """
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set.")
            return 0.0
        price = self.backtrader_strategy.data.close[0]
        logger.debug(f"Current price from Backtrader's data feed: {price}")
        return price

    # 4. Override get_last_candle
    def get_last_candle(self, candle_part: str, timeframe: str, exchange: str, symbol: str):
        """
        Return one field of the current candle from Backtrader's data feed.

        :param candle_part: 'open', 'high', 'low', 'close' or 'volume'
            (case-insensitive).
        timeframe, exchange and symbol are only used in the debug log.
        :returns: The value, or None if no strategy is attached or candle_part
            is invalid.
        """
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return None

        data = self.backtrader_strategy.data
        candle_map = {
            'open': data.open[0],
            'high': data.high[0],
            'low': data.low[0],
            'close': data.close[0],
            'volume': data.volume[0],
        }
        value = candle_map.get(candle_part.lower())
        if value is None:
            logger.error(f"Invalid candle_part '{candle_part}'. Must be one of {list(candle_map.keys())}.")
        else:
            logger.debug(
                f"Retrieved '{candle_part}' from last candle for {symbol} on {exchange} ({timeframe}): {value}"
            )
        return value

    # 5. Override get_filled_orders
    def get_filled_orders(self) -> int:
        """
        Return how many orders in the broker's ``orders`` list have status
        ``bt.Order.Completed`` (0 on error or when no strategy is attached).
        """
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return 0
        try:
            filled_orders = sum(1 for o in self.backtrader_strategy.broker.orders
                                if o.status == bt.Order.Completed)
            logger.debug(f"Number of filled orders: {filled_orders}")
            return filled_orders
        except Exception as e:
            logger.error(f"Error retrieving filled orders: {e}", exc_info=True)
            return 0

    # 6. Override get_available_balance
    def get_available_balance(self) -> float:
        """
        Refresh and return the broker's available cash.

        Also updates ``available_balance`` and exec_context. Without a broker,
        logs an error and returns the cached value.
        """
        if self.broker is None:
            logger.error("Broker is not set. Returning cached available balance.")
            return self.available_balance
        self.available_balance = self.broker.getcash()
        self.exec_context['available_balance'] = self.available_balance
        logger.debug(f"Available balance retrieved from Backtrader's broker: {self.available_balance}")
        return self.available_balance

    # 7. Override get_current_balance
    def get_current_balance(self) -> float:
        """
        Refresh and return the broker's total portfolio value.

        Also updates ``current_balance`` and exec_context. Without a broker,
        logs an error and returns the cached value.
        """
        if self.broker is None:
            logger.error("Broker is not set. Returning cached current balance.")
            return self.current_balance
        self.current_balance = self.broker.getvalue()
        self.exec_context['current_balance'] = self.current_balance
        logger.debug(f"Current balance retrieved from Backtrader's broker: {self.current_balance}")
        return self.current_balance

    # 8. Override get_filled_orders_details (Optional but Recommended)
    def get_filled_orders_details(self) -> list:
        """
        Return one dict per order in the broker's ``filled`` collection.

        Each dict holds ref, size, executed price/value/commission, status and
        created_at (a datetime, or None when unavailable). Returns [] on error or
        when no strategy is attached.
        """
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return []
        try:
            filled_orders = [
                {
                    'ref': order.ref,
                    'size': order.size,
                    'price': order.executed.price,
                    'value': order.executed.value,
                    'commission': order.executed.comm,
                    'status': order.status,
                    'created_at': self._order_created_datetime(order),
                }
                for order in self.backtrader_strategy.broker.filled
            ]
            logger.debug(f"Filled orders details: {filled_orders}")
            return filled_orders
        except Exception as e:
            logger.error(f"Error retrieving filled orders details: {e}", exc_info=True)
            return []

    @staticmethod
    def _order_created_datetime(order) -> Optional[dt.datetime]:
        """
        Return the creation datetime of *order*, or None if unavailable.

        Backtrader's order data normally stores ``dt`` as a float date number
        (converted with ``bt.num2date``); a datetime value is returned unchanged.
        NOTE(review): confirm the custom broker's orders follow one of these forms.
        """
        created_dt = getattr(getattr(order, 'created', None), 'dt', None)
        if created_dt is None:
            return None
        if isinstance(created_dt, dt.datetime):
            return created_dt
        try:
            return bt.num2date(created_dt)
        except (TypeError, ValueError, OverflowError):
            return None

    # 9. Override notify_user
    def notify_user(self, message: str):
        """
        Collect a notification, stamped with the current candle time, for the
        backtest results.

        :param message: Notification message.
        """
        timestamp = self.get_current_candle_datetime()
        alert = {
            # Append 'Z' to indicate UTC timezone (Backtrader uses UTC internally)
            'timestamp': (timestamp.isoformat() + 'Z') if timestamp else None,
            'message': message
        }
        self.collected_alerts.append(alert)
        logger.debug(f"Backtest notification: {message} (at {timestamp})")

    def get_current_candle_datetime(self) -> dt.datetime:
        """
        Return the datetime of the current candle from Backtrader's data feed.

        Falls back to the current time in UTC, as a naive datetime, when no
        strategy is attached or the feed cannot be read. notify_user() appends
        'Z' (UTC) to this value, so the fallback must not be local time.
        """
        if self.backtrader_strategy is None:
            return self._utc_now()
        try:
            return self.backtrader_strategy.data.datetime.datetime(0)
        except Exception as e:
            logger.warning(f"Could not get candle datetime: {e}")
            return self._utc_now()

    @staticmethod
    def _utc_now() -> dt.datetime:
        """Return the current UTC time as a naive datetime."""
        return dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)

    def get_collected_alerts(self) -> list:
        """
        Returns the list of collected alerts for inclusion in backtest results.
        """
        return self.collected_alerts

    def save_context(self):
        """
        Refresh balances from the broker, then save the execution context via
        StrategyInstance.save_context(). Errors are logged, not raised.
        """
        try:
            # These getters update both the attributes and exec_context.
            self.get_current_balance()
            self.get_available_balance()
            self.exec_context['available_strategy_balance'] = self.available_strategy_balance
            super().save_context()
        except Exception as e:
            logger.error(f"Error saving context for backtest: {e}", exc_info=True)

    def fetch_user_balance(self) -> float:
        """
        Return the broker's total value as the starting balance (0.0 without a broker).
        """
        # getattr: may run before `broker` is assigned (see __init__).
        broker = getattr(self, 'broker', None)
        if broker is None:
            logger.error("Broker is not set. Cannot fetch starting balance.")
            return 0.0
        balance = broker.getvalue()
        logger.debug(f"Fetched starting balance from Backtrader's broker: {balance}")
        return balance

    def calculate_available_balance(self) -> float:
        """
        Return the broker's available cash (0.0 without a broker).
        """
        if self.broker is None:
            logger.error("Broker is not set. Cannot calculate available balance.")
            return 0.0
        available_balance = self.broker.getcash()
        logger.debug(f"Calculated available cash balance from Backtrader's broker: {available_balance}")
        return available_balance

    def set_available_strategy_balance(self, balance: float):
        """
        Set the balance allocated to this strategy in the backtest and save context.

        :raises ValueError: if *balance* exceeds the broker's available cash.
        """
        if balance > self.get_available_balance():
            raise ValueError("Cannot allocate more than the available balance in backtest.")
        self.available_strategy_balance = balance
        self.exec_context['available_strategy_balance'] = self.available_strategy_balance
        self.save_context()
        logger.debug(f"Available strategy balance set to {balance} in backtest.")

    def get_available_strategy_balance(self) -> float:
        """
        Retrieves the available strategy balance in backtesting.
        """
        logger.debug(f"Available strategy balance in backtest: {self.available_strategy_balance}")
        return self.available_strategy_balance

    def get_starting_balance(self) -> float:
        """
        Returns the starting balance in backtesting.
        """
        logger.debug(f"Starting balance in backtest: {self.starting_balance}")
        return self.starting_balance

    def get_active_trades(self) -> int:
        """
        Return the number of broker positions with non-zero size (0 on error or
        when no strategy is attached).
        """
        if self.backtrader_strategy is None:
            logger.error("Backtrader strategy is not set in StrategyInstance.")
            return 0
        try:
            positions = self.broker.positions
            active_trades_count = sum(1 for position in positions.values() if position.size != 0)
            logger.debug(f"Number of active trades: {active_trades_count}")
            return active_trades_count
        except Exception as e:
            logger.error(f"Error retrieving active trades: {e}", exc_info=True)
            return 0