import ast
import datetime as dt
import json
import re
import threading

import backtrader as bt
import numpy as np

from DataCache_v3 import DataCache
from Strategies import Strategies


class Backtester:
    """Run user strategies through Backtrader and keep the runs in a cache.

    The class owns a row cache named ``'tests'`` (created in ``__init__``)
    holding one row per back-test, keyed ``backtest:<user>:<name>``.
    """

    # Column layout used for every row written to the 'tests' cache.
    _TEST_COLUMNS = ('user_name', 'strategy_name', 'start_time',
                     'capital', 'commission', 'results')

    # Format of the 'start_date' string received from the client.
    _START_DATE_FORMAT = '%Y-%m-%dT%H:%M'

    def __init__(self, data_cache: DataCache, strategies: Strategies):
        """Initialize the Backtesting class with a cache for back-tests."""
        self.data_cache = data_cache
        self.strategies = strategies
        # Create a cache for storing back-tests
        self.data_cache.create_cache('tests', cache_type='row', size_limit=100,
                                     default_expiration=dt.timedelta(days=1),
                                     eviction_policy='evict')

    def get_default_chart_view(self, user_name):
        """Fetch default chart view if no specific source is provided."""
        return self.data_cache.get_datacache_item(
            item_name='chart_view', cache_name='users',
            filter_vals=('user_name', user_name))

    def cache_backtest(self, user_name, backtest_name, backtest_data):
        """Cache the back-test request for a user.

        :param user_name: Used to build the cache key.
        :param backtest_name: Used to build the cache key.
        :param backtest_data: Request dict; 'user_name', 'strategy',
            'start_date', 'capital' and 'commission' are read from it.
        """
        values = (
            backtest_data.get('user_name'),
            backtest_data.get('strategy'),
            backtest_data.get('start_date'),
            backtest_data.get('capital', 10000),    # Default capital if not provided
            backtest_data.get('commission', 0.001),  # Default commission
            None,  # No results yet; filled in after backtest completion
        )
        cache_key = f"backtest:{user_name}:{backtest_name}"
        self.data_cache.insert_row_into_cache('tests', self._TEST_COLUMNS, values,
                                              key=cache_key)

    def map_user_strategy(self, user_strategy, user_name=None):
        """Map user strategy details into a Backtrader-compatible strategy class.

        :param user_strategy: Dict-like strategy description; 'params' and
            'blocks' are read from it.
        :param user_name: Optional user whose default chart view is used when
            a block names no source. When omitted, such blocks yield no source.
        :return: A ``bt.Strategy`` subclass (not an instance).
        """
        # Inside MappedStrategy, ``self`` is the Backtrader strategy, so the
        # Backtester must be reached through the closure instead.
        backtester = self
        strategy_params = user_strategy.get('params') or {}

        class MappedStrategy(bt.Strategy):
            params = (
                ('initial_cash', strategy_params.get('initial_cash', 10000)),
                ('commission', strategy_params.get('commission', 0.001)),
            )

            # NOTE(review): nothing in this file assigns the compiled trading
            # logic; until it is set, next() is a no-op instead of raising on
            # every bar.
            compiled_logic = None

            def __init__(self):
                self.user_name = user_name
                # Extract unique sources (exchange, symbol, timeframe) from blocks
                self.sources = self.extract_sources(user_strategy)
                # Map of source to data feed (used later in next())
                self.source_data_feed_map = {}

            def extract_sources(self, user_strategy):
                """Extract the unique sources referenced by the strategy blocks."""
                sources = []
                for block in user_strategy.get('blocks', []):
                    block_type = block.get('type')
                    if block_type in ('last_candle_value', 'trade_action'):
                        source = self.extract_source_from_block(block)
                    elif block_type == 'target_market':
                        source = self.extract_target_market(block)
                    else:
                        continue
                    if source and source not in sources:
                        sources.append(source)
                return sources

            def extract_source_from_block(self, block):
                """Return the block's source, or the user's default chart view."""
                source = None
                if block.get('type') == 'last_candle_value':
                    source = block.get('SOURCE')
                # If SOURCE is missing, fall back to the user's default view.
                if not source and self.user_name is not None:
                    source = backtester.get_default_chart_view(self.user_name)
                return source

            def extract_target_market(self, block):
                """Extract target market (timeframe, exchange, symbol) from a block."""
                target_market = block.get('target_market', {})
                return {
                    'timeframe': target_market.get('TF', '5m'),
                    'exchange': target_market.get('EXC', 'Binance'),
                    'symbol': target_market.get('SYM', 'BTCUSD'),
                }

            def next(self):
                """Execute trading logic using the compiled strategy."""
                if self.compiled_logic is None:
                    return
                try:
                    # SECURITY: exec() runs arbitrary code. compiled_logic must
                    # only ever come from a trusted compilation step, never
                    # directly from client input.
                    exec(self.compiled_logic,
                         {'self': self, 'data_feeds': self.source_data_feed_map})
                except Exception as e:
                    print(f"Error executing trading logic: {e}")

        return MappedStrategy

    def prepare_data_feed(self, start_date: str, sources: list, user_name: str):
        """Prepare multiple data feeds based on the start date and sources.

        :param start_date: String in ``%Y-%m-%dT%H:%M`` format.
        :param sources: Iterable of dicts with 'symbol' (or legacy 'asset'),
            'timeframe' and 'exchange'; missing or empty values fall back to
            BTCUSD / 5m / Binance.
        :param user_name: Passed through to the data cache lookup.
        :return: Dict mapping ``(asset, timeframe, exchange, user_name)`` to
            whatever ``get_records_since`` returns, or None on failure.
        """
        try:
            start_dt = dt.datetime.strptime(start_date, self._START_DATE_FORMAT)
        except (TypeError, ValueError) as e:
            print(f"Error preparing data feed: {e}")
            return None

        data_feeds = {}
        try:
            for source in sources or []:
                # Every source producer in this file emits 'symbol'; 'asset'
                # is still honoured first for older callers. ``or`` is used so
                # that keys present with a None value also get the default.
                asset = source.get('asset') or source.get('symbol') or 'BTCUSD'
                timeframe = source.get('timeframe') or '5m'
                exchange = source.get('exchange') or 'Binance'

                # Fetch OHLC data from DataCache based on the source
                ex_details = [asset, timeframe, exchange, user_name]
                data_feeds[tuple(ex_details)] = self.data_cache.get_records_since(
                    start_dt, ex_details)
        except Exception as e:
            # Boundary with the data layer: report and signal failure.
            print(f"Error preparing data feed: {e}")
            return None
        return data_feeds

    @staticmethod
    def _to_float(value):
        """Return ``value`` as a float, or None when it cannot be converted."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _period_returns(equity_curve):
        """Return bar-to-bar fractional changes of an equity curve."""
        return [
            (current - previous) / previous
            for previous, current in zip(equity_curve, equity_curve[1:])
            if previous != 0
        ]

    def run_backtest(self, strategy, data_feed_map, msg_data, user_name, callback,
                     socket_conn):
        """Run a back-test on a separate thread and report through ``callback``.

        Progress is sent to ``socket_conn`` as ``{"progress": <int>}`` JSON
        messages in 10% steps, followed by a final 100.

        :param strategy: ``bt.Strategy`` subclass to run.
        :param data_feed_map: Dict of source key -> data accepted by
            ``bt.feeds.PandasData``; None entries are skipped.
        :param msg_data: Request dict; 'capital' and 'commission', when
            present and numeric, are applied to the broker.
        :param user_name: Kept for interface compatibility; unused here.
        :param callback: Called with a results dict holding
            'initial_capital', 'final_portfolio_value', 'run_duration',
            'returns', 'equity_curve' and 'trades'.
        :param socket_conn: Object with a ``send(str)`` method.
        :return: The started ``threading.Thread``.
        """
        feeds = {source: feed for source, feed in (data_feed_map or {}).items()
                 if feed is not None}
        feed_keys = list(feeds)
        # Progress is measured against the largest feed.
        total_bars = max((len(feed) for feed in feeds.values()), default=0)
        request = msg_data or {}

        equity_curve = []
        trades = []
        current_bar = 0
        last_progress = 0

        def track_progress():
            nonlocal current_bar, last_progress
            if total_bars <= 0:
                return
            current_bar += 1
            # Integer arithmetic keeps the 10% boundaries exact.
            step = (current_bar * 10 // total_bars) * 10
            # 100% is sent once, after the run has really finished.
            if last_progress < step < 100:
                last_progress = step
                socket_conn.send(json.dumps({"progress": int(step)}))

        # Wrap the strategy instead of overwriting its next(), so the trading
        # logic still runs and per-instance state can be populated.
        class TrackedStrategy(strategy):
            def __init__(self):
                super().__init__()
                # self.datas holds the feeds in the order they were added.
                self.source_data_feed_map = dict(zip(feed_keys, self.datas))

            def next(self):
                super().next()
                equity_curve.append(self.broker.getvalue())
                track_progress()

            def notify_trade(self, trade):
                super().notify_trade(trade)
                if trade.isclosed:
                    trades.append({'profit': trade.pnlcomm})

        def execute_backtest():
            cerebro = bt.Cerebro()
            cerebro.addstrategy(TrackedStrategy)

            capital = self._to_float(request.get('capital'))
            if capital is not None:
                cerebro.broker.setcash(capital)
            commission = self._to_float(request.get('commission'))
            if commission is not None:
                cerebro.broker.setcommission(commission=commission)

            for feed in feeds.values():
                cerebro.adddata(bt.feeds.PandasData(dataname=feed))

            # Capture initial capital
            initial_capital = cerebro.broker.getvalue()

            print("Running backtest...")
            start_time = dt.datetime.now()
            cerebro.run()
            end_time = dt.datetime.now()

            final_value = cerebro.broker.getvalue()
            run_duration = (end_time - start_time).total_seconds()

            # Send 100% completion
            socket_conn.send(json.dumps({"progress": 100}))

            callback({
                "initial_capital": initial_capital,
                "final_portfolio_value": final_value,
                "run_duration": run_duration,
                "returns": self._period_returns(equity_curve),
                "equity_curve": list(equity_curve),
                "trades": list(trades),
            })

        # Run the backtest in a separate thread
        thread = threading.Thread(target=execute_backtest)
        thread.start()
        return thread

    def handle_backtest_message(self, user_id, msg_data, socket_conn):
        """Validate a back-test request, start the run and reply immediately.

        :param user_id: Used to look up the strategy.
        :param msg_data: Request dict; 'user_name', 'strategy' and
            'start_date' are read here.
        :param socket_conn: Forwarded to ``run_backtest`` for progress updates.
        :return: ``{"reply": "backtest_started"}`` or ``{"error": <message>}``.
        """
        user_name = msg_data.get('user_name')
        strategy_name = msg_data.get('strategy')
        start_date = msg_data.get('start_date')
        if not strategy_name:
            return {"error": "No strategy specified for the backtest."}
        if not start_date:
            return {"error": "No start date specified for the backtest."}

        backtest_name = f"{strategy_name}_backtest"

        # Cache the backtest data
        self.cache_backtest(user_name, backtest_name, msg_data)

        # Fetch the strategy using user_id and strategy_name
        user_strategy = self.strategies.get_strategy_by_name(user_id=user_id,
                                                             name=strategy_name)
        if not user_strategy:
            return {"error": f"Strategy {strategy_name} not found for user {user_name}"}

        # Extract sources from the strategy JSON
        sources = self.extract_sources_from_strategy_json(
            user_strategy.get('strategy_json'))
        if not sources:
            return {"error": "No valid sources found in the strategy."}

        # Prepare the data feed map based on extracted sources
        data_feed_map = self.prepare_data_feed(start_date, sources, user_name)
        if data_feed_map is None:
            return {"error": "Data feed could not be prepared. Please check the data source."}

        # Map the user strategy to a Backtrader strategy class
        mapped_strategy = self.map_user_strategy(user_strategy, user_name=user_name)

        # Define the callback function to handle backtest completion
        def backtest_callback(results):
            self.store_backtest_results(user_name, backtest_name, results)
            self.update_strategy_stats(user_id, strategy_name, results)

        self.run_backtest(mapped_strategy, data_feed_map, msg_data, user_name,
                          backtest_callback, socket_conn)

        return {"reply": "backtest_started"}

    def extract_sources_from_strategy_json(self, strategy_json):
        """Collect every 'source' block found in a strategy block tree.

        :param strategy_json: A list of block dicts, a single block dict, or a
            JSON string encoding either. Anything else yields no sources.
        :return: List of ``{'timeframe', 'exchange', 'symbol'}`` dicts taken
            from each source block's 'fields' (TF / EXC / SYM), in traversal
            order.
        """
        if isinstance(strategy_json, str):
            try:
                strategy_json = json.loads(strategy_json)
            except ValueError:
                return []
        if isinstance(strategy_json, dict):
            strategy_json = [strategy_json]
        if not strategy_json:
            return []

        sources = []

        def traverse_blocks(blocks):
            for block in blocks:
                if not isinstance(block, dict):
                    continue  # Tolerate empty slots / scalar values
                if block.get('type') == 'source':
                    fields = block.get('fields') or {}
                    sources.append({
                        'timeframe': fields.get('TF'),
                        'exchange': fields.get('EXC'),
                        'symbol': fields.get('SYM'),
                    })
                # Recursively traverse inputs, statements and the next block
                for child_key in ('inputs', 'statements'):
                    children = block.get(child_key)
                    if isinstance(children, dict):
                        traverse_blocks(children.values())
                if 'next' in block:
                    traverse_blocks([block['next']])

        traverse_blocks(strategy_json)
        return sources

    @staticmethod
    def _compute_stats(results):
        """Derive summary statistics from a back-test results dict.

        'returns', 'equity_curve' and 'trades' are optional; when a series is
        missing or empty the statistic that depends on it is reported as 0.
        """
        initial_capital = results['initial_capital']
        final_value = results['final_portfolio_value']
        returns = np.asarray(results.get('returns') or [], dtype=float)
        equity_curve = np.asarray(results.get('equity_curve') or [], dtype=float)
        trades = results.get('trades') or []

        total_return = 0.0
        if initial_capital:
            total_return = (final_value - initial_capital) / initial_capital * 100

        risk_free_rate = 0.0
        sharpe_ratio = 0.0
        if returns.size:
            std_return = float(np.std(returns))
            if std_return != 0:
                sharpe_ratio = (float(np.mean(returns)) - risk_free_rate) / std_return

        max_drawdown = 0.0
        if equity_curve.size:
            running_max = np.maximum.accumulate(equity_curve)
            # Points whose running maximum is 0 contribute no drawdown.
            safe_max = np.where(running_max != 0, running_max, 1.0)
            drawdowns = np.where(running_max != 0,
                                 (equity_curve - running_max) / safe_max, 0.0)
            max_drawdown = float(np.min(drawdowns)) * 100

        num_trades = len(trades)
        wins = sum(1 for trade in trades if trade['profit'] > 0)
        losses = num_trades - wins  # Break-even trades count as losses
        win_loss_ratio = wins / losses if losses != 0 else wins

        return {
            'total_return': total_return,
            'sharpe_ratio': sharpe_ratio,
            'max_drawdown': max_drawdown,
            'number_of_trades': num_trades,
            'win_loss_ratio': win_loss_ratio,
        }

    def update_strategy_stats(self, user_id, strategy_name, results):
        """Update the strategy stats with the backtest results."""
        strategy = self.strategies.get_strategy_by_name(user_id=user_id,
                                                        name=strategy_name)
        if not strategy:
            print(f"Strategy {strategy_name} not found for user {user_id}.")
            return
        # NOTE(review): handle_backtest_message reads this same object with
        # .get(); confirm it also provides update_stats().
        strategy.update_stats(self._compute_stats(results))

    def store_backtest_results(self, user_name, backtest_name, results):
        """Store the backtest results in the cache.

        Re-inserts the cached row under the same key with its 'results'
        column replaced by ``results``.
        """
        cache_key = f"backtest:{user_name}:{backtest_name}"
        filter_vals = [('tbl_key', cache_key)]
        backtest_data = self.data_cache.get_rows_from_cache('tests', filter_vals)

        if backtest_data is None or backtest_data.empty:
            print(f"Backtest {backtest_name} not found in cache.")
            return

        # Assumes a pandas DataFrame is returned (it exposes .empty) — TODO
        # confirm. DataFrame.values is a property, not a method, so the row is
        # rebuilt column by column instead.
        row = backtest_data.iloc[0]
        values = tuple(results if column == 'results' else row.get(column)
                       for column in self._TEST_COLUMNS)
        self.data_cache.insert_row_into_cache('tests', self._TEST_COLUMNS, values,
                                              key=cache_key)