"""Tests for tool.py - Tool definitions and management.""" import tempfile from pathlib import Path from unittest.mock import patch import pytest import yaml from smarttools.tool import ( Tool, ToolArgument, PromptStep, CodeStep, validate_tool_name, load_tool, save_tool, delete_tool, list_tools, tool_exists, create_wrapper_script, DEFAULT_CATEGORIES ) class TestToolArgument: """Tests for ToolArgument dataclass.""" def test_create_basic(self): arg = ToolArgument(flag="--max", variable="max_size") assert arg.flag == "--max" assert arg.variable == "max_size" assert arg.default is None assert arg.description == "" def test_create_with_defaults(self): arg = ToolArgument( flag="--format", variable="format", default="json", description="Output format" ) assert arg.default == "json" assert arg.description == "Output format" def test_to_dict_minimal(self): arg = ToolArgument(flag="--max", variable="max_size") d = arg.to_dict() assert d == {"flag": "--max", "variable": "max_size"} # Optional fields should not be present assert "default" not in d assert "description" not in d def test_to_dict_full(self): arg = ToolArgument( flag="--format", variable="format", default="json", description="Output format" ) d = arg.to_dict() assert d["default"] == "json" assert d["description"] == "Output format" def test_from_dict(self): data = { "flag": "--output", "variable": "output", "default": "stdout", "description": "Where to write" } arg = ToolArgument.from_dict(data) assert arg.flag == "--output" assert arg.variable == "output" assert arg.default == "stdout" assert arg.description == "Where to write" def test_roundtrip(self): original = ToolArgument( flag="--count", variable="count", default="10", description="Number of items" ) restored = ToolArgument.from_dict(original.to_dict()) assert restored.flag == original.flag assert restored.variable == original.variable assert restored.default == original.default assert restored.description == original.description class TestPromptStep: """Tests for PromptStep dataclass.""" def test_create_basic(self): step = PromptStep( prompt="Summarize: {input}", provider="claude", output_var="summary" ) assert step.prompt == "Summarize: {input}" assert step.provider == "claude" assert step.output_var == "summary" assert step.prompt_file is None def test_to_dict(self): step = PromptStep( prompt="Translate: {input}", provider="gpt4", output_var="translation" ) d = step.to_dict() assert d["type"] == "prompt" assert d["prompt"] == "Translate: {input}" assert d["provider"] == "gpt4" assert d["output_var"] == "translation" def test_to_dict_with_prompt_file(self): step = PromptStep( prompt="", provider="claude", output_var="result", prompt_file="complex_prompt.txt" ) d = step.to_dict() assert d["prompt_file"] == "complex_prompt.txt" def test_from_dict(self): data = { "type": "prompt", "prompt": "Fix grammar: {input}", "provider": "openai", "output_var": "fixed" } step = PromptStep.from_dict(data) assert step.prompt == "Fix grammar: {input}" assert step.provider == "openai" assert step.output_var == "fixed" def test_roundtrip(self): original = PromptStep( prompt="Analyze: {input}", provider="claude", output_var="analysis", prompt_file="analysis.txt" ) restored = PromptStep.from_dict(original.to_dict()) assert restored.prompt == original.prompt assert restored.provider == original.provider assert restored.output_var == original.output_var assert restored.prompt_file == original.prompt_file class TestCodeStep: """Tests for CodeStep dataclass.""" def test_create_basic(self): step = CodeStep( code="result = input.upper()", output_var="result" ) assert step.code == "result = input.upper()" assert step.output_var == "result" assert step.code_file is None def test_multiple_output_vars(self): step = CodeStep( code="a = 1\nb = 2\nc = 3", output_var="a, b, c" ) assert step.output_var == "a, b, c" def test_to_dict(self): step = CodeStep( code="count = len(input.split())", output_var="count" ) d = step.to_dict() assert d["type"] == "code" assert d["code"] == "count = len(input.split())" assert d["output_var"] == "count" def test_from_dict(self): data = { "type": "code", "code": "lines = input.splitlines()", "output_var": "lines" } step = CodeStep.from_dict(data) assert step.code == "lines = input.splitlines()" assert step.output_var == "lines" def test_from_dict_empty_code(self): """Code can be empty if code_file is used.""" data = { "type": "code", "output_var": "result", "code_file": "process.py" } step = CodeStep.from_dict(data) assert step.code == "" assert step.code_file == "process.py" class TestTool: """Tests for Tool dataclass.""" def test_create_minimal(self): tool = Tool(name="test-tool") assert tool.name == "test-tool" assert tool.description == "" assert tool.category == "Other" assert tool.arguments == [] assert tool.steps == [] assert tool.output == "{input}" def test_create_full(self): tool = Tool( name="summarize", description="Summarize text", category="Text", arguments=[ ToolArgument(flag="--max", variable="max_words", default="100") ], steps=[ PromptStep( prompt="Summarize in {max_words} words: {input}", provider="claude", output_var="summary" ) ], output="{summary}" ) assert tool.name == "summarize" assert tool.category == "Text" assert len(tool.arguments) == 1 assert len(tool.steps) == 1 def test_from_dict(self): data = { "name": "translate", "description": "Translate text", "category": "Text", "arguments": [ {"flag": "--to", "variable": "target_lang", "default": "Spanish"} ], "steps": [ { "type": "prompt", "prompt": "Translate to {target_lang}: {input}", "provider": "claude", "output_var": "translation" } ], "output": "{translation}" } tool = Tool.from_dict(data) assert tool.name == "translate" assert tool.category == "Text" assert tool.arguments[0].variable == "target_lang" assert isinstance(tool.steps[0], PromptStep) def test_from_dict_with_code_step(self): data = { "name": "word-count", "steps": [ { "type": "code", "code": "count = len(input.split())", "output_var": "count" } ], "output": "Word count: {count}" } tool = Tool.from_dict(data) assert isinstance(tool.steps[0], CodeStep) assert tool.steps[0].code == "count = len(input.split())" def test_to_dict(self): tool = Tool( name="greet", description="Greet someone", category="Text", arguments=[], steps=[], output="Hello, {input}!" ) d = tool.to_dict() assert d["name"] == "greet" assert d["description"] == "Greet someone" # Category "Other" is default, "Text" should be included assert d["category"] == "Text" assert d["output"] == "Hello, {input}!" def test_to_dict_default_category_omitted(self): """Default category 'Other' should not appear in dict.""" tool = Tool(name="test", category="Other") d = tool.to_dict() assert "category" not in d def test_get_available_variables(self): tool = Tool( name="test", arguments=[ ToolArgument(flag="--max", variable="max_size"), ToolArgument(flag="--format", variable="format") ], steps=[ PromptStep(prompt="", provider="mock", output_var="step1_out"), CodeStep(code="", output_var="step2_out") ] ) vars = tool.get_available_variables() assert "input" in vars assert "max_size" in vars assert "format" in vars assert "step1_out" in vars assert "step2_out" in vars def test_roundtrip(self): original = Tool( name="complex-tool", description="A complex tool", category="Developer", arguments=[ ToolArgument(flag="--verbose", variable="verbose", default="false") ], steps=[ PromptStep(prompt="Analyze: {input}", provider="claude", output_var="analysis"), CodeStep(code="result = analysis.upper()", output_var="result") ], output="{result}" ) d = original.to_dict() restored = Tool.from_dict(d) assert restored.name == original.name assert restored.description == original.description assert restored.category == original.category assert len(restored.arguments) == len(original.arguments) assert len(restored.steps) == len(original.steps) assert restored.output == original.output class TestValidateToolName: """Tests for validate_tool_name function.""" def test_valid_names(self): valid_names = [ "summarize", "my-tool", "tool_v2", "_private", "CamelCase", "tool123", ] for name in valid_names: is_valid, error = validate_tool_name(name) assert is_valid, f"'{name}' should be valid but got: {error}" def test_empty_name(self): is_valid, error = validate_tool_name("") assert not is_valid assert "empty" in error.lower() def test_name_with_spaces(self): is_valid, error = validate_tool_name("my tool") assert not is_valid assert "spaces" in error.lower() def test_name_with_shell_chars(self): bad_names = [ ("my/tool", "/"), ("tool|pipe", "|"), ("tool;cmd", ";"), ("tool$var", "$"), ("tool`cmd`", "`"), ('tool"quote', '"'), ] for name, bad_char in bad_names: is_valid, error = validate_tool_name(name) assert not is_valid, f"'{name}' should be invalid" def test_name_starting_with_number(self): is_valid, error = validate_tool_name("123tool") assert not is_valid assert "start" in error.lower() def test_name_starting_with_dash(self): is_valid, error = validate_tool_name("-tool") assert not is_valid assert "start" in error.lower() class TestToolPersistence: """Tests for tool save/load/delete operations.""" @pytest.fixture def temp_tools_dir(self, tmp_path): """Create a temporary tools directory.""" with patch('smarttools.tool.TOOLS_DIR', tmp_path / ".smarttools"): with patch('smarttools.tool.BIN_DIR', tmp_path / ".local" / "bin"): yield tmp_path def test_save_and_load_tool(self, temp_tools_dir): tool = Tool( name="test-save", description="Test saving", steps=[ PromptStep(prompt="Hello {input}", provider="mock", output_var="response") ], output="{response}" ) # Save config_path = save_tool(tool) assert config_path.exists() # Load loaded = load_tool("test-save") assert loaded is not None assert loaded.name == "test-save" assert loaded.description == "Test saving" assert len(loaded.steps) == 1 def test_load_nonexistent_tool(self, temp_tools_dir): result = load_tool("does-not-exist") assert result is None def test_delete_tool(self, temp_tools_dir): # Create tool first tool = Tool(name="to-delete") save_tool(tool) assert tool_exists("to-delete") # Delete result = delete_tool("to-delete") assert result is True assert not tool_exists("to-delete") def test_delete_nonexistent_tool(self, temp_tools_dir): result = delete_tool("never-existed") assert result is False def test_list_tools(self, temp_tools_dir): # Create some tools save_tool(Tool(name="alpha")) save_tool(Tool(name="beta")) save_tool(Tool(name="gamma")) tools = list_tools() assert "alpha" in tools assert "beta" in tools assert "gamma" in tools # Should be sorted assert tools == sorted(tools) def test_tool_exists(self, temp_tools_dir): save_tool(Tool(name="exists")) assert tool_exists("exists") assert not tool_exists("does-not-exist") def test_create_wrapper_script(self, temp_tools_dir): save_tool(Tool(name="wrapper-test")) from smarttools.tool import get_bin_dir wrapper_path = get_bin_dir() / "wrapper-test" assert wrapper_path.exists() content = wrapper_path.read_text() assert "#!/bin/bash" in content assert "wrapper-test" in content assert "smarttools.runner" in content class TestLegacyFormat: """Tests for loading legacy tool format.""" @pytest.fixture def temp_tools_dir(self, tmp_path): """Create a temporary tools directory.""" with patch('smarttools.tool.TOOLS_DIR', tmp_path / ".smarttools"): with patch('smarttools.tool.BIN_DIR', tmp_path / ".local" / "bin"): yield tmp_path / ".smarttools" def test_load_legacy_format(self, temp_tools_dir): """Test loading a tool in the old format.""" # Create legacy format tool tool_dir = temp_tools_dir / "legacy-tool" tool_dir.mkdir(parents=True) legacy_config = { "name": "legacy-tool", "description": "A legacy tool", "prompt": "Process: {input}", "provider": "claude", "inputs": [ {"name": "max_size", "flag": "--max", "default": "100"} ] } (tool_dir / "config.yaml").write_text(yaml.dump(legacy_config)) # Load and verify conversion tool = load_tool("legacy-tool") assert tool is not None assert tool.name == "legacy-tool" assert len(tool.steps) == 1 assert isinstance(tool.steps[0], PromptStep) assert tool.steps[0].provider == "claude" assert tool.output == "{response}" # Arguments should be converted assert len(tool.arguments) == 1 assert tool.arguments[0].variable == "max_size" class TestDefaultCategories: """Tests for default categories.""" def test_default_categories_exist(self): assert "Text" in DEFAULT_CATEGORIES assert "Developer" in DEFAULT_CATEGORIES assert "Data" in DEFAULT_CATEGORIES assert "Other" in DEFAULT_CATEGORIES