505 lines
20 KiB
Python
505 lines
20 KiB
Python
"""Test Step Dialog for interactive step testing."""
|
||
|
||
import ast
import html
import json
import re
import time
from typing import Optional, Union

from PySide6.QtWidgets import (
    QDialog, QVBoxLayout, QHBoxLayout, QFormLayout, QGroupBox,
    QLabel, QLineEdit, QPlainTextEdit, QPushButton, QComboBox,
    QTextEdit, QTableWidget, QTableWidgetItem, QHeaderView,
    QSplitter, QWidget, QMessageBox
)
from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtGui import QColor, QFont

from ...tool import PromptStep, CodeStep, ToolStep
from ...runner import execute_prompt_step, execute_code_step, execute_tool_step
from ...providers import load_providers
|
||
|
||
|
||
class StepTestWorker(QThread):
    """Background worker that executes a single step off the GUI thread.

    The outcome is reported once through ``finished`` as a dict with keys:
    ``success`` (bool), ``output`` (text shown to the user), ``output_vars``
    (dict of variables the step produced), ``error`` (str or None) and
    ``elapsed_ms`` (int duration of the step execution).

    NOTE(review): this ``finished`` shadows ``QThread.finished`` (which carries
    no arguments); callers in this module connect to the dict-carrying signal.
    """

    finished = Signal(dict)  # Emits result dict

    def __init__(self, step: Union[PromptStep, CodeStep, ToolStep], variables: dict,
                 provider_override: Optional[str] = None):
        """Store the step and its inputs; execution happens in :meth:`run`.

        Args:
            step: The step to execute.
            variables: Variable name -> value mapping used when running the step.
            provider_override: Provider name to use instead of the step's
                default (used for prompt and tool steps), or None.
        """
        super().__init__()
        self.step = step
        self.variables = variables
        self.provider_override = provider_override

    def run(self):
        """Execute the step on the worker thread and emit the result dict.

        Exceptions raised by the step runners are caught and reported as a
        failed result instead of escaping the thread.
        """
        result = {
            "success": False,
            "output": "",
            "output_vars": {},
            "error": None,
            "elapsed_ms": 0,
        }

        # perf_counter() is monotonic, so the measured duration cannot jump or
        # go negative if the wall clock is adjusted mid-run (time.time() can).
        start_time = time.perf_counter()

        try:
            if isinstance(self.step, PromptStep):
                output, success = execute_prompt_step(
                    self.step, self.variables, self.provider_override
                )
                result["success"] = success
                result["output"] = output
                result["output_vars"] = {self.step.output_var: output}
                if not success:
                    result["error"] = "Provider call failed"

            elif isinstance(self.step, CodeStep):
                outputs, success = execute_code_step(
                    self.step, self.variables, step_num=1
                )
                # Guard against a runner that reports failure with no outputs dict.
                outputs = outputs or {}
                result["success"] = success
                result["output_vars"] = outputs
                result["output"] = "\n".join(f"{k} = {v}" for k, v in outputs.items())
                if not success:
                    result["error"] = "Code execution failed"

            elif isinstance(self.step, ToolStep):
                output, success = execute_tool_step(
                    self.step, self.variables,
                    depth=0,
                    provider_override=self.provider_override,
                    dry_run=False,
                    verbose=False,
                )
                result["success"] = success
                result["output"] = output
                result["output_vars"] = {self.step.output_var: output}
                if not success:
                    result["error"] = f"Tool '{self.step.tool}' execution failed"

            else:
                # Previously this fell through with success=False and error=None.
                result["error"] = f"Unsupported step type: {type(self.step).__name__}"

        except Exception as e:
            result["success"] = False
            # Some exceptions stringify to "", which would show a blank error.
            result["error"] = str(e) or type(e).__name__

        result["elapsed_ms"] = int((time.perf_counter() - start_time) * 1000)
        self.finished.emit(result)
|
||
|
||
|
||
class TestStepDialog(QDialog):
    """Dialog for interactively testing a single step.

    The top half collects test values for the variables the step references,
    plus optional assertions. The bottom half runs the step on a
    StepTestWorker thread and shows its output and each assertion's outcome.
    """

    # Assertion types available: (type id, display name, tooltip text).
    ASSERTION_TYPES = [
        ("not_empty", "Not Empty", "Output must not be empty"),
        ("contains", "Contains", "Output must contain the specified text"),
        ("not_contains", "Does Not Contain", "Output must NOT contain the specified text"),
        ("equals", "Equals", "Output must exactly equal the expected value"),
        ("valid_json", "Valid JSON", "Output must be valid JSON"),
        ("valid_python", "Valid Python", "Output must be valid Python syntax"),
        ("matches_regex", "Matches Regex", "Output must match the regular expression"),
        ("min_length", "Min Length", "Output must be at least N characters"),
        ("max_length", "Max Length", "Output must be at most N characters"),
    ]

    # A {name} template reference. The look-behind/look-ahead reject braces
    # that belong to an escaped "{{" / "}}" pair, so "{{name}}" is not reported.
    _VAR_PATTERN = re.compile(r'(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})')

    # Output-variable values longer than this are truncated in the output view.
    _PREVIEW_CHARS = 200

    def __init__(self, parent, step: Union[PromptStep, CodeStep, ToolStep],
                 available_vars: Optional[list] = None):
        """Build the dialog for *step*.

        Args:
            parent: Parent widget, passed to QDialog.
            step: The step to test.
            available_vars: Names of variables available from earlier steps;
                defaults to ``["input"]``. Each gets an input field, in
                addition to the variables referenced by the step itself.
        """
        super().__init__(parent)
        self.step = step
        self.available_vars = available_vars or ["input"]
        self._worker = None  # StepTestWorker for the current/most recent run

        # Determine step type for title
        if isinstance(step, PromptStep):
            step_type = "Prompt"
        elif isinstance(step, CodeStep):
            step_type = "Code"
        elif isinstance(step, ToolStep):
            step_type = f"Tool ({step.tool})"
        else:
            step_type = "Unknown"

        step_name = step.name if step.name else step_type
        self.setWindowTitle(f"Test Step: {step_name}")
        self.setMinimumSize(800, 700)

        self._setup_ui()
        self._detect_variables()

    def _setup_ui(self):
        """Build the widgets: inputs and assertions on top, run controls and output below."""
        layout = QVBoxLayout(self)
        layout.setSpacing(12)

        # Main splitter: top (inputs) | bottom (output)
        splitter = QSplitter(Qt.Vertical)

        # Top section: Variables and Assertions
        top_widget = QWidget()
        top_layout = QHBoxLayout(top_widget)
        top_layout.setContentsMargins(0, 0, 0, 0)

        # Left: Variables input (rows are added by _detect_variables)
        vars_group = QGroupBox("Input Variables")
        vars_layout = QVBoxLayout(vars_group)

        vars_help = QLabel("Provide test values for variables used in this step:")
        vars_help.setStyleSheet("color: #718096; font-size: 11px;")
        vars_layout.addWidget(vars_help)

        self.vars_form = QFormLayout()
        self.vars_form.setSpacing(8)
        self.var_inputs = {}  # variable name -> QLineEdit or QPlainTextEdit
        vars_layout.addLayout(self.vars_form)
        vars_layout.addStretch()

        top_layout.addWidget(vars_group, 1)

        # Right: Assertions
        assert_group = QGroupBox("Assertions (Optional)")
        assert_layout = QVBoxLayout(assert_group)

        assert_help = QLabel("Define checks to validate the step output:")
        assert_help.setStyleSheet("color: #718096; font-size: 11px;")
        assert_layout.addWidget(assert_help)

        # Assertions table: type combo | expected value | remove button
        self.assertions_table = QTableWidget(0, 3)
        self.assertions_table.setHorizontalHeaderLabels(["Type", "Value", ""])
        header = self.assertions_table.horizontalHeader()
        header.setSectionResizeMode(0, QHeaderView.Fixed)
        header.setSectionResizeMode(1, QHeaderView.Stretch)
        header.setSectionResizeMode(2, QHeaderView.Fixed)
        self.assertions_table.setColumnWidth(0, 130)
        self.assertions_table.setColumnWidth(2, 40)
        self.assertions_table.verticalHeader().setVisible(False)
        self.assertions_table.setMaximumHeight(150)
        assert_layout.addWidget(self.assertions_table)

        btn_add_assertion = QPushButton("+ Add Assertion")
        btn_add_assertion.clicked.connect(self._add_assertion_row)
        assert_layout.addWidget(btn_add_assertion)

        top_layout.addWidget(assert_group, 1)
        splitter.addWidget(top_widget)

        # Bottom section: Controls and Output
        bottom_widget = QWidget()
        bottom_layout = QVBoxLayout(bottom_widget)
        bottom_layout.setContentsMargins(0, 0, 0, 0)

        controls_layout = QHBoxLayout()

        # Provider override (for prompt and tool steps)
        if isinstance(self.step, (PromptStep, ToolStep)):
            controls_layout.addWidget(QLabel("Provider:"))
            self.provider_combo = QComboBox()
            # Index 0 means "no override"; _run_test relies on that.
            self.provider_combo.addItem("(use step's default)")
            for provider in sorted(load_providers(), key=lambda p: p.name):
                self.provider_combo.addItem(provider.name)
            # Add common defaults that are not already listed
            for default in ["mock"]:
                if self.provider_combo.findText(default) < 0:
                    self.provider_combo.addItem(default)
            self.provider_combo.setMinimumWidth(150)
            controls_layout.addWidget(self.provider_combo)
        else:
            self.provider_combo = None

        controls_layout.addStretch()

        self.btn_run = QPushButton("Run Step")
        self.btn_run.setMinimumHeight(36)
        self.btn_run.setMinimumWidth(120)
        self.btn_run.clicked.connect(self._run_test)
        controls_layout.addWidget(self.btn_run)

        bottom_layout.addLayout(controls_layout)

        # Output section
        output_group = QGroupBox("Output")
        output_layout = QVBoxLayout(output_group)

        self.status_label = QLabel("Click 'Run Step' to test this step")
        self.status_label.setStyleSheet("color: #718096;")
        output_layout.addWidget(self.status_label)

        self.output_display = QTextEdit()
        self.output_display.setReadOnly(True)
        self.output_display.setPlaceholderText("Step output will appear here...")
        font = self.output_display.font()
        # QFont.setFamily() takes ONE family name: a CSS-style comma list is
        # looked up literally and silently falls back to a proportional font.
        font.setFamilies(["Consolas", "Monaco", "monospace"])
        font.setStyleHint(QFont.StyleHint.Monospace)
        self.output_display.setFont(font)
        output_layout.addWidget(self.output_display)

        # Assertion results are rendered as rich text (see _run_assertions)
        self.assertion_results = QLabel("")
        self.assertion_results.setWordWrap(True)
        self.assertion_results.setTextFormat(Qt.RichText)
        output_layout.addWidget(self.assertion_results)

        bottom_layout.addWidget(output_group)

        splitter.addWidget(bottom_widget)
        splitter.setSizes([300, 400])

        layout.addWidget(splitter)

        # Dialog buttons
        buttons_layout = QHBoxLayout()
        buttons_layout.addStretch()

        btn_close = QPushButton("Close")
        btn_close.clicked.connect(self.accept)
        buttons_layout.addWidget(btn_close)

        layout.addLayout(buttons_layout)

    def _detect_variables(self):
        """Create an input field for every variable the step may reference.

        The variables are the union of ``available_vars`` and every ``{name}``
        reference in the step's template text (the prompt, the code, or the
        tool input template plus its argument values). ``input`` gets a
        multi-line editor; every other variable gets a single-line field.
        """
        step = self.step
        template = ""
        if isinstance(step, PromptStep):
            template = step.prompt or ""
        elif isinstance(step, CodeStep):
            template = step.code or ""
        elif isinstance(step, ToolStep):
            # Tool argument values may also contain {var} references.
            parts = [step.input_template or ""]
            parts.extend(str(value) for value in (step.args or {}).values())
            template = " ".join(parts)

        found_vars = set(self._VAR_PATTERN.findall(template))
        all_vars = sorted(set(self.available_vars) | found_vars)

        for var_name in all_vars:
            if var_name == "input":
                # Use multiline for input
                widget = QPlainTextEdit()
                widget.setPlaceholderText("Enter test input text...")
                widget.setMaximumHeight(80)
            else:
                widget = QLineEdit()
                widget.setPlaceholderText(f"Value for {{{var_name}}}")

            self.var_inputs[var_name] = widget
            self.vars_form.addRow(f"{{{var_name}}}:", widget)

    def _add_assertion_row(self):
        """Append an empty assertion row (type combo, value field, remove button)."""
        row = self.assertions_table.rowCount()
        self.assertions_table.insertRow(row)
        self.assertions_table.setRowHeight(row, 36)

        type_combo = QComboBox()
        type_combo.setMinimumHeight(28)
        type_combo.setMinimumWidth(120)
        for type_id, display_name, tooltip in self.ASSERTION_TYPES:
            type_combo.addItem(display_name, type_id)
            type_combo.setItemData(type_combo.count() - 1, tooltip, Qt.ToolTipRole)
        self.assertions_table.setCellWidget(row, 0, type_combo)

        value_edit = QLineEdit()
        value_edit.setPlaceholderText("Expected value (if applicable)")
        value_edit.setMinimumHeight(28)
        self.assertions_table.setCellWidget(row, 1, value_edit)

        btn_remove = QPushButton("×")
        btn_remove.setFixedWidth(30)
        # Resolve the row when clicked: indices shift as other rows are removed.
        btn_remove.clicked.connect(lambda: self._remove_assertion_row_for(btn_remove))
        self.assertions_table.setCellWidget(row, 2, btn_remove)

    def _remove_assertion_row_for(self, button):
        """Remove the assertion row whose remove-button cell holds *button*."""
        for row in range(self.assertions_table.rowCount()):
            if self.assertions_table.cellWidget(row, 2) is button:
                self._remove_assertion_row(row)
                return

    def _remove_assertion_row(self, row: int):
        """Remove the assertion row at index *row*; out-of-range indices are ignored."""
        if 0 <= row < self.assertions_table.rowCount():
            self.assertions_table.removeRow(row)

    def _get_assertions(self) -> list:
        """Return the table's assertions as dicts with ``type``, ``display`` and ``value`` keys."""
        assertions = []
        for row in range(self.assertions_table.rowCount()):
            type_combo = self.assertions_table.cellWidget(row, 0)
            value_edit = self.assertions_table.cellWidget(row, 1)
            if type_combo:
                assertions.append({
                    "type": type_combo.currentData(),
                    "display": type_combo.currentText(),
                    "value": value_edit.text() if value_edit else "",
                })
        return assertions

    def _run_test(self):
        """Collect inputs and start a StepTestWorker for the step."""
        variables = {}
        for var_name, widget in self.var_inputs.items():
            if isinstance(widget, QPlainTextEdit):
                variables[var_name] = widget.toPlainText()
            else:
                variables[var_name] = widget.text()

        # Index 0 is "(use step's default)", i.e. no override.
        provider_override = None
        if self.provider_combo is not None and self.provider_combo.currentIndex() > 0:
            provider_override = self.provider_combo.currentText()

        # Disable run button and show loading
        self.btn_run.setEnabled(False)
        self.btn_run.setText("Running...")
        self.status_label.setText("Executing step...")
        self.status_label.setStyleSheet("color: #718096;")
        self.output_display.clear()
        self.assertion_results.clear()

        self._worker = StepTestWorker(self.step, variables, provider_override)
        self._worker.finished.connect(self._on_test_finished)
        self._worker.start()

    @staticmethod
    def _as_text(value) -> str:
        """Return *value* as display text, mapping None to an empty string."""
        return "" if value is None else str(value)

    def _on_test_finished(self, result: dict):
        """Display a finished run and, on success, evaluate the assertions.

        Args:
            result: The dict emitted by ``StepTestWorker.finished``.
        """
        self.btn_run.setEnabled(True)
        self.btn_run.setText("Run Step")
        elapsed_ms = result["elapsed_ms"]

        if result["success"]:
            self.status_label.setText(f"✓ Step completed in {elapsed_ms}ms")
            self.status_label.setStyleSheet("color: #38a169; font-weight: bold;")

            output_text = self._as_text(result["output"])
            if result["output_vars"]:
                output_text += "\n\n--- Output Variables ---\n"
                for var, value in result["output_vars"].items():
                    # Values are not guaranteed to be str; preview their text form.
                    preview = self._as_text(value)
                    if len(preview) > self._PREVIEW_CHARS:
                        preview = preview[:self._PREVIEW_CHARS] + "..."
                    output_text += f"{var} = {preview}\n"
            self.output_display.setPlainText(output_text)
        else:
            self.status_label.setText(f"✗ Step failed ({elapsed_ms}ms)")
            self.status_label.setStyleSheet("color: #e53e3e; font-weight: bold;")

            # The "error" key is always present but may be None, so a .get()
            # default would never apply.
            error_text = result.get("error") or "Unknown error"
            # Escape so "<", ">" and "&" in the message render literally.
            safe_error = html.escape(str(error_text)).replace("\n", "<br>")
            self.output_display.setHtml(
                f"<span style='color: #e53e3e;'><b>Error:</b> {safe_error}</span>"
            )

        assertions = self._get_assertions()
        if assertions and result["success"]:
            self._run_assertions(self._as_text(result["output"]), assertions)

    @staticmethod
    def _evaluate_assertion(a_type: str, a_value: str, output: str) -> tuple:
        """Evaluate one assertion against *output*.

        Args:
            a_type: An assertion type id from ``ASSERTION_TYPES``.
            a_value: The user-entered expected value (may be empty).
            output: The step output being checked.

        Returns:
            ``(passed, message)``, where *message* is plain, unescaped text.
        """
        if a_type == "not_empty":
            passed = bool(output.strip())
            return passed, ("Output is not empty" if passed else "Output is empty")

        if a_type == "contains":
            passed = a_value in output
            return passed, (f"Output contains '{a_value}'" if passed
                            else f"Output does not contain '{a_value}'")

        if a_type == "not_contains":
            passed = a_value not in output
            return passed, (f"Output does not contain '{a_value}'" if passed
                            else f"Output contains '{a_value}'")

        if a_type == "equals":
            # Leading/trailing whitespace is ignored on both sides.
            passed = output.strip() == a_value.strip()
            return passed, ("Output equals expected" if passed
                            else "Output does not equal expected")

        if a_type == "valid_json":
            try:
                json.loads(output)
            except json.JSONDecodeError as e:
                return False, f"Invalid JSON: {e}"
            return True, "Output is valid JSON"

        if a_type == "valid_python":
            try:
                ast.parse(output)
            except (SyntaxError, ValueError) as e:
                # Python versions before 3.12 raise ValueError for null bytes.
                return False, f"Invalid Python: {e}"
            return True, "Output is valid Python"

        if a_type == "matches_regex":
            try:
                passed = bool(re.search(a_value, output))
            except re.error as e:
                return False, f"Invalid regex: {e}"
            return passed, ("Output matches regex" if passed
                            else "Output does not match regex")

        if a_type == "min_length":
            try:
                min_len = int(a_value)
            except ValueError:
                return False, "Invalid minimum length value"
            passed = len(output) >= min_len
            return passed, (f"Length {len(output)} >= {min_len}" if passed
                            else f"Length {len(output)} < {min_len}")

        if a_type == "max_length":
            try:
                max_len = int(a_value)
            except ValueError:
                return False, "Invalid maximum length value"
            passed = len(output) <= max_len
            return passed, (f"Length {len(output)} <= {max_len}" if passed
                            else f"Length {len(output)} > {max_len}")

        return False, f"Unknown assertion type: {a_type}"

    def _run_assertions(self, output: str, assertions: list):
        """Evaluate *assertions* against *output* and show a rich-text summary.

        Args:
            output: The step output.
            assertions: Dicts as returned by :meth:`_get_assertions`.
        """
        results = []
        for assertion in assertions:
            try:
                passed, message = self._evaluate_assertion(
                    assertion["type"], assertion["value"], output
                )
            except Exception as e:
                # An assertion that raises counts as failed rather than crashing the UI.
                passed, message = False, f"Error: {e}"
            results.append((assertion["display"], passed, message))

        failed_count = sum(1 for _, ok, _ in results if not ok)
        if failed_count == 0:
            summary = (f"<span style='color: #38a169; font-weight: bold;'>"
                       f"All {len(results)} assertion(s) passed!</span>")
        else:
            summary = (f"<span style='color: #e53e3e; font-weight: bold;'>"
                       f"{failed_count} of {len(results)} assertion(s) failed</span>")

        lines = [summary, "<b>Assertion Results:</b>"]
        for display, passed, message in results:
            icon = "✓" if passed else "✗"
            color = "#38a169" if passed else "#e53e3e"
            # Messages echo user-entered values and parser errors; escape them
            # so characters such as "<" and "&" are shown literally.
            lines.append(
                f"<span style='color: {color};'>{icon} "
                f"{html.escape(display)}: {html.escape(message)}</span>"
            )
        self.assertion_results.setText("<br>".join(lines) + "<br>")

    def done(self, result_code: int):
        """Close the dialog, but never abandon a step test that is still running.

        Qt documents that deleting a QThread whose run() has not returned
        crashes the program. The worker is referenced only by ``self._worker``,
        so it could be collected mid-run once the dialog is discarded
        (presumably when the caller drops it — confirm against callers). The
        user is asked whether to wait for the worker instead.
        """
        worker = self._worker
        if worker is not None and worker.isRunning():
            answer = QMessageBox.question(
                self,
                "Step Still Running",
                "The step is still running. Wait for it to finish and close?",
                QMessageBox.Yes | QMessageBox.No,
                QMessageBox.No,
            )
            if answer != QMessageBox.Yes:
                return
            # Blocks the GUI until the step returns; the user opted into this.
            worker.wait()
        super().done(result_code)
|