CmdForge/src/cmdforge/tool.py

443 lines
14 KiB
Python

"""Tool loading, saving, and management."""
import os
import stat
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, List, Literal
import yaml
# Root for tool definitions: one sub-directory per tool, each holding a config.yaml
TOOLS_DIR = Path.home().joinpath(".cmdforge")
# Destination for the generated per-tool wrapper scripts
BIN_DIR = Path.home().joinpath(".local", "bin")
@dataclass
class ToolArgument:
    """Definition of a custom input argument (one entry of a config's ``arguments`` list)."""
    flag: str  # e.g., "--max-size"
    variable: str  # e.g., "max_size"
    default: Optional[str] = None
    description: str = ""

    def to_dict(self) -> dict:
        """Serialize to a config mapping, omitting unset optional fields.

        ``default`` is omitted only when it is ``None`` or ``""``. Other falsy
        values such as ``0`` or ``False`` (which YAML produces for
        ``default: 0`` / ``default: false``) are kept so they survive a
        load/save round trip instead of silently disappearing.
        """
        d = {"flag": self.flag, "variable": self.variable}
        if self.default not in (None, ""):
            d["default"] = self.default
        if self.description:
            d["description"] = self.description
        return d

    @classmethod
    def from_dict(cls, data: dict) -> "ToolArgument":
        """Build from a config mapping; ``flag`` and ``variable`` are required (KeyError otherwise)."""
        return cls(
            flag=data["flag"],
            variable=data["variable"],
            default=data.get("default"),
            description=data.get("description", ""),
        )
@dataclass
class PromptStep:
    """Pipeline step that sends a prompt template to an AI provider and stores the reply."""
    prompt: str  # The prompt template
    provider: str  # Provider name
    output_var: str  # Variable to store output
    prompt_file: Optional[str] = None  # Optional filename for external prompt
    profile: Optional[str] = None  # Optional AI persona profile name

    def to_dict(self) -> dict:
        """Return the config mapping; ``prompt_file``/``profile`` appear only when set."""
        result = dict(
            type="prompt",
            prompt=self.prompt,
            provider=self.provider,
            output_var=self.output_var,
        )
        for key, value in (("prompt_file", self.prompt_file), ("profile", self.profile)):
            if value:
                result[key] = value
        return result

    @classmethod
    def from_dict(cls, data: dict) -> "PromptStep":
        """Build from a config mapping; ``prompt``, ``provider`` and ``output_var`` are required."""
        prompt, provider, output_var = data["prompt"], data["provider"], data["output_var"]
        return cls(
            prompt=prompt,
            provider=provider,
            output_var=output_var,
            prompt_file=data.get("prompt_file"),
            profile=data.get("profile"),
        )
@dataclass
class CodeStep:
    """Pipeline step that executes Python code and captures one or more variables."""
    code: str  # Python code (inline or loaded from file)
    output_var: str  # Variable name(s) to capture (comma-separated for multiple)
    code_file: Optional[str] = None  # Optional filename for external code

    def to_dict(self) -> dict:
        """Return the config mapping; ``code_file`` appears only when set."""
        result = {"type": "code", "code": self.code, "output_var": self.output_var}
        return {**result, "code_file": self.code_file} if self.code_file else result

    @classmethod
    def from_dict(cls, data: dict) -> "CodeStep":
        """Build from a config mapping; only ``output_var`` is required (``code`` defaults to "")."""
        output_var = data["output_var"]
        return cls(code=data.get("code", ""), output_var=output_var, code_file=data.get("code_file"))
@dataclass
class ToolStep:
    """A step that calls another tool."""
    tool: str  # Tool reference (owner/name or just name)
    output_var: str  # Variable to store output
    input_template: str = "{input}"  # Input template (supports variable substitution)
    args: dict = field(default_factory=dict)  # Arguments to pass to the tool
    provider: Optional[str] = None  # Provider override for the called tool

    def to_dict(self) -> dict:
        """Serialize to a config mapping, omitting fields left at their defaults.

        ``args`` is copied so that a dict shared between steps does not make
        ``yaml.dump`` emit anchors/aliases, and so callers cannot mutate this
        step through the returned mapping.
        """
        d = {
            "type": "tool",
            "tool": self.tool,
            "output_var": self.output_var,
        }
        if self.input_template != "{input}":
            d["input"] = self.input_template
        if self.args:
            d["args"] = dict(self.args)
        if self.provider:
            d["provider"] = self.provider
        return d

    @classmethod
    def from_dict(cls, data: dict) -> "ToolStep":
        """Build from a config mapping; ``tool`` and ``output_var`` are required.

        An explicitly null ``args`` (e.g. ``args:`` with nothing after it in
        YAML) is treated as no arguments. The mapping is copied rather than
        aliased.
        """
        return cls(
            tool=data["tool"],
            output_var=data["output_var"],
            input_template=data.get("input", "{input}"),
            args=dict(data.get("args") or {}),
            provider=data.get("provider"),
        )
# Union of the step kinds a Tool can hold (Tool.from_dict builds one of these per
# "prompt" / "code" / "tool" entry). The PEP 604 `X | Y` form is evaluated at
# import time, so this line requires Python 3.10+.
Step = PromptStep | CodeStep | ToolStep
@dataclass
class ToolSource:
    """Provenance of a tool (original, imported or forked) plus optional attribution."""
    type: str = "original"  # "original", "imported", "forked"
    license: Optional[str] = None
    url: Optional[str] = None
    author: Optional[str] = None
    original_tool: Optional[str] = None  # e.g., "fabric/patterns/extract_wisdom"

    def to_dict(self) -> dict:
        """Serialize; ``type`` is always written, attribution keys only when set."""
        attribution = {
            "license": self.license,
            "url": self.url,
            "author": self.author,
            "original_tool": self.original_tool,
        }
        return {"type": self.type, **{key: value for key, value in attribution.items() if value}}

    @classmethod
    def from_dict(cls, data: dict) -> "ToolSource":
        """Build from a mapping; every key is optional (``type`` defaults to "original")."""
        optional_keys = ("license", "url", "author", "original_tool")
        return cls(type=data.get("type", "original"), **{key: data.get(key) for key in optional_keys})
# Built-in categories for organizing tools; "Other" matches Tool.category's default
DEFAULT_CATEGORIES = [
    "Text",
    "Developer",
    "Data",
    "Other",
]
@dataclass
class Tool:
    """A CmdForge tool definition.

    Bundles metadata, custom arguments, pipeline steps and the output
    template, and converts to/from the config.yaml mapping via
    ``from_dict``/``to_dict``.
    """
    name: str
    description: str = ""
    category: str = "Other"  # Tool category for organization
    arguments: List[ToolArgument] = field(default_factory=list)
    steps: List[Step] = field(default_factory=list)
    output: str = "{input}"  # Output template
    dependencies: List[str] = field(default_factory=list)  # Required tools for meta-tools
    source: Optional[ToolSource] = None  # Attribution for imported/external tools
    version: str = ""  # Tool version

    @classmethod
    def from_dict(cls, data: dict) -> "Tool":
        """Build a Tool from a parsed config mapping.

        ``name`` is required at this level (KeyError otherwise). ``arguments``,
        ``steps``, ``dependencies`` and ``source`` may be missing or explicitly
        null. Steps whose ``type`` is not "prompt", "code" or "tool" are skipped.
        """
        # `or []` treats an explicit null (e.g. an empty `steps:` key in YAML)
        # like a missing key instead of failing with "NoneType is not iterable".
        arguments = [ToolArgument.from_dict(arg) for arg in data.get("arguments") or []]

        step_parsers = {
            "prompt": PromptStep.from_dict,
            "code": CodeStep.from_dict,
            "tool": ToolStep.from_dict,
        }
        steps = []
        for step in data.get("steps") or []:
            parse = step_parsers.get(step.get("type"))
            if parse is not None:
                steps.append(parse(step))

        # Parse source attribution if present; an explicit null means no source.
        source_data = data.get("source")
        source = ToolSource.from_dict(source_data) if source_data is not None else None

        return cls(
            name=data["name"],
            description=data.get("description", ""),
            category=data.get("category", "Other"),
            arguments=arguments,
            steps=steps,
            output=data.get("output", "{input}"),
            dependencies=data.get("dependencies") or [],
            source=source,
            version=data.get("version", ""),
        )

    def to_dict(self) -> dict:
        """Serialize to a config mapping; optional metadata is written only when set."""
        d = {
            "name": self.name,
            "description": self.description,
        }
        if self.version:
            d["version"] = self.version
        # Only include category if it's not the default
        if self.category and self.category != "Other":
            d["category"] = self.category
        # Include source attribution if present
        if self.source:
            d["source"] = self.source.to_dict()
        if self.dependencies:
            d["dependencies"] = self.dependencies
        d["arguments"] = [arg.to_dict() for arg in self.arguments]
        d["steps"] = [step.to_dict() for step in self.steps]
        d["output"] = self.output
        return d

    def get_available_variables(self) -> List[str]:
        """Get all variables available for use in templates.

        Returns ``input``, then each argument's variable, then each step's
        output variable(s) in step order. A code step's ``output_var`` may list
        several comma-separated names; each name is reported separately.
        """
        variables = ["input"]  # Always available
        variables.extend(arg.variable for arg in self.arguments)
        for step in self.steps:
            if isinstance(step, CodeStep):
                variables.extend(
                    var.strip() for var in step.output_var.split(",") if var.strip()
                )
            else:
                variables.append(step.output_var)
        return variables
def get_tools_dir() -> Path:
    """Return TOOLS_DIR, first creating it (and any missing parents) on disk."""
    tools_dir = TOOLS_DIR
    tools_dir.mkdir(exist_ok=True, parents=True)
    return tools_dir
def get_bin_dir() -> Path:
    """Return BIN_DIR (where wrapper scripts go), first creating it on disk if absent."""
    bin_dir = BIN_DIR
    bin_dir.mkdir(exist_ok=True, parents=True)
    return bin_dir
def list_tools() -> list[str]:
    """Return the sorted names of tool directories that contain a config.yaml."""
    return sorted(
        entry.name
        for entry in get_tools_dir().iterdir()
        if entry.is_dir() and (entry / "config.yaml").exists()
    )
def _convert_legacy_config(data: dict) -> dict:
    """Translate a legacy config (top-level prompt/provider/inputs) into the steps format.

    The single prompt becomes one "prompt" step whose result is stored in
    ``response`` (provider defaults to "mock"), and each legacy input becomes
    an argument. ``provider_args`` is not carried over. Raises KeyError if
    ``name`` or an input's ``name`` is missing.
    """
    steps = []
    if data.get("prompt"):
        steps.append({
            "type": "prompt",
            "prompt": data["prompt"],
            "provider": data.get("provider", "mock"),
            "output_var": "response",
        })
    arguments = [
        {
            "flag": inp.get("flag", f"--{inp['name']}"),
            "variable": inp["name"],
            "default": inp.get("default"),
            "description": inp.get("description", ""),
        }
        for inp in data.get("inputs") or []
    ]
    return {
        "name": data["name"],
        "description": data.get("description", ""),
        "arguments": arguments,
        "steps": steps,
        # With no prompt there is no step, so fall back to echoing the input.
        "output": "{response}" if steps else "{input}",
    }


def _print_yaml_error(name: str, config_path: Path, error: "yaml.YAMLError") -> None:
    """Report a YAML syntax error in a tool's config on stderr, with source context."""
    import sys
    print(f"Error loading tool '{name}': YAML syntax error", file=sys.stderr)
    mark = getattr(error, "problem_mark", None)
    if mark:
        print(f" Line {mark.line + 1}, column {mark.column + 1}", file=sys.stderr)
        # Best effort: the context lines are a nicety, so an unreadable file just skips them.
        try:
            lines = config_path.read_text(encoding="utf-8", errors="replace").split('\n')
        except OSError:
            lines = []
        if mark.line < len(lines):
            print(file=sys.stderr)
            # Show line before for context
            if mark.line > 0:
                print(f" {mark.line}: {lines[mark.line - 1]}", file=sys.stderr)
            prefix = f" > {mark.line + 1}: "
            print(f"{prefix}{lines[mark.line]}", file=sys.stderr)
            # Pad by the real prefix width so the caret sits under the offending
            # column no matter how many digits the line number has.
            print(" " * (len(prefix) + mark.column) + "^", file=sys.stderr)
    problem = getattr(error, "problem", None)
    if problem:
        print(f"\n Problem: {problem}", file=sys.stderr)


def load_tool(name: str) -> Optional[Tool]:
    """Load a tool by name from ``<tools dir>/<name>/config.yaml``.

    Legacy single-prompt configs are converted to the steps format on the fly.
    Problems are reported on stderr rather than raised.

    Returns:
        The parsed Tool, or None if the config is missing, is not valid YAML,
        is not a mapping, lacks a required field, or fails to load for any
        other reason.
    """
    import sys
    config_path = get_tools_dir() / name / "config.yaml"
    if not config_path.exists():
        return None
    try:
        # YAML text is UTF-8 by default; don't depend on the platform locale.
        data = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        if not isinstance(data, dict):
            # An empty file loads as None; a scalar or list is not a tool definition.
            print(f"Error loading tool '{name}': config.yaml must contain a mapping of tool fields", file=sys.stderr)
            return None
        # Handle legacy format (prompt/provider/provider_args/inputs)
        if "prompt" in data and "steps" not in data:
            data = _convert_legacy_config(data)
        return Tool.from_dict(data)
    except yaml.YAMLError as e:
        _print_yaml_error(name, config_path, e)
        return None
    except KeyError as e:
        print(f"Error loading tool '{name}': Missing required field {e}", file=sys.stderr)
        return None
    except Exception as e:
        # Any other failure is reported and mapped to None, like the cases above.
        print(f"Error loading tool '{name}': {e}", file=sys.stderr)
        return None
def save_tool(tool: Tool) -> Path:
    """Write the tool's config.yaml (creating its directory) and regenerate its wrapper.

    Returns:
        Path of the written config.yaml.
    """
    config_path = get_tools_dir() / tool.name / "config.yaml"
    config_path.parent.mkdir(parents=True, exist_ok=True)
    # sort_keys=False keeps the field order produced by Tool.to_dict
    serialized = yaml.dump(tool.to_dict(), default_flow_style=False, sort_keys=False)
    config_path.write_text(serialized)
    create_wrapper_script(tool.name)
    return config_path
def delete_tool(name: str) -> bool:
    """Delete a tool's directory and its wrapper script.

    Returns:
        False (touching nothing) when the tool directory does not exist,
        otherwise True after removal.
    """
    import shutil
    tool_dir = get_tools_dir() / name
    if not tool_dir.exists():
        return False
    # NOTE(review): this removes whatever file is at <bin dir>/<name>, without
    # checking that CmdForge generated it.
    wrapper = get_bin_dir() / name
    if wrapper.exists():
        wrapper.unlink()
    shutil.rmtree(tool_dir)
    return True
def create_wrapper_script(name: str) -> Path:
    """Create a wrapper script for a tool in ~/.local/bin.

    The script is a bash launcher that ``exec``s the current interpreter
    (``sys.executable``) with ``-m cmdforge.runner <name>``, forwarding all
    arguments. Any existing file at that path is overwritten, and the
    user/group/other execute bits are added to its mode.

    Returns:
        Path of the wrapper script.
    """
    import shlex
    import sys
    bin_dir = get_bin_dir()
    wrapper_path = bin_dir / name
    # Use the current Python interpreter to ensure cmdforge is available.
    # Shell-quote it: interpreter paths can contain spaces (e.g. a venv under
    # "/Users/Jane Doe/..."), which would otherwise split the exec command.
    # shlex.quote leaves ordinary paths and validated names unchanged.
    python_path = shlex.quote(sys.executable)
    quoted_name = shlex.quote(name)
    script = f"""#!/bin/bash
# CmdForge wrapper for '{name}'
# Auto-generated - do not edit
exec {python_path} -m cmdforge.runner {quoted_name} "$@"
"""
    wrapper_path.write_text(script)
    wrapper_path.chmod(wrapper_path.stat().st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
    return wrapper_path
def tool_exists(name: str) -> bool:
    """Return True when ``<tools dir>/<name>/config.yaml`` is present."""
    config_file = get_tools_dir().joinpath(name, "config.yaml")
    return config_file.exists()
def validate_tool_name(name: str) -> tuple[bool, str]:
    """
    Validate a tool name.

    The name is used as a directory name and as a wrapper-script filename, so
    it must be non-empty, start with a letter or underscore, and contain only
    letters, digits, underscores and dashes (letters/digits as judged by
    str.isalpha/str.isalnum, so non-ASCII letters are accepted).

    Returns:
        (is_valid, error_message) - error_message is empty if valid
    """
    if not name:
        return False, "Tool name cannot be empty"
    if ' ' in name:
        return False, "Tool name cannot contain spaces"
    # Call out shell-problematic characters explicitly, listing each offender
    # once, in order of first appearance.
    bad_chars = set('/\\|&;$`"\'<>(){}[]!?*#~')
    found = list(dict.fromkeys(c for c in name if c in bad_chars))
    if found:
        return False, f"Tool name cannot contain: {' '.join(found)}"
    # Must start with letter or underscore
    if not (name[0].isalpha() or name[0] == '_'):
        return False, "Tool name must start with a letter or underscore"
    # Check it's a valid identifier-ish (alphanumeric, underscore, dash)
    if not all(c.isalnum() or c in '_-' for c in name):
        return False, "Tool name can only contain letters, numbers, underscore, and dash"
    return True, ""