"""Flow visualization widget using NodeGraphQt."""

from typing import Optional, List

from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QMenu
from PySide6.QtCore import Signal, Qt, QTimer, QEvent, QPropertyAnimation, QEasingCurve
from PySide6.QtGui import QKeyEvent, QAction

from NodeGraphQt import NodeGraph, BaseNode

from ...tool import Tool, PromptStep, CodeStep


# =============================================================================
# Custom Node Types
# =============================================================================

class CmdForgeBaseNode(BaseNode):
    """Base class for CmdForge nodes with common styling."""

    __identifier__ = 'cmdforge'

    def __init__(self):
        super().__init__()
        # Step object this node visualizes (set by ``_bind_step``).
        self._step_data = None
        # Index of that step in ``Tool.steps``; stays -1 for nodes that do
        # not represent a step (Input / Output).
        self._step_index = -1

    def _bind_step(self, step, index: int, output_var: str):
        """Remember *step* / *index* and label the first output port.

        Args:
            step: The tool step this node represents.
            index: Position of the step in ``Tool.steps``.
            output_var: Variable name the step writes its result to.
        """
        self._step_data = step
        self._step_index = index
        self.set_property('output_var', output_var)
        ports = self.output_ports()
        if ports:
            # NOTE(review): this writes a private attribute on the port; the
            # port's public ``name()`` may read its name from elsewhere, so
            # the displayed label might not change — confirm against the
            # installed NodeGraphQt version.
            ports[0]._name = output_var


class InputNode(CmdForgeBaseNode):
    """Node representing tool input (stdin and arguments)."""

    NODE_NAME = 'Input'

    def __init__(self):
        super().__init__()
        self.set_color(90, 90, 160)  # Indigo
        # Default output for stdin
        self.add_output('input', color=(180, 180, 250))

    def set_arguments(self, arguments: list):
        """Add one output port per argument variable.

        Arguments whose ``variable`` is empty are skipped. A port is added at
        most once per name (the built-in ``input`` port included), so repeated
        calls or duplicate variables never register the same port twice.

        Args:
            arguments: Argument objects exposing a ``variable`` attribute.
        """
        existing = {port.name() for port in self.output_ports()}
        for arg in arguments:
            if arg.variable and arg.variable not in existing:
                self.add_output(arg.variable, color=(180, 180, 250))
                existing.add(arg.variable)


class PromptNode(CmdForgeBaseNode):
    """Node representing a prompt step."""

    NODE_NAME = 'Prompt'

    def __init__(self):
        super().__init__()
        self.set_color(102, 126, 234)  # Indigo (matching CmdForge theme)
        self.add_input('in', color=(180, 180, 250), multi_input=True)
        self.add_output('out', color=(180, 250, 180))

        # Add properties for display
        self.add_text_input('provider', 'Provider', text='claude')
        self.add_text_input('output_var', 'Output', text='response')

    def set_step(self, step: PromptStep, index: int):
        """Configure node from a PromptStep.

        Empty ``provider`` / ``output_var`` fall back to ``'claude'`` and
        ``'response'`` respectively.
        """
        self.set_property('provider', step.provider or 'claude')
        self._bind_step(step, index, step.output_var or 'response')


class CodeNode(CmdForgeBaseNode):
    """Node representing a code step."""

    NODE_NAME = 'Code'

    def __init__(self):
        super().__init__()
        self.set_color(72, 187, 120)  # Green
        self.add_input('in', color=(180, 250, 180), multi_input=True)
        self.add_output('out', color=(250, 220, 180))

        # Add properties
        self.add_text_input('output_var', 'Output', text='result')

    def set_step(self, step: CodeStep, index: int):
        """Configure node from a CodeStep (empty ``output_var`` -> ``'result'``)."""
        self._bind_step(step, index, step.output_var or 'result')


class OutputNode(CmdForgeBaseNode):
    """Node representing tool output."""

    NODE_NAME = 'Output'

    def __init__(self):
        super().__init__()
        self.set_color(237, 137, 54)  # Orange
        self.add_input('in', color=(250, 220, 180), multi_input=True)


# =============================================================================
# Flow Graph Widget
# =============================================================================

class FlowGraphWidget(QWidget):
    """Widget for visualizing tool flow as a node graph."""

    # Emitted when a node is double-clicked (step_index, step_type)
    node_double_clicked = Signal(int, str)

    # Emitted when the flow structure changes
    flow_changed = Signal()

    def __init__(self, parent=None):
        super().__init__(parent)
        self._tool: Optional[Tool] = None
        self._graph: Optional[NodeGraph] = None
        self._input_node = None
        self._output_node = None
        self._step_nodes: List[CmdForgeBaseNode] = []
        # Created in _setup_ui; pre-set so early event-filter calls are safe.
        self._help_banner: Optional[QLabel] = None

        self._setup_ui()

    def _setup_ui(self):
        """Create the node graph, help banner, timers and context menu."""
        layout = QVBoxLayout(self)
        layout.setContentsMargins(0, 0, 0, 0)
        layout.setSpacing(0)

        # Create node graph and register custom node types
        self._graph = NodeGraph()
        for node_cls in (InputNode, PromptNode, CodeNode, OutputNode):
            self._graph.register_node(node_cls)

        self._graph.node_double_clicked.connect(self._on_node_double_clicked)

        graph_widget = self._graph.widget
        layout.addWidget(graph_widget, 1)

        self._help_banner = self._create_help_banner(graph_widget)

        # NOTE: not started anywhere in this file; _fade_out_banner hides the
        # banner directly because windowOpacity has no effect on child widgets.
        self._fade_animation = QPropertyAnimation(self._help_banner, b"windowOpacity")
        self._fade_animation.setDuration(500)
        self._fade_animation.setEasingCurve(QEasingCurve.Type.InOutQuad)

        # Timer for auto-hide
        self._hide_timer = QTimer(self)
        self._hide_timer.setSingleShot(True)
        self._hide_timer.timeout.connect(self._fade_out_banner)

        # Set up context menu
        graph_widget.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu)
        graph_widget.customContextMenuRequested.connect(self._show_context_menu)

        # Install filters last: eventFilter uses the banner and the hide
        # timer, so both must exist before any filtered event can arrive.
        # The banner filter replaces per-instance enterEvent/leaveEvent
        # overrides with the standard Qt mechanism.
        self._help_banner.installEventFilter(self)
        graph_widget.installEventFilter(self)

    @staticmethod
    def _create_help_banner(parent: QWidget) -> QLabel:
        """Build the hidden floating help banner overlaid on *parent*."""
        banner = QLabel(parent)
        banner.setText(
            "Pan: Middle-click drag | "
            "Zoom: Scroll wheel | "
            "Select: Click or drag box | "
            "Edit: Double-click | "
            "Select All: A | "
            "Fit Selection: F"
        )
        banner.setStyleSheet("""
            QLabel {
                background-color: rgba(45, 55, 72, 0.9);
                color: #e2e8f0;
                font-size: 11px;
                padding: 8px 16px;
                border-radius: 4px;
            }
        """)
        banner.adjustSize()
        banner.hide()
        banner.setMouseTracking(True)
        return banner

    def set_tool(self, tool: Optional[Tool]):
        """Set the tool to visualize and rebuild the graph."""
        self._tool = tool
        self._rebuild_graph()

    def _rebuild_graph(self):
        """Rebuild the graph from the tool data.

        Lays out Input -> steps (in order) -> Output as a linear chain. Steps
        that are neither PromptStep nor CodeStep are skipped. After building,
        all nodes are selected, and the view is fitted and the selection
        cleared shortly afterwards (once the widget has rendered).
        """
        if not self._graph:
            return

        # Clear existing nodes
        self._graph.clear_session()
        self._input_node = None
        self._output_node = None
        self._step_nodes = []

        if not self._tool:
            return

        self._input_node = self._graph.create_node(
            'cmdforge.InputNode',
            name='Input',
            pos=[-300, 0]
        )
        if self._tool.arguments:
            self._input_node.set_arguments(self._tool.arguments)

        x_pos = 0
        prev_node = self._input_node
        # Per-type counters used for default names ("Prompt 1", "Code 1", ...)
        type_counts: dict = {}

        for index, step in enumerate(self._tool.steps or []):
            node = self._create_step_node(step, index, x_pos, type_counts)
            if node is None:
                continue  # unsupported step type: no node, no position shift
            self._step_nodes.append(node)
            self._connect_nodes(prev_node, node)
            prev_node = node
            x_pos += 250

        self._output_node = self._graph.create_node(
            'cmdforge.OutputNode',
            name='Output',
            pos=[x_pos, 0]
        )
        self._connect_nodes(prev_node, self._output_node)

        self._graph.auto_layout_nodes()

        # Select everything so the deferred fit covers the whole graph.
        self.select_all_nodes()
        # Use a timer to fit after the widget is fully rendered
        QTimer.singleShot(50, self._fit_and_clear_selection)

    def _create_step_node(self, step, index: int, x_pos: int, type_counts: dict):
        """Create the node for *step*, or return None for unsupported types.

        Uses ``step.name`` when set, otherwise a per-type default such as
        ``"Prompt 2"``; *type_counts* is updated in place.
        """
        if isinstance(step, PromptStep):
            label, node_type = 'Prompt', 'cmdforge.PromptNode'
        elif isinstance(step, CodeStep):
            label, node_type = 'Code', 'cmdforge.CodeNode'
        else:
            return None

        type_counts[label] = type_counts.get(label, 0) + 1
        node_name = step.name if step.name else f'{label} {type_counts[label]}'
        node = self._graph.create_node(node_type, name=node_name, pos=[x_pos, 0])
        node.set_step(step, index)
        return node

    @staticmethod
    def _connect_nodes(source, target):
        """Connect *source*'s first output port to *target*'s first input port."""
        if source and source.output_ports():
            source.output_ports()[0].connect_to(target.input_ports()[0])

    def _fit_and_clear_selection(self):
        """Fit view to all nodes and clear selection."""
        if self._graph:
            self._graph.fit_to_selection()
            self._graph.clear_selection()

    def select_all_nodes(self):
        """Select all nodes in the graph."""
        if not self._graph:
            return
        for node in self._graph.all_nodes():
            node.set_selected(True)

    def fit_selection(self):
        """Fit view to show selected nodes (or all if none selected).

        When nothing is selected the whole graph is fitted and the temporary
        selection used for fitting is cleared again; an existing selection is
        left untouched.
        """
        if not self._graph:
            return

        if self._graph.selected_nodes():
            self._graph.fit_to_selection()
            return

        if not self._graph.all_nodes():
            return
        self.select_all_nodes()
        self._graph.fit_to_selection()
        self._graph.clear_selection()

    def _handle_shortcut_key(self, event: QKeyEvent) -> bool:
        """Run the F (fit) / A (select all) shortcut; return True if handled."""
        key = event.key()
        if key == Qt.Key.Key_F:
            self.fit_selection()
        elif key == Qt.Key.Key_A:
            self.select_all_nodes()
        else:
            return False
        return True

    def keyPressEvent(self, event: QKeyEvent):
        """Handle keyboard shortcuts."""
        if self._handle_shortcut_key(event):
            event.accept()
        else:
            super().keyPressEvent(event)

    def showEvent(self, event):
        """Handle widget becoming visible: show the help banner shortly after."""
        super().showEvent(event)
        QTimer.singleShot(100, self._show_help_banner)

    def eventFilter(self, obj, event):
        """Route filtered events from the banner and the graph widget.

        Banner: Enter/Leave pause/restart the auto-hide timer.
        Graph widget: F/A key presses are consumed as shortcuts; Enter shows
        the help banner. All other events pass through unchanged.
        """
        event_type = event.type()

        if obj is not None and obj is self._help_banner:
            if event_type == QEvent.Type.Enter:
                self._on_banner_enter(event)
            elif event_type == QEvent.Type.Leave:
                self._on_banner_leave(event)
            return super().eventFilter(obj, event)

        if event_type == QEvent.Type.KeyPress:
            if self._handle_shortcut_key(event):
                return True  # Event handled
        elif event_type == QEvent.Type.Enter:
            # Show banner when mouse enters the graph area
            self._show_help_banner()
        return super().eventFilter(obj, event)

    def _show_help_banner(self):
        """Show the help banner at the top centre and (re)start auto-hide."""
        if not self._help_banner or not self._graph:
            return

        # Make sure size is calculated before centring
        self._help_banner.adjustSize()

        parent_width = self._graph.widget.width()
        x = max(10, (parent_width - self._help_banner.width()) // 2)
        self._help_banner.move(x, 10)

        self._help_banner.raise_()
        self._help_banner.show()

        self._hide_timer.start(3000)  # Hide after 3 seconds

    def _fade_out_banner(self):
        """Hide the help banner if it is visible."""
        if self._help_banner and self._help_banner.isVisible():
            # Plain hide: windowOpacity doesn't work well on child widgets
            self._help_banner.hide()

    def _on_banner_enter(self, event):
        """Mouse entered banner - stop hide timer."""
        self._hide_timer.stop()

    def _on_banner_leave(self, event):
        """Mouse left banner - restart hide timer."""
        self._hide_timer.start(1500)  # Hide 1.5 seconds after mouse leaves

    def _show_context_menu(self, pos):
        """Show the Select All / Fit Selection context menu at *pos*."""
        graph_widget = self._graph.widget
        menu = QMenu(graph_widget)

        select_all_action = QAction("Select All (A)", menu)
        select_all_action.triggered.connect(self.select_all_nodes)
        menu.addAction(select_all_action)

        fit_action = QAction("Fit Selection (F)", menu)
        fit_action.triggered.connect(self.fit_selection)
        menu.addAction(fit_action)

        # ``exec_`` is deprecated in PySide6; ``exec`` is the supported name.
        menu.exec(graph_widget.mapToGlobal(pos))

    def _on_node_double_clicked(self, node):
        """Emit ``node_double_clicked`` for step nodes (Input/Output are ignored)."""
        step_index = getattr(node, '_step_index', -1)
        if step_index < 0:
            return
        step_type = 'prompt' if isinstance(node, PromptNode) else 'code'
        self.node_double_clicked.emit(step_index, step_type)

    def refresh(self):
        """Refresh the graph from current tool data."""
        self._rebuild_graph()

    def get_graph(self) -> Optional[NodeGraph]:
        """Get the underlying NodeGraph instance."""
        return self._graph