# Backtester: maps user-defined strategy JSON onto Backtrader strategies,
# prepares cached OHLC data feeds, and runs back-tests on a worker thread.
import ast
|
|
import json
|
|
import re
|
|
|
|
import backtrader as bt
|
|
import datetime as dt
|
|
from DataCache_v3 import DataCache
|
|
from Strategies import Strategies
|
|
import threading
|
|
import numpy as np
|
|
|
|
|
|
class Backtester:

    def __init__(self, data_cache: DataCache, strategies: Strategies):
        """Set up the back-testing service.

        Args:
            data_cache: shared DataCache used for market data and results.
            strategies: registry of user-defined strategies.
        """
        self.data_cache = data_cache
        self.strategies = strategies

        # Dedicated row cache for back-test runs: at most 100 rows, each
        # expiring after one day, evicted outright when stale.
        one_day = dt.timedelta(days=1)
        self.data_cache.create_cache('tests', cache_type='row', size_limit=100,
                                     default_expiration=one_day,
                                     eviction_policy='evict')
def get_default_chart_view(self, user_name):
|
|
"""Fetch default chart view if no specific source is provided."""
|
|
return self.data_cache.get_datacache_item(
|
|
item_name='chart_view', cache_name='users', filter_vals=('user_name', user_name))
|
|
|
|
def cache_backtest(self, user_name, backtest_name, backtest_data):
|
|
""" Cache the backtest data for a user """
|
|
columns = ('user_name', 'strategy_name', 'start_time', 'capital', 'commission', 'results')
|
|
values = (
|
|
backtest_data.get('user_name'),
|
|
backtest_data.get('strategy'),
|
|
backtest_data.get('start_date'),
|
|
backtest_data.get('capital', 10000), # Default capital if not provided
|
|
backtest_data.get('commission', 0.001), # Default commission
|
|
None # No results yet; will be filled in after backtest completion
|
|
)
|
|
cache_key = f"backtest:{user_name}:{backtest_name}"
|
|
self.data_cache.insert_row_into_cache('tests', columns, values, key=cache_key)
|
|
|
|
    def map_user_strategy(self, user_strategy):
        """Build a Backtrader-compatible strategy class from a user strategy dict.

        The returned value is a CLASS (not an instance); Cerebro instantiates
        it itself. `user_strategy` is captured by closure, so params defaults
        are resolved at class-creation time.
        """

        class MappedStrategy(bt.Strategy):
            # Backtrader params tuple; defaults come from the user's config.
            params = (
                ('initial_cash', user_strategy['params'].get('initial_cash', 10000)),
                ('commission', user_strategy['params'].get('commission', 0.001)),
            )

            def __init__(self):
                # Extract unique sources (exchange, symbol, timeframe) from blocks
                self.sources = self.extract_sources(user_strategy)

                # Map of source to data feed (used later in next()).
                # NOTE(review): run_backtest assigns into this map on the CLASS
                # before cerebro.run(); this per-instance reset to {} discards
                # those entries — confirm intended wiring.
                self.source_data_feed_map = {}

            def extract_sources(self, user_strategy):
                """Extract the unique market sources referenced by the strategy blocks.

                Returns a de-duplicated list of source dicts in first-seen order.
                """
                sources = []
                for block in user_strategy.get('blocks', []):
                    if block.get('type') in ['last_candle_value', 'trade_action']:
                        source = self.extract_source_from_block(block)
                        if source and source not in sources:
                            sources.append(source)
                    elif block.get('type') == 'target_market':
                        target_source = self.extract_target_market(block)
                        if target_source and target_source not in sources:
                            sources.append(target_source)
                return sources

            def extract_source_from_block(self, block):
                """Extract source (exchange, symbol, timeframe) from a strategy block."""
                source = {}
                if block.get('type') == 'last_candle_value':
                    source = block.get('SOURCE', None)
                # If SOURCE is missing, use the trade target or default
                if not source:
                    # NOTE(review): MappedStrategy defines neither
                    # get_default_chart_view nor user_name (both live on
                    # Backtester), so this fallback raises AttributeError
                    # at runtime — confirm how it is meant to be resolved.
                    source = self.get_default_chart_view(self.user_name)  # Fallback to default
                return source

            def extract_target_market(self, block):
                """Extract target market data (timeframe, exchange, symbol) from the trade_action block.

                Missing fields fall back to 5m / Binance / BTCUSD defaults.
                """
                target_market = block.get('target_market', {})
                return {
                    'timeframe': target_market.get('TF', '5m'),
                    'exchange': target_market.get('EXC', 'Binance'),
                    'symbol': target_market.get('SYM', 'BTCUSD')
                }

            def next(self):
                """Execute trading logic using the compiled strategy."""
                try:
                    # NOTE(review): self.compiled_logic is never assigned
                    # anywhere in this file — TODO confirm it is injected
                    # elsewhere. exec() of user-supplied strategy code is
                    # also a code-injection risk; review before shipping.
                    exec(self.compiled_logic, {'self': self, 'data_feeds': self.source_data_feed_map})
                except Exception as e:
                    print(f"Error executing trading logic: {e}")

        return MappedStrategy
def prepare_data_feed(self, start_date: str, sources: list, user_name: str):
|
|
"""
|
|
Prepare multiple data feeds based on the start date and list of sources.
|
|
"""
|
|
try:
|
|
# Convert the start date to a datetime object
|
|
start_dt = dt.datetime.strptime(start_date, '%Y-%m-%dT%H:%M')
|
|
|
|
# Dictionary to map each source to its corresponding data feed
|
|
data_feeds = {}
|
|
|
|
for source in sources:
|
|
# Ensure exchange details contain required keys (fallback if missing)
|
|
asset = source.get('asset', 'BTCUSD')
|
|
timeframe = source.get('timeframe', '5m')
|
|
exchange = source.get('exchange', 'Binance')
|
|
|
|
# Fetch OHLC data from DataCache based on the source
|
|
ex_details = [asset, timeframe, exchange, user_name]
|
|
data = self.data_cache.get_records_since(start_dt, ex_details)
|
|
|
|
# Return the data as a Pandas DataFrame compatible with Backtrader
|
|
data_feeds[tuple(ex_details)] = data
|
|
|
|
return data_feeds
|
|
|
|
except Exception as e:
|
|
print(f"Error preparing data feed: {e}")
|
|
return None
|
|
|
|
    def run_backtest(self, strategy, data_feed_map, msg_data, user_name, callback, socket_conn):
        """
        Runs a backtest using Backtrader on a separate thread and calls the callback with the results when finished.
        Also sends progress updates to the client via WebSocket.

        Args:
            strategy: mapped Backtrader strategy CLASS (from map_user_strategy).
            data_feed_map: NOTE(review): this parameter is shadowed by the
                recomputation near the end of this method, so the caller's
                value is never used — confirm which one is intended.
            msg_data: original request payload (provides 'start_date').
            callback: called with a results dict on completion.
            socket_conn: WebSocket-like object; .send() receives JSON progress.
        """

        def execute_backtest():
            # Worker body, run on a background thread. NOTE(review):
            # exceptions raised here die silently with the thread.
            cerebro = bt.Cerebro()

            # Add the mapped strategy to the backtest
            cerebro.addstrategy(strategy)

            # Add all the data feeds to Cerebro
            total_bars = 0  # Total number of data points (bars) across all feeds
            for source, data_feed in data_feed_map.items():
                bt_feed = bt.feeds.PandasData(dataname=data_feed)
                cerebro.adddata(bt_feed)
                # NOTE(review): 'strategy' is a class here, so this writes a
                # class attribute; MappedStrategy.__init__ rebinds
                # source_data_feed_map to {} per instance — verify intent.
                strategy.source_data_feed_map[source] = bt_feed
                total_bars = max(total_bars, len(data_feed))  # Get the total bars from the largest feed

            # Capture initial capital
            initial_capital = cerebro.broker.getvalue()

            # Progress tracking variables
            current_bar = 0
            last_progress = 0

            # Custom next function to track progress (if you have a large dataset)
            def track_progress():
                nonlocal current_bar, last_progress
                current_bar += 1
                progress = (current_bar / total_bars) * 100

                # Send progress update every 10% increment
                if progress >= last_progress + 10:
                    last_progress += 10
                    socket_conn.send(json.dumps({"progress": int(last_progress)}))

            # Attach the custom next method to the strategy
            # NOTE(review): this REPLACES MappedStrategy.next entirely, so the
            # strategy's trading logic never executes during the run — and
            # track_progress takes no self, so Backtrader's bound call to
            # next() would also misfire. Confirm before relying on results.
            strategy.next = track_progress

            # Run the backtest
            print("Running backtest...")
            start_time = dt.datetime.now()
            cerebro.run()
            end_time = dt.datetime.now()

            # Extract performance metrics
            final_value = cerebro.broker.getvalue()
            run_duration = (end_time - start_time).total_seconds()

            # Send 100% completion
            socket_conn.send(json.dumps({"progress": 100}))

            # Prepare the results to pass into the callback
            callback({
                "initial_capital": initial_capital,
                "final_portfolio_value": final_value,
                "run_duration": run_duration
            })

        # Map the user strategy and prepare the data feeds
        # NOTE(review): extract_sources is an instance method with signature
        # (self, user_strategy); calling it argument-less on the class raises
        # TypeError — TODO confirm how this call is meant to be wired.
        sources = strategy.extract_sources()
        data_feed_map = self.prepare_data_feed(msg_data['start_date'], sources, user_name)

        # Run the backtest in a separate thread
        thread = threading.Thread(target=execute_backtest)
        thread.start()
def handle_backtest_message(self, user_id, msg_data, socket_conn):
|
|
user_name = msg_data.get('user_name')
|
|
backtest_name = f"{msg_data['strategy']}_backtest"
|
|
|
|
# Cache the backtest data
|
|
self.cache_backtest(user_name, backtest_name, msg_data)
|
|
|
|
# Fetch the strategy using user_id and strategy_name
|
|
strategy_name = msg_data.get('strategy')
|
|
user_strategy = self.strategies.get_strategy_by_name(user_id=user_id, name=strategy_name)
|
|
|
|
if not user_strategy:
|
|
return {"error": f"Strategy {strategy_name} not found for user {user_name}"}
|
|
|
|
# Extract sources from the strategy JSON
|
|
sources = self.extract_sources_from_strategy_json(user_strategy.get('strategy_json'))
|
|
|
|
if not sources:
|
|
return {"error": "No valid sources found in the strategy."}
|
|
|
|
# Prepare the data feed map based on extracted sources
|
|
data_feed_map = self.prepare_data_feed(msg_data['start_date'], sources, user_name)
|
|
|
|
if data_feed_map is None:
|
|
return {"error": "Data feed could not be prepared. Please check the data source."}
|
|
|
|
# Map the user strategy to a Backtrader strategy class
|
|
mapped_strategy = self.map_user_strategy(user_strategy)
|
|
|
|
# Define the callback function to handle backtest completion
|
|
def backtest_callback(results):
|
|
self.store_backtest_results(user_name, backtest_name, results)
|
|
self.update_strategy_stats(user_id, strategy_name, results)
|
|
|
|
# Run the backtest and pass the callback function, msg_data, and user_name
|
|
self.run_backtest(mapped_strategy, data_feed_map, msg_data, user_name, backtest_callback, socket_conn)
|
|
|
|
return {"reply": "backtest_started"}
|
|
|
|
def extract_sources_from_strategy_json(self, strategy_json):
|
|
sources = []
|
|
|
|
# Parse the JSON strategy to extract sources
|
|
def traverse_blocks(blocks):
|
|
for block in blocks:
|
|
if block['type'] == 'source':
|
|
source = {
|
|
'timeframe': block['fields'].get('TF'),
|
|
'exchange': block['fields'].get('EXC'),
|
|
'symbol': block['fields'].get('SYM')
|
|
}
|
|
sources.append(source)
|
|
# Recursively traverse inputs and statements
|
|
if 'inputs' in block:
|
|
traverse_blocks(block['inputs'].values())
|
|
if 'statements' in block:
|
|
traverse_blocks(block['statements'].values())
|
|
if 'next' in block:
|
|
traverse_blocks([block['next']])
|
|
|
|
traverse_blocks(strategy_json)
|
|
return sources
|
|
|
|
def update_strategy_stats(self, user_id, strategy_name, results):
|
|
""" Update the strategy stats with the backtest results """
|
|
strategy = self.strategies.get_strategy_by_name(user_id=user_id, name=strategy_name)
|
|
|
|
if strategy:
|
|
initial_capital = results['initial_capital']
|
|
final_value = results['final_portfolio_value']
|
|
returns = np.array(results['returns'])
|
|
equity_curve = np.array(results['equity_curve'])
|
|
trades = results['trades']
|
|
|
|
total_return = (final_value - initial_capital) / initial_capital * 100
|
|
|
|
risk_free_rate = 0.0
|
|
mean_return = np.mean(returns)
|
|
std_return = np.std(returns)
|
|
sharpe_ratio = (mean_return - risk_free_rate) / std_return if std_return != 0 else 0
|
|
|
|
running_max = np.maximum.accumulate(equity_curve)
|
|
drawdowns = (equity_curve - running_max) / running_max
|
|
max_drawdown = np.min(drawdowns) * 100
|
|
|
|
num_trades = len(trades)
|
|
wins = sum(1 for trade in trades if trade['profit'] > 0)
|
|
losses = num_trades - wins
|
|
win_loss_ratio = wins / losses if losses != 0 else wins
|
|
|
|
stats = {
|
|
'total_return': total_return,
|
|
'sharpe_ratio': sharpe_ratio,
|
|
'max_drawdown': max_drawdown,
|
|
'number_of_trades': num_trades,
|
|
'win_loss_ratio': win_loss_ratio,
|
|
}
|
|
|
|
strategy.update_stats(stats)
|
|
else:
|
|
print(f"Strategy {strategy_name} not found for user {user_id}.")
|
|
|
|
def store_backtest_results(self, user_name, backtest_name, results):
|
|
""" Store the backtest results in the cache """
|
|
cache_key = f"backtest:{user_name}:{backtest_name}"
|
|
|
|
filter_vals = [('tbl_key', cache_key)]
|
|
backtest_data = self.data_cache.get_rows_from_cache('tests', filter_vals)
|
|
|
|
if not backtest_data.empty:
|
|
backtest_data['results'] = results
|
|
self.data_cache.insert_row_into_cache('tests', backtest_data.keys(), backtest_data.values(), key=cache_key)
|
|
else:
|
|
print(f"Backtest {backtest_name} not found in cache.")
|