The test are running but not without errors.

This commit is contained in:
Rob 2024-11-11 17:22:06 -04:00
parent 4fcc6f661d
commit 33298b7178
8 changed files with 1183 additions and 358 deletions

View File

@ -1,61 +1,3 @@
"""
{ {"generated_code": "def next():\n if flags.get('start', False) == False:\n if (get_current_price(timeframe='1h', exchange='binance', symbol='BTC/USD') > process_indicator('a bolengerband', 'middle')):\n flags['start'] = True\n set_available_strategy_balance((get_current_balance() / 20))\n if (get_last_candle(candle_part='close', timeframe='1h', exchange='binance', symbol='BTC/USD') > get_last_candle(candle_part='open', timeframe='1h', exchange='binance', symbol='BTC/USD')):\n variables['order_amount'] = (get_available_strategy_balance() / 20)\n trade_order(trade_type='buy', size=variables.get('order_amount', None), order_type='market', source={'exchange': 'binance', 'timeframe': '1h', 'market': 'BTC/USD'}, tif='GTC', stop_loss={'value': (get_current_price(timeframe='1h', exchange='binance', symbol='BTC/USD') - variables.get('order_amount', None))}, trailing_stop=None, take_profit={'value': (get_current_price(timeframe='1h', exchange='binance', symbol='BTC/USD') + variables.get('order_amount', None))}, limit=None, trailing_limit=None, target_market=None, name_order=None)\n if exit:\n exit_strategy()\n paused = True # Pause the strategy while exiting.", "indicators": [{"name": "a bolengerband", "output": "middle"}], "data_sources": [["binance", "BTC/USD", "1h"]], "flags_used": ["start"]}
"name": "fff",
"strategy_json": {
"type": "strategy",
"statements": [
{
"type": "set_available_strategy_balance",
"inputs": {
"BALANCE": {
"type": "math_operation",
"inputs": {
"operator": "add",
"left_operand": 1,
"right_operand": {
"type": "math_operation",
"inputs": {
"operator": "add",
"left_operand": {
"type": "power",
"inputs": {
"base": 2,
"exponent": 3
}
},
"right_operand": {
"type": "math_operation",
"inputs": {
"operator": "multiply",
"left_operand": {
"type": "min",
"inputs": {
"numbers": [
{
"type": "current_balance",
"inputs": {}
},
{
"type": "dynamic_value",
"values": [
5,
6
]
}
]
}
},
"right_operand": 4
}
}
}
}
}
}
}
}
]
},
"workspace": "<xml xmlns=\"https://developers.google.com/blockly/xml\"><block type=\"set_available_strategy_balance\" id=\"AC+~IeO;#NLcE`*p`-{8\" x=\"-250\" y=\"130\"><comment pinned=\"false\" h=\"80\" w=\"160\">Set the balance allocated to the strategy.</comment><value name=\"BALANCE\"><block type=\"math_operation\" id=\"|-W{hIFm.z5Bd#m-OZs}\"><field name=\"operator\">ADD</field><comment pinned=\"false\" h=\"80\" w=\"160\">Perform basic arithmetic operations between two values.</comment><value name=\"LEFT\"><block type=\"value_input\" id=\"3`4X?VL:]|2FbEiMRz]f\"><field name=\"VALUE\">1</field><comment pinned=\"false\" h=\"80\" w=\"160\">Enter a numerical value. Chain multiple for a list.</comment></block></value><value name=\"RIGHT\"><block type=\"math_operation\" id=\"%~{X-%0Q37cn0L2aul,8\"><field name=\"operator\">ADD</field><comment pinned=\"false\" h=\"80\" w=\"160\">Perform basic arithmetic operations between two values.</comment><value name=\"LEFT\"><block type=\"power\" id=\"3;*+0$b19;lFVk%5iR({\"><comment pinned=\"false\" h=\"80\" w=\"160\">Raise a number to the power of another number (x^y).</comment><value name=\"VALUES\"><block type=\"value_input\" id=\"z`L(+hvv*Fu#slvL`yDj\"><field name=\"VALUE\">2</field><comment pinned=\"false\" h=\"80\" w=\"160\">Enter a numerical value. Chain multiple for a list.</comment><value name=\"NEXT\"><block type=\"value_input\" id=\"r6%+3l|2TezFf/hj^5:o\"><field name=\"VALUE\">3</field><comment pinned=\"false\" h=\"80\" w=\"160\">Enter a numerical value. Chain multiple for a list.</comment></block></value></block></value></block></value><value name=\"RIGHT\"><block type=\"math_operation\" id=\"(O686k#xdA9_nW5]s4*;\"><field name=\"operator\">MULTIPLY</field><comment pinned=\"false\" h=\"80\" w=\"160\">Perform basic arithmetic operations between two values.</comment><value name=\"LEFT\"><block type=\"min\" id=\"{XPyvd;n_o~??i(P65Yi\"><comment pinned=\"false\" h=\"80\" w=\"160\">Determine the minimum value among given numbers.</comment><value name=\"VALUES\"><block type=\"current_balance\" id=\"BapPJm/W8_)QeI98~JF6\"><comment pinned=\"false\" h=\"80\" w=\"160\">Retrieve the current balance of the strategy.</comment><value name=\"VALUES\"><block type=\"value_input\" id=\"4:wZlOvQQJ}b)cO6MSh~\"><field name=\"VALUE\">5</field><comment pinned=\"false\" h=\"80\" w=\"160\">Enter a numerical value. Chain multiple for a list.</comment><value name=\"NEXT\"><block type=\"value_input\" id=\"L.Hk_E@2r6#_F)1UrrZG\"><field name=\"VALUE\">6</field><comment pinned=\"false\" h=\"80\" w=\"160\">Enter a numerical value. Chain multiple for a list.</comment></block></value></block></value></block></value></block></value><value name=\"RIGHT\"><block type=\"value_input\" id=\"xjq($pfgII*[?r|Mu:*F\"><field name=\"VALUE\">4</field><comment pinned=\"false\" h=\"80\" w=\"160\">Enter a numerical value. Chain multiple for a list.</comment></block></value></block></value></block></value></block></value></block></xml>"
}
"""

View File

@ -2,6 +2,7 @@
import logging import logging
import math import math
import re
from typing import Any, Dict, List, Set, Tuple, Union from typing import Any, Dict, List, Set, Tuple, Union
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,50 +26,40 @@ class PythonGenerator:
self.scheduled_action_count: int = 0 self.scheduled_action_count: int = 0
self.scheduled_actions: List[str] = [] self.scheduled_actions: List[str] = []
def generate(self, strategy_json: Dict[str, Any]) -> Dict[str, Any]: def generate(self, strategy_json: dict) -> dict:
""" code_lines = ["def next():"]
Generates the 'next()' method code and collects indicators, data sources, and flags used. indent_level = 1 # Starting indentation inside the function
:param strategy_json: The JSON definition of the strategy. # Extract statements from the strategy
:return: A dictionary containing 'generated_code', 'indicators', 'data_sources', and 'flags_used'. statements = strategy_json.get('statements', [])
""" if not isinstance(statements, list):
# Reset tracking attributes logger.error("'statements' should be a list in strategy_json.")
self.indicators_used.clear() # Handle as needed, possibly returning an error or skipping code generation
self.data_sources_used.clear() return {
self.flags_used.clear() "generated_code": '\n'.join(code_lines),
"indicators": list(self.indicators_used),
"data_sources": list(self.data_sources_used),
"flags_used": list(self.flags_used)
}
# Initialize code components # Process each statement
code_lines = [] code_lines.extend(self.generate_code_from_json(statements, indent_level))
indent_level = 1 # For 'next' method code indentation
# Start generating the 'next' method
code_lines.append("def next():")
indent_level += 1 # Increase indent level inside the 'next' method
# Recursively generate code from JSON nodes
code_lines.extend(self.generate_code_from_json(strategy_json, indent_level))
# Handle exit logic at the end of 'next()' # Handle exit logic at the end of 'next()'
indent = ' ' * indent_level indent = ' ' * indent_level
exit_indent = ' ' * (indent_level + 1) exit_indent = ' ' * (indent_level + 1)
code_lines.append(f"{indent}if self.exit:") code_lines.append(f"{indent}exit = flags.get('exit', False)")
code_lines.append(f"{exit_indent}self.exit_strategy()") code_lines.append(f"{indent}if exit:")
code_lines.append(f"{exit_indent}self.paused = True # Pause the strategy while exiting.") code_lines.append(f"{exit_indent}exit_strategy()")
code_lines.append(f"{exit_indent}set_paused(True) # Pause the strategy while exiting.")
# Join the code lines into a single string return {
next_method_code = '\n'.join(code_lines) "generated_code": '\n'.join(code_lines),
"indicators": list(self.indicators_used),
# Prepare the combined dictionary "data_sources": list(self.data_sources_used),
strategy_components = { "flags_used": list(self.flags_used)
'generated_code': next_method_code,
'indicators': self.indicators_used.copy(),
'data_sources': list(self.data_sources_used),
'flags_used': list(self.flags_used)
} }
logger.debug("Generated 'next()' method code.")
return strategy_components
# ============================== # ==============================
# Helper Methods # Helper Methods
# ============================== # ==============================
@ -156,6 +147,11 @@ class PythonGenerator:
elif isinstance(handler_code, str): elif isinstance(handler_code, str):
code_lines.append(handler_code) code_lines.append(handler_code)
# Process 'next' recursively if present
next_node = node.get('next')
if next_node:
code_lines.extend(self.generate_code_from_json(next_node, indent_level))
return code_lines return code_lines
def handle_default(self, node: Dict[str, Any], indent_level: int) -> str: def handle_default(self, node: Dict[str, Any], indent_level: int) -> str:
@ -199,17 +195,18 @@ class PythonGenerator:
def handle_indicator(self, node: Dict[str, Any], indent_level: int) -> str: def handle_indicator(self, node: Dict[str, Any], indent_level: int) -> str:
""" """
Handles the 'indicator' condition type. Handles the 'indicator_a_bolengerband' node type by generating a function call to retrieve indicator values.
:param node: The indicator node. :param node: The indicator_a_bolengerband node.
:param indent_level: Current indentation level. :param indent_level: Current indentation level.
:return: A string representing the condition. :return: A string representing the indicator value retrieval.
""" """
indicator_name = node.get('name') fields = node.get('fields', {})
output_field = node.get('output') indicator_name = fields.get('NAME')
output_field = fields.get('OUTPUT')
if not indicator_name or not output_field: if not indicator_name or not output_field:
logger.error("indicator node missing 'name' or 'output'.") logger.error("indicator node missing 'NAME' or 'OUTPUT'.")
return 'None' return 'None'
# Collect the indicator information # Collect the indicator information
@ -470,8 +467,18 @@ class PythonGenerator:
:param indent_level: Current indentation level. :param indent_level: Current indentation level.
:return: A string representing the condition. :return: A string representing the condition.
""" """
symbol = node.get('symbol', self.default_source.get('market', 'BTCUSD')) # Process source input
return f"get_current_price(symbol='{symbol}')" inputs = node.get('inputs', {})
source_node = inputs.get('source', {})
timeframe = source_node.get('timeframe', self.default_source.get('timeframe', '1m'))
exchange = source_node.get('exchange', self.default_source.get('exchange', 'Binance'))
symbol = source_node.get('symbol', self.default_source.get('market', 'BTC/USD'))
# Track data sources
self.data_sources_used.add((exchange, symbol, timeframe))
# Correctly format the function call with separate parameters
return f"get_current_price(timeframe='{timeframe}', exchange='{exchange}', symbol='{symbol}')"
def handle_bid_price(self, node: Dict[str, Any], indent_level: int) -> str: def handle_bid_price(self, node: Dict[str, Any], indent_level: int) -> str:
""" """
@ -481,8 +488,18 @@ class PythonGenerator:
:param indent_level: Current indentation level. :param indent_level: Current indentation level.
:return: A string representing the condition. :return: A string representing the condition.
""" """
symbol = node.get('symbol', self.default_source.get('market', 'BTCUSD')) # Process source input
return f"get_bid_price(symbol='{symbol}')" inputs = node.get('inputs', {})
source_node = inputs.get('source', {})
timeframe = source_node.get('timeframe', self.default_source.get('timeframe', '1m'))
exchange = source_node.get('exchange', self.default_source.get('exchange', 'Binance'))
symbol = source_node.get('symbol', self.default_source.get('market', 'BTC/USD'))
# Track data sources
self.data_sources_used.add((exchange, symbol, timeframe))
# Correctly format the function call with separate parameters
return f"get_bid_price(timeframe='{timeframe}', exchange='{exchange}', symbol='{symbol}')"
def handle_ask_price(self, node: Dict[str, Any], indent_level: int) -> str: def handle_ask_price(self, node: Dict[str, Any], indent_level: int) -> str:
""" """
@ -492,18 +509,29 @@ class PythonGenerator:
:param indent_level: Current indentation level. :param indent_level: Current indentation level.
:return: A string representing the condition. :return: A string representing the condition.
""" """
symbol = node.get('symbol', self.default_source.get('market', 'BTCUSD')) # Process source input
return f"get_ask_price(symbol='{symbol}')" inputs = node.get('inputs', {})
source_node = inputs.get('source', {})
timeframe = source_node.get('timeframe', self.default_source.get('timeframe', '1m'))
exchange = source_node.get('exchange', self.default_source.get('exchange', 'Binance'))
symbol = source_node.get('symbol', self.default_source.get('market', 'BTC/USD'))
# Track data sources
self.data_sources_used.add((exchange, symbol, timeframe))
# Correctly format the function call with separate parameters
return f"get_ask_price(timeframe='{timeframe}', exchange='{exchange}', symbol='{symbol}')"
def handle_last_candle_value(self, node: Dict[str, Any], indent_level: int) -> str: def handle_last_candle_value(self, node: Dict[str, Any], indent_level: int) -> str:
""" """
Handles the 'last_candle_value' condition type. Handles the 'last_candle_value' condition type by generating a function call to get candle data.
:param node: The last_candle_value node. :param node: The last_candle_value node.
:param indent_level: Current indentation level. :param indent_level: Current indentation level.
:return: A string representing the condition. :return: A string representing the candle data retrieval.
""" """
candle_part = node.get('candle_part', 'close').lower() inputs = node.get('inputs', {})
candle_part = inputs.get('candle_part', 'close').lower()
valid_candle_parts = ['open', 'high', 'low', 'close'] valid_candle_parts = ['open', 'high', 'low', 'close']
if candle_part not in valid_candle_parts: if candle_part not in valid_candle_parts:
logger.error(f"Invalid candle_part '{candle_part}' in 'last_candle_value'. Defaulting to 'close'.") logger.error(f"Invalid candle_part '{candle_part}' in 'last_candle_value'. Defaulting to 'close'.")
@ -513,7 +541,7 @@ class PythonGenerator:
source_node = node.get('source', {}) source_node = node.get('source', {})
timeframe = source_node.get('timeframe', self.default_source.get('timeframe', '1m')) timeframe = source_node.get('timeframe', self.default_source.get('timeframe', '1m'))
exchange = source_node.get('exchange', self.default_source.get('exchange', 'Binance')) exchange = source_node.get('exchange', self.default_source.get('exchange', 'Binance'))
symbol = source_node.get('symbol', self.default_source.get('market', 'BTCUSD')) symbol = source_node.get('symbol', self.default_source.get('market', 'BTC/USD'))
# Track data sources # Track data sources
self.data_sources_used.add((exchange, symbol, timeframe)) self.data_sources_used.add((exchange, symbol, timeframe))
@ -532,7 +560,7 @@ class PythonGenerator:
""" """
timeframe = node.get('time_frame', '5m') timeframe = node.get('time_frame', '5m')
exchange = node.get('exchange', 'Binance') exchange = node.get('exchange', 'Binance')
symbol = node.get('symbol', 'BTCUSD') symbol = node.get('symbol', 'BTC/USD')
# Track data sources # Track data sources
self.data_sources_used.add((exchange, symbol, timeframe)) self.data_sources_used.add((exchange, symbol, timeframe))
@ -583,7 +611,7 @@ class PythonGenerator:
left_expr = self.generate_condition_code(left_node, indent_level) left_expr = self.generate_condition_code(left_node, indent_level)
right_expr = self.generate_condition_code(right_node, indent_level) right_expr = self.generate_condition_code(right_node, indent_level)
condition = f"({left_expr} {python_operator} {right_expr})" condition = f"{left_expr} {python_operator} {right_expr}"
logger.debug(f"Generated comparison condition: {condition}") logger.debug(f"Generated comparison condition: {condition}")
return condition return condition
@ -605,7 +633,7 @@ class PythonGenerator:
left_expr = self.generate_condition_code(left_node, indent_level) left_expr = self.generate_condition_code(left_node, indent_level)
right_expr = self.generate_condition_code(right_node, indent_level) right_expr = self.generate_condition_code(right_node, indent_level)
condition = f"({left_expr} and {right_expr})" condition = f"{left_expr} and {right_expr}"
logger.debug(f"Generated logical AND condition: {condition}") logger.debug(f"Generated logical AND condition: {condition}")
return condition return condition
@ -627,7 +655,7 @@ class PythonGenerator:
left_expr = self.generate_condition_code(left_node, indent_level) left_expr = self.generate_condition_code(left_node, indent_level)
right_expr = self.generate_condition_code(right_node, indent_level) right_expr = self.generate_condition_code(right_node, indent_level)
condition = f"({left_expr} or {right_expr})" condition = f"{left_expr} or {right_expr}"
logger.debug(f"Generated logical OR condition: {condition}") logger.debug(f"Generated logical OR condition: {condition}")
return condition return condition
@ -696,7 +724,7 @@ class PythonGenerator:
# Collect data sources # Collect data sources
source = trade_options.get('source', self.default_source) source = trade_options.get('source', self.default_source)
exchange = source.get('exchange', 'binance') exchange = source.get('exchange', 'binance')
symbol = source.get('symbol', 'BTCUSDT') symbol = source.get('symbol', 'BTC/USD')
timeframe = source.get('timeframe', '5m') timeframe = source.get('timeframe', '5m')
self.data_sources_used.add((exchange, symbol, timeframe)) self.data_sources_used.add((exchange, symbol, timeframe))
@ -737,11 +765,31 @@ class PythonGenerator:
if not option: if not option:
return 'None' return 'None'
# Precompile the regex pattern for market symbols (e.g., 'BTC/USD')
market_symbol_pattern = re.compile(r'^[A-Z]{3}/[A-Z]{3}$')
def is_market_symbol(value: str) -> bool:
"""
Determines if a string is a market symbol following the pattern 'XXX/YYY'.
:param value: The string to check.
:return: True if it matches the market symbol pattern, False otherwise.
"""
return bool(market_symbol_pattern.match(value))
def format_value(value: Any) -> str: def format_value(value: Any) -> str:
if isinstance(value, str): if isinstance(value, str):
return f"'{value}'" if is_market_symbol(value):
return f"'{value}'" # Quote market symbols like 'BTC/USD'
# Check if the string represents an expression (contains operators or function calls)
elif any(op in value for op in ['(', ')', '+', '-', '*', '/', '.']):
return value # Assume it's an expression and return as-is
else:
return f"'{value}'"
elif isinstance(value, dict): elif isinstance(value, dict):
return self.format_trade_option(value) # Recursively format nested dictionaries
nested_items = [f"'{k}': {format_value(v)}" for k, v in value.items()]
return f"{{{', '.join(nested_items)}}}"
elif isinstance(value, (int, float)): elif isinstance(value, (int, float)):
return str(value) return str(value)
else: else:
@ -879,7 +927,7 @@ class PythonGenerator:
""" """
time_frame = inputs.get('time_frame', '1m') time_frame = inputs.get('time_frame', '1m')
exchange = inputs.get('exchange', 'Binance') exchange = inputs.get('exchange', 'Binance')
symbol = inputs.get('symbol', 'BTCUSDT') symbol = inputs.get('symbol', 'BTC/USD')
target_market = { target_market = {
'time_frame': time_frame, 'time_frame': time_frame,
@ -915,7 +963,7 @@ class PythonGenerator:
""" """
code_lines = [] code_lines = []
indent = ' ' * indent_level indent = ' ' * indent_level
code_lines.append(f"{indent}pause_strategy()") code_lines.append(f"{indent}set_paused(True)")
code_lines.append(f"{indent}notify_user('Strategy paused.')") code_lines.append(f"{indent}notify_user('Strategy paused.')")
return code_lines return code_lines
@ -946,8 +994,8 @@ class PythonGenerator:
code_lines = [] code_lines = []
indent = ' ' * indent_level indent = ' ' * indent_level
exit_option = node.get('condition', 'all') # 'all', 'in_profit', 'in_loss' exit_option = node.get('condition', 'all') # 'all', 'in_profit', 'in_loss'
code_lines.append(f"{indent}self.set_exit(True, '{exit_option}') # Initiate exit") code_lines.append(f"{indent}set_exit(True, '{exit_option}') # Initiate exit")
code_lines.append(f"{indent}self.set_paused(True) # Pause the strategy while exiting") code_lines.append(f"{indent}set_paused(True) # Pause the strategy while exiting")
return code_lines return code_lines
@ -1068,15 +1116,15 @@ class PythonGenerator:
inputs = node.get('inputs', {}) inputs = node.get('inputs', {})
condition = inputs.get('CONDITION', {}) condition = inputs.get('CONDITION', {})
statements = node.get('statements', {}).get('DO', []) do_statements = node.get('statements', {}).get('DO', [])
condition_code = self.generate_condition_code(condition, indent_level) condition_code = self.generate_condition_code(condition, indent_level)
code_lines.append(f"{indent}if {condition_code}:") code_lines.append(f"{indent}if {condition_code}:")
if not statements:
if not do_statements:
code_lines.append(f"{indent} pass # No actions defined") code_lines.append(f"{indent} pass # No actions defined")
else: else:
action_code = self.generate_code_from_json(statements, indent_level + 1) action_code = self.generate_code_from_json(do_statements, indent_level + 1)
if not action_code: if not action_code:
code_lines.append(f"{indent} pass # No valid actions defined") code_lines.append(f"{indent} pass # No valid actions defined")
else: else:
@ -1091,6 +1139,25 @@ class PythonGenerator:
# Values and Flags Handlers # Values and Flags Handlers
# ============================== # ==============================
def handle_dynamic_value(self, node: Dict[str, Any], indent_level: int) -> str:
"""
Handles the 'dynamic_value' node type.
:param node: The dynamic_value node.
:param indent_level: Current indentation level.
:return: A string representing the value.
"""
values = node.get('values', [])
if not values:
logger.error("dynamic_value node has no 'values'.")
return 'None'
# Assuming the first value is the primary value
first_value = values[0]
if isinstance(first_value, dict):
return self.generate_condition_code(first_value, indent_level)
else:
return str(first_value)
def handle_notify_user(self, node: Dict[str, Any], indent_level: int) -> List[str]: def handle_notify_user(self, node: Dict[str, Any], indent_level: int) -> List[str]:
""" """
Handles the 'notify_user' node type. Handles the 'notify_user' node type.
@ -1190,21 +1257,42 @@ class PythonGenerator:
flag_value = 'True' if str(flag_value_input).strip().lower() == 'true' else 'False' flag_value = 'True' if str(flag_value_input).strip().lower() == 'true' else 'False'
code_lines.append(f"{indent}flags['{flag_name}'] = {flag_value}") code_lines.append(f"{indent}flags['{flag_name}'] = {flag_value}")
self.flags_used.add(flag_name) self.flags_used.add(flag_name)
# # Process 'next' field if present
# next_node = node.get('next')
# if next_node:
# next_code = self.generate_code_from_json(next_node, indent_level)
# if isinstance(next_code, list):
# code_lines.extend(next_code)
# elif isinstance(next_code, str):
# code_lines.append(next_code)
return code_lines return code_lines
def handle_flag_is_set(self, node: Dict[str, Any], indent_level: int) -> str: def handle_flag_is_set(self, node: Dict[str, Any], indent_level: int) -> str:
""" """
Handles the 'flag_is_set' node type, checking if a flag is set. Handles the 'flag_is_set' condition type, checking if a flag is set to a specific value.
:param node: The flag_is_set node. :param node: The flag_is_set node.
:param indent_level: Current indentation level. :param indent_level: Current indentation level.
:return: A string representing the flag condition. :return: A string representing the flag condition.
""" """
flag_name = node.get('flag_name') flag_name = node.get('flag_name')
flag_value = node.get('flag_value', True) # Default to True if not specified
if not flag_name: if not flag_name:
logger.error("flag_is_set node missing 'flag_name'.") logger.error("flag_is_set node missing 'flag_name'.")
return 'False' return 'False'
return f"flags.get('{flag_name}', False)"
# Generate condition based on flag_value
if isinstance(flag_value, bool):
condition = f"flags.get('{flag_name}', False) == {flag_value}"
else:
logger.error(f"Unsupported flag_value type: {type(flag_value)}. Defaulting to 'False'.")
condition = 'False'
logger.debug(f"Generated flag_is_set condition: {condition}")
return condition
# Add other Values and Flags handlers here... # Add other Values and Flags handlers here...
@ -1279,10 +1367,6 @@ class PythonGenerator:
# Math Handlers # Math Handlers
# ============================== # ==============================
import math
import statistics
import random
def handle_math_operation(self, node: Dict[str, Any], indent_level: int) -> str: def handle_math_operation(self, node: Dict[str, Any], indent_level: int) -> str:
""" """
Handles the 'math_operation' node type. Handles the 'math_operation' node type.
@ -1291,9 +1375,11 @@ class PythonGenerator:
:param indent_level: Current indentation level. :param indent_level: Current indentation level.
:return: A string representing the math operation. :return: A string representing the math operation.
""" """
operator = node.get('operator', 'ADD') # Extract from 'inputs' instead of top-level
left_operand = node.get('left_operand') inputs = node.get('inputs', {})
right_operand = node.get('right_operand') operator = inputs.get('operator', 'ADD').upper()
left_operand = inputs.get('left_operand')
right_operand = inputs.get('right_operand')
operator_map = { operator_map = {
'ADD': '+', 'ADD': '+',
@ -1307,21 +1393,21 @@ class PythonGenerator:
left_expr = self.process_numeric_list(left_operand, indent_level) left_expr = self.process_numeric_list(left_operand, indent_level)
right_expr = self.process_numeric_list(right_operand, indent_level) right_expr = self.process_numeric_list(right_operand, indent_level)
expr = f"({left_expr} {python_operator} {right_expr})" expr = f"{left_expr} {python_operator} {right_expr}"
logger.debug(f"Generated math_operation expression: {expr}") logger.debug(f"Generated math_operation expression: {expr}")
return expr return expr
def handle_power(self, node: Dict[str, Any], indent_level: int) -> str: def handle_power(self, node: Dict[str, Any], indent_level: int) -> str:
base = self.process_numeric_list(node.get('base', 2), indent_level) base = self.process_numeric_list(node.get('base', 2), indent_level)
exponent = self.process_numeric_list(node.get('exponent', 3), indent_level) exponent = self.process_numeric_list(node.get('exponent', 3), indent_level)
expr = f"({base} ** {exponent})" expr = f"{base} ** {exponent}"
logger.debug(f"Generated power expression: {expr}") logger.debug(f"Generated power expression: {expr}")
return expr return expr
def handle_modulo(self, node: Dict[str, Any], indent_level: int) -> str: def handle_modulo(self, node: Dict[str, Any], indent_level: int) -> str:
dividend = self.process_numeric_list(node.get('dividend', 10), indent_level) dividend = self.process_numeric_list(node.get('dividend', 10), indent_level)
divisor = self.process_numeric_list(node.get('divisor', 3), indent_level) divisor = self.process_numeric_list(node.get('divisor', 3), indent_level)
expr = f"({dividend} % {divisor})" expr = f"{dividend} % {divisor}"
logger.debug(f"Generated modulo expression: {expr}") logger.debug(f"Generated modulo expression: {expr}")
return expr return expr

View File

@ -35,6 +35,24 @@ class Strategies:
default_expiration=dt.timedelta(hours=24), default_expiration=dt.timedelta(hours=24),
columns=["id", "creator", "name", "workspace", "code", "stats", "public", "fee", columns=["id", "creator", "name", "workspace", "code", "stats", "public", "fee",
"tbl_key", "strategy_components"]) "tbl_key", "strategy_components"])
# Create a cache for strategy contexts to store strategy states and settings
self.data_cache.create_cache(
name='strategy_contexts',
cache_type='table',
size_limit=1000,
eviction_policy='deny',
default_expiration=dt.timedelta(hours=24),
columns=[
"strategy_instance_id", # Unique identifier for the strategy instance
"flags", # JSON-encoded string to store flags
"profit_loss", # Float value for tracking profit/loss
"active", # Boolean or Integer (1/0) for active status
"paused", # Boolean or Integer (1/0) for paused status
"exit", # Boolean or Integer (1/0) for exit status
"exit_method", # String defining exit method
"start_time" # ISO-formatted datetime string for start time
]
)
# Initialize default settings # Initialize default settings
self.default_timeframe = '5m' self.default_timeframe = '5m'
@ -106,6 +124,12 @@ class Strategies:
return {"success": False, "message": "Invalid JSON format for 'code'."} return {"success": False, "message": "Invalid JSON format for 'code'."}
elif isinstance(code, dict): elif isinstance(code, dict):
strategy_json = code strategy_json = code
# Serialize 'code' to JSON string
try:
serialized_code = json.dumps(code)
strategy_data['code'] = serialized_code
except (TypeError, ValueError):
return {"success": False, "message": "Unable to serialize 'code' field."}
else: else:
return {"success": False, "message": "'code' must be a JSON string or dictionary."} return {"success": False, "message": "'code' must be a JSON string or dictionary."}
@ -361,8 +385,9 @@ class Strategies:
if not strategy_id or not strategy_name or not user_id: if not strategy_id or not strategy_name or not user_id:
return {"success": False, "message": "Strategy data is incomplete."} return {"success": False, "message": "Strategy data is incomplete."}
# Unique key for the strategy-user pair # Generate a deterministic strategy_instance_id
instance_key = (user_id, strategy_id) strategy_instance_id = f"{user_id}_{strategy_name}"
instance_key = (user_id, strategy_id) # Unique key for the strategy-user pair
# Retrieve or create StrategyInstance # Retrieve or create StrategyInstance
if instance_key not in self.active_instances: if instance_key not in self.active_instances:
@ -372,7 +397,7 @@ class Strategies:
# Instantiate StrategyInstance # Instantiate StrategyInstance
strategy_instance = StrategyInstance( strategy_instance = StrategyInstance(
strategy_instance_id=str(uuid.uuid4()), strategy_instance_id=strategy_instance_id,
strategy_id=strategy_id, strategy_id=strategy_id,
strategy_name=strategy_name, strategy_name=strategy_name,
user_id=user_id, user_id=user_id,
@ -382,18 +407,14 @@ class Strategies:
trades=self.trades trades=self.trades
) )
# Load existing context or initialize
strategy_instance.load_context()
# Store in active_instances # Store in active_instances
self.active_instances[instance_key] = strategy_instance self.active_instances[instance_key] = strategy_instance
logger.debug( logger.debug(f"Created new StrategyInstance '{strategy_instance_id}' for strategy '{strategy_id}'.")
f"Created new StrategyInstance '{strategy_instance.strategy_instance_id}' for strategy '{strategy_id}'.")
else: else:
strategy_instance = self.active_instances[instance_key] strategy_instance = self.active_instances[instance_key]
logger.debug( logger.debug(
f"Retrieved existing StrategyInstance '{strategy_instance.strategy_instance_id}' for strategy '{strategy_id}'.") f"Retrieved existing StrategyInstance '{strategy_instance_id}' for strategy '{strategy_id}'.")
# Execute the strategy # Execute the strategy
execution_result = strategy_instance.execute() execution_result = strategy_instance.execute()
@ -420,8 +441,8 @@ class Strategies:
overwrite='tbl_key' overwrite='tbl_key'
) )
else: else:
logger.info(f"Strategy '{strategy_id}' is exiting. Remaining" logger.info(
f" trades will be closed in subsequent executions.") f"Strategy '{strategy_id}' is exiting. Remaining trades will be closed in subsequent executions.")
return {"success": True, "strategy_profit_loss": profit_loss} return {"success": True, "strategy_profit_loss": profit_loss}
else: else:
@ -448,3 +469,45 @@ class Strategies:
except Exception as e: except Exception as e:
logger.error(f"Error updating strategies: {e}", exc_info=True) logger.error(f"Error updating strategies: {e}", exc_info=True)
traceback.print_exc() traceback.print_exc()
def update_stats(self, strategy_id: str, stats: dict) -> None:
"""
Updates the strategy's statistics with the provided stats.
:param strategy_id: Identifier of the strategy (tbl_key).
:param stats: Dictionary containing statistics to update.
"""
try:
# Fetch the current strategy data
strategy = self.data_cache.get_rows_from_datacache(
cache_name='strategies',
filter_vals=[('tbl_key', strategy_id)]
)
if strategy.empty:
logger.warning(f"Strategy ID {strategy_id} not found for stats update.")
return
strategy_row = strategy.iloc[0].to_dict()
current_stats = json.loads(strategy_row.get('stats', '{}'))
# Merge the new stats with existing stats
current_stats.update(stats)
# Serialize the updated stats
updated_stats_serialized = json.dumps(current_stats)
# Update the stats in the data cache
self.data_cache.modify_datacache_item(
cache_name='strategies',
filter_vals=[('tbl_key', strategy_id)],
field_names=('stats',),
new_values=(updated_stats_serialized,),
key=strategy_id,
overwrite='tbl_key'
)
logger.info(f"Updated stats for strategy '{strategy_id}': {current_stats}")
except Exception as e:
logger.error(f"Error updating stats for strategy '{strategy_id}': {e}", exc_info=True)

View File

@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
class StrategyInstance: class StrategyInstance:
def __init__(self, strategy_instance_id: str, strategy_id: str, strategy_name: str, def __init__(self, strategy_instance_id: str, strategy_id: str, strategy_name: str,
user_id: int, generated_code: str, data_cache: DataCache, indicators: Indicators, trades: Trades): user_id: int, generated_code: str, data_cache: DataCache, indicators: Indicators | None, trades: Trades | None):
""" """
Initializes a StrategyInstance. Initializes a StrategyInstance.
@ -25,6 +25,9 @@ class StrategyInstance:
:param indicators: Reference to the Indicators manager. :param indicators: Reference to the Indicators manager.
:param trades: Reference to the Trades manager. :param trades: Reference to the Trades manager.
""" """
# Initialize the backtrader_strategy attribute
self.backtrader_strategy = None # Will be set by Backtrader's MappedStrategy
self.strategy_instance_id = strategy_instance_id self.strategy_instance_id = strategy_instance_id
self.strategy_id = strategy_id self.strategy_id = strategy_id
self.strategy_name = strategy_name self.strategy_name = strategy_name
@ -36,7 +39,8 @@ class StrategyInstance:
# Initialize context variables # Initialize context variables
self.flags: dict[str, Any] = {} self.flags: dict[str, Any] = {}
self.starting_balance = self.trades.get_current_balance(self.user_id) self.variables: dict[str, Any] = {}
self.starting_balance: float = 0.0
self.profit_loss: float = 0.0 self.profit_loss: float = 0.0
self.active: bool = True self.active: bool = True
self.paused: bool = False self.paused: bool = False
@ -44,6 +48,68 @@ class StrategyInstance:
self.exit_method: str = 'all' self.exit_method: str = 'all'
self.start_time = dt.datetime.now() self.start_time = dt.datetime.now()
# Define the local execution environment
self.exec_context = {
'flags': self.flags,
'variables': self.variables,
'exit': self.exit,
'paused': self.paused,
'strategy_id': self.strategy_id,
'user_id': self.user_id,
'get_last_candle': self.get_last_candle,
'get_current_price': self.get_current_price,
'trade_order': self.trade_order,
'exit_strategy': self.exit_strategy,
'notify_user': self.notify_user,
'process_indicator': self.process_indicator,
'get_strategy_profit_loss': self.get_strategy_profit_loss,
'is_in_profit': self.is_in_profit,
'is_in_loss': self.is_in_loss,
'get_active_trades': self.get_active_trades,
'get_starting_balance': self.get_starting_balance,
'set_paused': self.set_paused,
'set_exit': self.set_exit,
'set_available_strategy_balance': self.set_available_strategy_balance,
'get_current_balance': self.get_current_balance,
'get_available_strategy_balance': self.get_available_strategy_balance
}
# Automatically load or initialize the context
self._initialize_or_load_context()
def _initialize_or_load_context(self):
"""
Checks if a context exists for the strategy instance. If it does, load it;
otherwise, initialize a new context.
"""
if self.data_cache.get_rows_from_datacache(
cache_name='strategy_contexts',
filter_vals=[('strategy_instance_id', self.strategy_instance_id)]
).empty:
self.initialize_new_context()
logger.debug(f"Initialized new context for StrategyInstance '{self.strategy_instance_id}'.")
else:
self.load_context()
logger.debug(f"Loaded existing context for StrategyInstance '{self.strategy_instance_id}'.")
def initialize_new_context(self):
"""
Initializes a new context for the strategy instance.
"""
self.flags = {}
self.variables = {}
self.profit_loss = 0.0
self.active = True
self.paused = False
self.exit = False
self.exit_method = 'all'
self.start_time = dt.datetime.now()
# Insert initial context into the cache
self.save_context()
logger.debug(f"New context created and saved for StrategyInstance '{self.strategy_instance_id}'.")
def load_context(self): def load_context(self):
""" """
Loads the strategy execution context from the database. Loads the strategy execution context from the database.
@ -59,17 +125,25 @@ class StrategyInstance:
context = context_data.iloc[0].to_dict() context = context_data.iloc[0].to_dict()
self.flags = json.loads(context.get('flags', '{}')) self.flags = json.loads(context.get('flags', '{}'))
self.starting_balance = context.get('starting_balance', 0.0)
self.profit_loss = context.get('profit_loss', 0.0) self.profit_loss = context.get('profit_loss', 0.0)
self.active = context.get('active', True) self.active = bool(context.get('active', True))
self.paused = context.get('paused', False) self.paused = bool(context.get('paused', False))
self.exit = context.get('exit', False) self.exit = bool(context.get('exit', False))
self.exit_method = context.get('exit_method', 'all') self.exit_method = context.get('exit_method', 'all')
start_time_str = context.get('start_time')
if start_time_str:
self.start_time = dt.datetime.fromisoformat(start_time_str)
context_start_time = context.get('start_time', None) # Update exec_context with loaded flags and variables
if context_start_time: self.exec_context['flags'] = self.flags
self.start_time = dt.datetime.fromisoformat(context_start_time) self.exec_context['variables'] = self.variables
self.exec_context['profit_loss'] = self.profit_loss
self.exec_context['active'] = self.active
self.exec_context['paused'] = self.paused
self.exec_context['exit'] = self.exit
self.exec_context['exit_method'] = self.exit_method
logger.debug(f"Context loaded for StrategyInstance '{self.strategy_instance_id}'.")
except Exception as e: except Exception as e:
logger.error(f"Error loading context for StrategyInstance '{self.strategy_instance_id}': {e}", logger.error(f"Error loading context for StrategyInstance '{self.strategy_instance_id}': {e}",
exc_info=True) exc_info=True)
@ -78,6 +152,7 @@ class StrategyInstance:
def save_context(self): def save_context(self):
""" """
Saves the current strategy execution context to the database. Saves the current strategy execution context to the database.
Inserts a new row if it doesn't exist; otherwise, updates the existing row.
""" """
try: try:
self.data_cache.modify_datacache_item( self.data_cache.modify_datacache_item(
@ -87,62 +162,74 @@ class StrategyInstance:
new_values=( new_values=(
json.dumps(self.flags), json.dumps(self.flags),
self.profit_loss, self.profit_loss,
self.active, int(self.active),
self.paused, int(self.paused),
self.exit, int(self.exit),
self.exit_method, self.exit_method,
self.start_time.isoformat() self.start_time.isoformat()
) )
) )
logger.debug(f"Context saved for StrategyInstance '{self.strategy_instance_id}'.")
except ValueError as ve:
# If the record does not exist, insert it
logger.warning(f"StrategyInstance '{self.strategy_instance_id}' context not found. Attempting to insert.")
self.data_cache.insert_row_into_datacache(
cache_name='strategy_contexts',
columns=(
"strategy_instance_id", "flags", "profit_loss",
"active", "paused", "exit", "exit_method", "start_time"
),
values=(
self.strategy_instance_id,
json.dumps(self.flags),
self.profit_loss,
int(self.active),
int(self.paused),
int(self.exit),
self.exit_method,
self.start_time.isoformat()
)
)
logger.debug(f"Inserted new context for StrategyInstance '{self.strategy_instance_id}'.")
except Exception as e: except Exception as e:
logger.error(f"Error saving context for StrategyInstance '{self.strategy_instance_id}': {e}") logger.error(f"Error saving context for StrategyInstance '{self.strategy_instance_id}': {e}")
traceback.print_exc() traceback.print_exc()
def override_exec_context(self, key: str, value: Any):
"""
Overrides a specific mapping in the execution context with a different method or variable.
:param key: The key in exec_context to override.
:param value: The new method or value to assign.
"""
self.exec_context[key] = value
logger.debug(f"Overridden exec_context key '{key}' with new value '{value}'.")
def execute(self) -> dict[str, Any]: def execute(self) -> dict[str, Any]:
""" """
Executes the strategy's 'next()' method. Executes the strategy's 'next()' method.
:return: Result of the execution. :return: Result of the execution.
""" """
try: try:
# Define the local execution environment # Execute the generated 'next()' method with exec_context as globals
exec_context = { exec(self.generated_code, self.exec_context)
'flags': self.flags,
'strategy_id': self.strategy_id,
'user_id': self.user_id,
'get_last_candle': self.get_last_candle,
'get_current_price': self.get_current_price, # Added method
'buy': self.buy_order,
'sell': self.sell_order,
'exit_strategy': self.exit_strategy,
'notify_user': self.notify_user,
'process_indicator': self.process_indicator,
'get_strategy_profit_loss': self.get_strategy_profit_loss,
'is_in_profit': self.is_in_profit,
'is_in_loss': self.is_in_loss,
'get_active_trades': self.get_active_trades,
'get_starting_balance': self.get_starting_balance,
'set_paused': self.set_paused,
'set_exit': self.set_exit
}
# Execute the generated 'next()' method # Call the 'next()' method if defined
exec(self.generated_code, {}, exec_context) if 'next' in self.exec_context and callable(self.exec_context['next']):
self.exec_context['next']()
# Call the 'next()' method
if 'next' in exec_context and callable(exec_context['next']):
exec_context['next']()
else: else:
logger.error( logger.error(
f"'next' method not defined in generated_code for StrategyInstance '{self.strategy_instance_id}'.") f"'next' method not defined in generated_code for StrategyInstance '{self.strategy_instance_id}'.")
# Retrieve and update profit/loss # Retrieve and update profit/loss
self.profit_loss = exec_context.get('profit_loss', self.profit_loss) self.profit_loss = self.exec_context.get('profit_loss', self.profit_loss)
self.save_context() self.save_context()
return {"success": True, "profit_loss": self.profit_loss} return {"success": True, "profit_loss": self.profit_loss}
except Exception as e: except Exception as e:
logger.error(f"Error executing 'next()' for StrategyInstance '{self.strategy_instance_id}': {e}") logger.error(f"Error executing 'next()' for StrategyInstance '{self.strategy_instance_id}': {e}",
exc_info=True)
traceback.print_exc() traceback.print_exc()
return {"success": False, "message": str(e)} return {"success": False, "message": str(e)}
@ -152,6 +239,7 @@ class StrategyInstance:
:param value: True to pause, False to resume. :param value: True to pause, False to resume.
""" """
self.paused = value self.paused = value
self.exec_context['paused'] = self.paused
self.save_context() self.save_context()
logger.debug(f"Strategy '{self.strategy_id}' paused: {self.paused}") logger.debug(f"Strategy '{self.strategy_id}' paused: {self.paused}")
@ -166,6 +254,43 @@ class StrategyInstance:
self.save_context() self.save_context()
logger.debug(f"Strategy '{self.strategy_id}' exit set: {self.exit} with method '{self.exit_method}'") logger.debug(f"Strategy '{self.strategy_id}' exit set: {self.exit} with method '{self.exit_method}'")
def set_available_strategy_balance(self, balance: float):
"""
Sets the available balance for the strategy.
:param balance: The new available balance.
"""
self.variables['available_strategy_balance'] = balance
logger.debug(f"Available strategy balance set to {balance}.")
def get_current_balance(self) -> float:
"""
Retrieves the current balance from the Trades manager.
:return: Current balance.
"""
try:
balance = self.trades.get_current_balance(self.user_id)
logger.debug(f"Current balance retrieved: {balance}.")
return balance
except Exception as e:
logger.error(f"Error retrieving current balance: {e}", exc_info=True)
return 0.0
def get_available_strategy_balance(self) -> float:
"""
Retrieves the available strategy balance.
:return: Available strategy balance.
"""
try:
balance = self.variables.get('available_strategy_balance', self.starting_balance)
logger.debug(f"Available strategy balance retrieved: {balance}.")
return balance
except Exception as e:
logger.error(f"Error retrieving available strategy balance: {e}", exc_info=True)
return 0.0
def get_total_filled_order_volume(self) -> float: def get_total_filled_order_volume(self) -> float:
""" """
Retrieves the total filled order volume for the strategy. Retrieves the total filled order volume for the strategy.
@ -224,62 +349,47 @@ class StrategyInstance:
logger.error(f"Error retrieving current price for {symbol} on {exchange} ({timeframe}): {e}", exc_info=True) logger.error(f"Error retrieving current price for {symbol} on {exchange} ({timeframe}): {e}", exc_info=True)
return None return None
# Define helper methods def trade_order(
def buy_order(self, size: float, symbol: str, order_type: str = 'market', price: float | None = None, **kwargs): self,
trade_type: str,
size: float,
symbol: str,
order_type: str,
source: dict = None,
tif: str = 'GTC',
stop_loss: dict = None,
trailing_stop: dict = None,
take_profit: dict = None,
limit: dict = None,
trailing_limit: dict = None,
target_market: dict = None,
name_order: dict = None
):
""" """
Executes a buy order. Unified trade order handler for executing buy and sell orders.
"""
if trade_type == 'buy':
logger.info(f"Executing BUY order: Size={size}, Symbol={symbol}, Order Type={order_type}")
# Implement buy order logic here
elif trade_type == 'sell':
logger.info(f"Executing SELL order: Size={size}, Symbol={symbol}, Order Type={order_type}")
# Implement sell order logic here
else:
logger.error(f"Invalid trade_type '{trade_type}'. Order not executed.")
return
:param size: Quantity to buy. # Handle trade options like stop_loss, take_profit, etc.
:param symbol: Trading symbol. if stop_loss:
:param order_type: Type of order ('market' or 'limit'). # Implement stop loss logic
:param price: Price for limit orders. pass
""" if take_profit:
try: # Implement take profit logic
order_data = { pass
'size': size, # Add handling for other trade options as needed
'symbol': symbol,
'order_type': order_type.lower(),
'price': price,
**kwargs
}
status, msg = self.trades.buy(order_data, self.user_id)
if status != 'success':
logger.error(f"Buy order failed: {msg}")
self.notify_user(f"Buy order failed: {msg}")
else:
logger.info(f"Buy order executed successfully: {msg}")
except Exception as e:
logger.error(f"Error executing buy order in StrategyInstance '{self.strategy_instance_id}': {e}",
exc_info=True)
traceback.print_exc()
def sell_order(self, size: float, symbol: str, order_type: str = 'market', price: float | None = None, **kwargs): # Notify user about the trade execution
""" self.notify_user(f"{trade_type.capitalize()} order executed for {size} {symbol} at {order_type} price.")
Executes a sell order.
:param size: Quantity to sell.
:param symbol: Trading symbol.
:param order_type: Type of order ('market' or 'limit').
:param price: Price for limit orders.
"""
try:
order_data = {
'size': size,
'symbol': symbol,
'order_type': order_type.lower(),
'price': price,
**kwargs
}
status, msg = self.trades.sell(order_data, self.user_id)
if status != 'success':
logger.error(f"Sell order failed: {msg}")
self.notify_user(f"Sell order failed: {msg}")
else:
logger.info(f"Sell order executed successfully: {msg}")
except Exception as e:
logger.error(f"Error executing sell order in StrategyInstance '{self.strategy_instance_id}': {e}",
exc_info=True)
traceback.print_exc()
def exit_strategy(self): def exit_strategy(self):
""" """

View File

@ -1,12 +1,35 @@
import logging
import types
import uuid
import backtrader as bt import backtrader as bt
import datetime as dt import datetime as dt
from DataCache_v3 import DataCache from DataCache_v3 import DataCache
from Strategies import Strategies from Strategies import Strategies
from StrategyInstance import StrategyInstance
from indicators import Indicators from indicators import Indicators
import numpy as np import numpy as np
import pandas as pd import pandas as pd
# Configure logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) # Set to DEBUG for detailed logging
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
# Custom EquityCurveAnalyzer
class EquityCurveAnalyzer(bt.Analyzer):
def __init__(self):
self.equity_curve = []
def next(self):
self.equity_curve.append(self.strategy.broker.getvalue())
def get_analysis(self):
return {'equity_curve': self.equity_curve}
# Backtester Class
class Backtester: class Backtester:
def __init__(self, data_cache: DataCache, strategies: Strategies, indicators: Indicators, socketio): def __init__(self, data_cache: DataCache, strategies: Strategies, indicators: Indicators, socketio):
""" Initialize the Backtesting class with a cache for back-tests """ """ Initialize the Backtesting class with a cache for back-tests """
@ -33,17 +56,42 @@ class Backtester:
cache_key = f"backtest:{user_name}:{backtest_name}" cache_key = f"backtest:{user_name}:{backtest_name}"
self.data_cache.insert_row_into_cache('tests', columns, values, key=cache_key) self.data_cache.insert_row_into_cache('tests', columns, values, key=cache_key)
def map_user_strategy(self, user_strategy, precomputed_indicators): def map_user_strategy(self, user_strategy: dict, precomputed_indicators: dict[str, pd.DataFrame],
"""Maps user strategy details into a Backtrader-compatible strategy class.""" mode: str = 'testing') -> any:
"""
Maps user strategy details into a Backtrader-compatible strategy class.
"""
# Extract the generated code and indicators from the strategy components # Extract the generated code and indicators from the strategy components
strategy_components = user_strategy['strategy_components'] strategy_components = user_strategy['strategy_components']
generated_code = strategy_components['generated_code'] generated_code = strategy_components['generated_code']
indicators_used = strategy_components['indicators'] indicators_used = strategy_components['indicators']
# Validate extracted data
if not generated_code:
logger.error("No 'generated_code' found in strategy components.")
raise ValueError("Strategy must contain 'generated_code'.")
if not isinstance(indicators_used, list):
logger.error("'indicators_used' should be a list.")
raise ValueError("'indicators_used' should be a list.")
logger.info(f"Mapping strategy '{user_strategy.get('strategy_name', 'Unnamed')}' with mode '{mode}'.")
# Define the strategy class dynamically # Define the strategy class dynamically
class MappedStrategy(bt.Strategy): class MappedStrategy(bt.Strategy):
params = (
('mode', mode),
('strategy_instance', None), # Will be set during instantiation
)
def __init__(self): def __init__(self):
super().__init__()
self.strategy_instance: StrategyInstance = self.p.strategy_instance
logger.debug(f"StrategyInstance '{self.strategy_instance.strategy_instance_id}' attached to MappedStrategy.")
# Establish backreference
self.strategy_instance.backtrader_strategy = self
self.precomputed_indicators = precomputed_indicators self.precomputed_indicators = precomputed_indicators
self.indicator_pointers = {} self.indicator_pointers = {}
self.indicator_names = list(precomputed_indicators.keys()) self.indicator_names = list(precomputed_indicators.keys())
@ -54,28 +102,8 @@ class Backtester:
self.indicator_pointers[name] = 0 # Start at the first row self.indicator_pointers[name] = 0 # Start at the first row
# Initialize any other needed variables # Initialize any other needed variables
self.flags = {}
self.starting_balance = self.broker.getvalue() self.starting_balance = self.broker.getvalue()
def process_indicator(self, indicator_name, output_field):
# Get the DataFrame for the indicator
df = self.precomputed_indicators[indicator_name]
# Get the current index for the indicator
idx = self.indicator_pointers[indicator_name]
if idx >= len(df):
return None # No more data
# Get the specific output value
if output_field in df.columns:
value = df.iloc[idx][output_field]
if pd.isna(value):
return None # Handle NaN values
return value
else:
return None # Output field not found
def next(self): def next(self):
# Increment pointers # Increment pointers
for name in self.indicator_names: for name in self.indicator_names:
@ -86,68 +114,464 @@ class Backtester:
# Generated strategy logic # Generated strategy logic
try: try:
# Execute the generated code # Execute the strategy logic via StrategyInstance
exec(generated_code) execution_result = self.strategy_instance.execute()
if not execution_result.get('success', False):
error_msg = execution_result.get('message', 'Unknown error during strategy execution.')
logger.error(f"Strategy execution failed: {error_msg}")
# Handle the failure (stop the strategy)
self.stop()
except Exception as e: except Exception as e:
print(f"Error in strategy execution: {e}") logger.error(f"Error in strategy execution: {e}")
return MappedStrategy return MappedStrategy
def prepare_data_feed(self, start_date: str, source: dict): # Add custom handlers to the StrategyInstance
def add_custom_handlers(self, strategy_instance: StrategyInstance) -> StrategyInstance:
"""
Define custom methods to be injected into exec_context.
:param strategy_instance: The strategy instance to inject the custom handlers into.
:return: The modified strategy instance.
"""
# 1. Override trade_order
def trade_order(
trade_type: str,
size: float,
order_type: str,
source: dict = None,
tif: str = 'GTC',
stop_loss: dict = None,
trailing_stop: dict = None,
take_profit: dict = None,
limit: dict = None,
trailing_limit: dict = None,
target_market: dict = None,
name_order: dict = None
):
"""
Custom trade_order method for backtesting.
Executes trades within the Backtrader environment.
:param trade_type: Type of trade ('buy' or 'sell').
:param size: Size of the trade.
:param order_type: Type of order (e.g., 'market').
:param source: Dictionary containing additional trade information, including 'market'.
:param tif: Time in Force for the order.
:param stop_loss: Dictionary with stop loss parameters.
:param trailing_stop: Dictionary with trailing stop parameters.
:param take_profit: Dictionary with take profit parameters.
:param limit: Dictionary with limit order parameters.
:param trailing_limit: Dictionary with trailing limit parameters.
:param target_market: Dictionary with target market parameters.
:param name_order: Dictionary with order name parameters.
"""
# Validate and extract 'symbol' from 'source'
if source and 'market' in source:
symbol = source['market']
logger.debug(f"Extracted symbol '{symbol}' from source.")
else:
logger.error("Symbol not provided in source. Order not executed.")
return # Abort the order execution
if trade_type.lower() == 'buy':
logger.info(f"Executing BUY order: Size={size}, Symbol={symbol}, Order Type={order_type}")
# Execute a buy order in Backtrader via Cerebro
order = strategy_instance.backtrader_strategy.buy(size=size, exectype=bt.Order.Market, name=symbol)
elif trade_type.lower() == 'sell':
logger.info(f"Executing SELL order: Size={size}, Symbol={symbol}, Order Type={order_type}")
# Execute a sell order in Backtrader via Cerebro
order = strategy_instance.backtrader_strategy.sell(size=size, exectype=bt.Order.Market, name=symbol)
else:
logger.error(f"Invalid trade_type '{trade_type}'. Order not executed.")
return # Abort the order execution
# Handle trade options like stop_loss and take_profit
if stop_loss or take_profit:
if stop_loss:
stop_price = stop_loss.get('value')
if stop_price is not None:
logger.info(f"Setting STOP LOSS at {stop_price} for order {order.ref}")
strategy_instance.backtrader_strategy.sell(
size=size,
exectype=bt.Order.Stop,
price=stop_price,
parent=order,
name=f"StopLoss_{order.ref}"
)
if take_profit:
take_profit_price = take_profit.get('value')
if take_profit_price is not None:
logger.info(f"Setting TAKE PROFIT at {take_profit_price} for order {order.ref}")
strategy_instance.backtrader_strategy.sell(
size=size,
exectype=bt.Order.Limit,
price=take_profit_price,
parent=order,
name=f"TakeProfit_{order.ref}"
)
# Notify user about the trade execution
strategy_instance.notify_user(
f"{trade_type.capitalize()} order executed for {size} {symbol} at {order_type} price."
)
logger.debug(f"{trade_type.capitalize()} order executed for {size} {symbol} at {order_type} price.")
# Override the trade_order method
strategy_instance.override_exec_context('trade_order', trade_order)
# 2. Override process_indicator
def process_indicator(indicator_name, output_field):
"""
Custom process_indicator method for backtesting.
:param indicator_name: Name of the indicator.
:param output_field: Specific field to retrieve from the indicator.
:return: The value of the specified indicator field at the current step.
"""
# Access precomputed_indicators via backtrader_strategy
if strategy_instance.backtrader_strategy is None:
logger.error("Backtrader strategy is not set in StrategyInstance.")
return None
df = strategy_instance.backtrader_strategy.precomputed_indicators.get(indicator_name)
if df is None:
logger.error(f"Indicator '{indicator_name}' not found in precomputed indicators.")
return None
# Access indicator_pointers via backtrader_strategy
idx = strategy_instance.backtrader_strategy.indicator_pointers.get(indicator_name, 0)
if idx >= len(df):
logger.warning(f"No more data for indicator '{indicator_name}' at index {idx}.")
return None # No more data
# Get the specific output value
if output_field in df.columns:
value = df.iloc[idx][output_field]
if pd.isna(value):
logger.warning(f"NaN value encountered for indicator '{indicator_name}' at index {idx}.")
return None # Handle NaN values
return value
else:
logger.error(f"Output field '{output_field}' not found in indicator '{indicator_name}'.")
return None # Output field not found
# Override the process_indicator method
strategy_instance.override_exec_context('process_indicator', process_indicator)
# 3. Override get_current_price
def get_current_price(timeframe: str = '1h', exchange: str = 'binance',
symbol: str = 'BTC/USD') -> float | None:
"""
Retrieves the current market price from Backtrader's data feed.
"""
try:
# Access the current close price from Backtrader's data
current_price = strategy_instance.backtrader_strategy.data.close[0]
logger.debug(f"Retrieved current price for {symbol} on {exchange} ({timeframe}): {current_price}")
return current_price
except Exception as e:
logger.error(f"Error retrieving current price for {symbol} on {exchange} ({timeframe}): {e}",
exc_info=True)
return None
# Override the get_current_price method
strategy_instance.override_exec_context('get_current_price', get_current_price)
# 4. Override get_last_candle
def get_last_candle(candle_part: str, timeframe: str, exchange: str, symbol: str):
"""
Retrieves the specified part of the last candle from Backtrader's data feed.
"""
try:
# Map candle_part to Backtrader's data attributes
candle_map = {
'open': strategy_instance.backtrader_strategy.data.open[0],
'high': strategy_instance.backtrader_strategy.data.high[0],
'low': strategy_instance.backtrader_strategy.data.low[0],
'close': strategy_instance.backtrader_strategy.data.close[0],
'volume': strategy_instance.backtrader_strategy.data.volume[0],
}
value = candle_map.get(candle_part.lower())
if value is None:
logger.error(f"Invalid candle_part '{candle_part}'. Must be one of {list(candle_map.keys())}.")
else:
logger.debug(
f"Retrieved '{candle_part}' from last candle for {symbol} on {exchange} ({timeframe}): {value}")
return value
except Exception as e:
logger.error(
f"Error retrieving last candle '{candle_part}' for {symbol} on {exchange} ({timeframe}): {e}",
exc_info=True)
return None
# Override the get_last_candle method
strategy_instance.override_exec_context('get_last_candle', get_last_candle)
# 5. Override get_filled_orders
def get_filled_orders() -> int:
"""
Retrieves the number of filled orders from Backtrader's broker.
"""
try:
# Access Backtrader's broker's filled orders
filled_orders = len(strategy_instance.backtrader_strategy.broker.filled)
logger.debug(f"Number of filled orders: {filled_orders}")
return filled_orders
except Exception as e:
logger.error(f"Error retrieving filled orders: {e}", exc_info=True)
return 0
# Override the get_filled_orders method
strategy_instance.override_exec_context('get_filled_orders', get_filled_orders)
# 6. Override get_available_balance
def get_available_balance() -> float:
"""
Retrieves the available balance from Backtrader's broker.
"""
try:
available_balance = strategy_instance.backtrader_strategy.broker.getcash()
logger.debug(f"Available balance: {available_balance}")
return available_balance
except Exception as e:
logger.error(f"Error retrieving available balance: {e}", exc_info=True)
return 0.0
# Override the get_available_balance method
strategy_instance.override_exec_context('get_available_balance', get_available_balance)
# 7. Override get_current_balance
def get_current_balance() -> float:
"""
Retrieves the current balance from Backtrader's broker.
:return: Current balance.
"""
try:
# Access the total portfolio value from Backtrader's broker
balance = strategy_instance.backtrader_strategy.broker.getvalue()
logger.debug(f"Current balance retrieved: {balance}.")
return balance
except Exception as e:
logger.error(f"Error retrieving current balance: {e}", exc_info=True)
return 0.0
# Override the get_current_balance method
strategy_instance.override_exec_context('get_current_balance', get_current_balance)
# 8. Override get_filled_orders_details (Optional but Recommended)
def get_filled_orders_details() -> list:
"""
Retrieves detailed information about filled orders.
"""
try:
filled_orders = []
for order in strategy_instance.backtrader_strategy.broker.filled:
order_info = {
'ref': order.ref,
'size': order.size,
'price': order.executed.price,
'value': order.executed.value,
'commission': order.executed.comm,
'status': order.status,
'created_at': dt.datetime.fromtimestamp(order.created.dt.timestamp())
}
filled_orders.append(order_info)
logger.debug(f"Filled orders details: {filled_orders}")
return filled_orders
except Exception as e:
logger.error(f"Error retrieving filled orders details: {e}", exc_info=True)
return []
# Override the get_filled_orders_details method
strategy_instance.override_exec_context('get_filled_orders_details', get_filled_orders_details)
def notify_user(self, message: str):
"""
Suppresses user notifications and instead logs them.
:param message: Notification message.
"""
logger.debug(f"User notification during backtest for user ID '{self.user_id}': {message}")
# Bind the overridden method to the instance
strategy_instance.notify_user = types.MethodType(notify_user, strategy_instance)
# Return the modified strategy_instance
return strategy_instance
def prepare_data_feed(self, start_date: str, source, user_name: str) -> pd.DataFrame:
""" """
Prepare the main data feed based on the start date and source. Prepare the main data feed based on the start date and source.
:param start_date: Start date in 'YYYY-MM-DDTHH:MM' format.
:param source: Can be either a list or a dictionary.
- If a list: Expected order is [exchange, symbol, timeframe].
- If a dictionary: Expected keys are 'exchange', 'symbol', and 'timeframe'.
:param user_name: The user name associated with the data feed.
:return: Pandas DataFrame with OHLC data.
This method is designed to be flexible, supporting both list and dictionary formats
for the source. This flexibility allows for backward compatibility with existing code
using lists, while providing a clearer structure when dictionaries are preferred.
""" """
try: try:
# Convert the start date to a datetime object # Convert the start date to a datetime object and make it timezone-aware (UTC)
start_dt = dt.datetime.strptime(start_date, '%Y-%m-%dT%H:%M') start_dt = dt.datetime.strptime(start_date, '%Y-%m-%dT%H:%M')
start_dt = start_dt.replace(tzinfo=dt.timezone.utc) # Set UTC timezone
# Ensure exchange details contain required keys (fallback if missing) # Check if source is a dictionary or a list, then set exchange, symbol, and timeframe
timeframe = source.get('timeframe', '1h') if isinstance(source, dict):
exchange = source.get('exchange', 'Binance') exchange = source.get('exchange', 'Binance')
symbol = source.get('symbol', 'BTCUSDT') symbol = source.get('symbol', 'BTCUSDT')
timeframe = source.get('timeframe', '1h')
elif isinstance(source, list) and len(source) >= 3:
exchange, symbol, timeframe = source[0], source[1], source[2]
else:
logger.error("Source must be either a list with at least 3 elements or a dictionary.")
return pd.DataFrame()
# Now include user_name in ex_details
ex_details = [symbol, timeframe, exchange, user_name]
# Fetch OHLC data from DataCache based on the source # Fetch OHLC data from DataCache based on the source
data = self.data_cache.get_records_since(start_datetime=start_dt, ex_details=[symbol, timeframe, exchange]) data = self.data_cache.get_records_since(start_datetime=start_dt, ex_details=ex_details)
if data.empty:
logger.error(
f"No data retrieved for symbol {symbol} on exchange {exchange} with timeframe {timeframe}.")
return pd.DataFrame() # Return empty DataFrame
logger.info(f"Data feed prepared for {symbol} on {exchange} with timeframe {timeframe}.")
return data return data
except Exception as e: except Exception as e:
print(f"Error preparing data feed: {e}") logger.error(f"Error preparing data feed: {e}")
return None return pd.DataFrame()
def precompute_indicators(self, indicators_definitions, data_feed): def precompute_indicators(self, indicators_definitions: list, user_name: str, data_feed: pd.DataFrame) -> dict:
""" """
Precompute indicator values and return a dictionary of DataFrames. Precompute indicator values and return a dictionary of DataFrames.
:param indicators_definitions: List of indicator definitions.
:param data_feed: Pandas DataFrame with OHLC data.
:return: Dictionary mapping indicator names to their precomputed DataFrames.
""" """
precomputed_indicators = {} precomputed_indicators = {}
total_candles = len(data_feed) total_candles = len(data_feed)
for indicator_def in indicators_definitions: for indicator_def in indicators_definitions:
indicator_name = indicator_def['name'] indicator_name = indicator_def.get('name')
output = indicator_def.get('output') # e.g., 'middle'
if not indicator_name:
logger.warning("Indicator definition missing 'name'. Skipping.")
continue
# Compute the indicator values # Compute the indicator values
indicator_df = self.indicators_manager.process_indicator(indicator=indicator_def, indicator_data = self.indicators_manager.get_latest_indicator_data(
num_results=total_candles) user_name=user_name,
indicator_name=indicator_name,
num_results=total_candles
)
if not indicator_data:
logger.warning(f"No data returned for indicator '{indicator_name}'. Skipping.")
continue
data = indicator_data.get(indicator_name)
# Convert the data to a DataFrame
if isinstance(data, list):
df = pd.DataFrame(data)
elif isinstance(data, dict):
df = pd.DataFrame([data])
else:
logger.warning(f"Unexpected data format for indicator '{indicator_name}'. Skipping.")
continue
# If 'output' is specified, extract that column without renaming
if output:
if output in df.columns:
df = df[['time', output]]
else:
logger.warning(f"Output '{output}' not found in indicator '{indicator_name}'. Skipping.")
continue
# Ensure the DataFrame has a consistent index # Ensure the DataFrame has a consistent index
indicator_df.reset_index(drop=True, inplace=True) df.reset_index(drop=True, inplace=True)
precomputed_indicators[indicator_name] = indicator_df precomputed_indicators[indicator_name] = df
logger.debug(f"Precomputed indicator '{indicator_name}' with {len(df)} data points.")
return precomputed_indicators return precomputed_indicators
def run_backtest(self, strategy_class, data_feed, msg_data, user_name, callback, socket_conn_id): def prepare_backtest_data(self, msg_data: dict, strategy_components: dict) -> tuple:
"""
Prepare the data feed and precomputed indicators for backtesting.
:param msg_data: Message data containing backtest parameters.
:param strategy_components: Components of the user-defined strategy.
:return: Tuple of (data_feed, precomputed_indicators).
:raises ValueError: If data sources are invalid or data feed cannot be prepared.
"""
user_name = msg_data.get('user_name', 'default_user')
data_sources = strategy_components.get('data_sources', [])
if not data_sources:
logger.error("No valid data sources found in the strategy.")
raise ValueError("No valid data sources found in the strategy.")
# For simplicity, use the first data source as the main data feed.
main_source = data_sources[0]
# Prepare the main data feed
data_feed = self.prepare_data_feed(msg_data.get('start_date', '2023-01-01T00:00'), main_source, user_name)
if data_feed.empty:
logger.error("Data feed could not be prepared. Please check the data source.")
raise ValueError("Data feed could not be prepared. Please check the data source.")
# Precompute indicator values
indicators_definitions = strategy_components.get('indicators', [])
precomputed_indicators = self.precompute_indicators(indicators_definitions, user_name, data_feed)
logger.info("Backtest data prepared successfully.")
return data_feed, precomputed_indicators
def run_backtest(self, strategy_class, data_feed: pd.DataFrame, msg_data: dict, user_name: str,
callback, socket_conn_id: str, strategy_instance: StrategyInstance):
""" """
Runs a backtest using Backtrader and uses Flask-SocketIO's background tasks. Runs a backtest using Backtrader and uses Flask-SocketIO's background tasks.
Sends progress updates to the client via WebSocket. Sends progress updates to the client via WebSocket.
""" """
def execute_backtest(): def execute_backtest():
nonlocal data_feed
try: try:
# **Convert 'time' to 'datetime' if necessary**
if 'time' in data_feed.columns:
data_feed['datetime'] = pd.to_datetime(data_feed['time'], unit='ms') # Adjust 'unit' if needed
data_feed.set_index('datetime', inplace=True)
logger.info("Converted 'time' to 'datetime' and set as index in data_feed.")
# **Select relevant columns for Backtrader**
columns_to_keep = ['open', 'high', 'low', 'close', 'volume']
if not set(columns_to_keep).issubset(data_feed.columns):
logger.error("Data feed is missing one or more required columns: %s", columns_to_keep)
raise ValueError("Incomplete data feed for Backtrader.")
data_feed = data_feed[columns_to_keep]
cerebro = bt.Cerebro() cerebro = bt.Cerebro()
# Add the mapped strategy to the backtest # Assign cerebro to strategy_instance for potential use in custom methods
cerebro.addstrategy(strategy_class) strategy_instance.cerebro = cerebro
# Add the mapped strategy to the backtest, including strategy_instance as a parameter
cerebro.addstrategy(strategy_class, strategy_instance=strategy_instance)
# Add the main data feed to Cerebro # Add the main data feed to Cerebro
# noinspection PyArgumentList
bt_feed = bt.feeds.PandasData(dataname=data_feed) bt_feed = bt.feeds.PandasData(dataname=data_feed)
cerebro.adddata(bt_feed) cerebro.adddata(bt_feed)
@ -157,16 +581,27 @@ class Backtester:
commission = msg_data.get('commission', 0.001) commission = msg_data.get('commission', 0.001)
cerebro.broker.setcommission(commission=commission) cerebro.broker.setcommission(commission=commission)
# Add analyzers
cerebro.addanalyzer(EquityCurveAnalyzer, _name='equity_curve')
cerebro.addanalyzer(bt.analyzers.TradeAnalyzer, _name='trade_analyzer')
# Run the backtest # Run the backtest
print("Running backtest...") logger.info("Running backtest...")
start_time = dt.datetime.now() start_time = dt.datetime.now()
cerebro.run() results = cerebro.run()
end_time = dt.datetime.now() end_time = dt.datetime.now()
# Extract performance metrics # Extract performance metrics
final_value = cerebro.broker.getvalue() final_value = cerebro.broker.getvalue()
run_duration = (end_time - start_time).total_seconds() run_duration = (end_time - start_time).total_seconds()
# Extract equity curve from analyzers
equity_curve = results[0].analyzers.equity_curve.get_analysis().get('equity_curve', [])
# Extract trade data from TradeAnalyzer
trade_analyzer = results[0].analyzers.trade_analyzer.get_analysis()
trades = self.parse_trade_analyzer(trade_analyzer)
# Send 100% completion # Send 100% completion
self.socketio.emit('progress_update', {"progress": 100}, room=socket_conn_id) self.socketio.emit('progress_update', {"progress": 100}, room=socket_conn_id)
@ -174,111 +609,163 @@ class Backtester:
backtest_results = { backtest_results = {
"initial_capital": initial_cash, "initial_capital": initial_cash,
"final_portfolio_value": final_value, "final_portfolio_value": final_value,
"run_duration": run_duration "run_duration": run_duration,
"equity_curve": equity_curve,
"trades": trades,
} }
logger.info("Backtest executed successfully.")
callback(backtest_results) callback(backtest_results)
except Exception as e: except Exception as e:
# Handle exceptions and send error messages to the client # Handle exceptions and send error messages to the client
error_message = f"Backtest execution failed: {str(e)}" error_message = f"Backtest execution failed: {str(e)}"
self.socketio.emit('backtest_error', {"message": error_message}, room=socket_conn_id) self.socketio.emit('backtest_error', {"message": error_message}, room=socket_conn_id)
print(f"[BACKTEST ERROR] {error_message}") logger.error(f"[BACKTEST ERROR] {error_message}")
# Start the backtest as a background task # Start the backtest as a background task
self.socketio.start_background_task(execute_backtest) self.socketio.start_background_task(execute_backtest)
def handle_backtest_message(self, user_id, msg_data, socket_conn_id): def handle_backtest_message(self, user_id: str, msg_data: dict, socket_conn_id: str) -> dict:
"""
Handle incoming backtest messages, orchestrate the backtest process.
:param user_id: ID of the user initiating the backtest.
:param msg_data: Dictionary containing backtest parameters.
:param socket_conn_id: Socket connection ID for emitting updates.
:return: Dictionary with the status of backtest initiation.
"""
user_name = msg_data.get('user_name') user_name = msg_data.get('user_name')
backtest_name = f"{msg_data['strategy']}_backtest" backtest_name = f"{msg_data.get('strategy', 'UnnamedStrategy')}_backtest"
# Cache the backtest data # Cache the backtest data
self.cache_backtest(user_name, backtest_name, msg_data) self.cache_backtest(user_name, backtest_name, msg_data)
# Fetch the strategy using user_id and strategy_name # Fetch the strategy using user_id and strategy_name
strategy_name = msg_data.get('strategy') strategy_name = msg_data.get('strategy')
user_strategy = self.strategies.get_strategy_by_name(user_id=user_id, name=strategy_name) user_strategy = self.strategies.get_strategy_by_name(user_id=int(user_id), name=strategy_name)
if not user_strategy: if not user_strategy:
return {"error": f"Strategy {strategy_name} not found for user {user_name}"} logger.error(f"Strategy '{strategy_name}' not found for user '{user_name}'.")
return {"error": f"Strategy '{strategy_name}' not found for user '{user_name}'."}
# Extract the main data source from the strategy components # Prepare the source feeds for the sources referenced in the strategy.
strategy_components = user_strategy['strategy_components'] strategy_components = user_strategy.get('strategy_components', {})
data_sources = strategy_components['data_sources'] try:
data_feed, precomputed_indicators = self.prepare_backtest_data(msg_data, strategy_components)
except ValueError as ve:
logger.error(f"Error preparing backtest data: {ve}")
return {"error": str(ve)}
if not data_sources: # Ensure user_id is an integer
return {"error": "No valid data sources found in the strategy."} try:
user_id_int = int(user_id)
except ValueError:
logger.error(f"Invalid user_id '{user_id}'. Must be an integer.")
return {"error": f"Invalid user_id '{user_id}'. Must be an integer."}
# For simplicity, use the first data source as the main data feed # Generate unique strategy_instance_id for the backtest
main_source = data_sources[0] strategy_instance_id = f"test_{user_id}_{strategy_name}_{dt.datetime.now().isoformat()}"
# Prepare the main data feed # Instantiate StrategyInstance with proper indicators and trades
data_feed = self.prepare_data_feed(msg_data['start_date'], main_source) strategy_instance = StrategyInstance(
strategy_instance_id=strategy_instance_id,
strategy_id=user_strategy.get("id"),
strategy_name=strategy_name,
user_id=user_id_int,
generated_code=strategy_components.get("generated_code", ""),
data_cache=self.data_cache,
indicators=None, # Indicators are handled via overridden methods
trades=None # Trades are handled via overridden methods
)
if data_feed is None: # Override any methods that access exchanges and market data with custom handlers for backtesting
return {"error": "Data feed could not be prepared. Please check the data source."} strategy_instance = self.add_custom_handlers(strategy_instance)
# Precompute indicator values # Map the user strategy to a Backtrader-compatible strategy class
indicators_definitions = strategy_components['indicators']
precomputed_indicators = self.precompute_indicators(indicators_definitions, data_feed)
# Map the user strategy to a Backtrader strategy class
mapped_strategy_class = self.map_user_strategy(user_strategy, precomputed_indicators) mapped_strategy_class = self.map_user_strategy(user_strategy, precomputed_indicators)
# Define the callback function to handle backtest completion # Define the callback function to handle backtest completion
def backtest_callback(results): def backtest_callback(results):
self.store_backtest_results(user_name, backtest_name, results) self.store_backtest_results(user_name, backtest_name, results)
self.update_strategy_stats(user_id, strategy_name, results) self.update_strategy_stats(user_id_int, strategy_name, results)
# Emit the results back to the client # Emit the results back to the client
self.socketio.emit('backtest_results', {"test_id": backtest_name, "results": results}, room=socket_conn_id) self.socketio.emit('backtest_results', {"test_id": backtest_name, "results": results}, room=socket_conn_id)
print(f"[BACKTEST COMPLETE] Results emitted to user '{user_name}'.") logger.info(f"[BACKTEST COMPLETE] Results emitted to user '{user_name}'.")
# Run the backtest asynchronously # Run the backtest asynchronously, passing the strategy_instance
self.run_backtest(mapped_strategy_class, data_feed, msg_data, user_name, backtest_callback, socket_conn_id) self.run_backtest(
mapped_strategy_class,
data_feed,
msg_data,
user_name,
backtest_callback,
socket_conn_id,
strategy_instance
)
logger.info(f"Backtest '{backtest_name}' started for user '{user_name}'.")
return {"reply": "backtest_started"} return {"reply": "backtest_started"}
def update_strategy_stats(self, user_id, strategy_name, results): def update_strategy_stats(self, user_id: int, strategy_name: str, results: dict):
""" Update the strategy stats with the backtest results """ """
Update the strategy stats with the backtest results.
:param user_id: ID of the user.
:param strategy_name: Name of the strategy.
:param results: Dictionary containing backtest results.
"""
strategy = self.strategies.get_strategy_by_name(user_id=user_id, name=strategy_name) strategy = self.strategies.get_strategy_by_name(user_id=user_id, name=strategy_name)
if strategy: if strategy:
initial_capital = results['initial_capital'] strategy_id = strategy.get('id') or strategy.get('tbl_key')
final_value = results['final_portfolio_value'] initial_capital = results.get('initial_capital')
returns = np.array(results['returns']) final_value = results.get('final_portfolio_value')
equity_curve = np.array(results['equity_curve']) equity_curve = results.get('equity_curve', [])
trades = results['trades']
total_return = (final_value - initial_capital) / initial_capital * 100 # Calculate returns based on the equity curve
returns = self.calculate_returns(equity_curve)
trades = results.get('trades', [])
risk_free_rate = 0.0 if returns and trades:
mean_return = np.mean(returns) returns = np.array(returns)
std_return = np.std(returns) equity_curve = np.array(equity_curve)
sharpe_ratio = (mean_return - risk_free_rate) / std_return if std_return != 0 else 0
running_max = np.maximum.accumulate(equity_curve) total_return = (final_value - initial_capital) / initial_capital * 100
drawdowns = (equity_curve - running_max) / running_max
max_drawdown = np.min(drawdowns) * 100
num_trades = len(trades) risk_free_rate = 0.0 # Modify as needed
wins = sum(1 for trade in trades if trade['profit'] > 0) mean_return = np.mean(returns)
losses = num_trades - wins std_return = np.std(returns)
win_loss_ratio = wins / losses if losses != 0 else wins sharpe_ratio = (mean_return - risk_free_rate) / std_return if std_return != 0 else 0
stats = { running_max = np.maximum.accumulate(equity_curve)
'total_return': total_return, drawdowns = (equity_curve - running_max) / running_max
'sharpe_ratio': sharpe_ratio, max_drawdown = np.min(drawdowns) * 100
'max_drawdown': max_drawdown,
'number_of_trades': num_trades,
'win_loss_ratio': win_loss_ratio,
}
strategy.update_stats(stats) num_trades = len(trades)
wins = sum(1 for trade in trades if trade.get('pnl', 0) > 0)
losses = num_trades - wins
win_loss_ratio = wins / losses if losses != 0 else wins
stats = {
'total_return': total_return,
'sharpe_ratio': sharpe_ratio,
'max_drawdown': max_drawdown,
'number_of_trades': num_trades,
'win_loss_ratio': win_loss_ratio,
}
# Update the strategy's stats using the Strategies class
self.strategies.update_stats(strategy_id, stats)
logger.info(f"Strategy '{strategy_name}' stats updated successfully.")
else:
logger.warning("Missing 'returns' or 'trades' data for statistics calculation.")
else: else:
print(f"Strategy {strategy_name} not found for user {user_id}.") logger.error(f"Strategy '{strategy_name}' not found for user '{user_id}'.")
def store_backtest_results(self, user_name, backtest_name, results): def store_backtest_results(self, user_name: str, backtest_name: str, results: dict):
""" Store the backtest results in the cache """ """ Store the backtest results in the cache """
cache_key = f"backtest:{user_name}:{backtest_name}" cache_key = f"backtest:{user_name}:{backtest_name}"
@ -286,7 +773,89 @@ class Backtester:
backtest_data = self.data_cache.get_rows_from_cache('tests', filter_vals) backtest_data = self.data_cache.get_rows_from_cache('tests', filter_vals)
if not backtest_data.empty: if not backtest_data.empty:
backtest_data['results'] = results backtest_data['results'] = str(results) # Convert dict to string or JSON as per your cache implementation
self.data_cache.insert_row_into_cache('tests', backtest_data.keys(), backtest_data.values(), key=cache_key) self.data_cache.insert_row_into_cache('tests', backtest_data.columns, backtest_data.values, key=cache_key)
logger.info(f"Backtest results stored for '{backtest_name}' of user '{user_name}'.")
else: else:
print(f"Backtest {backtest_name} not found in cache.") logger.error(f"Backtest '{backtest_name}' not found in cache for user '{user_name}'.")
def calculate_returns(self, equity_curve: list) -> list:
"""
Calculate returns based on the equity curve.
:param equity_curve: List of portfolio values over time.
:return: List of returns.
"""
if not equity_curve or len(equity_curve) < 2:
logger.warning("Insufficient data to calculate returns.")
return []
returns = []
for i in range(1, len(equity_curve)):
ret = (equity_curve[i] - equity_curve[i - 1]) / equity_curve[i - 1]
returns.append(ret)
logger.debug(f"Calculated returns: {returns}")
return returns
def extract_trades(self, strategy_instance: StrategyInstance) -> list:
"""
Extract trades from the strategy instance.
:param strategy_instance: The strategy instance.
:return: List of trades with profit information.
"""
# Since Trades class is not used, extract trades from TradeAnalyzer
# This method is now obsolete due to integration with TradeAnalyzer
# Instead, trades are extracted directly from 'results' in run_backtest
# Kept here for backward compatibility or future use
return []
def parse_trade_analyzer(self, trade_analyzer: dict) -> list:
"""
Parse the TradeAnalyzer results from Backtrader and return a list of trades.
:param trade_analyzer: Dictionary containing trade analysis.
:return: List of trade dictionaries with relevant information.
"""
trades = []
if not trade_analyzer:
logger.warning("No trade data available in TradeAnalyzer.")
return trades
# TradeAnalyzer stores trades under 'trades'
trade_list = trade_analyzer.get('trades', {})
# Check if 'trades' is a dict (with trade references) or a list
if isinstance(trade_list, dict):
for ref, trade in trade_list.items():
trade_info = {
'ref': ref,
'size': trade.get('size'),
'price': trade.get('price'),
'value': trade.get('value'),
'pnl': trade.get('pnl'),
'commission': trade.get('commission'),
'opendate': trade.get('opendate'),
'closedate': trade.get('closedate'),
'status': trade.get('status'),
}
trades.append(trade_info)
logger.debug(f"Parsed trade: {trade_info}")
elif isinstance(trade_list, list):
for trade in trade_list:
trade_info = {
'ref': trade.get('ref'),
'size': trade.get('size'),
'price': trade.get('price'),
'value': trade.get('value'),
'pnl': trade.get('pnl'),
'commission': trade.get('commission'),
'opendate': trade.get('opendate'),
'closedate': trade.get('closedate'),
'status': trade.get('status'),
}
trades.append(trade_info)
logger.debug(f"Parsed trade: {trade_info}")
else:
logger.error("Unexpected format for 'trades' in TradeAnalyzer.")
logger.info(f"Parsed {len(trades)} trades from TradeAnalyzer.")
return trades

View File

@ -497,6 +497,15 @@ class Indicators:
""" """
username = self.users.get_username(indicator.creator) username = self.users.get_username(indicator.creator)
src = indicator.source src = indicator.source
# Deserialize src if it's a string
if isinstance(src, str):
try:
src = json.loads(src)
except json.JSONDecodeError as e:
print(f"Error decoding JSON for indicator '{indicator.name}': {e}")
return None # or handle the error as appropriate
symbol, timeframe, exchange_name = src['market'], src['timeframe'], src['exchange'] symbol, timeframe, exchange_name = src['market'], src['timeframe'], src['exchange']
# Retrieve necessary details to instantiate the indicator # Retrieve necessary details to instantiate the indicator
@ -596,6 +605,52 @@ class Indicators:
return json_ready_results return json_ready_results
def get_latest_indicator_data(self, user_name: str, indicator_name: str, num_results: int = 1) -> Optional[
Dict[str, Any]]:
"""
Retrieves the latest data points for a specific indicator for a given user.
:param user_name: The name of the user.
:param indicator_name: The name of the indicator.
:param num_results: Number of latest results to fetch (default is 1).
:return: A dictionary containing the latest indicator data, or None if not found.
"""
try:
# Step 1: Get User ID
user_id = self.users.get_id(user_name=user_name)
if not user_id:
raise ValueError(f"Invalid user_name: '{user_name}'")
# Step 2: Retrieve the Specific Indicator
indicators = self.cache_manager.get_rows_from_datacache(
'indicators',
[('creator', str(user_id)), ('name', indicator_name)]
)
if indicators.empty:
print(f"Indicator '{indicator_name}' not found for user '{user_name}'.")
return None # Indicator not found
# Assuming indicator names are unique per user, take the first match
indicator = indicators.iloc[0]
# Step 3: Process the Indicator to Get Data
indicator_result = self.process_indicator(indicator=indicator, num_results=num_results)
# Step 4: Extract the Latest Data Points
if isinstance(indicator_result, pd.DataFrame):
latest_data = indicator_result.tail(num_results).to_dict(orient='records')
return {indicator_name: latest_data}
elif isinstance(indicator_result, dict):
return {indicator_name: indicator_result}
else:
print(f"Unexpected data format for indicator '{indicator_name}'.")
return None
except Exception as e:
print(f"Error retrieving latest data for indicator '{indicator_name}': {e}")
return None
def delete_indicator(self, indicator_name: str, user_name: str) -> None: def delete_indicator(self, indicator_name: str, user_name: str) -> None:
""" """
Remove the indicator by name Remove the indicator by name

View File

@ -141,12 +141,12 @@ export function defineControlGenerators() {
// Process DO statements // Process DO statements
const doStatements = []; const doStatements = [];
let currentBlock = block.getInputTargetBlock('DO'); let currentBlock = block.getInputTargetBlock('DO');
while (currentBlock) { if (currentBlock) {
const blockJson = Blockly.JSON._blockToJson(currentBlock, 1); const blockJson = Blockly.JSON._blockToJson(currentBlock, 1);
if (blockJson) { if (blockJson) {
doStatements.push(blockJson); doStatements.push(blockJson);
} }
currentBlock = currentBlock.getNextBlock(); // No need to handle 'next' here, as _blockToJson will process it
} }
// If no DO statements, exclude the block // If no DO statements, exclude the block
@ -160,17 +160,14 @@ export function defineControlGenerators() {
type: 'execute_if', type: 'execute_if',
inputs: { inputs: {
CONDITION: conditionJson CONDITION: conditionJson
},
statements: {
DO: doStatements
} }
}; };
if (doStatements.length > 0) { // // **Set the skipAdditionalParsing flag**
json.statements = { // json.skipAdditionalParsing = true;
DO: doStatements
};
}
// **Set the skipAdditionalParsing flag**
json.skipAdditionalParsing = true;
console.log(`Generated JSON for 'execute_if' block:`, json); console.log(`Generated JSON for 'execute_if' block:`, json);
return json; return json;

View File

@ -169,7 +169,10 @@ export function defineVAFGenerators() {
console.warn("Empty variable_name in get_variable block. Defaulting to 'undefined_var'."); console.warn("Empty variable_name in get_variable block. Defaulting to 'undefined_var'.");
} }
variables.push({ 'variable': trimmedName }); variables.push({
type: 'get_variable',
variable_name: variableName
});
// Process the 'NEXT' connection // Process the 'NEXT' connection
const nextBlock = currentBlock.getInputTargetBlock('NEXT'); const nextBlock = currentBlock.getInputTargetBlock('NEXT');