# smarttools/tests/test_tool.py (519 lines, 16 KiB, Python)

"""Tests for tool.py - Tool definitions and management."""
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
import yaml
from smarttools.tool import (
Tool, ToolArgument, PromptStep, CodeStep,
validate_tool_name, load_tool, save_tool, delete_tool,
list_tools, tool_exists, create_wrapper_script,
DEFAULT_CATEGORIES
)
class TestToolArgument:
    """Unit tests covering construction and (de)serialization of ToolArgument."""

    def test_create_basic(self):
        argument = ToolArgument(flag="--max", variable="max_size")
        assert (argument.flag, argument.variable) == ("--max", "max_size")
        # Optional fields fall back to their dataclass defaults.
        assert argument.default is None
        assert argument.description == ""

    def test_create_with_defaults(self):
        argument = ToolArgument(
            flag="--format",
            variable="format",
            default="json",
            description="Output format",
        )
        assert argument.default == "json"
        assert argument.description == "Output format"

    def test_to_dict_minimal(self):
        serialized = ToolArgument(flag="--max", variable="max_size").to_dict()
        assert serialized == {"flag": "--max", "variable": "max_size"}
        # Unset optional fields must be omitted from the dict entirely.
        for optional_key in ("default", "description"):
            assert optional_key not in serialized

    def test_to_dict_full(self):
        serialized = ToolArgument(
            flag="--format",
            variable="format",
            default="json",
            description="Output format",
        ).to_dict()
        assert serialized["default"] == "json"
        assert serialized["description"] == "Output format"

    def test_from_dict(self):
        payload = {
            "flag": "--output",
            "variable": "output",
            "default": "stdout",
            "description": "Where to write",
        }
        argument = ToolArgument.from_dict(payload)
        # Every key in the payload maps onto the attribute of the same name.
        for field_name, expected in payload.items():
            assert getattr(argument, field_name) == expected

    def test_roundtrip(self):
        original = ToolArgument(
            flag="--count",
            variable="count",
            default="10",
            description="Number of items",
        )
        restored = ToolArgument.from_dict(original.to_dict())
        for attr in ("flag", "variable", "default", "description"):
            assert getattr(restored, attr) == getattr(original, attr)
class TestPromptStep:
    """Unit tests for PromptStep construction and dict (de)serialization."""

    def test_create_basic(self):
        summarize = PromptStep(
            prompt="Summarize: {input}",
            provider="claude",
            output_var="summary",
        )
        assert summarize.prompt == "Summarize: {input}"
        assert summarize.provider == "claude"
        assert summarize.output_var == "summary"
        # No external prompt file unless one is supplied explicitly.
        assert summarize.prompt_file is None

    def test_to_dict(self):
        serialized = PromptStep(
            prompt="Translate: {input}",
            provider="gpt4",
            output_var="translation",
        ).to_dict()
        expected = {
            "type": "prompt",
            "prompt": "Translate: {input}",
            "provider": "gpt4",
            "output_var": "translation",
        }
        for key, value in expected.items():
            assert serialized[key] == value

    def test_to_dict_with_prompt_file(self):
        serialized = PromptStep(
            prompt="",
            provider="claude",
            output_var="result",
            prompt_file="complex_prompt.txt",
        ).to_dict()
        assert serialized["prompt_file"] == "complex_prompt.txt"

    def test_from_dict(self):
        step = PromptStep.from_dict({
            "type": "prompt",
            "prompt": "Fix grammar: {input}",
            "provider": "openai",
            "output_var": "fixed",
        })
        assert (step.prompt, step.provider, step.output_var) == (
            "Fix grammar: {input}",
            "openai",
            "fixed",
        )

    def test_roundtrip(self):
        original = PromptStep(
            prompt="Analyze: {input}",
            provider="claude",
            output_var="analysis",
            prompt_file="analysis.txt",
        )
        restored = PromptStep.from_dict(original.to_dict())
        for attr in ("prompt", "provider", "output_var", "prompt_file"):
            assert getattr(restored, attr) == getattr(original, attr)
class TestCodeStep:
    """Unit tests for CodeStep construction and dict (de)serialization."""

    def test_create_basic(self):
        upper = CodeStep(code="result = input.upper()", output_var="result")
        assert upper.code == "result = input.upper()"
        assert upper.output_var == "result"
        # Inline code is used unless a code file is given.
        assert upper.code_file is None

    def test_multiple_output_vars(self):
        # output_var is stored verbatim, even as a comma-separated list.
        multi = CodeStep(code="a = 1\nb = 2\nc = 3", output_var="a, b, c")
        assert multi.output_var == "a, b, c"

    def test_to_dict(self):
        serialized = CodeStep(
            code="count = len(input.split())",
            output_var="count",
        ).to_dict()
        expected = {
            "type": "code",
            "code": "count = len(input.split())",
            "output_var": "count",
        }
        for key, value in expected.items():
            assert serialized[key] == value

    def test_from_dict(self):
        step = CodeStep.from_dict({
            "type": "code",
            "code": "lines = input.splitlines()",
            "output_var": "lines",
        })
        assert (step.code, step.output_var) == (
            "lines = input.splitlines()",
            "lines",
        )

    def test_from_dict_empty_code(self):
        """Code can be empty if code_file is used."""
        step = CodeStep.from_dict({
            "type": "code",
            "output_var": "result",
            "code_file": "process.py",
        })
        assert step.code == ""
        assert step.code_file == "process.py"
class TestTool:
    """Tests for Tool dataclass."""

    def test_create_minimal(self):
        tool = Tool(name="test-tool")
        assert tool.name == "test-tool"
        assert tool.description == ""
        assert tool.category == "Other"
        assert tool.arguments == []
        assert tool.steps == []
        # With no steps, the tool echoes its input by default.
        assert tool.output == "{input}"

    def test_create_full(self):
        tool = Tool(
            name="summarize",
            description="Summarize text",
            category="Text",
            arguments=[
                ToolArgument(flag="--max", variable="max_words", default="100")
            ],
            steps=[
                PromptStep(
                    prompt="Summarize in {max_words} words: {input}",
                    provider="claude",
                    output_var="summary",
                )
            ],
            output="{summary}",
        )
        assert tool.name == "summarize"
        assert tool.category == "Text"
        assert len(tool.arguments) == 1
        assert len(tool.steps) == 1

    def test_from_dict(self):
        data = {
            "name": "translate",
            "description": "Translate text",
            "category": "Text",
            "arguments": [
                {"flag": "--to", "variable": "target_lang", "default": "Spanish"}
            ],
            "steps": [
                {
                    "type": "prompt",
                    "prompt": "Translate to {target_lang}: {input}",
                    "provider": "claude",
                    "output_var": "translation",
                }
            ],
            "output": "{translation}",
        }
        tool = Tool.from_dict(data)
        assert tool.name == "translate"
        assert tool.category == "Text"
        assert tool.arguments[0].variable == "target_lang"
        assert isinstance(tool.steps[0], PromptStep)

    def test_from_dict_with_code_step(self):
        data = {
            "name": "word-count",
            "steps": [
                {
                    "type": "code",
                    "code": "count = len(input.split())",
                    "output_var": "count",
                }
            ],
            "output": "Word count: {count}",
        }
        tool = Tool.from_dict(data)
        # The "type" key selects the step class.
        assert isinstance(tool.steps[0], CodeStep)
        assert tool.steps[0].code == "count = len(input.split())"

    def test_to_dict(self):
        tool = Tool(
            name="greet",
            description="Greet someone",
            category="Text",
            arguments=[],
            steps=[],
            output="Hello, {input}!",
        )
        d = tool.to_dict()
        assert d["name"] == "greet"
        assert d["description"] == "Greet someone"
        # Only the default category "Other" is omitted; "Text" must be kept.
        assert d["category"] == "Text"
        assert d["output"] == "Hello, {input}!"

    def test_to_dict_default_category_omitted(self):
        """Default category 'Other' should not appear in dict."""
        tool = Tool(name="test", category="Other")
        d = tool.to_dict()
        assert "category" not in d

    def test_get_available_variables(self):
        tool = Tool(
            name="test",
            arguments=[
                ToolArgument(flag="--max", variable="max_size"),
                ToolArgument(flag="--format", variable="format"),
            ],
            steps=[
                PromptStep(prompt="", provider="mock", output_var="step1_out"),
                CodeStep(code="", output_var="step2_out"),
            ],
        )
        # Named `available` rather than `vars` to avoid shadowing the builtin.
        available = tool.get_available_variables()
        # "input" plus every argument variable and every step output.
        for name in ("input", "max_size", "format", "step1_out", "step2_out"):
            assert name in available, f"{name!r} missing from {available!r}"

    def test_roundtrip(self):
        original = Tool(
            name="complex-tool",
            description="A complex tool",
            category="Developer",
            arguments=[
                ToolArgument(flag="--verbose", variable="verbose", default="false")
            ],
            steps=[
                PromptStep(prompt="Analyze: {input}", provider="claude", output_var="analysis"),
                CodeStep(code="result = analysis.upper()", output_var="result"),
            ],
            output="{result}",
        )
        d = original.to_dict()
        restored = Tool.from_dict(d)
        assert restored.name == original.name
        assert restored.description == original.description
        assert restored.category == original.category
        assert len(restored.arguments) == len(original.arguments)
        assert len(restored.steps) == len(original.steps)
        assert restored.output == original.output

        # Equal lengths alone would not catch lost step types or argument
        # contents, so check them explicitly.
        assert restored.arguments[0].variable == original.arguments[0].variable
        assert restored.arguments[0].default == original.arguments[0].default
        assert isinstance(restored.steps[0], PromptStep)
        assert isinstance(restored.steps[1], CodeStep)
        assert restored.steps[0].prompt == original.steps[0].prompt
        assert restored.steps[0].output_var == original.steps[0].output_var
        assert restored.steps[1].code == original.steps[1].code
        assert restored.steps[1].output_var == original.steps[1].output_var
class TestValidateToolName:
    """Tests for validate_tool_name function.

    Case lists are parametrized so each name is reported as its own test,
    rather than the first failure in a loop masking the remaining cases.
    """

    @pytest.mark.parametrize(
        "name",
        [
            "summarize",
            "my-tool",
            "tool_v2",
            "_private",
            "CamelCase",
            "tool123",
        ],
    )
    def test_valid_names(self, name):
        is_valid, error = validate_tool_name(name)
        assert is_valid, f"'{name}' should be valid but got: {error}"

    def test_empty_name(self):
        is_valid, error = validate_tool_name("")
        assert not is_valid
        assert "empty" in error.lower()

    def test_name_with_spaces(self):
        is_valid, error = validate_tool_name("my tool")
        assert not is_valid
        assert "spaces" in error.lower()

    @pytest.mark.parametrize(
        ("name", "bad_char"),
        [
            ("my/tool", "/"),
            ("tool|pipe", "|"),
            ("tool;cmd", ";"),
            ("tool$var", "$"),
            ("tool`cmd`", "`"),
            ('tool"quote', '"'),
        ],
    )
    def test_name_with_shell_chars(self, name, bad_char):
        is_valid, error = validate_tool_name(name)
        assert not is_valid, f"'{name}' should be invalid (contains {bad_char!r})"
        # Rejections should explain themselves, like the other invalid cases.
        assert error, f"rejecting '{name}' should produce an error message"

    def test_name_starting_with_number(self):
        is_valid, error = validate_tool_name("123tool")
        assert not is_valid
        assert "start" in error.lower()

    def test_name_starting_with_dash(self):
        is_valid, error = validate_tool_name("-tool")
        assert not is_valid
        assert "start" in error.lower()
class TestToolPersistence:
    """Tests for tool save/load/delete operations."""

    @pytest.fixture
    def temp_tools_dir(self, tmp_path):
        """Redirect TOOLS_DIR and BIN_DIR into pytest's tmp_path.

        Yields tmp_path itself; the patched tools directory is
        tmp_path / ".smarttools" and the bin directory is
        tmp_path / ".local" / "bin".
        """
        with patch('smarttools.tool.TOOLS_DIR', tmp_path / ".smarttools"):
            with patch('smarttools.tool.BIN_DIR', tmp_path / ".local" / "bin"):
                yield tmp_path

    def test_save_and_load_tool(self, temp_tools_dir):
        tool = Tool(
            name="test-save",
            description="Test saving",
            steps=[
                PromptStep(prompt="Hello {input}", provider="mock", output_var="response")
            ],
            output="{response}",
        )
        # Save
        config_path = save_tool(tool)
        assert config_path.exists()
        # If the TOOLS_DIR patch were ever bypassed, save_tool would write
        # into the real home directory. Fail loudly instead.
        assert temp_tools_dir in config_path.parents, (
            f"{config_path} was written outside the sandbox {temp_tools_dir}"
        )
        # Load
        loaded = load_tool("test-save")
        assert loaded is not None
        assert loaded.name == "test-save"
        assert loaded.description == "Test saving"
        assert loaded.output == "{response}"
        assert len(loaded.steps) == 1
        assert isinstance(loaded.steps[0], PromptStep)
        assert loaded.steps[0].output_var == "response"

    def test_load_nonexistent_tool(self, temp_tools_dir):
        result = load_tool("does-not-exist")
        assert result is None

    def test_delete_tool(self, temp_tools_dir):
        # Create tool first
        tool = Tool(name="to-delete")
        save_tool(tool)
        assert tool_exists("to-delete")
        # Delete
        result = delete_tool("to-delete")
        assert result is True
        assert not tool_exists("to-delete")

    def test_delete_nonexistent_tool(self, temp_tools_dir):
        result = delete_tool("never-existed")
        assert result is False

    def test_list_tools(self, temp_tools_dir):
        # Create some tools
        for name in ("alpha", "beta", "gamma"):
            save_tool(Tool(name=name))
        tools = list_tools()
        assert "alpha" in tools
        assert "beta" in tools
        assert "gamma" in tools
        # Should be sorted
        assert tools == sorted(tools)

    def test_tool_exists(self, temp_tools_dir):
        save_tool(Tool(name="exists"))
        assert tool_exists("exists")
        assert not tool_exists("does-not-exist")

    def test_create_wrapper_script(self, temp_tools_dir):
        # save_tool is expected to install the wrapper as a side effect.
        save_tool(Tool(name="wrapper-test"))
        from smarttools.tool import get_bin_dir
        wrapper_path = get_bin_dir() / "wrapper-test"
        assert wrapper_path.exists()
        content = wrapper_path.read_text()
        assert "#!/bin/bash" in content
        assert "wrapper-test" in content
        assert "smarttools.runner" in content
class TestLegacyFormat:
    """Loading of tools stored in the older single-prompt config format."""

    @pytest.fixture
    def temp_tools_dir(self, tmp_path):
        """Patch TOOLS_DIR/BIN_DIR into tmp_path and yield the tools directory."""
        tools_root = tmp_path / ".smarttools"
        with patch('smarttools.tool.TOOLS_DIR', tools_root), \
                patch('smarttools.tool.BIN_DIR', tmp_path / ".local" / "bin"):
            yield tools_root

    def test_load_legacy_format(self, temp_tools_dir):
        """Test loading a tool in the old format."""
        legacy_dir = temp_tools_dir / "legacy-tool"
        legacy_dir.mkdir(parents=True)
        config_yaml = yaml.dump({
            "name": "legacy-tool",
            "description": "A legacy tool",
            "prompt": "Process: {input}",
            "provider": "claude",
            "inputs": [
                {"name": "max_size", "flag": "--max", "default": "100"},
            ],
        })
        (legacy_dir / "config.yaml").write_text(config_yaml)

        tool = load_tool("legacy-tool")
        assert tool is not None
        assert tool.name == "legacy-tool"

        # The lone legacy prompt becomes a single PromptStep feeding the output.
        assert len(tool.steps) == 1
        only_step = tool.steps[0]
        assert isinstance(only_step, PromptStep)
        assert only_step.provider == "claude"
        assert tool.output == "{response}"

        # Legacy "inputs" entries are converted into ToolArguments.
        assert len(tool.arguments) == 1
        assert tool.arguments[0].variable == "max_size"
class TestDefaultCategories:
    """Sanity checks on the built-in category list."""

    def test_default_categories_exist(self):
        for category in ("Text", "Developer", "Data", "Other"):
            assert category in DEFAULT_CATEGORIES, (
                f"missing default category {category!r}"
            )