# CmdForge/src/cmdforge/gui/widgets/flow_graph.py (source listing: 419 lines, 14 KiB, Python)

"""Flow visualization widget using NodeGraphQt."""
from typing import Optional, List

from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QMenu, QGraphicsOpacityEffect
from PySide6.QtCore import Signal, Qt, QTimer, QEvent, QPropertyAnimation, QEasingCurve
from PySide6.QtGui import QKeyEvent, QAction
from NodeGraphQt import NodeGraph, BaseNode

from ...tool import Tool, PromptStep, CodeStep
# =============================================================================
# Custom Node Types
# =============================================================================
class CmdForgeBaseNode(BaseNode):
    """Shared base for all CmdForge graph nodes.

    Puts every subclass under the ``cmdforge`` node-identifier namespace and
    gives each node two bookkeeping attributes linking it back to the tool
    step it renders (left unset by nodes that do not represent a step).
    """

    __identifier__ = 'cmdforge'

    def __init__(self):
        super().__init__()
        # Index of the represented step in the tool's step list; -1 = no step.
        self._step_index: int = -1
        # The step object itself; stays None for nodes that never bind a step.
        self._step_data = None
class InputNode(CmdForgeBaseNode):
    """Node representing tool input (stdin plus one port per argument variable)."""

    NODE_NAME = 'Input'

    def __init__(self):
        super().__init__()
        self.set_color(90, 90, 160)  # Indigo
        # Default output port carrying stdin.
        self.add_output('input', color=(180, 180, 250))

    def set_arguments(self, arguments: list):
        """Add an output port for each argument that declares a variable.

        Arguments whose ``variable`` is falsy are skipped, and no port is added
        for a name that already exists (including names added earlier in the
        same call), so calling this more than once does not duplicate ports.
        ``None`` or an empty sequence adds nothing.

        Args:
            arguments: Argument objects exposing a ``variable`` attribute.
        """
        # Snapshot the existing port names once and keep the set current as we
        # add ports, instead of rebuilding a list for every argument.
        existing = {port.name() for port in self.output_ports()}
        for arg in arguments or ():
            variable = arg.variable
            if variable and variable not in existing:
                self.add_output(variable, color=(180, 180, 250))
                existing.add(variable)
class PromptNode(CmdForgeBaseNode):
    """Graph node for a PromptStep.

    Has one input port that accepts multiple connections, one output port,
    and two text fields showing the step's provider and output variable.
    """

    NODE_NAME = 'Prompt'

    def __init__(self):
        super().__init__()
        self.set_color(102, 126, 234)  # Indigo (CmdForge theme colour)
        self.add_input('in', color=(180, 180, 250), multi_input=True)
        self.add_output('out', color=(180, 250, 180))
        # Text fields mirroring the step configuration; defaults shown until
        # set_step() is called.
        self.add_text_input('provider', 'Provider', text='claude')
        self.add_text_input('output_var', 'Output', text='response')

    def set_step(self, step: PromptStep, index: int):
        """Bind this node to ``step`` (position ``index``) and refresh its fields.

        Empty ``provider`` / ``output_var`` values fall back to 'claude' /
        'response'.
        """
        self._step_data, self._step_index = step, index

        provider = step.provider or 'claude'
        output_name = step.output_var or 'response'
        self.set_property('provider', provider)
        self.set_property('output_var', output_name)

        ports = self.output_ports()
        if ports:
            # NOTE(review): this sets a private attribute on the NodeGraphQt
            # Port wrapper. Port.name() appears to read from the port model,
            # so the visible port label may not actually change — confirm.
            ports[0]._name = output_name
class CodeNode(CmdForgeBaseNode):
    """Graph node for a CodeStep.

    Has one input port that accepts multiple connections, one output port,
    and a text field showing the step's output variable.
    """

    NODE_NAME = 'Code'

    def __init__(self):
        super().__init__()
        self.set_color(72, 187, 120)  # Green
        self.add_input('in', color=(180, 250, 180), multi_input=True)
        self.add_output('out', color=(250, 220, 180))
        # Text field mirroring the step's output variable.
        self.add_text_input('output_var', 'Output', text='result')

    def set_step(self, step: CodeStep, index: int):
        """Bind this node to ``step`` (position ``index``) and refresh its field.

        An empty ``output_var`` falls back to 'result'.
        """
        self._step_data, self._step_index = step, index

        output_name = step.output_var or 'result'
        self.set_property('output_var', output_name)

        ports = self.output_ports()
        if ports:
            # NOTE(review): this sets a private attribute on the NodeGraphQt
            # Port wrapper. Port.name() appears to read from the port model,
            # so the visible port label may not actually change — confirm.
            ports[0]._name = output_name
class OutputNode(CmdForgeBaseNode):
    """Sink node of the flow; its single input receives the final result."""

    NODE_NAME = 'Output'

    def __init__(self):
        super().__init__()
        # Orange node with one input port that accepts several connections.
        self.set_color(237, 137, 54)
        self.add_input('in', color=(250, 220, 180), multi_input=True)
# =============================================================================
# Flow Graph Widget
# =============================================================================
class FlowGraphWidget(QWidget):
    """Widget for visualizing a tool's flow as a node graph.

    The graph is rebuilt from scratch by :meth:`set_tool` / :meth:`refresh`:
    an Input node, one node per prompt or code step (other step types are
    skipped), and an Output node, chained left to right in step order.

    Shortcuts (also listed in a floating help banner): ``A`` selects all
    nodes; ``F`` fits the view to the selection, or to every node when
    nothing is selected.
    """

    # Emitted when a step node is double-clicked: (step_index, step_type),
    # where step_type is 'prompt' or 'code'.
    node_double_clicked = Signal(int, str)
    # Emitted when the flow structure changes.
    # NOTE(review): not emitted anywhere in this class.
    flow_changed = Signal()

    def __init__(self, parent=None):
        super().__init__(parent)
        self._tool: Optional[Tool] = None
        self._graph: Optional[NodeGraph] = None
        self._input_node = None
        self._output_node = None
        self._step_nodes: List[CmdForgeBaseNode] = []
        # Created in _setup_ui; declared up front so eventFilter can compare
        # against it even for events delivered while the UI is being built.
        self._help_banner: Optional[QLabel] = None
        self._setup_ui()

    def _setup_ui(self):
        """Build the graph view, the help banner and the context menu."""
        layout = QVBoxLayout(self)
        layout.setContentsMargins(0, 0, 0, 0)
        layout.setSpacing(0)

        # Node graph with the CmdForge node types registered.
        self._graph = NodeGraph()
        for node_cls in (InputNode, PromptNode, CodeNode, OutputNode):
            self._graph.register_node(node_cls)
        self._graph.node_double_clicked.connect(self._on_node_double_clicked)
        layout.addWidget(self._graph.widget, 1)

        self._setup_help_banner()

        # Catch F/A shortcuts and mouse-enter on the graph widget. Installed
        # after the banner exists so the filter never sees a half-built widget.
        self._graph.widget.installEventFilter(self)

        self._graph.widget.setContextMenuPolicy(Qt.CustomContextMenu)
        self._graph.widget.customContextMenuRequested.connect(self._show_context_menu)

    def _setup_help_banner(self):
        """Create the floating help banner overlaid on the graph widget."""
        self._help_banner = QLabel(self._graph.widget)
        self._help_banner.setText(
            "Pan: Middle-click drag | "
            "Zoom: Scroll wheel | "
            "Select: Click or drag box | "
            "Edit: Double-click | "
            "Select All: A | "
            "Fit Selection: F"
        )
        self._help_banner.setStyleSheet("""
            QLabel {
                background-color: rgba(45, 55, 72, 0.9);
                color: #e2e8f0;
                font-size: 11px;
                padding: 8px 16px;
                border-radius: 4px;
            }
        """)
        self._help_banner.adjustSize()
        self._help_banner.hide()

        # windowOpacity only applies to top-level windows, so this child
        # label is faded through a graphics opacity effect instead.
        self._banner_opacity = QGraphicsOpacityEffect(self._help_banner)
        self._banner_opacity.setOpacity(1.0)
        self._help_banner.setGraphicsEffect(self._banner_opacity)

        self._fade_animation = QPropertyAnimation(self._banner_opacity, b"opacity", self)
        self._fade_animation.setDuration(500)
        self._fade_animation.setStartValue(1.0)
        self._fade_animation.setEndValue(0.0)
        self._fade_animation.setEasingCurve(QEasingCurve.InOutQuad)
        # Once fully transparent, hide it so it stops covering the graph.
        self._fade_animation.finished.connect(self._help_banner.hide)

        # Single-shot timer that starts the fade-out.
        self._hide_timer = QTimer(self)
        self._hide_timer.setSingleShot(True)
        self._hide_timer.timeout.connect(self._fade_out_banner)

        # Hovering the banner pauses auto-hide; Enter/Leave arrive via eventFilter.
        self._help_banner.setMouseTracking(True)
        self._help_banner.installEventFilter(self)

    def set_tool(self, tool: Optional[Tool]):
        """Set the tool to visualize and rebuild the graph."""
        self._tool = tool
        self._rebuild_graph()

    def _rebuild_graph(self):
        """Rebuild the graph from the tool data (Input -> steps -> Output)."""
        if self._graph is None:
            return

        self._graph.clear_session()
        self._input_node = None
        self._output_node = None
        self._step_nodes = []
        if not self._tool:
            return

        self._input_node = self._graph.create_node(
            'cmdforge.InputNode',
            name='Input',
            pos=[-300, 0]
        )
        if self._tool.arguments:
            self._input_node.set_arguments(self._tool.arguments)

        x_pos = 0
        prev_node = self._input_node
        # Per-type counters used for default names ("Prompt 1", "Code 1", ...).
        type_counts = {}
        for index, step in enumerate(self._tool.steps or []):
            node = self._create_step_node(step, index, x_pos, type_counts)
            if node is None:
                # Unsupported step type: not drawn and takes no layout slot.
                continue
            self._step_nodes.append(node)
            self._connect_nodes(prev_node, node)
            prev_node = node
            x_pos += 250

        self._output_node = self._graph.create_node(
            'cmdforge.OutputNode',
            name='Output',
            pos=[x_pos, 0]
        )
        self._connect_nodes(prev_node, self._output_node)

        self._graph.auto_layout_nodes()
        # Frame all nodes once the widget has been laid out and rendered.
        QTimer.singleShot(50, self._fit_and_clear_selection)

    def _create_step_node(self, step, index, x_pos, type_counts):
        """Create the node for ``step``; return None for unsupported step types.

        ``type_counts`` is updated in place and numbers unnamed steps per type.
        """
        if isinstance(step, PromptStep):
            node_type, label = 'cmdforge.PromptNode', 'Prompt'
        elif isinstance(step, CodeStep):
            node_type, label = 'cmdforge.CodeNode', 'Code'
        else:
            return None
        type_counts[label] = type_counts.get(label, 0) + 1
        # Custom step name if set, otherwise "<Type> N" counted per type.
        node_name = step.name if step.name else f'{label} {type_counts[label]}'
        node = self._graph.create_node(node_type, name=node_name, pos=[x_pos, 0])
        node.set_step(step, index)
        return node

    @staticmethod
    def _connect_nodes(source, target):
        """Wire ``source``'s first output port to ``target``'s first input port."""
        if source is None:
            return
        outputs = source.output_ports()
        inputs = target.input_ports()
        if outputs and inputs:
            outputs[0].connect_to(inputs[0])

    def _fit_and_clear_selection(self):
        """Fit view to all nodes and clear selection."""
        self._fit_all_nodes()

    def _fit_all_nodes(self):
        """Frame every node, leaving nothing selected afterwards."""
        if self._graph is None:
            return
        all_nodes = self._graph.all_nodes()
        if not all_nodes:
            return
        # fit_to_selection frames the selected nodes, so select all temporarily.
        for node in all_nodes:
            node.set_selected(True)
        self._graph.fit_to_selection()
        self._graph.clear_selection()

    def select_all_nodes(self):
        """Select all nodes in the graph."""
        if self._graph is None:
            return
        for node in self._graph.all_nodes():
            node.set_selected(True)

    def fit_selection(self):
        """Fit view to show selected nodes (or all if none selected)."""
        if self._graph is None:
            return
        if self._graph.selected_nodes():
            # Fit to the current selection and keep it.
            self._graph.fit_to_selection()
        else:
            self._fit_all_nodes()

    def _handle_shortcut_key(self, key) -> bool:
        """Run the action bound to ``key``; return True if one was run."""
        if key == Qt.Key_F:
            self.fit_selection()
            return True
        if key == Qt.Key_A:
            self.select_all_nodes()
            return True
        return False

    def keyPressEvent(self, event: QKeyEvent):
        """Handle keyboard shortcuts (A = select all, F = fit)."""
        if self._handle_shortcut_key(event.key()):
            event.accept()
        else:
            super().keyPressEvent(event)

    def showEvent(self, event):
        """Show the help banner shortly after the widget becomes visible."""
        super().showEvent(event)
        QTimer.singleShot(100, self._show_help_banner)

    def eventFilter(self, obj, event):
        """Handle banner hover plus shortcut keys and mouse-enter on the graph widget."""
        event_type = event.type()

        banner = self._help_banner
        if banner is not None and obj is banner:
            if event_type == QEvent.Enter:
                self._on_banner_enter(event)
            elif event_type == QEvent.Leave:
                self._on_banner_leave(event)
            return super().eventFilter(obj, event)

        if event_type == QEvent.KeyPress:
            if self._handle_shortcut_key(event.key()):
                return True  # Event handled
        elif event_type == QEvent.Enter:
            # Mouse entered the graph area: re-show the banner.
            self._show_help_banner()
        return super().eventFilter(obj, event)

    def _show_help_banner(self):
        """Show the banner at full opacity and (re)start its auto-hide timer."""
        if self._help_banner is None or self._graph is None:
            return

        # Cancel any fade in progress. stop() does not reach the end, so
        # `finished` (which hides the banner) is not emitted.
        self._fade_animation.stop()
        self._banner_opacity.setOpacity(1.0)

        # Recompute size, then centre horizontally near the top of the graph.
        self._help_banner.adjustSize()
        parent_width = self._graph.widget.width()
        x = max(10, (parent_width - self._help_banner.width()) // 2)
        self._help_banner.move(x, 10)

        self._help_banner.raise_()
        self._help_banner.show()
        self._hide_timer.start(3000)  # Begin fading after 3 seconds

    def _fade_out_banner(self):
        """Fade the help banner out; it is hidden when the fade finishes."""
        if self._help_banner is not None and self._help_banner.isVisible():
            self._fade_animation.start()

    def _on_banner_enter(self, event):
        """Mouse entered banner - keep it visible (stop timer and any fade)."""
        self._hide_timer.stop()
        self._fade_animation.stop()
        self._banner_opacity.setOpacity(1.0)

    def _on_banner_leave(self, event):
        """Mouse left banner - restart hide timer."""
        self._hide_timer.start(1500)  # Begin fading 1.5 seconds after leaving

    def _show_context_menu(self, pos):
        """Show the graph context menu at widget position ``pos``."""
        menu = QMenu(self._graph.widget)

        select_all_action = QAction("Select All (A)", menu)
        select_all_action.triggered.connect(self.select_all_nodes)
        menu.addAction(select_all_action)

        fit_action = QAction("Fit Selection (F)", menu)
        fit_action.triggered.connect(self.fit_selection)
        menu.addAction(fit_action)

        menu.exec(self._graph.widget.mapToGlobal(pos))
        # The menu is parented to the long-lived graph widget; without this,
        # every right-click would leave another QMenu (and its actions) alive.
        menu.deleteLater()

    def _on_node_double_clicked(self, node):
        """Emit node_double_clicked for step nodes; ignore Input/Output."""
        step_index = getattr(node, '_step_index', -1)
        if step_index < 0:
            return
        step_type = 'prompt' if isinstance(node, PromptNode) else 'code'
        self.node_double_clicked.emit(step_index, step_type)

    def refresh(self):
        """Refresh the graph from current tool data."""
        self._rebuild_graph()

    def get_graph(self) -> Optional[NodeGraph]:
        """Get the underlying NodeGraph instance."""
        return self._graph