"""Registry API client for CmdForge. Handles all HTTP communication with the registry server. """ import hashlib import json import time from dataclasses import dataclass, field from datetime import datetime, timedelta from pathlib import Path from typing import Optional, List, Dict, Any from urllib.parse import urljoin, urlencode import requests from .config import ( load_config, get_registry_url, get_registry_token, get_client_id, CONFIG_DIR ) # Local cache directory CACHE_DIR = CONFIG_DIR / "registry" INDEX_CACHE_FILE = CACHE_DIR / "index.json" INDEX_CACHE_MAX_AGE = timedelta(hours=24) @dataclass class RegistryError(Exception): """Base exception for registry errors.""" code: str message: str details: Optional[Dict] = None http_status: int = 0 def __str__(self): return f"{self.code}: {self.message}" @dataclass class RateLimitError(RegistryError): """Raised when rate limited by the registry.""" retry_after: int = 60 def __init__(self, retry_after: int = 60): super().__init__( code="RATE_LIMITED", message=f"Rate limited. Retry after {retry_after} seconds.", http_status=429 ) self.retry_after = retry_after @dataclass class PaginatedResponse: """Paginated API response.""" data: List[Dict] page: int = 1 per_page: int = 20 total: int = 0 total_pages: int = 0 @dataclass class ToolInfo: """Tool information from the registry.""" owner: str name: str version: str description: str = "" category: str = "" tags: List[str] = field(default_factory=list) downloads: int = 0 deprecated: bool = False deprecated_message: str = "" replacement: str = "" published_at: str = "" readme: str = "" @property def full_name(self) -> str: return f"{self.owner}/{self.name}" @classmethod def from_dict(cls, data: dict) -> "ToolInfo": return cls( owner=data.get("owner", ""), name=data.get("name", ""), version=data.get("version", ""), description=data.get("description", ""), category=data.get("category", ""), tags=data.get("tags", []), downloads=data.get("downloads", 0), deprecated=data.get("deprecated", False), deprecated_message=data.get("deprecated_message", ""), replacement=data.get("replacement", ""), published_at=data.get("published_at", ""), readme=data.get("readme", "") ) @dataclass class DownloadResult: """Result of downloading a tool.""" owner: str name: str resolved_version: str config_yaml: str readme: str = "" class RegistryClient: """Client for interacting with the CmdForge registry API.""" def __init__( self, base_url: Optional[str] = None, token: Optional[str] = None, timeout: int = 30, max_retries: int = 3 ): """ Initialize the registry client. Args: base_url: Registry API base URL (default: from config) token: Auth token for authenticated requests (default: from config) timeout: Request timeout in seconds max_retries: Maximum number of retries for failed requests """ self.base_url = base_url or get_registry_url() self.token = token or get_registry_token() self.timeout = timeout self.max_retries = max_retries self.client_id = get_client_id() # Session for connection pooling self._session = requests.Session() self._session.headers.update({ "User-Agent": "CmdForge-CLI/1.0", "X-CmdForge-Client": "cli/1.0.0", "Accept": "application/json" }) # Add client ID header if self.client_id: self._session.headers["X-Client-ID"] = self.client_id def _url(self, path: str) -> str: """Build full URL from path.""" # Ensure base_url ends without /api/v1 duplication base = self.base_url.rstrip("/") if not path.startswith("/"): path = "/" + path return base + path def _auth_headers(self) -> Dict[str, str]: """Get authentication headers if token is available.""" if self.token: return {"Authorization": f"Bearer {self.token}"} return {} def _request( self, method: str, path: str, params: Optional[Dict] = None, json_data: Optional[Dict] = None, require_auth: bool = False, etag: Optional[str] = None ) -> requests.Response: """ Make an HTTP request with retry logic. Args: method: HTTP method path: API path params: Query parameters json_data: JSON body data require_auth: Whether auth is required etag: ETag for conditional requests Returns: Response object Raises: RegistryError: On API errors RateLimitError: When rate limited """ url = self._url(path) headers = {} if require_auth: if not self.token: raise RegistryError( code="UNAUTHORIZED", message="Authentication required. Set registry token with 'cmdforge config set-token'", http_status=401 ) headers.update(self._auth_headers()) if etag: headers["If-None-Match"] = etag last_error = None for attempt in range(self.max_retries): try: response = self._session.request( method=method, url=url, params=params, json=json_data, headers=headers, timeout=self.timeout ) # Handle rate limiting if response.status_code == 429: retry_after = int(response.headers.get("Retry-After", 60)) if attempt < self.max_retries - 1: time.sleep(min(retry_after, 30)) # Cap wait at 30s per attempt continue raise RateLimitError(retry_after=retry_after) # Handle server errors with retry if response.status_code >= 500: if attempt < self.max_retries - 1: time.sleep(2 ** attempt) # Exponential backoff continue return response except requests.exceptions.Timeout: last_error = RegistryError( code="TIMEOUT", message="Request timed out" ) if attempt < self.max_retries - 1: time.sleep(2 ** attempt) continue except requests.exceptions.ConnectionError: last_error = RegistryError( code="CONNECTION_ERROR", message="Could not connect to registry" ) if attempt < self.max_retries - 1: time.sleep(2 ** attempt) continue raise last_error or RegistryError( code="REQUEST_FAILED", message="Request failed after retries" ) def _handle_error_response(self, response: requests.Response) -> None: """Parse and raise appropriate error from error response.""" try: data = response.json() error = data.get("error", {}) raise RegistryError( code=error.get("code", "UNKNOWN_ERROR"), message=error.get("message", "Unknown error"), details=error.get("details"), http_status=response.status_code ) except (json.JSONDecodeError, KeyError): raise RegistryError( code="UNKNOWN_ERROR", message=f"HTTP {response.status_code}: {response.text[:200]}", http_status=response.status_code ) # ------------------------------------------------------------------------- # Public API Methods # ------------------------------------------------------------------------- def list_tools( self, category: Optional[str] = None, page: int = 1, per_page: int = 20, sort: str = "downloads", order: str = "desc" ) -> PaginatedResponse: """ List tools from the registry. Args: category: Filter by category page: Page number (1-indexed) per_page: Items per page (max 100) sort: Sort field (downloads, published_at, name) order: Sort order (asc, desc) Returns: PaginatedResponse with tool data """ params = { "page": page, "per_page": min(per_page, 100), "sort": sort, "order": order } if category: params["category"] = category response = self._request("GET", "/tools", params=params) if response.status_code != 200: self._handle_error_response(response) data = response.json() meta = data.get("meta", {}) return PaginatedResponse( data=data.get("data", []), page=meta.get("page", page), per_page=meta.get("per_page", per_page), total=meta.get("total", 0), total_pages=meta.get("total_pages", 0) ) def search_tools( self, query: str, category: Optional[str] = None, page: int = 1, per_page: int = 20, sort: str = "relevance" ) -> PaginatedResponse: """ Search for tools in the registry. Args: query: Search query category: Filter by category page: Page number per_page: Items per page sort: Sort field (relevance, downloads, published_at) Returns: PaginatedResponse with matching tools """ params = { "q": query, "page": page, "per_page": min(per_page, 100), "sort": sort } if category: params["category"] = category response = self._request("GET", "/tools/search", params=params) if response.status_code != 200: self._handle_error_response(response) data = response.json() meta = data.get("meta", {}) return PaginatedResponse( data=data.get("data", []), page=meta.get("page", page), per_page=meta.get("per_page", per_page), total=meta.get("total", 0), total_pages=meta.get("total_pages", 0) ) def get_tool(self, owner: str, name: str) -> ToolInfo: """ Get detailed information about a tool. Args: owner: Tool owner (namespace) name: Tool name Returns: ToolInfo object """ response = self._request("GET", f"/tools/{owner}/{name}") if response.status_code == 404: raise RegistryError( code="TOOL_NOT_FOUND", message=f"Tool '{owner}/{name}' not found", http_status=404 ) if response.status_code != 200: self._handle_error_response(response) data = response.json().get("data", {}) return ToolInfo.from_dict(data) def get_tool_versions(self, owner: str, name: str) -> List[str]: """ Get all versions of a tool. Args: owner: Tool owner name: Tool name Returns: List of version strings (sorted newest first) """ response = self._request("GET", f"/tools/{owner}/{name}/versions") if response.status_code == 404: raise RegistryError( code="TOOL_NOT_FOUND", message=f"Tool '{owner}/{name}' not found", http_status=404 ) if response.status_code != 200: self._handle_error_response(response) data = response.json() return data.get("data", {}).get("versions", []) def download_tool( self, owner: str, name: str, version: Optional[str] = None, install: bool = True ) -> DownloadResult: """ Download a tool's configuration. Args: owner: Tool owner name: Tool name version: Version or constraint (default: latest) install: Whether to count as install for stats Returns: DownloadResult with config YAML """ params = {"install": str(install).lower()} if version: params["version"] = version response = self._request( "GET", f"/tools/{owner}/{name}/download", params=params ) if response.status_code == 404: error_data = {} try: error_data = response.json().get("error", {}) except json.JSONDecodeError: pass code = error_data.get("code", "TOOL_NOT_FOUND") message = error_data.get("message", f"Tool '{owner}/{name}' not found") raise RegistryError( code=code, message=message, details=error_data.get("details"), http_status=404 ) if response.status_code != 200: self._handle_error_response(response) data = response.json().get("data", {}) return DownloadResult( owner=data.get("owner", owner), name=data.get("name", name), resolved_version=data.get("resolved_version", ""), config_yaml=data.get("config", ""), readme=data.get("readme", "") ) def get_categories(self) -> List[Dict[str, Any]]: """ Get list of tool categories. Returns: List of category dicts with name, description, icon """ response = self._request("GET", "/categories") if response.status_code != 200: self._handle_error_response(response) return response.json().get("data", []) def publish_tool( self, config_yaml: str, readme: str = "", dry_run: bool = False ) -> Dict[str, Any]: """ Publish a tool to the registry. Args: config_yaml: Tool configuration YAML content readme: README.md content dry_run: If True, validate without publishing Returns: Dict with PR URL or validation results """ payload = { "config": config_yaml, "readme": readme, "dry_run": dry_run } response = self._request( "POST", "/tools", json_data=payload, require_auth=True ) if response.status_code == 409: # Version already exists self._handle_error_response(response) if response.status_code not in (200, 201): self._handle_error_response(response) return response.json().get("data", {}) def get_my_tools(self) -> List[ToolInfo]: """ Get tools published by the authenticated user. Returns: List of ToolInfo objects """ response = self._request("GET", "/me/tools", require_auth=True) if response.status_code != 200: self._handle_error_response(response) tools = response.json().get("data", []) return [ToolInfo.from_dict(t) for t in tools] def get_popular_tools(self, limit: int = 10) -> List[ToolInfo]: """ Get most popular tools. Args: limit: Maximum number of tools to return Returns: List of ToolInfo objects """ response = self._request( "GET", "/stats/popular", params={"limit": limit} ) if response.status_code != 200: self._handle_error_response(response) tools = response.json().get("data", []) return [ToolInfo.from_dict(t) for t in tools] # ------------------------------------------------------------------------- # Index Caching # ------------------------------------------------------------------------- def get_index(self, force_refresh: bool = False) -> Dict[str, Any]: """ Get the full tool index, using cache when possible. Args: force_refresh: Force refresh from server Returns: Index dict with tools list """ # Check cache first if not force_refresh: cached = self._load_cached_index() if cached: return cached # Fetch from server etag = self._get_cached_etag() response = self._request("GET", "/index.json", etag=etag) if response.status_code == 304: # Not modified, use cache cached = self._load_cached_index() if cached: return cached if response.status_code != 200: # Try to use stale cache on error cached = self._load_cached_index(ignore_age=True) if cached: return cached self._handle_error_response(response) data = response.json() # Cache the response new_etag = response.headers.get("ETag") self._save_cached_index(data, new_etag) return data def _load_cached_index(self, ignore_age: bool = False) -> Optional[Dict]: """Load cached index if valid.""" if not INDEX_CACHE_FILE.exists(): return None try: cache_data = json.loads(INDEX_CACHE_FILE.read_text()) # Check age if not ignore_age: cached_at = datetime.fromisoformat(cache_data.get("_cached_at", "")) if datetime.now() - cached_at > INDEX_CACHE_MAX_AGE: return None # Verify checksum if not self._verify_index_checksum(cache_data): return None return cache_data except (json.JSONDecodeError, KeyError, ValueError): return None def _save_cached_index(self, data: Dict, etag: Optional[str] = None) -> None: """Save index to cache.""" CACHE_DIR.mkdir(parents=True, exist_ok=True) data["_cached_at"] = datetime.now().isoformat() if etag: data["_etag"] = etag INDEX_CACHE_FILE.write_text(json.dumps(data, indent=2)) def _get_cached_etag(self) -> Optional[str]: """Get ETag from cached index.""" if not INDEX_CACHE_FILE.exists(): return None try: cache_data = json.loads(INDEX_CACHE_FILE.read_text()) return cache_data.get("_etag") except (json.JSONDecodeError, KeyError): return None def _verify_index_checksum(self, data: Dict) -> bool: """Verify cached index integrity.""" checksum = data.get("checksum", "") if not checksum: return True # No checksum to verify # Compute checksum of tools list tools = data.get("tools", []) content = json.dumps(tools, sort_keys=True) computed = "sha256:" + hashlib.sha256(content.encode()).hexdigest() return computed == checksum def clear_cache(self) -> None: """Clear the local index cache.""" if INDEX_CACHE_FILE.exists(): INDEX_CACHE_FILE.unlink() # ------------------------------------------------------------------------- # Convenience functions # ------------------------------------------------------------------------- def get_client() -> RegistryClient: """Get a configured registry client instance.""" return RegistryClient() def search(query: str, **kwargs) -> PaginatedResponse: """Search the registry for tools.""" return get_client().search_tools(query, **kwargs) def install_tool(tool_spec: str, version: Optional[str] = None) -> DownloadResult: """ Download a tool for installation. Args: tool_spec: Tool specification (owner/name or just name) version: Version constraint Returns: DownloadResult with config YAML """ client = get_client() # Parse tool spec if "/" in tool_spec: owner, name = tool_spec.split("/", 1) else: # Shorthand - try official namespace first owner = "official" name = tool_spec try: return client.download_tool(owner, name, version=version, install=True) except RegistryError as e: if e.code == "TOOL_NOT_FOUND" and owner == "official": # Fall back to searching for most popular tool with this name results = client.search_tools(name, per_page=1) if results.data: first = results.data[0] return client.download_tool( first["owner"], first["name"], version=version, install=True ) raise