267 lines
9.2 KiB
Python
267 lines
9.2 KiB
Python
"""Provider abstraction for AI CLI tools."""
|
|
|
|
import os
|
|
import subprocess
|
|
import shutil
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Optional, List
|
|
|
|
import yaml
|
|
|
|
|
|
# Location of the user's providers config: ~/.smarttools/providers.yaml
PROVIDERS_FILE = Path.home().joinpath(".smarttools", "providers.yaml")
|
|
|
|
|
|
@dataclass
class Provider:
    """Definition of an AI provider.

    A provider is a named command line; ``call_provider`` looks providers up
    by ``name`` and runs ``command`` with the prompt on stdin.
    """

    # Identifier used to look the provider up.
    name: str
    # Command line executed for this provider (may contain $VARS).
    command: str
    # Free-form human-readable note; optional.
    description: str = ""

    def to_dict(self) -> dict:
        """Return a plain-dict form of this provider for serialization."""
        return {
            "name": self.name,
            "command": self.command,
            "description": self.description,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "Provider":
        """Build a Provider from a mapping such as one produced by ``to_dict``.

        Args:
            data: Mapping with required keys ``name`` and ``command`` and an
                optional ``description``.

        Returns:
            The constructed Provider.

        Raises:
            KeyError: If ``name`` or ``command`` is missing.
        """
        return cls(
            name=data["name"],
            command=data["command"],
            # A config entry may carry an explicit null description; keep the
            # field a string in that case instead of storing None.
            description=data.get("description") or "",
        )
|
|
|
|
|
|
@dataclass
class ProviderResult:
    """Outcome of a single provider invocation."""

    # Response text; the failure paths in call_provider set this to "".
    text: str
    # True when the call produced usable output, False otherwise.
    success: bool
    # Human-readable failure description; None when not provided.
    error: Optional[str] = None
|
|
|
|
|
|
# Providers that ship pre-configured.
# Descriptions carry profiling notes from a 4-task test
# (Math, Code, Reasoning, Extract; Dec 2025): "<time> <passed>/4 | <remarks>".
DEFAULT_PROVIDERS = [
    # TOP PICKS - best value/performance ratio
    Provider(
        name="opencode-deepseek",
        command="$HOME/.opencode/bin/opencode run --model deepseek/deepseek-chat",
        description="13s 4/4 | BEST VALUE, cheap, fast, accurate",
    ),
    Provider(
        name="opencode-pickle",
        command="$HOME/.opencode/bin/opencode run --model opencode/big-pickle",
        description="13s 4/4 | BEST FREE, accurate",
    ),
    Provider(
        name="claude-haiku",
        command="claude -p --model haiku",
        description="14s 4/4 | fast, accurate, best paid option",
    ),
    Provider(
        name="codex",
        command="codex exec -",
        description="14s 4/4 | reliable, auto-routes",
    ),

    # CLAUDE - all accurate, paid, good for code
    Provider(
        name="claude",
        command="claude -p",
        description="18s 4/4 | auto-routes to best model",
    ),
    Provider(
        name="claude-opus",
        command="claude -p --model opus",
        description="18s 4/4 | highest quality, expensive",
    ),
    Provider(
        name="claude-sonnet",
        command="claude -p --model sonnet",
        description="21s 4/4 | balanced quality/speed",
    ),

    # OPENCODE - additional models
    Provider(
        name="opencode-nano",
        command="$HOME/.opencode/bin/opencode run --model opencode/gpt-5-nano",
        description="24s 4/4 | GPT-5 Nano, reliable",
    ),
    Provider(
        name="opencode-reasoner",
        command="$HOME/.opencode/bin/opencode run --model deepseek/deepseek-reasoner",
        description="33s 4/4 | complex reasoning, cheap",
    ),
    Provider(
        name="opencode-grok",
        command="$HOME/.opencode/bin/opencode run --model opencode/grok-code",
        description="11s 2/4 | fastest but unreliable, FREE",
    ),

    # GEMINI - slow CLI but good for large docs (1M token context)
    Provider(
        name="gemini-flash",
        command="gemini --model gemini-2.5-flash",
        description="28s 4/4 | use this for quick tasks",
    ),
    Provider(
        name="gemini",
        command="gemini --model gemini-2.5-pro",
        description="91s 3/4 | slow CLI, best for large docs/PDFs",
    ),

    # Mock for testing
    Provider(
        name="mock",
        command="mock",
        description="Mock provider for testing",
    ),
]
|
|
|
|
|
|
def get_providers_file() -> Path:
    """Return the providers config path, seeding it with defaults if absent."""
    if PROVIDERS_FILE.exists():
        return PROVIDERS_FILE

    # First run: make sure the directory is there, then write the defaults.
    PROVIDERS_FILE.parent.mkdir(parents=True, exist_ok=True)
    save_providers(DEFAULT_PROVIDERS)
    return PROVIDERS_FILE
|
|
|
|
|
|
def load_providers() -> List[Provider]:
    """Return the providers defined in the config file.

    Falls back to a copy of DEFAULT_PROVIDERS when the file is empty, has no
    "providers" key, or cannot be read/parsed into Provider objects.
    """
    config_path = get_providers_file()

    try:
        parsed = yaml.safe_load(config_path.read_text())
        has_entries = bool(parsed) and "providers" in parsed
        if has_entries:
            return [Provider.from_dict(entry) for entry in parsed["providers"]]
        return DEFAULT_PROVIDERS.copy()
    except Exception:
        # Best-effort: any read/parse/shape problem yields the defaults.
        return DEFAULT_PROVIDERS.copy()
|
|
|
|
|
|
def save_providers(providers: List[Provider]):
    """Write the given providers to the config file as YAML.

    The parent directory is created if missing; the file is overwritten with
    a single top-level "providers" list, keeping each entry's key order.
    """
    config_dir = PROVIDERS_FILE.parent
    config_dir.mkdir(parents=True, exist_ok=True)

    payload = {"providers": [entry.to_dict() for entry in providers]}
    serialized = yaml.dump(payload, default_flow_style=False, sort_keys=False)
    PROVIDERS_FILE.write_text(serialized)
|
|
|
|
|
|
def get_provider(name: str) -> Optional[Provider]:
    """Return the first loaded provider whose name equals *name*, else None."""
    candidates = (entry for entry in load_providers() if entry.name == name)
    return next(candidates, None)
|
|
|
|
|
|
def add_provider(provider: Provider) -> bool:
    """Insert *provider*, or replace the existing entry with the same name.

    The updated list is written back to the config file. Always returns True.
    """
    current = load_providers()

    # Index of the first entry sharing this name, or None when it is new.
    slot = next(
        (idx for idx, existing in enumerate(current) if existing.name == provider.name),
        None,
    )
    if slot is None:
        current.append(provider)
    else:
        current[slot] = provider

    save_providers(current)
    return True
|
|
|
|
|
|
def delete_provider(name: str) -> bool:
    """Remove every provider called *name* from the config.

    Returns:
        True if at least one entry was removed (and the file rewritten),
        False if no provider had that name (file left untouched).
    """
    existing = load_providers()
    remaining = [entry for entry in existing if entry.name != name]

    if len(remaining) == len(existing):
        return False

    save_providers(remaining)
    return True
|
|
|
|
|
|
def call_provider(provider_name: str, prompt: str, timeout: int = 300) -> ProviderResult:
    """
    Call an AI provider with the given prompt.

    The provider's configured command is executed through the shell with the
    prompt supplied on stdin; stdout becomes the response text. Failures are
    reported through the returned ProviderResult rather than raised.

    Args:
        provider_name: Name of the provider to use
        prompt: The prompt to send
        timeout: Maximum execution time in seconds

    Returns:
        ProviderResult with the response text or error
    """
    # Handle mock provider specially
    if provider_name.lower() == "mock":
        return mock_provider(prompt)

    # Look up provider
    provider = get_provider(provider_name)
    if not provider:
        return ProviderResult(
            text="",
            success=False,
            error=f"Provider '{provider_name}' not found. Use 'smarttools providers' to manage providers."
        )

    # Parse command (expand environment variables). A null command in the
    # config is treated like an empty one instead of raising TypeError.
    cmd = os.path.expandvars(provider.command or "")

    # An empty or whitespace-only command has no executable to check or run;
    # report it instead of failing with IndexError on the split below.
    cmd_parts = cmd.split()
    if not cmd_parts:
        return ProviderResult(
            text="",
            success=False,
            error=f"Provider '{provider_name}' has an empty command. Use 'smarttools providers' to manage providers."
        )

    # Check if base command exists
    base_cmd = cmd_parts[0]
    # Expand ~ for the which check
    base_cmd_expanded = os.path.expanduser(base_cmd)
    if not shutil.which(base_cmd_expanded) and not os.path.isfile(base_cmd_expanded):
        return ProviderResult(
            text="",
            success=False,
            error=f"Command '{base_cmd}' not found. Is it installed and in PATH?\n\nTo install AI providers, run: smarttools providers install"
        )

    try:
        # NOTE: the command string comes from the providers config and is run
        # via the shell; the prompt itself is passed on stdin, not on the
        # command line.
        result = subprocess.run(
            cmd,
            shell=True,
            input=prompt,
            capture_output=True,
            text=True,
            timeout=timeout
        )

        if result.returncode != 0:
            error_msg = f"Provider exited with code {result.returncode}: {result.stderr}"
            stderr_lower = result.stderr.lower()
            if "not found" in stderr_lower or "not installed" in stderr_lower:
                error_msg += "\n\nTo install AI providers, run: smarttools providers install"
            return ProviderResult(
                text="",
                success=False,
                error=error_msg
            )

        # Warn if output is empty (provider ran but returned nothing)
        if not result.stdout.strip():
            stderr = result.stderr.strip()

            # Check for OpenCode's ProviderModelNotFoundError
            if "ProviderModelNotFoundError" in stderr or "ModelNotFoundError" in stderr:
                # Extract provider and model info if possible
                import re
                provider_match = re.search(r'providerID:\s*"([^"]+)"', stderr)
                model_match = re.search(r'modelID:\s*"([^"]+)"', stderr)
                provider_id = provider_match.group(1) if provider_match else "unknown"
                model_id = model_match.group(1) if model_match else "unknown"

                return ProviderResult(
                    text="",
                    success=False,
                    error=f"Model '{model_id}' from provider '{provider_id}' is not available.\n\n"
                          f"To fix this, either:\n"
                          f" 1. Run 'opencode' to connect the {provider_id} provider\n"
                          f" 2. Use --provider to pick a different model (e.g., --provider opencode-pickle)\n"
                          f" 3. Run 'smarttools ui' to edit the tool's default provider"
                )

            # Include stderr in the message, truncated to 200 characters.
            if len(stderr) > 200:
                stderr_hint = f" (stderr: {stderr[:200]}...)"
            elif stderr:
                stderr_hint = f" (stderr: {stderr})"
            else:
                stderr_hint = ""
            return ProviderResult(
                text="",
                success=False,
                error=f"Provider returned empty output{stderr_hint}.\n\nThis may mean the model is not available. Try a different provider or run: smarttools providers install"
            )

        return ProviderResult(text=result.stdout, success=True)

    except subprocess.TimeoutExpired:
        return ProviderResult(
            text="",
            success=False,
            error=f"Provider timed out after {timeout} seconds"
        )
    except Exception as e:
        # Boundary handler: any other failure is reported to the caller as a
        # failed result rather than propagated.
        return ProviderResult(
            text="",
            success=False,
            error=f"Provider error: {str(e)}"
        )
|
|
|
|
|
|
def mock_provider(prompt: str) -> ProviderResult:
    """
    Build a canned response for testing, without running any command.

    Args:
        prompt: The prompt; only its length, line count and first line
            (after stripping surrounding whitespace) appear in the response

    Returns:
        ProviderResult with success=True and the mock response text
    """
    prompt_lines = prompt.strip().split('\n')
    first_line = prompt_lines[0]

    # Show at most 50 characters of the first line, marking any truncation.
    if len(first_line) > 50:
        preview = first_line[:50] + "..."
    else:
        preview = first_line

    message = "".join([
        "[MOCK RESPONSE]\n",
        f"Prompt length: {len(prompt)} chars, {len(prompt_lines)} lines\n",
        f"First line: {preview}\n",
        "\n",
        "This is a mock response. Use a real provider for actual output.",
    ])
    return ProviderResult(text=message, success=True)
|