From 57255a7e09a49d1fa0356b19a799e953744af47f Mon Sep 17 00:00:00 2001 From: rob Date: Sat, 1 Nov 2025 14:37:57 -0300 Subject: [PATCH] feat: load ai command chain from shared config --- automation/ai_config.py | 112 ++++++++++++++++++++++++++++++++++++++++ automation/patcher.py | 62 ++++++++++++---------- automation/runner.py | 2 +- tests/test_patcher.py | 3 +- tests/test_runner.py | 3 +- 5 files changed, 151 insertions(+), 31 deletions(-) create mode 100644 automation/ai_config.py diff --git a/automation/ai_config.py b/automation/ai_config.py new file mode 100644 index 0000000..d153f06 --- /dev/null +++ b/automation/ai_config.py @@ -0,0 +1,112 @@ +""" +Shared AI configuration loader for CascadingDev automation. + +The configuration lives at `config/ai.yml` within the repository (for both the +maintainer tool and generated projects). It coordinates preferences for: + • The automation runner's AI command chain and sentinel token. + • Ramble GUI defaults (provider selection and CLI wiring). + +This module keeps the parsing logic close to the automation package so the +installer can ship it alongside the pre-commit runner. +""" +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List + +import yaml + +DEFAULT_SENTINEL = "CASCADINGDEV_NO_CHANGES" +DEFAULT_COMMAND_CHAIN = ["claude -p"] +CONFIG_RELATIVE_PATH = Path("config") / "ai.yml" + + +def parse_command_chain(raw: str | None) -> List[str]: + """ + Split a fallback command string into individual commands. + + We reuse the shell-style '||' delimiter so maintainers can export + `CDEV_AI_COMMAND="cmdA || cmdB"`. Empty entries are discarded. + """ + if not raw: + return [] + parts = [segment.strip() for segment in raw.split("||")] + return [segment for segment in parts if segment] + + +@dataclass +class RunnerSettings: + command_chain: List[str] + sentinel: str + + +@dataclass +class RambleSettings: + default_provider: str + providers: Dict[str, Dict[str, Any]] + + +@dataclass +class AISettings: + runner: RunnerSettings + ramble: RambleSettings + + +def _load_yaml(path: Path) -> Dict[str, Any]: + if not path.exists(): + return {} + try: + with path.open("r", encoding="utf-8") as fh: + data = yaml.safe_load(fh) or {} + if not isinstance(data, dict): + return {} + return data + except OSError: + return {} + + +def load_ai_settings(repo_root: Path) -> AISettings: + """ + Load the AI configuration for the given repository root. + + Missing files fall back to sensible defaults so the automation remains + usable out of the box. + """ + config_path = (repo_root / CONFIG_RELATIVE_PATH).resolve() + data = _load_yaml(config_path) + + runner_data = data.get("runner") if isinstance(data, dict) else {} + if not isinstance(runner_data, dict): + runner_data = {} + + chain = runner_data.get("command_chain", []) + if not isinstance(chain, list): + chain = [] + command_chain = [str(entry).strip() for entry in chain if str(entry).strip()] + if not command_chain: + command_chain = DEFAULT_COMMAND_CHAIN.copy() + + sentinel = runner_data.get("sentinel") + if not isinstance(sentinel, str) or not sentinel.strip(): + sentinel = DEFAULT_SENTINEL + + ramble_data = data.get("ramble") if isinstance(data, dict) else {} + if not isinstance(ramble_data, dict): + ramble_data = {} + + default_provider = ramble_data.get("default_provider") + if not isinstance(default_provider, str) or not default_provider.strip(): + default_provider = "mock" + + providers = ramble_data.get("providers", {}) + if not isinstance(providers, dict): + providers = {} + + return AISettings( + runner=RunnerSettings(command_chain=command_chain, sentinel=sentinel), + ramble=RambleSettings( + default_provider=default_provider, + providers={str(k): v for k, v in providers.items() if isinstance(v, dict)}, + ), + ) diff --git a/automation/patcher.py b/automation/patcher.py index 5419d36..30881a1 100644 --- a/automation/patcher.py +++ b/automation/patcher.py @@ -11,9 +11,15 @@ import re import shutil import subprocess import tempfile -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path +from automation.ai_config import ( + DEFAULT_COMMAND_CHAIN, + DEFAULT_SENTINEL, + load_ai_settings, + parse_command_chain, +) from automation.config import RulesConfig @@ -21,36 +27,34 @@ class PatchGenerationError(RuntimeError): pass -NO_CHANGES_SENTINEL = "CASCADINGDEV_NO_CHANGES" -COMMAND_DELIMITER = "||" - - -def _parse_command_chain(raw: str) -> list[str]: - """ - Split a fallback command string into individual commands. - - Commands are separated by '||' to align with common shell semantics. - Whitespace-only segments are ignored. - """ - parts = [segment.strip() for segment in raw.split(COMMAND_DELIMITER)] - commands = [segment for segment in parts if segment] - return commands or ["claude -p"] - - @dataclass class ModelConfig: """Configuration for invoking the AI model command.""" - raw: str | None = None - commands: list[str] = field(init=False, repr=False) + commands: list[str] + sentinel: str - def __post_init__(self) -> None: - source = self.raw or os.environ.get("CDEV_AI_COMMAND", "claude -p") - self.commands = _parse_command_chain(source) + @classmethod + def from_sources(cls, repo_root: Path, override: str | None = None) -> "ModelConfig": + """ + Build a ModelConfig using CLI override, environment variable, or + repository configuration (in that precedence order). + """ + settings = load_ai_settings(repo_root) + sentinel = settings.runner.sentinel or DEFAULT_SENTINEL - @property - def command(self) -> str: - """Return the primary command (first in the fallback chain).""" - return self.commands[0] + if override and override.strip(): + commands = parse_command_chain(override) + else: + env_override = os.environ.get("CDEV_AI_COMMAND") + if env_override and env_override.strip(): + commands = parse_command_chain(env_override) + else: + commands = settings.runner.command_chain + + if not commands: + commands = settings.runner.command_chain or DEFAULT_COMMAND_CHAIN.copy() + + return cls(commands=commands, sentinel=sentinel) def generate_output( @@ -108,6 +112,7 @@ def generate_output( source_content=source_content, output_content=output_preimage, instruction=instruction, + no_change_token=model.sentinel, ) # Call the AI model and get its raw output. @@ -342,6 +347,7 @@ def build_prompt( source_content: str, output_content: str, instruction: str, + no_change_token: str, ) -> str: """ Constructs the full prompt string for the AI model by formatting the @@ -365,7 +371,7 @@ def build_prompt( source_content=source_content.strip(), output_content=output_content.strip() or "(empty)", # Indicate if output content is empty. instruction=instruction.strip(), - no_change_token=NO_CHANGES_SENTINEL, + no_change_token=no_change_token, ) @@ -400,7 +406,7 @@ def call_model(model: ModelConfig, prompt: str, cwd: Path) -> tuple[str, bool]: stderr = result.stderr.strip() if stdout: - if stdout == NO_CHANGES_SENTINEL: + if stdout == model.sentinel: return raw_stdout, True if "API Error:" in raw_stdout and "Overloaded" in raw_stdout: raise PatchGenerationError("Claude API is overloaded (500 error) - please retry later") diff --git a/automation/runner.py b/automation/runner.py index 306d358..89a831c 100644 --- a/automation/runner.py +++ b/automation/runner.py @@ -133,7 +133,7 @@ def main(argv: list[str] | None = None) -> int: return 0 # Instantiate the model config and delegate to the processing pipeline. - model = ModelConfig(args.model) + model = ModelConfig.from_sources(repo_root, args.model) return process(repo_root, rules, model) diff --git a/tests/test_patcher.py b/tests/test_patcher.py index 8825d24..c67eece 100644 --- a/tests/test_patcher.py +++ b/tests/test_patcher.py @@ -4,6 +4,7 @@ from pathlib import Path import pytest from automation.config import RulesConfig +from automation.ai_config import DEFAULT_SENTINEL from automation.patcher import ModelConfig, PatchGenerationError, generate_output @@ -39,7 +40,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea patch_file = tmp_path / "patch.txt" patch_file.write_text(patch_text, encoding="utf-8") - model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'") + model = ModelConfig(commands=[f"bash -lc 'cat {patch_file.as_posix()}'"], sentinel=DEFAULT_SENTINEL) rules = RulesConfig(root=temp_repo, global_rules={"file_associations": {}, "rules": {}}) generate_output( diff --git a/tests/test_runner.py b/tests/test_runner.py index 376def3..363e380 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -5,6 +5,7 @@ import textwrap import pytest from automation.config import RulesConfig +from automation.ai_config import DEFAULT_SENTINEL from automation.patcher import ModelConfig from automation.runner import process @@ -57,7 +58,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea patch_file.write_text(patch_text, encoding="utf-8") rules = RulesConfig.load(repo) - model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'") + model = ModelConfig(commands=[f"bash -lc 'cat {patch_file.as_posix()}'"], sentinel=DEFAULT_SENTINEL) rc = process(repo, rules, model) assert rc == 0