519 lines
16 KiB
Python
519 lines
16 KiB
Python
"""Tests for tool.py - Tool definitions and management."""
|
|
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import yaml
|
|
|
|
from smarttools.tool import (
|
|
Tool, ToolArgument, PromptStep, CodeStep,
|
|
validate_tool_name, load_tool, save_tool, delete_tool,
|
|
list_tools, tool_exists, create_wrapper_script,
|
|
DEFAULT_CATEGORIES
|
|
)
|
|
|
|
|
|
class TestToolArgument:
    """Construction and dict (de)serialization of ToolArgument."""

    def test_create_basic(self):
        argument = ToolArgument(flag="--max", variable="max_size")
        assert (argument.flag, argument.variable) == ("--max", "max_size")
        # Fields not passed to the constructor fall back to their defaults.
        assert argument.default is None
        assert argument.description == ""

    def test_create_with_defaults(self):
        argument = ToolArgument(
            flag="--format",
            variable="format",
            default="json",
            description="Output format",
        )
        assert (argument.default, argument.description) == ("json", "Output format")

    def test_to_dict_minimal(self):
        serialized = ToolArgument(flag="--max", variable="max_size").to_dict()
        assert serialized == {"flag": "--max", "variable": "max_size"}
        # Unset optional fields are omitted instead of being written out empty.
        for optional_key in ("default", "description"):
            assert optional_key not in serialized

    def test_to_dict_full(self):
        serialized = ToolArgument(
            flag="--format",
            variable="format",
            default="json",
            description="Output format",
        ).to_dict()
        assert serialized["default"] == "json"
        assert serialized["description"] == "Output format"

    def test_from_dict(self):
        parsed = ToolArgument.from_dict({
            "flag": "--output",
            "variable": "output",
            "default": "stdout",
            "description": "Where to write",
        })
        observed = (parsed.flag, parsed.variable, parsed.default, parsed.description)
        assert observed == ("--output", "output", "stdout", "Where to write")

    def test_roundtrip(self):
        original = ToolArgument(
            flag="--count",
            variable="count",
            default="10",
            description="Number of items",
        )
        restored = ToolArgument.from_dict(original.to_dict())
        # Every public field must survive to_dict -> from_dict unchanged.
        for field_name in ("flag", "variable", "default", "description"):
            assert getattr(restored, field_name) == getattr(original, field_name)
|
|
|
|
|
|
class TestPromptStep:
    """Construction and dict (de)serialization of PromptStep."""

    def test_create_basic(self):
        prompt_step = PromptStep(
            prompt="Summarize: {input}",
            provider="claude",
            output_var="summary",
        )
        assert (prompt_step.prompt, prompt_step.provider, prompt_step.output_var) == (
            "Summarize: {input}", "claude", "summary"
        )
        # No prompt file was supplied, so the attribute stays unset.
        assert prompt_step.prompt_file is None

    def test_to_dict(self):
        serialized = PromptStep(
            prompt="Translate: {input}",
            provider="gpt4",
            output_var="translation",
        ).to_dict()
        expected = {
            "type": "prompt",
            "prompt": "Translate: {input}",
            "provider": "gpt4",
            "output_var": "translation",
        }
        for key, value in expected.items():
            assert serialized[key] == value

    def test_to_dict_with_prompt_file(self):
        serialized = PromptStep(
            prompt="",
            provider="claude",
            output_var="result",
            prompt_file="complex_prompt.txt",
        ).to_dict()
        assert serialized["prompt_file"] == "complex_prompt.txt"

    def test_from_dict(self):
        parsed = PromptStep.from_dict({
            "type": "prompt",
            "prompt": "Fix grammar: {input}",
            "provider": "openai",
            "output_var": "fixed",
        })
        assert (parsed.prompt, parsed.provider, parsed.output_var) == (
            "Fix grammar: {input}", "openai", "fixed"
        )

    def test_roundtrip(self):
        original = PromptStep(
            prompt="Analyze: {input}",
            provider="claude",
            output_var="analysis",
            prompt_file="analysis.txt",
        )
        restored = PromptStep.from_dict(original.to_dict())
        for field_name in ("prompt", "provider", "output_var", "prompt_file"):
            assert getattr(restored, field_name) == getattr(original, field_name)
|
|
|
|
|
|
class TestCodeStep:
    """Construction and dict (de)serialization of CodeStep."""

    def test_create_basic(self):
        code_step = CodeStep(code="result = input.upper()", output_var="result")
        assert (code_step.code, code_step.output_var) == ("result = input.upper()", "result")
        # No code file was supplied, so the attribute stays unset.
        assert code_step.code_file is None

    def test_multiple_output_vars(self):
        # output_var is stored verbatim, even when it lists several names.
        code_step = CodeStep(code="a = 1\nb = 2\nc = 3", output_var="a, b, c")
        assert code_step.output_var == "a, b, c"

    def test_to_dict(self):
        serialized = CodeStep(
            code="count = len(input.split())",
            output_var="count",
        ).to_dict()
        expected = {
            "type": "code",
            "code": "count = len(input.split())",
            "output_var": "count",
        }
        for key, value in expected.items():
            assert serialized[key] == value

    def test_from_dict(self):
        parsed = CodeStep.from_dict({
            "type": "code",
            "code": "lines = input.splitlines()",
            "output_var": "lines",
        })
        assert (parsed.code, parsed.output_var) == ("lines = input.splitlines()", "lines")

    def test_from_dict_empty_code(self):
        """A missing "code" key yields an empty string when code_file is given."""
        parsed = CodeStep.from_dict({
            "type": "code",
            "output_var": "result",
            "code_file": "process.py",
        })
        assert parsed.code == ""
        assert parsed.code_file == "process.py"
|
|
|
|
|
|
class TestTool:
    """Tests for the Tool dataclass: defaults, (de)serialization, variables."""

    def test_create_minimal(self):
        tool = Tool(name="test-tool")
        assert tool.name == "test-tool"
        assert tool.description == ""
        assert tool.category == "Other"
        assert tool.arguments == []
        assert tool.steps == []
        assert tool.output == "{input}"

    def test_create_full(self):
        tool = Tool(
            name="summarize",
            description="Summarize text",
            category="Text",
            arguments=[
                ToolArgument(flag="--max", variable="max_words", default="100"),
            ],
            steps=[
                PromptStep(
                    prompt="Summarize in {max_words} words: {input}",
                    provider="claude",
                    output_var="summary",
                ),
            ],
            output="{summary}",
        )
        assert tool.name == "summarize"
        assert tool.category == "Text"
        assert len(tool.arguments) == 1
        assert len(tool.steps) == 1

    def test_from_dict(self):
        data = {
            "name": "translate",
            "description": "Translate text",
            "category": "Text",
            "arguments": [
                {"flag": "--to", "variable": "target_lang", "default": "Spanish"},
            ],
            "steps": [
                {
                    "type": "prompt",
                    "prompt": "Translate to {target_lang}: {input}",
                    "provider": "claude",
                    "output_var": "translation",
                },
            ],
            "output": "{translation}",
        }
        tool = Tool.from_dict(data)
        assert tool.name == "translate"
        assert tool.category == "Text"
        assert tool.arguments[0].variable == "target_lang"
        # A step dict with "type": "prompt" is materialized as a PromptStep.
        assert isinstance(tool.steps[0], PromptStep)

    def test_from_dict_with_code_step(self):
        data = {
            "name": "word-count",
            "steps": [
                {
                    "type": "code",
                    "code": "count = len(input.split())",
                    "output_var": "count",
                },
            ],
            "output": "Word count: {count}",
        }
        tool = Tool.from_dict(data)
        # A step dict with "type": "code" is materialized as a CodeStep.
        assert isinstance(tool.steps[0], CodeStep)
        assert tool.steps[0].code == "count = len(input.split())"

    def test_to_dict(self):
        tool = Tool(
            name="greet",
            description="Greet someone",
            category="Text",
            arguments=[],
            steps=[],
            output="Hello, {input}!",
        )
        d = tool.to_dict()
        assert d["name"] == "greet"
        assert d["description"] == "Greet someone"
        # A non-default category ("Text") must be serialized; only the
        # default "Other" is omitted (see test below).
        assert d["category"] == "Text"
        assert d["output"] == "Hello, {input}!"

    def test_to_dict_default_category_omitted(self):
        """Default category 'Other' should not appear in dict."""
        tool = Tool(name="test", category="Other")
        d = tool.to_dict()
        assert "category" not in d

    def test_get_available_variables(self):
        tool = Tool(
            name="test",
            arguments=[
                ToolArgument(flag="--max", variable="max_size"),
                ToolArgument(flag="--format", variable="format"),
            ],
            steps=[
                PromptStep(prompt="", provider="mock", output_var="step1_out"),
                CodeStep(code="", output_var="step2_out"),
            ],
        )
        # Named `available` rather than `vars` to avoid shadowing the builtin.
        available = tool.get_available_variables()
        # Expected: "input", every argument variable, and every step output_var.
        for name in ("input", "max_size", "format", "step1_out", "step2_out"):
            assert name in available, f"{name!r} missing from {available!r}"

    def test_roundtrip(self):
        original = Tool(
            name="complex-tool",
            description="A complex tool",
            category="Developer",
            arguments=[
                ToolArgument(flag="--verbose", variable="verbose", default="false"),
            ],
            steps=[
                PromptStep(prompt="Analyze: {input}", provider="claude", output_var="analysis"),
                CodeStep(code="result = analysis.upper()", output_var="result"),
            ],
            output="{result}",
        )
        restored = Tool.from_dict(original.to_dict())

        assert restored.name == original.name
        assert restored.description == original.description
        assert restored.category == original.category
        assert restored.output == original.output
        # Compare contents, not just lengths: a step coming back as the wrong
        # class or an argument losing its variable must fail the round-trip.
        assert [a.variable for a in restored.arguments] == [
            a.variable for a in original.arguments
        ]
        assert [type(s) for s in restored.steps] == [type(s) for s in original.steps]
        assert [s.output_var for s in restored.steps] == [
            s.output_var for s in original.steps
        ]
|
|
|
|
|
|
class TestValidateToolName:
    """Tests for validate_tool_name, which returns an (is_valid, error) pair."""

    def test_valid_names(self):
        valid_names = [
            "summarize",
            "my-tool",
            "tool_v2",
            "_private",
            "CamelCase",
            "tool123",
        ]
        for name in valid_names:
            is_valid, error = validate_tool_name(name)
            assert is_valid, f"'{name}' should be valid but got: {error}"

    def test_empty_name(self):
        is_valid, error = validate_tool_name("")
        assert not is_valid
        assert "empty" in error.lower()

    def test_name_with_spaces(self):
        is_valid, error = validate_tool_name("my tool")
        assert not is_valid
        assert "spaces" in error.lower()

    def test_name_with_shell_chars(self):
        # Each name pairs with the shell-significant character that should
        # cause the rejection; it is reported on failure for easier diagnosis.
        bad_names = [
            ("my/tool", "/"),
            ("tool|pipe", "|"),
            ("tool;cmd", ";"),
            ("tool$var", "$"),
            ("tool`cmd`", "`"),
            ('tool"quote', '"'),
        ]
        for name, bad_char in bad_names:
            is_valid, error = validate_tool_name(name)
            assert not is_valid, (
                f"'{name}' should be invalid (contains {bad_char!r}); "
                f"validator returned error={error!r}"
            )

    def test_name_starting_with_number(self):
        is_valid, error = validate_tool_name("123tool")
        assert not is_valid
        assert "start" in error.lower()

    def test_name_starting_with_dash(self):
        is_valid, error = validate_tool_name("-tool")
        assert not is_valid
        assert "start" in error.lower()
|
|
|
|
|
|
class TestToolPersistence:
    """save_tool / load_tool / delete_tool / list_tools / tool_exists behaviour."""

    @pytest.fixture
    def temp_tools_dir(self, tmp_path):
        """Point smarttools.tool.TOOLS_DIR and BIN_DIR inside tmp_path.

        Yields tmp_path itself; the patches are undone when the test ends.
        """
        tools_patch = patch('smarttools.tool.TOOLS_DIR', tmp_path / ".smarttools")
        bin_patch = patch('smarttools.tool.BIN_DIR', tmp_path / ".local" / "bin")
        with tools_patch, bin_patch:
            yield tmp_path

    def test_save_and_load_tool(self, temp_tools_dir):
        saved_path = save_tool(Tool(
            name="test-save",
            description="Test saving",
            steps=[
                PromptStep(prompt="Hello {input}", provider="mock", output_var="response"),
            ],
            output="{response}",
        ))
        # save_tool returns the path it wrote; it must exist on disk.
        assert saved_path.exists()

        reloaded = load_tool("test-save")
        assert reloaded is not None
        assert reloaded.name == "test-save"
        assert reloaded.description == "Test saving"
        assert len(reloaded.steps) == 1

    def test_load_nonexistent_tool(self, temp_tools_dir):
        assert load_tool("does-not-exist") is None

    def test_delete_tool(self, temp_tools_dir):
        save_tool(Tool(name="to-delete"))
        assert tool_exists("to-delete")

        # delete_tool reports success with the literal True.
        assert delete_tool("to-delete") is True
        assert not tool_exists("to-delete")

    def test_delete_nonexistent_tool(self, temp_tools_dir):
        assert delete_tool("never-existed") is False

    def test_list_tools(self, temp_tools_dir):
        names = ("alpha", "beta", "gamma")
        for tool_name in names:
            save_tool(Tool(name=tool_name))

        listed = list_tools()
        for tool_name in names:
            assert tool_name in listed
        # The listing is expected to come back in sorted order.
        assert listed == sorted(listed)

    def test_tool_exists(self, temp_tools_dir):
        save_tool(Tool(name="exists"))
        assert tool_exists("exists")
        assert not tool_exists("does-not-exist")

    def test_create_wrapper_script(self, temp_tools_dir):
        from smarttools.tool import get_bin_dir

        # NOTE(review): relies on save_tool also generating the wrapper script.
        save_tool(Tool(name="wrapper-test"))
        wrapper = get_bin_dir() / "wrapper-test"

        assert wrapper.exists()
        script_text = wrapper.read_text()
        for fragment in ("#!/bin/bash", "wrapper-test", "smarttools.runner"):
            assert fragment in script_text
|
|
|
|
|
|
class TestLegacyFormat:
    """Loading configs written in the older single-prompt format."""

    @pytest.fixture
    def temp_tools_dir(self, tmp_path):
        """Patch TOOLS_DIR/BIN_DIR into tmp_path and yield the tools directory."""
        tools_dir = tmp_path / ".smarttools"
        tools_patch = patch('smarttools.tool.TOOLS_DIR', tools_dir)
        bin_patch = patch('smarttools.tool.BIN_DIR', tmp_path / ".local" / "bin")
        with tools_patch, bin_patch:
            yield tools_dir

    def test_load_legacy_format(self, temp_tools_dir):
        """A legacy config (top-level prompt/provider/inputs) is converted on load."""
        legacy_dir = temp_tools_dir / "legacy-tool"
        legacy_dir.mkdir(parents=True)

        legacy_config = {
            "name": "legacy-tool",
            "description": "A legacy tool",
            "prompt": "Process: {input}",
            "provider": "claude",
            "inputs": [
                {"name": "max_size", "flag": "--max", "default": "100"},
            ],
        }
        (legacy_dir / "config.yaml").write_text(yaml.dump(legacy_config))

        converted = load_tool("legacy-tool")
        assert converted is not None
        assert converted.name == "legacy-tool"

        # The lone legacy prompt becomes a single PromptStep, and the output
        # template is "{response}".
        assert len(converted.steps) == 1
        only_step = converted.steps[0]
        assert isinstance(only_step, PromptStep)
        assert only_step.provider == "claude"
        assert converted.output == "{response}"

        # Legacy "inputs" entries become arguments keyed by their "name".
        assert len(converted.arguments) == 1
        assert converted.arguments[0].variable == "max_size"
|
|
|
|
|
|
class TestDefaultCategories:
    """DEFAULT_CATEGORIES must contain the built-in category names."""

    def test_default_categories_exist(self):
        for category in ("Text", "Developer", "Data", "Other"):
            assert category in DEFAULT_CATEGORIES
|