# Source: orchestrated-discussions/.venv/lib/python3.12/site-packages/cmdforge/registry_client.py
# (file-viewer metadata: 733 lines, 21 KiB, Python)

"""Registry API client for CmdForge.
Handles all HTTP communication with the registry server.
"""
import hashlib
import json
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, List, Dict, Any
from urllib.parse import urljoin, urlencode
import requests
from .config import (
load_config,
get_registry_url,
get_registry_token,
get_client_id,
CONFIG_DIR
)
# On-disk cache for the registry index, stored under the CmdForge config dir.
CACHE_DIR = CONFIG_DIR / "registry"
INDEX_CACHE_FILE = CACHE_DIR / "index.json"
# A cached index older than this (one day) is treated as stale.
INDEX_CACHE_MAX_AGE = timedelta(days=1)
@dataclass(eq=False)
class RegistryError(Exception):
    """Base exception for registry errors.

    Attributes:
        code: Machine-readable error code (e.g. ``"TOOL_NOT_FOUND"``).
        message: Human-readable description.
        details: Optional extra data supplied by the server.
        http_status: HTTP status code; 0 when not set (e.g. transport
            failures such as timeouts).

    ``eq=False`` keeps the identity-based ``__eq__``/``__hash__`` of
    ``Exception``; a value-based dataclass ``__eq__`` would set
    ``__hash__ = None`` and make the exception unhashable.
    """
    code: str
    message: str
    details: Optional[Dict] = None
    http_status: int = 0

    def __post_init__(self):
        # The dataclass-generated __init__ never calls Exception.__init__,
        # which would leave ``args`` empty and break pickling
        # (BaseException.__reduce__ rebuilds instances as ``cls(*args)``).
        # ``(code, message)`` matches the first two positional fields; the
        # remaining attributes are restored from ``__dict__`` on unpickle.
        super().__init__(self.code, self.message)

    def __str__(self):
        return f"{self.code}: {self.message}"


@dataclass(eq=False)
class RateLimitError(RegistryError):
    """Raised when rate limited by the registry.

    Attributes:
        retry_after: Seconds to wait before retrying (taken from the
            server's ``Retry-After`` header when raised by the client).
    """
    retry_after: int = 60

    def __init__(self, retry_after: int = 60):
        super().__init__(
            code="RATE_LIMITED",
            message=f"Rate limited. Retry after {retry_after} seconds.",
            http_status=429
        )
        self.retry_after = retry_after

    def __reduce__(self):
        # This __init__ accepts only ``retry_after``, so the inherited
        # ``cls(*args)`` reconstruction (args == (code, message)) would fail.
        return (type(self), (self.retry_after,))
@dataclass
class PaginatedResponse:
    """One page of results from a paginated registry endpoint.

    ``data`` holds the raw item dicts; the remaining fields mirror the
    pagination metadata (1-based ``page``).
    """
    data: List[Dict]
    page: int = field(default=1)
    per_page: int = field(default=20)
    total: int = field(default=0)
    total_pages: int = field(default=0)
@dataclass
class ToolInfo:
    """Tool information from the registry."""
    owner: str
    name: str
    version: str
    description: str = ""
    category: str = ""
    tags: List[str] = field(default_factory=list)
    downloads: int = 0
    deprecated: bool = False
    deprecated_message: str = ""
    replacement: str = ""
    published_at: str = ""
    readme: str = ""

    @property
    def full_name(self) -> str:
        """Return the ``owner/name`` identifier of the tool."""
        return f"{self.owner}/{self.name}"

    @classmethod
    def from_dict(cls, data: dict) -> "ToolInfo":
        """Build a ToolInfo from a registry API payload.

        Missing keys *and* explicit JSON ``null`` values fall back to the
        field defaults, so every attribute keeps its declared type (e.g.
        ``tags`` is always a list). ``tags`` is copied so mutating the
        result never mutates *data*.
        """
        def pick(key: str, default: Any) -> Any:
            # ``dict.get(key, default)`` alone would pass an explicit None through.
            value = data.get(key)
            return default if value is None else value

        return cls(
            owner=pick("owner", ""),
            name=pick("name", ""),
            version=pick("version", ""),
            description=pick("description", ""),
            category=pick("category", ""),
            tags=list(pick("tags", [])),
            downloads=pick("downloads", 0),
            deprecated=pick("deprecated", False),
            deprecated_message=pick("deprecated_message", ""),
            replacement=pick("replacement", ""),
            published_at=pick("published_at", ""),
            readme=pick("readme", "")
        )
@dataclass
class DownloadResult:
    """Outcome of a tool download request.

    ``resolved_version`` is the concrete version the registry picked and
    ``config_yaml`` the tool's configuration text; ``readme`` may be empty.
    """
    owner: str
    name: str
    resolved_version: str
    config_yaml: str
    readme: str = field(default="")
class RegistryClient:
    """Client for interacting with the CmdForge registry API.

    All requests share one ``requests.Session`` for connection pooling.
    The client may be used as a context manager (or call :meth:`close`)
    to release pooled connections when done.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        token: Optional[str] = None,
        timeout: int = 30,
        max_retries: int = 3
    ):
        """
        Initialize the registry client.

        Args:
            base_url: Registry API base URL (default: from config)
            token: Auth token for authenticated requests (default: from config)
            timeout: Request timeout in seconds
            max_retries: Maximum number of attempts per request, including
                the first one (values below 1 still make one attempt)
        """
        self.base_url = base_url or get_registry_url()
        self.token = token or get_registry_token()
        self.timeout = timeout
        self.max_retries = max_retries
        self.client_id = get_client_id()

        # Session for connection pooling
        self._session = requests.Session()
        self._session.headers.update({
            "User-Agent": "CmdForge-CLI/1.0",
            "X-CmdForge-Client": "cli/1.0.0",
            "Accept": "application/json"
        })
        # Add client ID header
        if self.client_id:
            self._session.headers["X-Client-ID"] = self.client_id

    def close(self) -> None:
        """Close the underlying HTTP session and its pooled connections."""
        self._session.close()

    def __enter__(self) -> "RegistryClient":
        return self

    def __exit__(self, exc_type, exc_value, tb) -> None:
        self.close()

    def _url(self, path: str) -> str:
        """Build the full URL for an API *path* (leading slash optional)."""
        # Strip any trailing slash from the base so joining never yields "//".
        base = self.base_url.rstrip("/")
        if not path.startswith("/"):
            path = "/" + path
        return base + path

    def _auth_headers(self) -> Dict[str, str]:
        """Get authentication headers if token is available."""
        if self.token:
            return {"Authorization": f"Bearer {self.token}"}
        return {}

    @staticmethod
    def _parse_retry_after(value: Optional[str], default: int = 60) -> int:
        """Convert a ``Retry-After`` header value into whole seconds.

        The header may be either a number of seconds or an HTTP-date.
        A missing or unparseable value yields *default* instead of
        raising, so a malformed header cannot abort the request.
        """
        if value is None:
            return default
        try:
            return max(0, int(value))
        except ValueError:
            pass
        import math
        from datetime import timezone
        from email.utils import parsedate_to_datetime
        try:
            when = parsedate_to_datetime(value)
        except (TypeError, ValueError, IndexError, OverflowError):
            return default
        if when is None:
            return default
        if when.tzinfo is None:
            # "-0000" dates parse as naive; treat them as UTC.
            when = when.replace(tzinfo=timezone.utc)
        remaining = (when - datetime.now(timezone.utc)).total_seconds()
        return max(0, math.ceil(remaining))

    def _request(
        self,
        method: str,
        path: str,
        params: Optional[Dict] = None,
        json_data: Optional[Dict] = None,
        require_auth: bool = False,
        etag: Optional[str] = None
    ) -> requests.Response:
        """
        Make an HTTP request with retry logic.

        Rate-limit (429) responses, server errors (5xx), timeouts and
        connection errors are retried up to ``max_retries`` attempts in
        total. On the final attempt a 5xx response is returned to the
        caller rather than raised.

        Args:
            method: HTTP method
            path: API path
            params: Query parameters
            json_data: JSON body data
            require_auth: Whether auth is required
            etag: ETag for conditional requests

        Returns:
            Response object

        Raises:
            RegistryError: On missing auth token, or on timeout/connection
                failure after all attempts
            RateLimitError: When still rate limited on the final attempt
        """
        url = self._url(path)
        headers: Dict[str, str] = {}
        if require_auth:
            if not self.token:
                raise RegistryError(
                    code="UNAUTHORIZED",
                    message="Authentication required. Set registry token with 'cmdforge config set-token'",
                    http_status=401
                )
            headers.update(self._auth_headers())
        if etag:
            headers["If-None-Match"] = etag

        # Always make at least one attempt, even if max_retries <= 0.
        attempts = max(1, self.max_retries)
        last_error: Optional[RegistryError] = None
        for attempt in range(attempts):
            is_last = attempt == attempts - 1
            try:
                response = self._session.request(
                    method=method,
                    url=url,
                    params=params,
                    json=json_data,
                    headers=headers,
                    timeout=self.timeout
                )
            except requests.exceptions.Timeout:
                last_error = RegistryError(
                    code="TIMEOUT",
                    message="Request timed out"
                )
            except requests.exceptions.ConnectionError:
                last_error = RegistryError(
                    code="CONNECTION_ERROR",
                    message="Could not connect to registry"
                )
            else:
                # Handle rate limiting
                if response.status_code == 429:
                    retry_after = self._parse_retry_after(
                        response.headers.get("Retry-After")
                    )
                    if is_last:
                        raise RateLimitError(retry_after=retry_after)
                    time.sleep(min(retry_after, 30))  # Cap wait at 30s per attempt
                    continue
                # Retry server errors; hand the last one back to the caller.
                if response.status_code >= 500 and not is_last:
                    time.sleep(2 ** attempt)  # Exponential backoff
                    continue
                return response

            # Transport failure: back off before trying again.
            if not is_last:
                time.sleep(2 ** attempt)

        raise last_error or RegistryError(
            code="REQUEST_FAILED",
            message="Request failed after retries"
        )

    @staticmethod
    def _error_payload(response: requests.Response) -> Optional[Dict[str, Any]]:
        """Extract the ``error`` object from an error response body.

        Returns None when the body is not a JSON object. A string
        ``error`` value is wrapped as ``{"message": value}``; a missing or
        otherwise-shaped ``error`` yields an empty dict.
        """
        try:
            data = response.json()
        except ValueError:
            # requests' JSON decode errors subclass ValueError whichever
            # JSON backend (json or simplejson) is in use.
            return None
        if not isinstance(data, dict):
            return None
        error = data.get("error")
        if isinstance(error, dict):
            return error
        if isinstance(error, str) and error:
            return {"message": error}
        return {}

    def _handle_error_response(self, response: requests.Response) -> None:
        """Parse an error response and raise the matching RegistryError.

        Always raises; never returns normally.
        """
        error = self._error_payload(response)
        if error is None:
            raise RegistryError(
                code="UNKNOWN_ERROR",
                message=f"HTTP {response.status_code}: {response.text[:200]}",
                http_status=response.status_code
            )
        raise RegistryError(
            code=error.get("code", "UNKNOWN_ERROR"),
            message=error.get("message", "Unknown error"),
            details=error.get("details"),
            http_status=response.status_code
        )

    @staticmethod
    def _raise_tool_not_found(owner: str, name: str) -> None:
        """Raise the standard TOOL_NOT_FOUND error for ``owner/name``."""
        raise RegistryError(
            code="TOOL_NOT_FOUND",
            message=f"Tool '{owner}/{name}' not found",
            http_status=404
        )

    def _get_paginated(
        self,
        path: str,
        params: Dict[str, Any],
        page: int,
        per_page: int
    ) -> PaginatedResponse:
        """GET a paginated endpoint and wrap the body in a PaginatedResponse.

        *page* and *per_page* are used only when the response omits them
        from its ``meta`` section.
        """
        response = self._request("GET", path, params=params)
        if response.status_code != 200:
            self._handle_error_response(response)
        data = response.json()
        meta = data.get("meta", {})
        return PaginatedResponse(
            data=data.get("data", []),
            page=meta.get("page", page),
            per_page=meta.get("per_page", per_page),
            total=meta.get("total", 0),
            total_pages=meta.get("total_pages", 0)
        )

    # -------------------------------------------------------------------------
    # Public API Methods
    # -------------------------------------------------------------------------

    def list_tools(
        self,
        category: Optional[str] = None,
        page: int = 1,
        per_page: int = 20,
        sort: str = "downloads",
        order: str = "desc"
    ) -> PaginatedResponse:
        """
        List tools from the registry.

        Args:
            category: Filter by category
            page: Page number (1-indexed)
            per_page: Items per page (capped at 100)
            sort: Sort field (downloads, published_at, name)
            order: Sort order (asc, desc)

        Returns:
            PaginatedResponse with tool data
        """
        per_page = min(per_page, 100)
        params = {
            "page": page,
            "per_page": per_page,
            "sort": sort,
            "order": order
        }
        if category:
            params["category"] = category
        return self._get_paginated("/tools", params, page, per_page)

    def search_tools(
        self,
        query: str,
        category: Optional[str] = None,
        page: int = 1,
        per_page: int = 20,
        sort: str = "relevance"
    ) -> PaginatedResponse:
        """
        Search for tools in the registry.

        Args:
            query: Search query
            category: Filter by category
            page: Page number
            per_page: Items per page (capped at 100)
            sort: Sort field (relevance, downloads, published_at)

        Returns:
            PaginatedResponse with matching tools
        """
        per_page = min(per_page, 100)
        params = {
            "q": query,
            "page": page,
            "per_page": per_page,
            "sort": sort
        }
        if category:
            params["category"] = category
        return self._get_paginated("/tools/search", params, page, per_page)

    def get_tool(self, owner: str, name: str) -> ToolInfo:
        """
        Get detailed information about a tool.

        Args:
            owner: Tool owner (namespace)
            name: Tool name

        Returns:
            ToolInfo object

        Raises:
            RegistryError: TOOL_NOT_FOUND on 404, or the server's error otherwise
        """
        response = self._request("GET", f"/tools/{owner}/{name}")
        if response.status_code == 404:
            self._raise_tool_not_found(owner, name)
        if response.status_code != 200:
            self._handle_error_response(response)
        return ToolInfo.from_dict(response.json().get("data", {}))

    def get_tool_versions(self, owner: str, name: str) -> List[str]:
        """
        Get all versions of a tool.

        Args:
            owner: Tool owner
            name: Tool name

        Returns:
            List of version strings, in the order the server returns them
            (presumably newest first — TODO confirm against the API)

        Raises:
            RegistryError: TOOL_NOT_FOUND on 404, or the server's error otherwise
        """
        response = self._request("GET", f"/tools/{owner}/{name}/versions")
        if response.status_code == 404:
            self._raise_tool_not_found(owner, name)
        if response.status_code != 200:
            self._handle_error_response(response)
        return response.json().get("data", {}).get("versions", [])

    def download_tool(
        self,
        owner: str,
        name: str,
        version: Optional[str] = None,
        install: bool = True
    ) -> DownloadResult:
        """
        Download a tool's configuration.

        Args:
            owner: Tool owner
            name: Tool name
            version: Version or constraint (default: latest)
            install: Whether to count as install for stats

        Returns:
            DownloadResult with config YAML

        Raises:
            RegistryError: On 404 the server's error code/message are used
                when present (defaulting to TOOL_NOT_FOUND); other failures
                raise the server's error.
        """
        params = {"install": str(install).lower()}
        if version:
            params["version"] = version

        response = self._request(
            "GET",
            f"/tools/{owner}/{name}/download",
            params=params
        )

        if response.status_code == 404:
            # A 404 may mean a missing tool or an unsatisfiable version
            # constraint; prefer the server's own code/message if it sent one.
            error = self._error_payload(response) or {}
            raise RegistryError(
                code=error.get("code", "TOOL_NOT_FOUND"),
                message=error.get("message", f"Tool '{owner}/{name}' not found"),
                details=error.get("details"),
                http_status=404
            )
        if response.status_code != 200:
            self._handle_error_response(response)

        data = response.json().get("data", {})
        return DownloadResult(
            owner=data.get("owner", owner),
            name=data.get("name", name),
            resolved_version=data.get("resolved_version", ""),
            config_yaml=data.get("config", ""),
            readme=data.get("readme", "")
        )

    def get_categories(self) -> List[Dict[str, Any]]:
        """
        Get list of tool categories.

        Returns:
            List of category dicts as returned by the server (expected to
            carry name, description, icon)
        """
        response = self._request("GET", "/categories")
        if response.status_code != 200:
            self._handle_error_response(response)
        return response.json().get("data", [])

    def publish_tool(
        self,
        config_yaml: str,
        readme: str = "",
        dry_run: bool = False
    ) -> Dict[str, Any]:
        """
        Publish a tool to the registry.

        Args:
            config_yaml: Tool configuration YAML content
            readme: README.md content
            dry_run: If True, validate without publishing

        Returns:
            Dict with PR URL or validation results

        Raises:
            RegistryError: If no token is configured, or on any non-2xx
                response (e.g. 409 when the version already exists)
        """
        payload = {
            "config": config_yaml,
            "readme": readme,
            "dry_run": dry_run
        }
        response = self._request(
            "POST",
            "/tools",
            json_data=payload,
            require_auth=True
        )
        if response.status_code not in (200, 201):
            self._handle_error_response(response)
        return response.json().get("data", {})

    def get_my_tools(self) -> List[ToolInfo]:
        """
        Get tools published by the authenticated user.

        Returns:
            List of ToolInfo objects
        """
        response = self._request("GET", "/me/tools", require_auth=True)
        if response.status_code != 200:
            self._handle_error_response(response)
        return [ToolInfo.from_dict(t) for t in response.json().get("data", [])]

    def get_popular_tools(self, limit: int = 10) -> List[ToolInfo]:
        """
        Get most popular tools.

        Args:
            limit: Maximum number of tools to return

        Returns:
            List of ToolInfo objects
        """
        response = self._request(
            "GET",
            "/stats/popular",
            params={"limit": limit}
        )
        if response.status_code != 200:
            self._handle_error_response(response)
        return [ToolInfo.from_dict(t) for t in response.json().get("data", [])]

    # -------------------------------------------------------------------------
    # Index Caching
    # -------------------------------------------------------------------------

    def get_index(self, force_refresh: bool = False) -> Dict[str, Any]:
        """
        Get the full tool index, using cache when possible.

        A fresh cache is returned without a request (unless
        *force_refresh*). Otherwise the index is fetched conditionally with
        the cached ETag. A 304 reuses the cache and refreshes its
        timestamp. On a server error, any stale cache is returned instead
        of raising.

        Args:
            force_refresh: Force refresh from server

        Returns:
            Index dict with tools list
        """
        # Check cache first
        if not force_refresh:
            cached = self._load_cached_index()
            if cached:
                return cached

        # Only advertise an ETag for a cache we can actually fall back on;
        # otherwise a 304 would leave nothing to return.
        stale = self._load_cached_index(ignore_age=True)
        etag = stale.get("_etag") if stale else None
        response = self._request("GET", "/index.json", etag=etag)

        if response.status_code == 304 and stale:
            # Not modified: re-save so the cache timestamp is refreshed;
            # otherwise an expired cache would cost a round-trip on every call.
            self._try_save_cached_index(stale, response.headers.get("ETag") or etag)
            return stale

        if response.status_code != 200:
            # Prefer stale data over failing outright.
            if stale:
                return stale
            self._handle_error_response(response)

        data = response.json()
        self._try_save_cached_index(data, response.headers.get("ETag"))
        return data

    def _load_cached_index(self, ignore_age: bool = False) -> Optional[Dict]:
        """Load the cached index, or None if missing, corrupt, expired or
        failing its checksum (age is not checked when *ignore_age*)."""
        try:
            cache_data = json.loads(INDEX_CACHE_FILE.read_text())
        except (OSError, ValueError):
            return None
        if not isinstance(cache_data, dict):
            return None

        # Check age
        if not ignore_age:
            try:
                cached_at = datetime.fromisoformat(cache_data.get("_cached_at", ""))
                expired = datetime.now() - cached_at > INDEX_CACHE_MAX_AGE
            except (TypeError, ValueError):
                # Missing, non-string or timezone-aware timestamp: treat as invalid.
                return None
            if expired:
                return None

        # Verify checksum
        if not self._verify_index_checksum(cache_data):
            return None
        return cache_data

    def _save_cached_index(self, data: Dict, etag: Optional[str] = None) -> None:
        """Save index to cache, stamping it (in place) with ``_cached_at``
        and, when given, ``_etag``.

        Raises:
            OSError: If the cache directory or file cannot be written.
        """
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        data["_cached_at"] = datetime.now().isoformat()
        if etag:
            data["_etag"] = etag
        INDEX_CACHE_FILE.write_text(json.dumps(data, indent=2))

    def _try_save_cached_index(self, data: Dict, etag: Optional[str] = None) -> None:
        """Best-effort cache write: the cache is an optimisation, so an
        unwritable config directory must not fail an otherwise good fetch."""
        try:
            self._save_cached_index(data, etag)
        except OSError:
            pass

    def _get_cached_etag(self) -> Optional[str]:
        """Get the ETag stored in the cached index, if readable."""
        try:
            cache_data = json.loads(INDEX_CACHE_FILE.read_text())
        except (OSError, ValueError):
            return None
        return cache_data.get("_etag") if isinstance(cache_data, dict) else None

    def _verify_index_checksum(self, data: Dict) -> bool:
        """Verify cached index integrity.

        Compares ``data["checksum"]`` with ``"sha256:" + hexdigest`` of the
        key-sorted JSON encoding of ``data["tools"]``. An index without a
        checksum is accepted.
        """
        checksum = data.get("checksum", "")
        if not checksum:
            return True  # No checksum to verify

        # Compute checksum of tools list
        tools = data.get("tools", [])
        content = json.dumps(tools, sort_keys=True)
        computed = "sha256:" + hashlib.sha256(content.encode()).hexdigest()
        return computed == checksum

    def clear_cache(self) -> None:
        """Clear the local index cache (no-op if it does not exist)."""
        try:
            INDEX_CACHE_FILE.unlink()
        except FileNotFoundError:
            pass
# -------------------------------------------------------------------------
# Convenience functions
# -------------------------------------------------------------------------
def get_client() -> RegistryClient:
    """Create a RegistryClient using the registry URL/token from config."""
    client = RegistryClient()
    return client
def search(query: str, **kwargs) -> PaginatedResponse:
    """Search the registry for tools.

    Builds a default client and forwards *query* plus any keyword
    arguments unchanged to ``RegistryClient.search_tools``.
    """
    client = get_client()
    return client.search_tools(query, **kwargs)
def install_tool(tool_spec: str, version: Optional[str] = None) -> DownloadResult:
    """
    Download a tool for installation.

    Args:
        tool_spec: Tool specification (owner/name or just name)
        version: Version constraint

    Returns:
        DownloadResult with config YAML

    Raises:
        RegistryError: INVALID_TOOL_SPEC for an empty or malformed spec;
            otherwise whatever the registry reports (e.g. TOOL_NOT_FOUND)
            when no matching tool can be downloaded.
    """
    spec = tool_spec.strip()

    # Parse tool spec
    owner, sep, name = spec.partition("/")
    if not sep:
        # Shorthand - try official namespace first
        owner, name = "official", spec
    if not owner or not name:
        # Guard against "", "owner/" or "/name", which would otherwise hit a
        # malformed URL (and, for "", fall back to an arbitrary search hit).
        raise RegistryError(
            code="INVALID_TOOL_SPEC",
            message=f"Invalid tool specification '{tool_spec}': expected 'owner/name' or 'name'"
        )

    client = get_client()
    try:
        return client.download_tool(owner, name, version=version, install=True)
    except RegistryError as e:
        if e.code != "TOOL_NOT_FOUND" or owner != "official":
            raise
        # Fall back to the most popular tool with exactly this name in any
        # namespace. Search is fuzzy, so its top hit may be a *different*
        # tool; only exact name matches are eligible.
        results = client.search_tools(name, per_page=20)
        matches = [
            tool for tool in results.data
            if tool.get("name") == name and tool.get("owner")
        ]
        if not matches:
            raise
        # max() keeps the earliest (most relevant) entry on download ties.
        best = max(matches, key=lambda tool: tool.get("downloads") or 0)
        return client.download_tool(
            best["owner"],
            name,
            version=version,
            install=True
        )