smarttools/tests/test_providers.py

379 lines
12 KiB
Python

"""Tests for providers.py - AI provider abstraction."""
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
import yaml
from smarttools.providers import (
Provider, ProviderResult,
load_providers, save_providers, get_provider,
add_provider, delete_provider,
call_provider, mock_provider,
DEFAULT_PROVIDERS
)
class TestProvider:
    """Unit tests for the ``Provider`` dataclass: construction and dict (de)serialisation."""

    def test_create_basic(self):
        p = Provider(name="test", command="echo test")
        # description is optional; omitting it must yield the empty string
        assert (p.name, p.command, p.description) == ("test", "echo test", "")

    def test_create_with_description(self):
        p = Provider(name="claude", command="claude -p", description="Anthropic Claude")
        assert p.description == "Anthropic Claude"

    def test_to_dict(self):
        p = Provider(name="gpt4", command="openai chat", description="OpenAI GPT-4")
        as_dict = p.to_dict()
        picked = {key: as_dict[key] for key in ("name", "command", "description")}
        assert picked == {
            "name": "gpt4",
            "command": "openai chat",
            "description": "OpenAI GPT-4",
        }

    def test_from_dict(self):
        p = Provider.from_dict({
            "name": "local",
            "command": "ollama run llama2",
            "description": "Local Ollama",
        })
        assert (p.name, p.command, p.description) == (
            "local",
            "ollama run llama2",
            "Local Ollama",
        )

    def test_from_dict_missing_description(self):
        # A dict without a "description" key should still deserialise
        p = Provider.from_dict({"name": "test", "command": "test cmd"})
        assert p.description == ""

    def test_roundtrip(self):
        source = Provider(
            name="custom",
            command="my-ai --prompt",
            description="My custom provider",
        )
        copy = Provider.from_dict(source.to_dict())
        for field_name in ("name", "command", "description"):
            assert getattr(copy, field_name) == getattr(source, field_name)
class TestProviderResult:
    """Checks field storage and the ``error`` default of ``ProviderResult``."""

    def test_success_result(self):
        ok = ProviderResult(text="Hello!", success=True)
        assert ok.text == "Hello!"
        assert ok.success is True
        # error is optional and defaults to None when not supplied
        assert ok.error is None

    def test_error_result(self):
        failed = ProviderResult(text="", success=False, error="API timeout")
        assert failed.text == ""
        assert failed.success is False
        assert failed.error == "API timeout"
class TestMockProvider:
    """Checks the canned text produced by ``mock_provider``."""

    def test_mock_returns_success(self):
        out = mock_provider("Test prompt")
        assert out.success is True
        assert "[MOCK RESPONSE]" in out.text

    def test_mock_includes_prompt_info(self):
        text = mock_provider("This is a test prompt").text
        for fragment in ("Prompt length:", "chars"):
            assert fragment in text

    def test_mock_shows_first_line_preview(self):
        text = mock_provider("First line here\nSecond line").text
        assert "First line" in text

    def test_mock_truncates_long_first_line(self):
        # 100 characters is long enough to trigger the preview ellipsis
        text = mock_provider(100 * "x").text
        assert "..." in text

    def test_mock_counts_lines(self):
        text = mock_provider("\n".join(["line1", "line2", "line3"])).text
        assert "3 lines" in text
class TestProviderPersistence:
    """Tests for saving, loading, looking up, adding and deleting providers on disk."""

    @pytest.fixture
    def temp_providers_file(self, tmp_path):
        """Yield an isolated providers path and patch the module to use it."""
        target = tmp_path.joinpath(".smarttools", "providers.yaml")
        with patch('smarttools.providers.PROVIDERS_FILE', target):
            yield target

    def test_save_and_load_providers(self, temp_providers_file):
        save_providers([
            Provider("test1", "cmd1", "Description 1"),
            Provider("test2", "cmd2", "Description 2"),
        ])
        # Order and count must survive the round trip
        assert [p.name for p in load_providers()] == ["test1", "test2"]

    def test_get_provider_exists(self, temp_providers_file):
        save_providers([Provider("target", "target-cmd", "Target provider")])
        found = get_provider("target")
        assert found is not None
        assert (found.name, found.command) == ("target", "target-cmd")

    def test_get_provider_not_exists(self, temp_providers_file):
        save_providers([])
        assert get_provider("nonexistent") is None

    def test_add_new_provider(self, temp_providers_file):
        save_providers([])
        add_provider(Provider("new", "new-cmd", "New provider"))
        assert "new" in [p.name for p in load_providers()]

    def test_add_provider_updates_existing(self, temp_providers_file):
        save_providers([Provider("existing", "old-cmd", "Old description")])
        # Adding under an existing name should overwrite, not duplicate
        add_provider(Provider("existing", "new-cmd", "New description"))
        match = next(p for p in load_providers() if p.name == "existing")
        assert (match.command, match.description) == ("new-cmd", "New description")

    def test_delete_provider(self, temp_providers_file):
        save_providers([
            Provider("keep", "keep-cmd"),
            Provider("delete", "delete-cmd"),
        ])
        assert delete_provider("delete") is True
        remaining = [p.name for p in load_providers()]
        assert "delete" not in remaining
        assert "keep" in remaining

    def test_delete_nonexistent_provider(self, temp_providers_file):
        save_providers([])
        assert delete_provider("nonexistent") is False
class TestDefaultProviders:
    """Sanity checks on the built-in ``DEFAULT_PROVIDERS`` collection."""

    def test_default_providers_exist(self):
        assert len(DEFAULT_PROVIDERS) >= 1

    def test_mock_in_defaults(self):
        assert "mock" in [p.name for p in DEFAULT_PROVIDERS]

    def test_claude_in_defaults(self):
        assert "claude" in [p.name for p in DEFAULT_PROVIDERS]

    def test_all_defaults_have_commands(self):
        # Every shipped provider must carry a non-empty command string
        for entry in DEFAULT_PROVIDERS:
            assert entry.command, f"Provider {entry.name} has no command"
class TestCallProvider:
    """Tests for call_provider function."""

    @pytest.fixture
    def temp_providers_file(self, tmp_path):
        """Point PROVIDERS_FILE at a fresh path under ``tmp_path`` for the test."""
        providers_file = tmp_path / ".smarttools" / "providers.yaml"
        with patch('smarttools.providers.PROVIDERS_FILE', providers_file):
            yield providers_file

    def test_call_mock_provider(self, temp_providers_file):
        """Mock provider should work without subprocess."""
        result = call_provider("mock", "Test prompt")
        assert result.success is True
        assert "[MOCK RESPONSE]" in result.text

    def test_call_nonexistent_provider(self, temp_providers_file):
        save_providers([])
        result = call_provider("nonexistent", "Test")
        assert result.success is False
        assert "not found" in result.error.lower()

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_call_real_provider_success(self, mock_which, mock_run, temp_providers_file):
        # Pretend the binary exists and exits cleanly with some output.
        mock_which.return_value = "/usr/bin/echo"
        mock_run.return_value = MagicMock(
            returncode=0,
            stdout="AI response here",
            stderr=""
        )
        save_providers([Provider("echo-test", "echo test")])
        result = call_provider("echo-test", "Prompt")
        assert result.success is True
        assert result.text == "AI response here"

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_call_provider_nonzero_exit(self, mock_which, mock_run, temp_providers_file):
        mock_which.return_value = "/usr/bin/cmd"
        mock_run.return_value = MagicMock(
            returncode=1,
            stdout="",
            stderr="Error occurred"
        )
        save_providers([Provider("failing", "failing-cmd")])
        result = call_provider("failing", "Prompt")
        assert result.success is False
        assert "exited with code 1" in result.error

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_call_provider_empty_output(self, mock_which, mock_run, temp_providers_file):
        mock_which.return_value = "/usr/bin/cmd"
        mock_run.return_value = MagicMock(
            returncode=0,
            stdout="  ",  # Only whitespace
            stderr=""
        )
        save_providers([Provider("empty", "empty-cmd")])
        result = call_provider("empty", "Prompt")
        assert result.success is False
        assert "empty output" in result.error.lower()

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_call_provider_timeout(self, mock_which, mock_run, temp_providers_file):
        import subprocess
        mock_which.return_value = "/usr/bin/slow"
        # Simulate the child process exceeding its time budget.
        mock_run.side_effect = subprocess.TimeoutExpired(cmd="slow", timeout=10)
        save_providers([Provider("slow", "slow-cmd")])
        result = call_provider("slow", "Prompt", timeout=10)
        assert result.success is False
        assert "timed out" in result.error.lower()

    @patch('shutil.which')
    def test_call_provider_command_not_found(self, mock_which, temp_providers_file):
        mock_which.return_value = None
        save_providers([Provider("missing", "nonexistent-binary")])
        result = call_provider("missing", "Prompt")
        assert result.success is False
        assert "not found" in result.error.lower()

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_provider_receives_prompt_as_stdin(self, mock_which, mock_run, temp_providers_file):
        mock_which.return_value = "/usr/bin/cat"
        mock_run.return_value = MagicMock(returncode=0, stdout="output", stderr="")
        save_providers([Provider("cat", "cat")])
        call_provider("cat", "My prompt text")
        # Verify prompt was passed as input (keyword args of the last run() call)
        call_kwargs = mock_run.call_args[1]
        assert call_kwargs["input"] == "My prompt text"

    @patch('shutil.which')
    def test_environment_variable_expansion(self, mock_which, temp_providers_file,
                                            tmp_path, monkeypatch):
        """Provider commands should expand $HOME etc."""
        # Pin HOME to a known value and force "command not found". The test then
        # no longer depends on the developer's real $HOME/bin contents.
        fake_home = str(tmp_path / "home")
        monkeypatch.setenv("HOME", fake_home)
        mock_which.return_value = None
        save_providers([
            Provider("home-test", "$HOME/bin/my-ai")
        ])
        result = call_provider("home-test", "Test")
        # Check failure first. A None ``error`` then fails here with a clear
        # message, not with a TypeError from the ``in`` test below.
        assert result.success is False
        # The error should mention the expanded path, not $HOME
        assert "$HOME" not in result.error
        assert fake_home in result.error
class TestProviderCommandParsing:
    """Tests for command parsing with shlex."""

    @pytest.fixture
    def temp_providers_file(self, tmp_path):
        """Point PROVIDERS_FILE at a fresh path under ``tmp_path`` for the test."""
        providers_file = tmp_path / ".smarttools" / "providers.yaml"
        with patch('smarttools.providers.PROVIDERS_FILE', providers_file):
            yield providers_file

    @patch('shutil.which')
    def test_command_with_quotes(self, mock_which, temp_providers_file):
        """Commands with quotes should be parsed correctly."""
        mock_which.return_value = None  # Will fail, but we test parsing
        save_providers([
            Provider("quoted", 'my-cmd --arg "value with spaces"')
        ])
        result = call_provider("quoted", "Test")
        # Should fail at command-not-found, not at parsing
        assert result.success is False
        assert "not found" in result.error.lower()
        assert "my-cmd" in result.error

    @patch('shutil.which')
    def test_command_with_env_vars(self, mock_which, temp_providers_file,
                                   tmp_path, monkeypatch):
        """Environment variables in commands should be expanded."""
        # Pin HOME to a known, non-empty value. The previous check,
        # ``"$HOME" not in error or home in error``, passed vacuously when HOME
        # was unset, because ``"" in s`` is always True.
        fake_home = str(tmp_path / "fake-home")
        monkeypatch.setenv("HOME", fake_home)
        mock_which.return_value = None
        save_providers([
            Provider("env-test", "$HOME/.local/bin/my-ai")
        ])
        result = call_provider("env-test", "Test")
        assert result.success is False
        # Error should show the expanded path, not the literal variable
        assert "$HOME" not in result.error
        assert fake_home in result.error