brighter-trading/src/backtesting.py

313 lines
13 KiB
Python

import ast
import json
import re
import backtrader as bt
import datetime as dt
from DataCache_v3 import DataCache
from Strategies import Strategies
import threading
import numpy as np
class Backtester:
    """Run Backtrader back-tests for user strategies and cache their results.

    Back-test requests and results are kept in a row cache named ``'tests'`` on
    the supplied ``DataCache``; strategy definitions are looked up through the
    supplied ``Strategies`` object.
    """

    def __init__(self, data_cache: "DataCache", strategies: "Strategies"):
        """Initialize the Backtesting class with a cache for back-tests.

        :param data_cache: Shared data cache; a ``'tests'`` row cache is created on it.
        :param strategies: Strategy store used to look up user strategies.
        """
        self.data_cache = data_cache
        self.strategies = strategies
        # Row cache for back-tests: size_limit 100, one-day default expiration.
        self.data_cache.create_cache('tests', cache_type='row', size_limit=100,
                                     default_expiration=dt.timedelta(days=1),
                                     eviction_policy='evict')

    @staticmethod
    def _cache_key(user_name, backtest_name):
        """Return the ``'tests'`` cache key for a user's back-test."""
        return f"backtest:{user_name}:{backtest_name}"

    def get_default_chart_view(self, user_name):
        """Return the ``chart_view`` item stored for ``user_name`` in the ``'users'`` cache.

        Used as the fallback source when a strategy block does not name one.
        """
        return self.data_cache.get_datacache_item(
            item_name='chart_view', cache_name='users', filter_vals=('user_name', user_name))

    def cache_backtest(self, user_name, backtest_name, backtest_data):
        """Record a pending back-test request in the ``'tests'`` cache.

        :param user_name: Owner of the back-test; part of the cache key.
        :param backtest_name: Name of the back-test; part of the cache key.
        :param backtest_data: The client request. ``user_name``, ``strategy`` and
            ``start_date`` are stored as-is; ``capital`` defaults to 10000 and
            ``commission`` to 0.001 when absent.
        """
        columns = ('user_name', 'strategy_name', 'start_time', 'capital', 'commission', 'results')
        values = (
            backtest_data.get('user_name'),
            backtest_data.get('strategy'),
            backtest_data.get('start_date'),
            backtest_data.get('capital', 10000),
            backtest_data.get('commission', 0.001),
            None,  # filled in by store_backtest_results once the run completes
        )
        cache_key = self._cache_key(user_name, backtest_name)
        self.data_cache.insert_row_into_cache('tests', columns, values, key=cache_key)

    def map_user_strategy(self, user_strategy, user_name=None):
        """Build a Backtrader ``Strategy`` subclass from a user's strategy definition.

        :param user_strategy: Mapping describing the strategy. ``params`` (optional)
            supplies ``initial_cash``/``commission``; ``blocks`` (optional) is
            scanned for data sources.
        :param user_name: User whose default chart view is used when a block names
            no source. Optional so existing callers keep working.
        :return: A new ``bt.Strategy`` subclass (a fresh class on every call).
        """
        backtester = self
        user_params = user_strategy.get('params') or {}

        class MappedStrategy(bt.Strategy):
            params = (
                ('initial_cash', user_params.get('initial_cash', 10000)),
                ('commission', user_params.get('commission', 0.001)),
            )
            # Source key -> Backtrader data feed. run_backtest replaces this on a
            # subclass before running; it is a class attribute so the instance that
            # Cerebro creates can see it.
            source_data_feed_map = {}

            def __init__(self):
                # Unique sources (exchange, symbol, timeframe) referenced by the blocks.
                self.sources = self.extract_sources(user_strategy)

            def extract_sources(self, user_strategy):
                """Return the unique sources referenced by the strategy's blocks, in order."""
                sources = []
                for block in user_strategy.get('blocks', []):
                    block_type = block.get('type')
                    if block_type in ('last_candle_value', 'trade_action'):
                        source = self.extract_source_from_block(block)
                    elif block_type == 'target_market':
                        source = self.extract_target_market(block)
                    else:
                        continue
                    if source and source not in sources:
                        sources.append(source)
                return sources

            def extract_source_from_block(self, block):
                """Return the block's ``SOURCE``, falling back to the user's default chart view."""
                source = block.get('SOURCE') if block.get('type') == 'last_candle_value' else None
                if not source:
                    # The fallback lives on the Backtester, not on the strategy instance.
                    source = backtester.get_default_chart_view(user_name)
                return source

            def extract_target_market(self, block):
                """Return the block's target market, defaulting to Binance BTCUSD on 5m."""
                target_market = block.get('target_market', {})
                return {
                    'timeframe': target_market.get('TF', '5m'),
                    'exchange': target_market.get('EXC', 'Binance'),
                    'symbol': target_market.get('SYM', 'BTCUSD'),
                }

            def next(self):
                """Execute the strategy's compiled trading logic for the current bar."""
                # NOTE(review): nothing in this module assigns ``compiled_logic``;
                # confirm where it is meant to be attached.
                logic = getattr(self, 'compiled_logic', None)
                if logic is None:
                    # Warn once instead of printing the same error on every bar.
                    if not getattr(self, '_warned_missing_logic', False):
                        print("Error executing trading logic: no compiled logic is attached to the strategy")
                        self._warned_missing_logic = True
                    return
                # SECURITY: exec runs arbitrary Python with full interpreter access.
                # It must only receive server-generated code, never raw client input.
                try:
                    exec(logic, {'self': self, 'data_feeds': self.source_data_feed_map})
                except Exception as e:
                    print(f"Error executing trading logic: {e}")

        return MappedStrategy

    def prepare_data_feed(self, start_date: str, sources: list, user_name: str):
        """Fetch OHLC records for each source starting at ``start_date``.

        :param start_date: Start time formatted ``'%Y-%m-%dT%H:%M'``.
        :param sources: Dicts with ``symbol`` (or legacy ``asset``), ``timeframe``
            and ``exchange``; missing or empty values default to BTCUSD, 5m and
            Binance.
        :param user_name: Passed through to the data cache lookup.
        :return: Dict mapping ``(asset, timeframe, exchange, user_name)`` to what
            ``data_cache.get_records_since`` returns, or ``None`` if anything
            fails (the error is printed).
        """
        try:
            start_dt = dt.datetime.strptime(start_date, '%Y-%m-%dT%H:%M')
            data_feeds = {}
            for source in sources:
                # Sources built in this module carry 'symbol'; 'asset' is still honoured.
                asset = source.get('asset') or source.get('symbol') or 'BTCUSD'
                timeframe = source.get('timeframe') or '5m'
                exchange = source.get('exchange') or 'Binance'
                ex_details = [asset, timeframe, exchange, user_name]
                feed_key = tuple(ex_details)
                if feed_key not in data_feeds:  # identical sources share one fetch
                    data_feeds[feed_key] = self.data_cache.get_records_since(start_dt, ex_details)
            return data_feeds
        except Exception as e:
            print(f"Error preparing data feed: {e}")
            return None

    def run_backtest(self, strategy, data_feed_map, msg_data, user_name, callback, socket_conn):
        """Run a Backtrader back-test on a background thread.

        Progress messages ``{"progress": <int>}`` are sent over ``socket_conn``
        for every completed 10% of bars, then ``{"progress": 100}`` when the run
        ends. If the run raises, ``{"error": ...}`` is sent instead and
        ``callback`` is not called.

        :param strategy: A ``bt.Strategy`` subclass (e.g. from ``map_user_strategy``).
            It is subclassed for the run, not modified.
        :param data_feed_map: Mapping of source key -> data accepted by
            ``bt.feeds.PandasData`` (as produced by ``prepare_data_feed``).
        :param msg_data: The client request; ``capital`` (default 10000) and
            ``commission`` (default 0.001) are applied to the broker.
        :param user_name: Requesting user (kept for interface compatibility).
        :param callback: Called on the worker thread with a dict holding
            ``initial_capital``, ``final_portfolio_value``, ``run_duration``,
            ``equity_curve``, ``returns`` and ``trades``.
        :param socket_conn: Object with a ``send(str)`` method.
        :return: The started ``threading.Thread``.
        """
        def as_number(key, default):
            # Tolerate a missing/None value and numeric strings; 0 is a valid commission.
            value = msg_data.get(key)
            return default if value is None or value == '' else float(value)

        def run_cerebro():
            """Run the back-test synchronously and return the results dict."""
            capital = as_number('capital', 10000.0)
            commission = as_number('commission', 0.001)

            cerebro = bt.Cerebro()
            # Apply the requested capital/commission; otherwise Backtrader's defaults
            # would silently be used.
            cerebro.broker.setcash(capital)
            cerebro.broker.setcommission(commission=commission)

            feeds = {}
            total_bars = 0  # bars in the longest feed: the progress denominator
            for source, data in (data_feed_map or {}).items():
                bt_feed = bt.feeds.PandasData(dataname=data)
                cerebro.adddata(bt_feed)
                feeds[source] = bt_feed
                total_bars = max(total_bars, len(data))

            # The equity curve starts at the initial capital, so the first bar's
            # return and any immediate drawdown are captured.
            equity_curve = [capital]
            trades = []
            bars_seen = 0
            last_reported = 0

            def report_progress():
                # Send progress whenever another full 10% of bars is done; 100% is
                # sent once after the run.
                nonlocal bars_seen, last_reported
                bars_seen += 1
                if total_bars <= 0:
                    return
                step = (bars_seen * 100 // total_bars) // 10 * 10
                if last_reported < step < 100:
                    last_reported = step
                    socket_conn.send(json.dumps({"progress": step}))

            class TrackedStrategy(strategy):
                # Class attribute so the instance Cerebro creates can reach the feeds.
                source_data_feed_map = feeds

                def next(strat):
                    # Run the user's logic first, then record equity and progress.
                    super().next()
                    equity_curve.append(strat.broker.getvalue())
                    report_progress()

                def notify_trade(strat, trade):
                    super().notify_trade(trade)
                    if trade.isclosed:
                        trades.append({'profit': trade.pnlcomm})

            cerebro.addstrategy(TrackedStrategy)

            print("Running backtest...")
            start_time = dt.datetime.now()
            cerebro.run()
            run_duration = (dt.datetime.now() - start_time).total_seconds()

            socket_conn.send(json.dumps({"progress": 100}))

            equity = np.asarray(equity_curve, dtype=float)
            if equity.size > 1:
                previous = equity[:-1]
                returns = np.divide(np.diff(equity), previous,
                                    out=np.zeros(previous.size), where=previous != 0)
            else:
                returns = np.zeros(0)

            return {
                "initial_capital": capital,
                "final_portfolio_value": cerebro.broker.getvalue(),
                "run_duration": run_duration,
                "equity_curve": equity_curve,
                "returns": returns.tolist(),
                "trades": trades,
            }

        def execute_backtest():
            try:
                results = run_cerebro()
            except Exception as e:
                # Tell the client the run failed rather than just stopping progress.
                print(f"Error running backtest: {e}")
                socket_conn.send(json.dumps({"error": f"Backtest failed: {e}"}))
                return
            callback(results)

        thread = threading.Thread(target=execute_backtest)
        thread.start()
        return thread

    def handle_backtest_message(self, user_id, msg_data, socket_conn):
        """Start a back-test requested by a client.

        This method:

        1. Caches the request.
        2. Resolves the strategy and its data sources.
        3. Prepares the data feeds.
        4. Launches ``run_backtest``.

        When the run completes, the results are stored and the strategy's stats
        are updated.

        :return: ``{"reply": "backtest_started"}`` on success, otherwise
            ``{"error": <message>}``.
        """
        user_name = msg_data.get('user_name')
        strategy_name = msg_data.get('strategy')
        if not strategy_name:
            return {"error": "No strategy specified for the backtest."}
        backtest_name = f"{strategy_name}_backtest"

        self.cache_backtest(user_name, backtest_name, msg_data)

        user_strategy = self.strategies.get_strategy_by_name(user_id=user_id, name=strategy_name)
        if not user_strategy:
            return {"error": f"Strategy {strategy_name} not found for user {user_name}"}

        sources = self.extract_sources_from_strategy_json(user_strategy.get('strategy_json'))
        if not sources:
            return {"error": "No valid sources found in the strategy."}

        # A missing or malformed start_date makes prepare_data_feed return None.
        data_feed_map = self.prepare_data_feed(msg_data.get('start_date'), sources, user_name)
        if data_feed_map is None:
            return {"error": "Data feed could not be prepared. Please check the data source."}

        mapped_strategy = self.map_user_strategy(user_strategy, user_name=user_name)

        def backtest_callback(results):
            self.store_backtest_results(user_name, backtest_name, results)
            self.update_strategy_stats(user_id, strategy_name, results)

        self.run_backtest(mapped_strategy, data_feed_map, msg_data, user_name, backtest_callback, socket_conn)
        return {"reply": "backtest_started"}

    def extract_sources_from_strategy_json(self, strategy_json):
        """Collect the unique data sources referenced by ``source`` blocks.

        Accepted inputs:

        * a list of block dicts, each with ``type``/``fields`` and optional
          ``inputs``/``statements``/``next`` children;
        * a single block;
        * a JSON string of either of the above;
        * Blockly-style wrappers: ``{'block': ...}`` or ``{'shadow': ...}`` around
          children, and ``{'blocks': ...}`` at workspace level.

        :return: List of ``{'timeframe', 'exchange', 'symbol'}`` dicts in pre-order
            traversal order, without duplicates. Returns ``[]`` for unparsable JSON.
        """
        if isinstance(strategy_json, (str, bytes, bytearray)):
            try:
                strategy_json = json.loads(strategy_json)
            except ValueError as e:
                print(f"Error parsing strategy JSON: {e}")
                return []

        sources = []
        # Explicit stack instead of recursion, so long ``next`` chains cannot hit the
        # recursion limit. Children are pushed in reverse to keep pre-order.
        stack = [strategy_json]
        while stack:
            node = stack.pop()
            if isinstance(node, (list, tuple)):
                stack.extend(reversed(node))
                continue
            if not isinstance(node, dict):
                continue  # None or unexpected scalars
            if 'type' not in node:
                wrapped = node.get('block') or node.get('shadow') or node.get('blocks')
                if wrapped is not None:
                    stack.append(wrapped)
                continue

            if node['type'] == 'source':
                fields = node.get('fields') or {}
                source = {
                    'timeframe': fields.get('TF'),
                    'exchange': fields.get('EXC'),
                    'symbol': fields.get('SYM'),
                }
                if source not in sources:
                    sources.append(source)

            children = []
            for group_key in ('inputs', 'statements'):
                group = node.get(group_key)
                if isinstance(group, dict):
                    children.extend(group.values())
                elif isinstance(group, (list, tuple)):
                    children.extend(group)
            if 'next' in node:
                children.append(node['next'])
            stack.extend(reversed(children))
        return sources

    def update_strategy_stats(self, user_id, strategy_name, results):
        """Compute summary statistics from back-test ``results`` and save them on the strategy.

        :param results: Dict with ``initial_capital`` and ``final_portfolio_value``.
            The keys ``returns`` (per-bar fractional returns), ``equity_curve``
            (portfolio values) and ``trades`` (dicts with a ``profit`` key) are
            optional and treated as empty when absent.

        Stats written:

        * ``total_return`` and ``max_drawdown``, in percent;
        * ``sharpe_ratio``, per-bar, not annualized, risk-free rate 0;
        * ``number_of_trades``;
        * ``win_loss_ratio``, which equals the win count when there are no losses.

        Degenerate inputs (empty series, zero capital) yield 0 instead of raising.
        """
        strategy = self.strategies.get_strategy_by_name(user_id=user_id, name=strategy_name)
        if not strategy:
            print(f"Strategy {strategy_name} not found for user {user_id}.")
            return

        initial_capital = results['initial_capital']
        final_value = results['final_portfolio_value']
        returns = np.asarray(results.get('returns', []), dtype=float)
        equity_curve = np.asarray(results.get('equity_curve', []), dtype=float)
        trades = results.get('trades', [])

        if initial_capital:
            total_return = (final_value - initial_capital) / initial_capital * 100
        else:
            total_return = 0.0

        risk_free_rate = 0.0
        std_return = float(np.std(returns)) if returns.size else 0.0
        if std_return != 0:
            sharpe_ratio = float((np.mean(returns) - risk_free_rate) / std_return)
        else:
            sharpe_ratio = 0.0

        if equity_curve.size:
            running_max = np.maximum.accumulate(equity_curve)
            # Guard against a zero running maximum (would divide by zero).
            drawdowns = np.divide(equity_curve - running_max, running_max,
                                  out=np.zeros_like(equity_curve), where=running_max != 0)
            max_drawdown = float(np.min(drawdowns)) * 100
        else:
            max_drawdown = 0.0

        num_trades = len(trades)
        wins = sum(1 for trade in trades if trade['profit'] > 0)
        losses = num_trades - wins  # break-even trades count as losses
        win_loss_ratio = wins / losses if losses else wins

        stats = {
            'total_return': total_return,
            'sharpe_ratio': sharpe_ratio,
            'max_drawdown': max_drawdown,
            'number_of_trades': num_trades,
            'win_loss_ratio': win_loss_ratio,
        }
        # NOTE(review): handle_backtest_message reads the same lookup's result with
        # .get(...); confirm the returned object also exposes update_stats.
        strategy.update_stats(stats)

    def store_backtest_results(self, user_name, backtest_name, results):
        """Attach ``results`` to the cached back-test row and write the row back.

        The results are stored JSON-encoded in the ``results`` column. If the
        back-test is not in the cache, a message is printed and nothing is written.
        """
        cache_key = self._cache_key(user_name, backtest_name)
        backtest_data = self.data_cache.get_rows_from_cache('tests', [('tbl_key', cache_key)])
        # Presumably a pandas DataFrame of matching rows; TODO confirm against DataCache.
        if backtest_data is None or backtest_data.empty:
            print(f"Backtest {backtest_name} not found in cache.")
            return

        row = backtest_data.iloc[0].to_dict()
        # The key goes through ``key=`` (as in cache_backtest), not as a column.
        row.pop('tbl_key', None)
        # Assigning a dict to a DataFrame column does not store it as one value;
        # serialize it so the row holds a single scalar.
        row['results'] = json.dumps(results)
        self.data_cache.insert_row_into_cache('tests', tuple(row), tuple(row.values()), key=cache_key)