feat: load ai command chain from shared config
This commit is contained in:
parent
88c783a278
commit
57255a7e09
|
|
@ -0,0 +1,112 @@
|
||||||
|
"""
|
||||||
|
Shared AI configuration loader for CascadingDev automation.
|
||||||
|
|
||||||
|
The configuration lives at `config/ai.yml` within the repository (for both the
|
||||||
|
maintainer tool and generated projects). It coordinates preferences for:
|
||||||
|
• The automation runner's AI command chain and sentinel token.
|
||||||
|
• Ramble GUI defaults (provider selection and CLI wiring).
|
||||||
|
|
||||||
|
This module keeps the parsing logic close to the automation package so the
|
||||||
|
installer can ship it alongside the pre-commit runner.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
DEFAULT_SENTINEL = "CASCADINGDEV_NO_CHANGES"
|
||||||
|
DEFAULT_COMMAND_CHAIN = ["claude -p"]
|
||||||
|
CONFIG_RELATIVE_PATH = Path("config") / "ai.yml"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_command_chain(raw: str | None) -> List[str]:
|
||||||
|
"""
|
||||||
|
Split a fallback command string into individual commands.
|
||||||
|
|
||||||
|
We reuse the shell-style '||' delimiter so maintainers can export
|
||||||
|
`CDEV_AI_COMMAND="cmdA || cmdB"`. Empty entries are discarded.
|
||||||
|
"""
|
||||||
|
if not raw:
|
||||||
|
return []
|
||||||
|
parts = [segment.strip() for segment in raw.split("||")]
|
||||||
|
return [segment for segment in parts if segment]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RunnerSettings:
|
||||||
|
command_chain: List[str]
|
||||||
|
sentinel: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RambleSettings:
|
||||||
|
default_provider: str
|
||||||
|
providers: Dict[str, Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AISettings:
|
||||||
|
runner: RunnerSettings
|
||||||
|
ramble: RambleSettings
|
||||||
|
|
||||||
|
|
||||||
|
def _load_yaml(path: Path) -> Dict[str, Any]:
|
||||||
|
if not path.exists():
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
with path.open("r", encoding="utf-8") as fh:
|
||||||
|
data = yaml.safe_load(fh) or {}
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return {}
|
||||||
|
return data
|
||||||
|
except OSError:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def load_ai_settings(repo_root: Path) -> AISettings:
|
||||||
|
"""
|
||||||
|
Load the AI configuration for the given repository root.
|
||||||
|
|
||||||
|
Missing files fall back to sensible defaults so the automation remains
|
||||||
|
usable out of the box.
|
||||||
|
"""
|
||||||
|
config_path = (repo_root / CONFIG_RELATIVE_PATH).resolve()
|
||||||
|
data = _load_yaml(config_path)
|
||||||
|
|
||||||
|
runner_data = data.get("runner") if isinstance(data, dict) else {}
|
||||||
|
if not isinstance(runner_data, dict):
|
||||||
|
runner_data = {}
|
||||||
|
|
||||||
|
chain = runner_data.get("command_chain", [])
|
||||||
|
if not isinstance(chain, list):
|
||||||
|
chain = []
|
||||||
|
command_chain = [str(entry).strip() for entry in chain if str(entry).strip()]
|
||||||
|
if not command_chain:
|
||||||
|
command_chain = DEFAULT_COMMAND_CHAIN.copy()
|
||||||
|
|
||||||
|
sentinel = runner_data.get("sentinel")
|
||||||
|
if not isinstance(sentinel, str) or not sentinel.strip():
|
||||||
|
sentinel = DEFAULT_SENTINEL
|
||||||
|
|
||||||
|
ramble_data = data.get("ramble") if isinstance(data, dict) else {}
|
||||||
|
if not isinstance(ramble_data, dict):
|
||||||
|
ramble_data = {}
|
||||||
|
|
||||||
|
default_provider = ramble_data.get("default_provider")
|
||||||
|
if not isinstance(default_provider, str) or not default_provider.strip():
|
||||||
|
default_provider = "mock"
|
||||||
|
|
||||||
|
providers = ramble_data.get("providers", {})
|
||||||
|
if not isinstance(providers, dict):
|
||||||
|
providers = {}
|
||||||
|
|
||||||
|
return AISettings(
|
||||||
|
runner=RunnerSettings(command_chain=command_chain, sentinel=sentinel),
|
||||||
|
ramble=RambleSettings(
|
||||||
|
default_provider=default_provider,
|
||||||
|
providers={str(k): v for k, v in providers.items() if isinstance(v, dict)},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
@ -11,9 +11,15 @@ import re
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from automation.ai_config import (
|
||||||
|
DEFAULT_COMMAND_CHAIN,
|
||||||
|
DEFAULT_SENTINEL,
|
||||||
|
load_ai_settings,
|
||||||
|
parse_command_chain,
|
||||||
|
)
|
||||||
from automation.config import RulesConfig
|
from automation.config import RulesConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,36 +27,34 @@ class PatchGenerationError(RuntimeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
NO_CHANGES_SENTINEL = "CASCADINGDEV_NO_CHANGES"
|
|
||||||
COMMAND_DELIMITER = "||"
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_command_chain(raw: str) -> list[str]:
|
|
||||||
"""
|
|
||||||
Split a fallback command string into individual commands.
|
|
||||||
|
|
||||||
Commands are separated by '||' to align with common shell semantics.
|
|
||||||
Whitespace-only segments are ignored.
|
|
||||||
"""
|
|
||||||
parts = [segment.strip() for segment in raw.split(COMMAND_DELIMITER)]
|
|
||||||
commands = [segment for segment in parts if segment]
|
|
||||||
return commands or ["claude -p"]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
"""Configuration for invoking the AI model command."""
|
"""Configuration for invoking the AI model command."""
|
||||||
raw: str | None = None
|
commands: list[str]
|
||||||
commands: list[str] = field(init=False, repr=False)
|
sentinel: str
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
@classmethod
|
||||||
source = self.raw or os.environ.get("CDEV_AI_COMMAND", "claude -p")
|
def from_sources(cls, repo_root: Path, override: str | None = None) -> "ModelConfig":
|
||||||
self.commands = _parse_command_chain(source)
|
"""
|
||||||
|
Build a ModelConfig using CLI override, environment variable, or
|
||||||
|
repository configuration (in that precedence order).
|
||||||
|
"""
|
||||||
|
settings = load_ai_settings(repo_root)
|
||||||
|
sentinel = settings.runner.sentinel or DEFAULT_SENTINEL
|
||||||
|
|
||||||
@property
|
if override and override.strip():
|
||||||
def command(self) -> str:
|
commands = parse_command_chain(override)
|
||||||
"""Return the primary command (first in the fallback chain)."""
|
else:
|
||||||
return self.commands[0]
|
env_override = os.environ.get("CDEV_AI_COMMAND")
|
||||||
|
if env_override and env_override.strip():
|
||||||
|
commands = parse_command_chain(env_override)
|
||||||
|
else:
|
||||||
|
commands = settings.runner.command_chain
|
||||||
|
|
||||||
|
if not commands:
|
||||||
|
commands = settings.runner.command_chain or DEFAULT_COMMAND_CHAIN.copy()
|
||||||
|
|
||||||
|
return cls(commands=commands, sentinel=sentinel)
|
||||||
|
|
||||||
|
|
||||||
def generate_output(
|
def generate_output(
|
||||||
|
|
@ -108,6 +112,7 @@ def generate_output(
|
||||||
source_content=source_content,
|
source_content=source_content,
|
||||||
output_content=output_preimage,
|
output_content=output_preimage,
|
||||||
instruction=instruction,
|
instruction=instruction,
|
||||||
|
no_change_token=model.sentinel,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call the AI model and get its raw output.
|
# Call the AI model and get its raw output.
|
||||||
|
|
@ -342,6 +347,7 @@ def build_prompt(
|
||||||
source_content: str,
|
source_content: str,
|
||||||
output_content: str,
|
output_content: str,
|
||||||
instruction: str,
|
instruction: str,
|
||||||
|
no_change_token: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Constructs the full prompt string for the AI model by formatting the
|
Constructs the full prompt string for the AI model by formatting the
|
||||||
|
|
@ -365,7 +371,7 @@ def build_prompt(
|
||||||
source_content=source_content.strip(),
|
source_content=source_content.strip(),
|
||||||
output_content=output_content.strip() or "(empty)", # Indicate if output content is empty.
|
output_content=output_content.strip() or "(empty)", # Indicate if output content is empty.
|
||||||
instruction=instruction.strip(),
|
instruction=instruction.strip(),
|
||||||
no_change_token=NO_CHANGES_SENTINEL,
|
no_change_token=no_change_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -400,7 +406,7 @@ def call_model(model: ModelConfig, prompt: str, cwd: Path) -> tuple[str, bool]:
|
||||||
stderr = result.stderr.strip()
|
stderr = result.stderr.strip()
|
||||||
|
|
||||||
if stdout:
|
if stdout:
|
||||||
if stdout == NO_CHANGES_SENTINEL:
|
if stdout == model.sentinel:
|
||||||
return raw_stdout, True
|
return raw_stdout, True
|
||||||
if "API Error:" in raw_stdout and "Overloaded" in raw_stdout:
|
if "API Error:" in raw_stdout and "Overloaded" in raw_stdout:
|
||||||
raise PatchGenerationError("Claude API is overloaded (500 error) - please retry later")
|
raise PatchGenerationError("Claude API is overloaded (500 error) - please retry later")
|
||||||
|
|
|
||||||
|
|
@ -133,7 +133,7 @@ def main(argv: list[str] | None = None) -> int:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Instantiate the model config and delegate to the processing pipeline.
|
# Instantiate the model config and delegate to the processing pipeline.
|
||||||
model = ModelConfig(args.model)
|
model = ModelConfig.from_sources(repo_root, args.model)
|
||||||
return process(repo_root, rules, model)
|
return process(repo_root, rules, model)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from pathlib import Path
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from automation.config import RulesConfig
|
from automation.config import RulesConfig
|
||||||
|
from automation.ai_config import DEFAULT_SENTINEL
|
||||||
from automation.patcher import ModelConfig, PatchGenerationError, generate_output
|
from automation.patcher import ModelConfig, PatchGenerationError, generate_output
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -39,7 +40,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea
|
||||||
patch_file = tmp_path / "patch.txt"
|
patch_file = tmp_path / "patch.txt"
|
||||||
patch_file.write_text(patch_text, encoding="utf-8")
|
patch_file.write_text(patch_text, encoding="utf-8")
|
||||||
|
|
||||||
model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'")
|
model = ModelConfig(commands=[f"bash -lc 'cat {patch_file.as_posix()}'"], sentinel=DEFAULT_SENTINEL)
|
||||||
rules = RulesConfig(root=temp_repo, global_rules={"file_associations": {}, "rules": {}})
|
rules = RulesConfig(root=temp_repo, global_rules={"file_associations": {}, "rules": {}})
|
||||||
|
|
||||||
generate_output(
|
generate_output(
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import textwrap
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from automation.config import RulesConfig
|
from automation.config import RulesConfig
|
||||||
|
from automation.ai_config import DEFAULT_SENTINEL
|
||||||
from automation.patcher import ModelConfig
|
from automation.patcher import ModelConfig
|
||||||
from automation.runner import process
|
from automation.runner import process
|
||||||
|
|
||||||
|
|
@ -57,7 +58,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea
|
||||||
patch_file.write_text(patch_text, encoding="utf-8")
|
patch_file.write_text(patch_text, encoding="utf-8")
|
||||||
|
|
||||||
rules = RulesConfig.load(repo)
|
rules = RulesConfig.load(repo)
|
||||||
model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'")
|
model = ModelConfig(commands=[f"bash -lc 'cat {patch_file.as_posix()}'"], sentinel=DEFAULT_SENTINEL)
|
||||||
|
|
||||||
rc = process(repo, rules, model)
|
rc = process(repo, rules, model)
|
||||||
assert rc == 0
|
assert rc == 0
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue