379 lines
12 KiB
Python
379 lines
12 KiB
Python
"""Tests for providers.py - AI provider abstraction."""
|
|
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
import pytest
|
|
import yaml
|
|
|
|
from smarttools.providers import (
|
|
Provider, ProviderResult,
|
|
load_providers, save_providers, get_provider,
|
|
add_provider, delete_provider,
|
|
call_provider, mock_provider,
|
|
DEFAULT_PROVIDERS
|
|
)
|
|
|
|
|
|
class TestProvider:
    """Exercise construction and dict (de)serialisation of ``Provider``."""

    def test_create_basic(self):
        basic = Provider(name="test", command="echo test")
        # description is optional; omitting it should yield an empty string
        assert (basic.name, basic.command, basic.description) == ("test", "echo test", "")

    def test_create_with_description(self):
        described = Provider(
            name="claude",
            command="claude -p",
            description="Anthropic Claude",
        )
        assert described.description == "Anthropic Claude"

    def test_to_dict(self):
        as_dict = Provider(
            name="gpt4",
            command="openai chat",
            description="OpenAI GPT-4",
        ).to_dict()
        expected = {
            "name": "gpt4",
            "command": "openai chat",
            "description": "OpenAI GPT-4",
        }
        # Check each expected key individually so extra keys in to_dict() are tolerated.
        for key, value in expected.items():
            assert as_dict[key] == value

    def test_from_dict(self):
        payload = {
            "name": "local",
            "command": "ollama run llama2",
            "description": "Local Ollama",
        }
        parsed = Provider.from_dict(payload)
        assert (parsed.name, parsed.command, parsed.description) == (
            "local",
            "ollama run llama2",
            "Local Ollama",
        )

    def test_from_dict_missing_description(self):
        parsed = Provider.from_dict({"name": "test", "command": "test cmd"})
        assert parsed.description == ""

    def test_roundtrip(self):
        source = Provider(
            name="custom",
            command="my-ai --prompt",
            description="My custom provider",
        )
        rebuilt = Provider.from_dict(source.to_dict())
        for attr in ("name", "command", "description"):
            assert getattr(rebuilt, attr) == getattr(source, attr)
|
|
|
|
|
|
class TestProviderResult:
    """Check the fields carried by ``ProviderResult`` on success and on failure."""

    def test_success_result(self):
        ok = ProviderResult(text="Hello!", success=True)
        assert ok.text == "Hello!"
        assert ok.success is True
        # error was not passed, so it is expected to be None
        assert ok.error is None

    def test_error_result(self):
        failed = ProviderResult(text="", success=False, error="API timeout")
        assert failed.text == ""
        assert failed.success is False
        assert failed.error == "API timeout"
|
|
|
|
|
|
class TestMockProvider:
    """Check the canned text produced by ``mock_provider``."""

    def test_mock_returns_success(self):
        response = mock_provider("Test prompt")
        assert response.success is True
        assert "[MOCK RESPONSE]" in response.text

    def test_mock_includes_prompt_info(self):
        text = mock_provider("This is a test prompt").text
        for fragment in ("Prompt length:", "chars"):
            assert fragment in text

    def test_mock_shows_first_line_preview(self):
        text = mock_provider("First line here\nSecond line").text
        assert "First line" in text

    def test_mock_truncates_long_first_line(self):
        # A 100-character first line is expected to be shortened with an ellipsis.
        text = mock_provider("x" * 100).text
        assert "..." in text

    def test_mock_counts_lines(self):
        three_lines = "\n".join(["line1", "line2", "line3"])
        assert "3 lines" in mock_provider(three_lines).text
|
|
|
|
|
|
class TestProviderPersistence:
    """Round-trip providers through an isolated, temporary providers file."""

    @pytest.fixture
    def temp_providers_file(self, tmp_path):
        """Point ``PROVIDERS_FILE`` at a path under ``tmp_path`` for one test."""
        isolated = tmp_path / ".smarttools" / "providers.yaml"
        with patch('smarttools.providers.PROVIDERS_FILE', isolated):
            yield isolated

    def test_save_and_load_providers(self, temp_providers_file):
        save_providers([
            Provider(name="test1", command="cmd1", description="Description 1"),
            Provider(name="test2", command="cmd2", description="Description 2"),
        ])
        loaded = load_providers()
        assert len(loaded) == 2
        # Order of the saved list should be preserved on load.
        assert [loaded[0].name, loaded[1].name] == ["test1", "test2"]

    def test_get_provider_exists(self, temp_providers_file):
        save_providers([
            Provider(name="target", command="target-cmd", description="Target provider"),
        ])
        found = get_provider("target")
        assert found is not None
        assert (found.name, found.command) == ("target", "target-cmd")

    def test_get_provider_not_exists(self, temp_providers_file):
        save_providers([])
        assert get_provider("nonexistent") is None

    def test_add_new_provider(self, temp_providers_file):
        save_providers([])
        add_provider(Provider(name="new", command="new-cmd", description="New provider"))
        names = {entry.name for entry in load_providers()}
        assert "new" in names

    def test_add_provider_updates_existing(self, temp_providers_file):
        save_providers([
            Provider(name="existing", command="old-cmd", description="Old description"),
        ])
        # Adding under an existing name should replace, not duplicate.
        add_provider(Provider(name="existing", command="new-cmd", description="New description"))
        matches = [entry for entry in load_providers() if entry.name == "existing"]
        current = matches[0]
        assert current.command == "new-cmd"
        assert current.description == "New description"

    def test_delete_provider(self, temp_providers_file):
        save_providers([
            Provider(name="keep", command="keep-cmd"),
            Provider(name="delete", command="delete-cmd"),
        ])
        assert delete_provider("delete") is True
        remaining = {entry.name for entry in load_providers()}
        assert "delete" not in remaining
        assert "keep" in remaining

    def test_delete_nonexistent_provider(self, temp_providers_file):
        save_providers([])
        assert delete_provider("nonexistent") is False
|
|
|
|
|
|
class TestDefaultProviders:
    """Sanity checks on the built-in ``DEFAULT_PROVIDERS`` list."""

    def test_default_providers_exist(self):
        assert len(DEFAULT_PROVIDERS) >= 1

    def test_mock_in_defaults(self):
        assert "mock" in [entry.name for entry in DEFAULT_PROVIDERS]

    def test_claude_in_defaults(self):
        assert "claude" in [entry.name for entry in DEFAULT_PROVIDERS]

    def test_all_defaults_have_commands(self):
        # Every default must carry a non-empty command string.
        for entry in DEFAULT_PROVIDERS:
            assert entry.command, f"Provider {entry.name} has no command"
|
|
|
|
|
|
class TestCallProvider:
    """Tests for ``call_provider``: lookup, subprocess invocation and error paths.

    Tests that reach command execution patch ``shutil.which`` (and
    ``subprocess.run`` when the command is reported as found). As a result,
    no real executable installed on the test machine is ever resolved or run.
    """

    @pytest.fixture
    def temp_providers_file(self, tmp_path):
        """Redirect ``PROVIDERS_FILE`` to a path under ``tmp_path`` for one test."""
        providers_file = tmp_path / ".smarttools" / "providers.yaml"
        with patch('smarttools.providers.PROVIDERS_FILE', providers_file):
            yield providers_file

    def test_call_mock_provider(self, temp_providers_file):
        """Mock provider should work without subprocess."""
        # No providers file is written here, so this relies on "mock" being
        # resolvable without a saved entry (presumably via the built-in
        # defaults -- confirm against load_providers).
        result = call_provider("mock", "Test prompt")
        assert result.success is True
        assert "[MOCK RESPONSE]" in result.text

    def test_call_nonexistent_provider(self, temp_providers_file):
        save_providers([])

        result = call_provider("nonexistent", "Test")
        assert result.success is False
        assert "not found" in result.error.lower()

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_call_real_provider_success(self, mock_which, mock_run, temp_providers_file):
        # Pretend the executable exists and exits cleanly with some output.
        mock_which.return_value = "/usr/bin/echo"
        mock_run.return_value = MagicMock(
            returncode=0,
            stdout="AI response here",
            stderr="",
        )
        save_providers([Provider("echo-test", "echo test")])

        result = call_provider("echo-test", "Prompt")

        assert result.success is True
        assert result.text == "AI response here"

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_call_provider_nonzero_exit(self, mock_which, mock_run, temp_providers_file):
        mock_which.return_value = "/usr/bin/cmd"
        mock_run.return_value = MagicMock(
            returncode=1,
            stdout="",
            stderr="Error occurred",
        )
        save_providers([Provider("failing", "failing-cmd")])

        result = call_provider("failing", "Prompt")

        assert result.success is False
        assert "exited with code 1" in result.error

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_call_provider_empty_output(self, mock_which, mock_run, temp_providers_file):
        mock_which.return_value = "/usr/bin/cmd"
        mock_run.return_value = MagicMock(
            returncode=0,
            stdout="   ",  # Only whitespace: should be treated as no output.
            stderr="",
        )
        save_providers([Provider("empty", "empty-cmd")])

        result = call_provider("empty", "Prompt")

        assert result.success is False
        assert "empty output" in result.error.lower()

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_call_provider_timeout(self, mock_which, mock_run, temp_providers_file):
        import subprocess

        mock_which.return_value = "/usr/bin/slow"
        mock_run.side_effect = subprocess.TimeoutExpired(cmd="slow", timeout=10)
        save_providers([Provider("slow", "slow-cmd")])

        result = call_provider("slow", "Prompt", timeout=10)

        assert result.success is False
        assert "timed out" in result.error.lower()

    @patch('shutil.which')
    def test_call_provider_command_not_found(self, mock_which, temp_providers_file):
        mock_which.return_value = None
        save_providers([Provider("missing", "nonexistent-binary")])

        result = call_provider("missing", "Prompt")

        assert result.success is False
        assert "not found" in result.error.lower()

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_provider_receives_prompt_as_stdin(self, mock_which, mock_run, temp_providers_file):
        mock_which.return_value = "/usr/bin/cat"
        mock_run.return_value = MagicMock(returncode=0, stdout="output", stderr="")
        save_providers([Provider("cat", "cat")])

        call_provider("cat", "My prompt text")

        # The prompt must be handed to the child process via stdin (input=...).
        mock_run.assert_called_once()
        call_kwargs = mock_run.call_args[1]
        assert call_kwargs["input"] == "My prompt text"

    @patch('shutil.which')
    def test_environment_variable_expansion(self, mock_which, temp_providers_file,
                                            tmp_path, monkeypatch):
        """Provider commands should expand $HOME etc."""
        # Pin HOME to an empty temp directory and force "command not found".
        # Without this, a real $HOME/bin/my-ai on the test machine could be
        # resolved and executed.
        monkeypatch.setenv("HOME", str(tmp_path / "home"))
        mock_which.return_value = None
        save_providers([
            Provider("home-test", "$HOME/bin/my-ai")
        ])

        result = call_provider("home-test", "Test")

        # Assert the failure first. On success, result.error would be None, and
        # the substring check below would raise TypeError instead of failing cleanly.
        assert result.success is False
        assert result.error is not None
        # The error should mention the expanded path, not the literal $HOME.
        assert "$HOME" not in result.error
|
|
|
|
|
|
class TestProviderCommandParsing:
    """Tests for command parsing with shlex and environment-variable expansion."""

    @pytest.fixture
    def temp_providers_file(self, tmp_path):
        """Redirect ``PROVIDERS_FILE`` to a path under ``tmp_path`` for one test."""
        providers_file = tmp_path / ".smarttools" / "providers.yaml"
        with patch('smarttools.providers.PROVIDERS_FILE', providers_file):
            yield providers_file

    @patch('shutil.which')
    def test_command_with_quotes(self, mock_which, temp_providers_file):
        """Commands with quotes should be parsed correctly."""
        mock_which.return_value = None  # Will fail, but we test parsing.

        save_providers([
            Provider("quoted", 'my-cmd --arg "value with spaces"')
        ])

        result = call_provider("quoted", "Test")

        # Should fail at command-not-found, not at parsing.
        assert "not found" in result.error.lower()
        assert "my-cmd" in result.error

    @patch('shutil.which')
    def test_command_with_env_vars(self, mock_which, temp_providers_file,
                                   tmp_path, monkeypatch):
        """Environment variables in commands should be expanded."""
        # Pin HOME to a known, non-empty value. Reading the real HOME made the
        # old assertion vacuous when HOME was unset: "" is a substring of
        # every string.
        fake_home = tmp_path / "home"
        monkeypatch.setenv("HOME", str(fake_home))
        mock_which.return_value = None

        save_providers([
            Provider("env-test", "$HOME/.local/bin/my-ai")
        ])

        result = call_provider("env-test", "Test")

        assert result.success is False
        assert result.error is not None
        # The error should show the expanded path, not the literal variable.
        assert "$HOME" not in result.error
        assert str(fake_home) in result.error
|