feat: load ai command chain from shared config

This commit is contained in:
rob 2025-11-01 14:37:57 -03:00
parent 88c783a278
commit 57255a7e09
5 changed files with 151 additions and 31 deletions

112
automation/ai_config.py Normal file
View File

@ -0,0 +1,112 @@
"""
Shared AI configuration loader for CascadingDev automation.
The configuration lives at `config/ai.yml` within the repository (for both the
maintainer tool and generated projects). It coordinates preferences for:
The automation runner's AI command chain and sentinel token.
Ramble GUI defaults (provider selection and CLI wiring).
This module keeps the parsing logic close to the automation package so the
installer can ship it alongside the pre-commit runner.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List
import yaml
DEFAULT_SENTINEL = "CASCADINGDEV_NO_CHANGES"
DEFAULT_COMMAND_CHAIN = ["claude -p"]
CONFIG_RELATIVE_PATH = Path("config") / "ai.yml"
def parse_command_chain(raw: str | None) -> List[str]:
"""
Split a fallback command string into individual commands.
We reuse the shell-style '||' delimiter so maintainers can export
`CDEV_AI_COMMAND="cmdA || cmdB"`. Empty entries are discarded.
"""
if not raw:
return []
parts = [segment.strip() for segment in raw.split("||")]
return [segment for segment in parts if segment]
@dataclass
class RunnerSettings:
command_chain: List[str]
sentinel: str
@dataclass
class RambleSettings:
default_provider: str
providers: Dict[str, Dict[str, Any]]
@dataclass
class AISettings:
runner: RunnerSettings
ramble: RambleSettings
def _load_yaml(path: Path) -> Dict[str, Any]:
if not path.exists():
return {}
try:
with path.open("r", encoding="utf-8") as fh:
data = yaml.safe_load(fh) or {}
if not isinstance(data, dict):
return {}
return data
except OSError:
return {}
def load_ai_settings(repo_root: Path) -> AISettings:
"""
Load the AI configuration for the given repository root.
Missing files fall back to sensible defaults so the automation remains
usable out of the box.
"""
config_path = (repo_root / CONFIG_RELATIVE_PATH).resolve()
data = _load_yaml(config_path)
runner_data = data.get("runner") if isinstance(data, dict) else {}
if not isinstance(runner_data, dict):
runner_data = {}
chain = runner_data.get("command_chain", [])
if not isinstance(chain, list):
chain = []
command_chain = [str(entry).strip() for entry in chain if str(entry).strip()]
if not command_chain:
command_chain = DEFAULT_COMMAND_CHAIN.copy()
sentinel = runner_data.get("sentinel")
if not isinstance(sentinel, str) or not sentinel.strip():
sentinel = DEFAULT_SENTINEL
ramble_data = data.get("ramble") if isinstance(data, dict) else {}
if not isinstance(ramble_data, dict):
ramble_data = {}
default_provider = ramble_data.get("default_provider")
if not isinstance(default_provider, str) or not default_provider.strip():
default_provider = "mock"
providers = ramble_data.get("providers", {})
if not isinstance(providers, dict):
providers = {}
return AISettings(
runner=RunnerSettings(command_chain=command_chain, sentinel=sentinel),
ramble=RambleSettings(
default_provider=default_provider,
providers={str(k): v for k, v in providers.items() if isinstance(v, dict)},
),
)

View File

@ -11,9 +11,15 @@ import re
import shutil import shutil
import subprocess import subprocess
import tempfile import tempfile
from dataclasses import dataclass, field from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from automation.ai_config import (
DEFAULT_COMMAND_CHAIN,
DEFAULT_SENTINEL,
load_ai_settings,
parse_command_chain,
)
from automation.config import RulesConfig from automation.config import RulesConfig
@ -21,36 +27,34 @@ class PatchGenerationError(RuntimeError):
pass pass
NO_CHANGES_SENTINEL = "CASCADINGDEV_NO_CHANGES"
COMMAND_DELIMITER = "||"
def _parse_command_chain(raw: str) -> list[str]:
"""
Split a fallback command string into individual commands.
Commands are separated by '||' to align with common shell semantics.
Whitespace-only segments are ignored.
"""
parts = [segment.strip() for segment in raw.split(COMMAND_DELIMITER)]
commands = [segment for segment in parts if segment]
return commands or ["claude -p"]
@dataclass @dataclass
class ModelConfig: class ModelConfig:
"""Configuration for invoking the AI model command.""" """Configuration for invoking the AI model command."""
raw: str | None = None commands: list[str]
commands: list[str] = field(init=False, repr=False) sentinel: str
def __post_init__(self) -> None: @classmethod
source = self.raw or os.environ.get("CDEV_AI_COMMAND", "claude -p") def from_sources(cls, repo_root: Path, override: str | None = None) -> "ModelConfig":
self.commands = _parse_command_chain(source) """
Build a ModelConfig using CLI override, environment variable, or
repository configuration (in that precedence order).
"""
settings = load_ai_settings(repo_root)
sentinel = settings.runner.sentinel or DEFAULT_SENTINEL
@property if override and override.strip():
def command(self) -> str: commands = parse_command_chain(override)
"""Return the primary command (first in the fallback chain).""" else:
return self.commands[0] env_override = os.environ.get("CDEV_AI_COMMAND")
if env_override and env_override.strip():
commands = parse_command_chain(env_override)
else:
commands = settings.runner.command_chain
if not commands:
commands = settings.runner.command_chain or DEFAULT_COMMAND_CHAIN.copy()
return cls(commands=commands, sentinel=sentinel)
def generate_output( def generate_output(
@ -108,6 +112,7 @@ def generate_output(
source_content=source_content, source_content=source_content,
output_content=output_preimage, output_content=output_preimage,
instruction=instruction, instruction=instruction,
no_change_token=model.sentinel,
) )
# Call the AI model and get its raw output. # Call the AI model and get its raw output.
@ -342,6 +347,7 @@ def build_prompt(
source_content: str, source_content: str,
output_content: str, output_content: str,
instruction: str, instruction: str,
no_change_token: str,
) -> str: ) -> str:
""" """
Constructs the full prompt string for the AI model by formatting the Constructs the full prompt string for the AI model by formatting the
@ -365,7 +371,7 @@ def build_prompt(
source_content=source_content.strip(), source_content=source_content.strip(),
output_content=output_content.strip() or "(empty)", # Indicate if output content is empty. output_content=output_content.strip() or "(empty)", # Indicate if output content is empty.
instruction=instruction.strip(), instruction=instruction.strip(),
no_change_token=NO_CHANGES_SENTINEL, no_change_token=no_change_token,
) )
@ -400,7 +406,7 @@ def call_model(model: ModelConfig, prompt: str, cwd: Path) -> tuple[str, bool]:
stderr = result.stderr.strip() stderr = result.stderr.strip()
if stdout: if stdout:
if stdout == NO_CHANGES_SENTINEL: if stdout == model.sentinel:
return raw_stdout, True return raw_stdout, True
if "API Error:" in raw_stdout and "Overloaded" in raw_stdout: if "API Error:" in raw_stdout and "Overloaded" in raw_stdout:
raise PatchGenerationError("Claude API is overloaded (500 error) - please retry later") raise PatchGenerationError("Claude API is overloaded (500 error) - please retry later")

View File

@ -133,7 +133,7 @@ def main(argv: list[str] | None = None) -> int:
return 0 return 0
# Instantiate the model config and delegate to the processing pipeline. # Instantiate the model config and delegate to the processing pipeline.
model = ModelConfig(args.model) model = ModelConfig.from_sources(repo_root, args.model)
return process(repo_root, rules, model) return process(repo_root, rules, model)

View File

@ -4,6 +4,7 @@ from pathlib import Path
import pytest import pytest
from automation.config import RulesConfig from automation.config import RulesConfig
from automation.ai_config import DEFAULT_SENTINEL
from automation.patcher import ModelConfig, PatchGenerationError, generate_output from automation.patcher import ModelConfig, PatchGenerationError, generate_output
@ -39,7 +40,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea
patch_file = tmp_path / "patch.txt" patch_file = tmp_path / "patch.txt"
patch_file.write_text(patch_text, encoding="utf-8") patch_file.write_text(patch_text, encoding="utf-8")
model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'") model = ModelConfig(commands=[f"bash -lc 'cat {patch_file.as_posix()}'"], sentinel=DEFAULT_SENTINEL)
rules = RulesConfig(root=temp_repo, global_rules={"file_associations": {}, "rules": {}}) rules = RulesConfig(root=temp_repo, global_rules={"file_associations": {}, "rules": {}})
generate_output( generate_output(

View File

@ -5,6 +5,7 @@ import textwrap
import pytest import pytest
from automation.config import RulesConfig from automation.config import RulesConfig
from automation.ai_config import DEFAULT_SENTINEL
from automation.patcher import ModelConfig from automation.patcher import ModelConfig
from automation.runner import process from automation.runner import process
@ -57,7 +58,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea
patch_file.write_text(patch_text, encoding="utf-8") patch_file.write_text(patch_text, encoding="utf-8")
rules = RulesConfig.load(repo) rules = RulesConfig.load(repo)
model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'") model = ModelConfig(commands=[f"bash -lc 'cat {patch_file.as_posix()}'"], sentinel=DEFAULT_SENTINEL)
rc = process(repo, rules, model) rc = process(repo, rules, model)
assert rc == 0 assert rc == 0