feat: load ai command chain from shared config
This commit is contained in:
parent 88c783a278
commit 57255a7e09
@@ -0,0 +1,112 @@
+"""
+Shared AI configuration loader for CascadingDev automation.
+
+The configuration lives at `config/ai.yml` within the repository (for both the
+maintainer tool and generated projects). It coordinates preferences for:
+• The automation runner's AI command chain and sentinel token.
+• Ramble GUI defaults (provider selection and CLI wiring).
+
+This module keeps the parsing logic close to the automation package so the
+installer can ship it alongside the pre-commit runner.
+"""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List
+
+import yaml
+
+DEFAULT_SENTINEL = "CASCADINGDEV_NO_CHANGES"
+DEFAULT_COMMAND_CHAIN = ["claude -p"]
+CONFIG_RELATIVE_PATH = Path("config") / "ai.yml"
+
+
+def parse_command_chain(raw: str | None) -> List[str]:
+    """
+    Split a fallback command string into individual commands.
+
+    We reuse the shell-style '||' delimiter so maintainers can export
+    `CDEV_AI_COMMAND="cmdA || cmdB"`. Empty entries are discarded.
+    """
+    if not raw:
+        return []
+    parts = [segment.strip() for segment in raw.split("||")]
+    return [segment for segment in parts if segment]
+
+
+@dataclass
+class RunnerSettings:
+    command_chain: List[str]
+    sentinel: str
+
+
+@dataclass
+class RambleSettings:
+    default_provider: str
+    providers: Dict[str, Dict[str, Any]]
+
+
+@dataclass
+class AISettings:
+    runner: RunnerSettings
+    ramble: RambleSettings
+
+
+def _load_yaml(path: Path) -> Dict[str, Any]:
+    if not path.exists():
+        return {}
+    try:
+        with path.open("r", encoding="utf-8") as fh:
+            data = yaml.safe_load(fh) or {}
+        if not isinstance(data, dict):
+            return {}
+        return data
+    except OSError:
+        return {}
+
+
+def load_ai_settings(repo_root: Path) -> AISettings:
+    """
+    Load the AI configuration for the given repository root.
+
+    Missing files fall back to sensible defaults so the automation remains
+    usable out of the box.
+    """
+    config_path = (repo_root / CONFIG_RELATIVE_PATH).resolve()
+    data = _load_yaml(config_path)
+
+    runner_data = data.get("runner") if isinstance(data, dict) else {}
+    if not isinstance(runner_data, dict):
+        runner_data = {}
+
+    chain = runner_data.get("command_chain", [])
+    if not isinstance(chain, list):
+        chain = []
+    command_chain = [str(entry).strip() for entry in chain if str(entry).strip()]
+    if not command_chain:
+        command_chain = DEFAULT_COMMAND_CHAIN.copy()
+
+    sentinel = runner_data.get("sentinel")
+    if not isinstance(sentinel, str) or not sentinel.strip():
+        sentinel = DEFAULT_SENTINEL
+
+    ramble_data = data.get("ramble") if isinstance(data, dict) else {}
+    if not isinstance(ramble_data, dict):
+        ramble_data = {}
+
+    default_provider = ramble_data.get("default_provider")
+    if not isinstance(default_provider, str) or not default_provider.strip():
+        default_provider = "mock"
+
+    providers = ramble_data.get("providers", {})
+    if not isinstance(providers, dict):
+        providers = {}
+
+    return AISettings(
+        runner=RunnerSettings(command_chain=command_chain, sentinel=sentinel),
+        ramble=RambleSettings(
+            default_provider=default_provider,
+            providers={str(k): v for k, v in providers.items() if isinstance(v, dict)},
+        ),
+    )
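For context, a minimal sketch of the `config/ai.yml` shape this loader accepts and how it is consumed, assuming the automation package is importable. The repo path, the second chain entry ("codex exec"), and the empty `mock` provider payload are illustrative assumptions, not part of this commit:

    from pathlib import Path
    from automation.ai_config import load_ai_settings

    repo_root = Path("/tmp/demo-repo")  # hypothetical checkout
    (repo_root / "config").mkdir(parents=True, exist_ok=True)
    (repo_root / "config" / "ai.yml").write_text(
        "runner:\n"
        "  command_chain:\n"
        "    - claude -p\n"
        "    - codex exec   # assumed fallback entry\n"
        "  sentinel: CASCADINGDEV_NO_CHANGES\n"
        "ramble:\n"
        "  default_provider: mock\n"
        "  providers:\n"
        "    mock: {}\n",
        encoding="utf-8",
    )

    settings = load_ai_settings(repo_root)
    assert settings.runner.command_chain == ["claude -p", "codex exec"]
    assert settings.ramble.default_provider == "mock"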
@@ -11,9 +11,15 @@ import re
 import shutil
 import subprocess
 import tempfile
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from pathlib import Path
 
+from automation.ai_config import (
+    DEFAULT_COMMAND_CHAIN,
+    DEFAULT_SENTINEL,
+    load_ai_settings,
+    parse_command_chain,
+)
 from automation.config import RulesConfig
 
 
@@ -21,36 +27,34 @@ class PatchGenerationError(RuntimeError):
     pass
 
 
-NO_CHANGES_SENTINEL = "CASCADINGDEV_NO_CHANGES"
-COMMAND_DELIMITER = "||"
-
-
-def _parse_command_chain(raw: str) -> list[str]:
-    """
-    Split a fallback command string into individual commands.
-
-    Commands are separated by '||' to align with common shell semantics.
-    Whitespace-only segments are ignored.
-    """
-    parts = [segment.strip() for segment in raw.split(COMMAND_DELIMITER)]
-    commands = [segment for segment in parts if segment]
-    return commands or ["claude -p"]
-
-
 @dataclass
 class ModelConfig:
     """Configuration for invoking the AI model command."""
-    raw: str | None = None
-    commands: list[str] = field(init=False, repr=False)
+    commands: list[str]
+    sentinel: str
 
-    def __post_init__(self) -> None:
-        source = self.raw or os.environ.get("CDEV_AI_COMMAND", "claude -p")
-        self.commands = _parse_command_chain(source)
+    @classmethod
+    def from_sources(cls, repo_root: Path, override: str | None = None) -> "ModelConfig":
+        """
+        Build a ModelConfig using CLI override, environment variable, or
+        repository configuration (in that precedence order).
+        """
+        settings = load_ai_settings(repo_root)
+        sentinel = settings.runner.sentinel or DEFAULT_SENTINEL
 
-    @property
-    def command(self) -> str:
-        """Return the primary command (first in the fallback chain)."""
-        return self.commands[0]
+        if override and override.strip():
+            commands = parse_command_chain(override)
+        else:
+            env_override = os.environ.get("CDEV_AI_COMMAND")
+            if env_override and env_override.strip():
+                commands = parse_command_chain(env_override)
+            else:
+                commands = settings.runner.command_chain
+
+        if not commands:
+            commands = settings.runner.command_chain or DEFAULT_COMMAND_CHAIN.copy()
+
+        return cls(commands=commands, sentinel=sentinel)
 
 
 def generate_output(
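A small sketch of the precedence from_sources implements: an explicit override first, then CDEV_AI_COMMAND, then config/ai.yml. The repo path is an assumption for illustration:

    import os
    from pathlib import Path
    from automation.patcher import ModelConfig

    repo_root = Path(".")  # hypothetical repo; config/ai.yml may or may not exist here

    os.environ["CDEV_AI_COMMAND"] = "cmdA || cmdB"
    model = ModelConfig.from_sources(repo_root)          # env var beats ai.yml
    assert model.commands == ["cmdA", "cmdB"]

    model = ModelConfig.from_sources(repo_root, "cmdC")  # explicit override beats env var
    assert model.commands == ["cmdC"]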
@@ -108,6 +112,7 @@ def generate_output(
         source_content=source_content,
         output_content=output_preimage,
         instruction=instruction,
+        no_change_token=model.sentinel,
     )
 
     # Call the AI model and get its raw output.
@@ -342,6 +347,7 @@ def build_prompt(
     source_content: str,
     output_content: str,
     instruction: str,
+    no_change_token: str,
 ) -> str:
     """
     Constructs the full prompt string for the AI model by formatting the
@@ -365,7 +371,7 @@ def build_prompt(
         source_content=source_content.strip(),
         output_content=output_content.strip() or "(empty)",  # Indicate if output content is empty.
         instruction=instruction.strip(),
-        no_change_token=NO_CHANGES_SENTINEL,
+        no_change_token=no_change_token,
     )
 
 
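The net effect of the two build_prompt hunks is that the sentinel is threaded from the loaded settings into the prompt rather than read from the removed module constant, so a non-default token configured in ai.yml reaches the model. A hypothetical illustration; the real template wording is not shown in this diff:

    # Hypothetical template text, for illustration only.
    TEMPLATE = (
        "Apply this instruction: {instruction}\n"
        "If nothing needs to change, reply with exactly {no_change_token}.\n"
    )
    prompt = TEMPLATE.format(
        instruction="sync summary",
        no_change_token="CASCADINGDEV_NO_CHANGES",
    )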
@@ -400,7 +406,7 @@ def call_model(model: ModelConfig, prompt: str, cwd: Path) -> tuple[str, bool]:
     stderr = result.stderr.strip()
 
     if stdout:
-        if stdout == NO_CHANGES_SENTINEL:
+        if stdout == model.sentinel:
             return raw_stdout, True
         if "API Error:" in raw_stdout and "Overloaded" in raw_stdout:
             raise PatchGenerationError("Claude API is overloaded (500 error) - please retry later")
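call_model now compares the stripped stdout against the configured token instead of the hard-coded constant. A minimal sketch of the contract a chained command must honour; the echo command stands in for a real AI CLI:

    import subprocess

    SENTINEL = "CASCADINGDEV_NO_CHANGES"  # the default; config/ai.yml may override it

    # A compliant "no changes" run: the command prints the bare sentinel on stdout.
    result = subprocess.run(
        ["bash", "-lc", f"echo {SENTINEL}"],
        capture_output=True, text=True, check=True,
    )
    assert result.stdout.strip() == SENTINEL  # what call_model checks against model.sentinel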
@@ -133,7 +133,7 @@ def main(argv: list[str] | None = None) -> int:
         return 0
 
     # Instantiate the model config and delegate to the processing pipeline.
-    model = ModelConfig(args.model)
+    model = ModelConfig.from_sources(repo_root, args.model)
     return process(repo_root, rules, model)
 
 
@@ -4,6 +4,7 @@ from pathlib import Path
 import pytest
 
 from automation.config import RulesConfig
+from automation.ai_config import DEFAULT_SENTINEL
 from automation.patcher import ModelConfig, PatchGenerationError, generate_output
 
 
@@ -39,7 +40,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea
     patch_file = tmp_path / "patch.txt"
     patch_file.write_text(patch_text, encoding="utf-8")
 
-    model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'")
+    model = ModelConfig(commands=[f"bash -lc 'cat {patch_file.as_posix()}'"], sentinel=DEFAULT_SENTINEL)
     rules = RulesConfig(root=temp_repo, global_rules={"file_associations": {}, "rules": {}})
 
     generate_output(
@@ -5,6 +5,7 @@ import textwrap
 import pytest
 
 from automation.config import RulesConfig
+from automation.ai_config import DEFAULT_SENTINEL
 from automation.patcher import ModelConfig
 from automation.runner import process
 
@@ -57,7 +58,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea
     patch_file.write_text(patch_text, encoding="utf-8")
 
     rules = RulesConfig.load(repo)
-    model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'")
+    model = ModelConfig(commands=[f"bash -lc 'cat {patch_file.as_posix()}'"], sentinel=DEFAULT_SENTINEL)
 
     rc = process(repo, rules, model)
     assert rc == 0