CascadingDev/automation/ai_config.py

150 lines
4.6 KiB
Python

"""
Shared AI configuration loader for CascadingDev automation.
The configuration lives at `config/ai.yml` within the repository (for both the
maintainer tool and generated projects). It coordinates preferences for:
• The automation runner's AI command chain and sentinel token.
• Ramble GUI defaults (provider selection and CLI wiring).
This module keeps the parsing logic close to the automation package so the
installer can ship it alongside the pre-commit runner.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List
import yaml
# Fallback sentinel token used when config/ai.yml does not provide one.
DEFAULT_SENTINEL = "CASCADINGDEV_NO_CHANGES"
# Fallback runner command chain used when none is configured.
DEFAULT_COMMAND_CHAIN = ["claude -p"]
# Location of the AI configuration file, relative to the repository root.
CONFIG_RELATIVE_PATH = Path("config", "ai.yml")
def parse_command_chain(raw: str | None) -> List[str]:
    """
    Break a '||'-separated fallback string into its individual commands.

    The shell-style delimiter lets maintainers export
    `CDEV_AI_COMMAND="cmdA || cmdB"`. Every piece is whitespace-trimmed and
    blank pieces are dropped. ``None`` or an empty string yields ``[]``.
    """
    commands: List[str] = []
    if raw:
        for piece in raw.split("||"):
            cleaned = piece.strip()
            if cleaned:
                commands.append(cleaned)
    return commands
@dataclass
class RunnerSettings:
    """
    Settings consumed by the automation runner.

    Attributes:
        command_chain: Default command chain. When built by load_ai_settings
            it falls back to DEFAULT_COMMAND_CHAIN if nothing is configured.
        command_chain_fast: Optional chain used for the 'fast' hint; empty
            means "use command_chain".
        command_chain_quality: Optional chain used for the 'quality' hint;
            empty means "use command_chain".
        sentinel: Sentinel token configured for the runner.
    """

    command_chain: List[str]
    command_chain_fast: List[str]
    command_chain_quality: List[str]
    sentinel: str

    def get_chain_for_hint(self, hint: str) -> List[str]:
        """
        Pick the command chain that matches a model hint.

        Args:
            hint: 'fast', 'quality', or an empty string.
        Returns:
            The hint-specific chain when the hint names one and that chain is
            non-empty; otherwise the default ``command_chain``. The stored
            list itself is returned, not a copy.
        """
        specialised = (
            ("fast", self.command_chain_fast),
            ("quality", self.command_chain_quality),
        )
        for label, chain in specialised:
            # An empty specialised chain means "not configured": keep looking.
            if hint == label and chain:
                return chain
        return self.command_chain
@dataclass
class RambleSettings:
    """
    Defaults for the Ramble GUI.

    Attributes:
        default_provider: Name of the provider selected by default.
        providers: Per-provider option mappings, keyed by provider name.
    """

    default_provider: str
    providers: Dict[str, Dict[str, Any]]
@dataclass
class AISettings:
    """
    Top-level AI configuration, split into its two consumers.

    Attributes:
        runner: Settings for the automation runner.
        ramble: Settings for the Ramble GUI.
    """

    runner: RunnerSettings
    ramble: RambleSettings
def _load_yaml(path: Path) -> Dict[str, Any]:
    """
    Read ``path`` as YAML and return its top-level mapping.

    Returns an empty dict when the file does not exist, cannot be opened or
    read (``OSError``), is empty, or its top level is not a mapping. Errors
    that are not ``OSError`` (YAML parse errors from ``yaml.safe_load``, text
    decoding errors) are not caught and propagate to the caller.
    """
    if not path.exists():
        return {}
    try:
        with path.open("r", encoding="utf-8") as handle:
            parsed = yaml.safe_load(handle)
    except OSError:
        return {}
    # None (empty file), scalars and lists all collapse to an empty mapping.
    return parsed if isinstance(parsed, dict) else {}
def _normalize_command_chain(raw: Any) -> List[str]:
    """
    Turn a configured command chain into a list of non-empty commands.

    Accepts either a YAML list of commands or, like ``CDEV_AI_COMMAND``, a
    single string using the ``cmdA || cmdB`` fallback syntax. ``None`` items
    and nested lists/mappings are skipped rather than stringified. Any other
    value type yields an empty list.
    """
    if isinstance(raw, str):
        return parse_command_chain(raw)
    if not isinstance(raw, list):
        return []
    commands: List[str] = []
    for entry in raw:
        # A bare "-" list item parses as None; str() would turn it (or a
        # nested structure) into a bogus command such as "None".
        if entry is None or isinstance(entry, (dict, list)):
            continue
        text = str(entry).strip()
        if text:
            commands.append(text)
    return commands


def _stripped_or_default(raw: Any, default: str) -> str:
    """
    Return ``raw`` without surrounding whitespace if it is a non-blank
    string, otherwise ``default``.
    """
    if isinstance(raw, str) and raw.strip():
        return raw.strip()
    return default


def _mapping_section(data: Dict[str, Any], key: str) -> Dict[str, Any]:
    """Return ``data[key]`` when it is a mapping, otherwise an empty dict."""
    section = data.get(key)
    return section if isinstance(section, dict) else {}


def load_ai_settings(repo_root: Path) -> AISettings:
    """
    Load the AI configuration for the given repository root.

    Reads ``config/ai.yml`` beneath ``repo_root``. A missing file, missing
    keys, or values of the wrong type fall back to sensible defaults so the
    automation remains usable out of the box:

    * ``runner.command_chain`` falls back to ``DEFAULT_COMMAND_CHAIN``.
    * ``runner.command_chain_fast`` / ``runner.command_chain_quality`` are
      optional and empty when unset.
    * ``runner.sentinel`` falls back to ``DEFAULT_SENTINEL``.
    * ``ramble.default_provider`` falls back to ``"mock"``.
    * ``ramble.providers`` keeps only entries whose value is a mapping.

    Command chains may be YAML lists or a single ``"cmdA || cmdB"`` string.
    String settings are stored with surrounding whitespace removed.

    Args:
        repo_root: Repository root; the config path is resolved against it.
    Returns:
        The populated ``AISettings``.
    """
    config_path = (repo_root / CONFIG_RELATIVE_PATH).resolve()
    data = _load_yaml(config_path)
    if not isinstance(data, dict):
        data = {}

    runner_data = _mapping_section(data, "runner")
    command_chain = _normalize_command_chain(runner_data.get("command_chain"))
    if not command_chain:
        # Always leave the runner with at least one command to try.
        command_chain = DEFAULT_COMMAND_CHAIN.copy()
    # Hint-specific chains are optional; empty means "use the default chain".
    command_chain_fast = _normalize_command_chain(runner_data.get("command_chain_fast"))
    command_chain_quality = _normalize_command_chain(runner_data.get("command_chain_quality"))
    sentinel = _stripped_or_default(runner_data.get("sentinel"), DEFAULT_SENTINEL)

    ramble_data = _mapping_section(data, "ramble")
    default_provider = _stripped_or_default(ramble_data.get("default_provider"), "mock")
    providers = ramble_data.get("providers")
    if not isinstance(providers, dict):
        providers = {}

    return AISettings(
        runner=RunnerSettings(
            command_chain=command_chain,
            command_chain_fast=command_chain_fast,
            command_chain_quality=command_chain_quality,
            sentinel=sentinel,
        ),
        ramble=RambleSettings(
            default_provider=default_provider,
            providers={str(k): v for k, v in providers.items() if isinstance(v, dict)},
        ),
    )