"""Tests for providers.py - AI provider abstraction."""

import os
import subprocess
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock

import pytest
import yaml

from smarttools.providers import (
    Provider, ProviderResult,
    load_providers, save_providers, get_provider, add_provider, delete_provider,
    call_provider, mock_provider,
    DEFAULT_PROVIDERS
)


@pytest.fixture
def temp_providers_file(tmp_path):
    """Redirect ``smarttools.providers.PROVIDERS_FILE`` to a per-test path.

    The file itself is not created here; tests that need contents call
    ``save_providers`` first. The patch is undone when the test finishes.
    Shared by every test class that persists providers, so the path layout
    is defined in exactly one place.
    """
    providers_file = tmp_path / ".smarttools" / "providers.yaml"
    with patch('smarttools.providers.PROVIDERS_FILE', providers_file):
        yield providers_file


class TestProvider:
    """Tests for Provider dataclass."""

    def test_create_basic(self):
        provider = Provider(name="test", command="echo test")
        assert provider.name == "test"
        assert provider.command == "echo test"
        assert provider.description == ""

    def test_create_with_description(self):
        provider = Provider(
            name="claude",
            command="claude -p",
            description="Anthropic Claude"
        )
        assert provider.description == "Anthropic Claude"

    def test_to_dict(self):
        provider = Provider(
            name="gpt4",
            command="openai chat",
            description="OpenAI GPT-4"
        )
        d = provider.to_dict()
        assert d["name"] == "gpt4"
        assert d["command"] == "openai chat"
        assert d["description"] == "OpenAI GPT-4"

    def test_from_dict(self):
        data = {
            "name": "local",
            "command": "ollama run llama2",
            "description": "Local Ollama"
        }
        provider = Provider.from_dict(data)
        assert provider.name == "local"
        assert provider.command == "ollama run llama2"
        assert provider.description == "Local Ollama"

    def test_from_dict_missing_description(self):
        data = {"name": "test", "command": "test cmd"}
        provider = Provider.from_dict(data)
        assert provider.description == ""

    def test_roundtrip(self):
        original = Provider(
            name="custom",
            command="my-ai --prompt",
            description="My custom provider"
        )
        restored = Provider.from_dict(original.to_dict())
        assert restored.name == original.name
        assert restored.command == original.command
        assert restored.description == original.description


class TestProviderResult:
    """Tests for ProviderResult dataclass."""

    def test_success_result(self):
        result = ProviderResult(text="Hello!", success=True)
        assert result.text == "Hello!"
        assert result.success is True
        assert result.error is None

    def test_error_result(self):
        result = ProviderResult(text="", success=False, error="API timeout")
        assert result.text == ""
        assert result.success is False
        assert result.error == "API timeout"


class TestMockProvider:
    """Tests for the mock provider."""

    def test_mock_returns_success(self):
        result = mock_provider("Test prompt")
        assert result.success is True
        assert "[MOCK RESPONSE]" in result.text

    def test_mock_includes_prompt_info(self):
        result = mock_provider("This is a test prompt")
        assert "Prompt length:" in result.text
        assert "chars" in result.text

    def test_mock_shows_first_line_preview(self):
        result = mock_provider("First line here\nSecond line")
        assert "First line" in result.text

    def test_mock_truncates_long_first_line(self):
        long_line = "x" * 100
        result = mock_provider(long_line)
        assert "..." in result.text

    def test_mock_counts_lines(self):
        prompt = "line1\nline2\nline3"
        result = mock_provider(prompt)
        assert "3 lines" in result.text


class TestProviderPersistence:
    """Tests for provider save/load operations."""

    def test_save_and_load_providers(self, temp_providers_file):
        providers = [
            Provider("test1", "cmd1", "Description 1"),
            Provider("test2", "cmd2", "Description 2")
        ]
        save_providers(providers)

        loaded = load_providers()
        assert len(loaded) == 2
        assert loaded[0].name == "test1"
        assert loaded[1].name == "test2"

    def test_get_provider_exists(self, temp_providers_file):
        providers = [
            Provider("target", "target-cmd", "Target provider")
        ]
        save_providers(providers)

        result = get_provider("target")
        assert result is not None
        assert result.name == "target"
        assert result.command == "target-cmd"

    def test_get_provider_not_exists(self, temp_providers_file):
        save_providers([])
        result = get_provider("nonexistent")
        assert result is None

    def test_add_new_provider(self, temp_providers_file):
        save_providers([])

        new_provider = Provider("new", "new-cmd", "New provider")
        add_provider(new_provider)

        loaded = load_providers()
        assert any(p.name == "new" for p in loaded)

    def test_add_provider_updates_existing(self, temp_providers_file):
        save_providers([
            Provider("existing", "old-cmd", "Old description")
        ])

        updated = Provider("existing", "new-cmd", "New description")
        add_provider(updated)

        loaded = load_providers()
        existing = next(p for p in loaded if p.name == "existing")
        assert existing.command == "new-cmd"
        assert existing.description == "New description"

    def test_delete_provider(self, temp_providers_file):
        save_providers([
            Provider("keep", "keep-cmd"),
            Provider("delete", "delete-cmd")
        ])

        result = delete_provider("delete")
        assert result is True

        loaded = load_providers()
        assert not any(p.name == "delete" for p in loaded)
        assert any(p.name == "keep" for p in loaded)

    def test_delete_nonexistent_provider(self, temp_providers_file):
        save_providers([])
        result = delete_provider("nonexistent")
        assert result is False


class TestDefaultProviders:
    """Tests for default providers."""

    def test_default_providers_exist(self):
        assert len(DEFAULT_PROVIDERS) > 0

    def test_mock_in_defaults(self):
        assert any(p.name == "mock" for p in DEFAULT_PROVIDERS)

    def test_claude_in_defaults(self):
        assert any(p.name == "claude" for p in DEFAULT_PROVIDERS)

    def test_all_defaults_have_commands(self):
        for provider in DEFAULT_PROVIDERS:
            assert provider.command, f"Provider {provider.name} has no command"


class TestCallProvider:
    """Tests for call_provider function."""

    def test_call_mock_provider(self, temp_providers_file):
        """Mock provider should work without subprocess."""
        result = call_provider("mock", "Test prompt")
        assert result.success is True
        assert "[MOCK RESPONSE]" in result.text

    def test_call_nonexistent_provider(self, temp_providers_file):
        save_providers([])
        result = call_provider("nonexistent", "Test")
        assert result.success is False
        assert "not found" in result.error.lower()

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_call_real_provider_success(self, mock_which, mock_run, temp_providers_file):
        # Setup
        mock_which.return_value = "/usr/bin/echo"
        mock_run.return_value = MagicMock(
            returncode=0,
            stdout="AI response here",
            stderr=""
        )
        save_providers([Provider("echo-test", "echo test")])

        result = call_provider("echo-test", "Prompt")

        assert result.success is True
        assert result.text == "AI response here"

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_call_provider_nonzero_exit(self, mock_which, mock_run, temp_providers_file):
        mock_which.return_value = "/usr/bin/cmd"
        mock_run.return_value = MagicMock(
            returncode=1,
            stdout="",
            stderr="Error occurred"
        )
        save_providers([Provider("failing", "failing-cmd")])

        result = call_provider("failing", "Prompt")

        assert result.success is False
        assert "exited with code 1" in result.error

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_call_provider_empty_output(self, mock_which, mock_run, temp_providers_file):
        mock_which.return_value = "/usr/bin/cmd"
        mock_run.return_value = MagicMock(
            returncode=0,
            stdout="   ",  # Only whitespace
            stderr=""
        )
        save_providers([Provider("empty", "empty-cmd")])

        result = call_provider("empty", "Prompt")

        assert result.success is False
        assert "empty output" in result.error.lower()

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_call_provider_timeout(self, mock_which, mock_run, temp_providers_file):
        mock_which.return_value = "/usr/bin/slow"
        # Only subprocess.run is patched; the exception class is still real.
        mock_run.side_effect = subprocess.TimeoutExpired(cmd="slow", timeout=10)
        save_providers([Provider("slow", "slow-cmd")])

        result = call_provider("slow", "Prompt", timeout=10)

        assert result.success is False
        assert "timed out" in result.error.lower()

    @patch('shutil.which')
    def test_call_provider_command_not_found(self, mock_which, temp_providers_file):
        mock_which.return_value = None
        save_providers([Provider("missing", "nonexistent-binary")])

        result = call_provider("missing", "Prompt")

        assert result.success is False
        assert "not found" in result.error.lower()

    @patch('subprocess.run')
    @patch('shutil.which')
    def test_provider_receives_prompt_as_stdin(self, mock_which, mock_run, temp_providers_file):
        mock_which.return_value = "/usr/bin/cat"
        mock_run.return_value = MagicMock(returncode=0, stdout="output", stderr="")
        save_providers([Provider("cat", "cat")])

        call_provider("cat", "My prompt text")

        # Verify prompt was passed as input
        call_kwargs = mock_run.call_args[1]
        assert call_kwargs["input"] == "My prompt text"

    @patch('shutil.which')
    def test_environment_variable_expansion(self, mock_which, temp_providers_file,
                                            tmp_path, monkeypatch):
        """Provider commands should expand $HOME etc."""
        # Pin HOME so the expansion is deterministic even where HOME is unset,
        # and stub the PATH lookup so a real binary at $HOME/bin/my-ai can
        # never be executed by the test.
        monkeypatch.setenv("HOME", str(tmp_path))
        mock_which.return_value = None
        save_providers([
            Provider("home-test", "$HOME/bin/my-ai")
        ])

        result = call_provider("home-test", "Test")

        # The lookup fails, so an error message is guaranteed to exist; it
        # must not contain the literal, unexpanded "$HOME".
        assert result.success is False
        assert "$HOME" not in result.error


class TestProviderCommandParsing:
    """Tests for command parsing with shlex."""

    @patch('shutil.which')
    def test_command_with_quotes(self, mock_which, temp_providers_file):
        """Commands with quotes should be parsed correctly."""
        mock_which.return_value = None  # Will fail, but we test parsing
        save_providers([
            Provider("quoted", 'my-cmd --arg "value with spaces"')
        ])

        result = call_provider("quoted", "Test")

        # Should fail at command-not-found, not at parsing
        assert "not found" in result.error.lower()
        assert "my-cmd" in result.error

    @patch('shutil.which')
    def test_command_with_env_vars(self, mock_which, temp_providers_file):
        """Environment variables in commands should be expanded."""
        mock_which.return_value = None
        save_providers([
            Provider("env-test", "$HOME/.local/bin/my-ai")
        ])

        result = call_provider("env-test", "Test")

        # Error should show expanded path
        home = os.environ.get("HOME", "")
        assert "$HOME" not in result.error or home in result.error