feat: add ai fallback chain with no-change sentinel

This commit is contained in:
rob 2025-11-01 14:25:22 -03:00
parent eed12ce749
commit feb8580b3a
4 changed files with 106 additions and 53 deletions

View File

@ -11,7 +11,7 @@ import re
import shutil
import subprocess
import tempfile
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from automation.config import RulesConfig
@ -21,11 +21,36 @@ class PatchGenerationError(RuntimeError):
pass
NO_CHANGES_SENTINEL = "CASCADINGDEV_NO_CHANGES"
COMMAND_DELIMITER = "||"
def _parse_command_chain(raw: str) -> list[str]:
"""
Split a fallback command string into individual commands.
Commands are separated by '||' to align with common shell semantics.
Whitespace-only segments are ignored.
"""
parts = [segment.strip() for segment in raw.split(COMMAND_DELIMITER)]
commands = [segment for segment in parts if segment]
return commands or ["claude -p"]
@dataclass
class ModelConfig:
    """Configuration for invoking the AI model command(s).

    ``raw`` is an optional command string; several alternatives may be chained
    with '||'. When ``raw`` is empty or None, the CDEV_AI_COMMAND environment
    variable is used, and failing that 'claude -p'.
    """

    # Unparsed command string supplied by the caller (None/"" -> use env/default).
    raw: str | None = None
    # Parsed fallback chain, filled in by __post_init__; not a constructor argument.
    commands: list[str] = field(init=False, repr=False)

    # NOTE: there is deliberately no `command` dataclass field. A field with that
    # name would collide with the read-only `command` property below and make
    # the generated __init__ fail when it tries to assign to the property.

    def __post_init__(self) -> None:
        """Resolve the command source and split it into the fallback chain."""
        # The environment is read at instantiation time, not at import time.
        source = self.raw or os.environ.get("CDEV_AI_COMMAND", "claude -p")
        self.commands = _parse_command_chain(source)

    @property
    def command(self) -> str:
        """Return the primary command (first in the fallback chain)."""
        return self.commands[0]
def generate_output(
@ -86,7 +111,9 @@ def generate_output(
)
# Call the AI model and get its raw output.
raw_patch = call_model(model, prompt, cwd=repo_root)
raw_patch, no_changes = call_model(model, prompt, cwd=repo_root)
if no_changes:
return
# Use a temporary directory for storing intermediate patch files for debugging.
with tempfile.TemporaryDirectory(prefix="cdev-patch-") as tmpdir_str:
@ -304,6 +331,7 @@ End with: <<<AI_DIFF_END>>>
Only include the diff between these markers.
If the output file doesn't exist, create it from scratch in the patch.
If no changes are needed, output only this exact token (no diff): {no_change_token}
"""
@ -337,10 +365,11 @@ def build_prompt(
source_content=source_content.strip(),
output_content=output_content.strip() or "(empty)", # Indicate if output content is empty.
instruction=instruction.strip(),
no_change_token=NO_CHANGES_SENTINEL,
)
def call_model(model: ModelConfig, prompt: str, cwd: Path) -> tuple[str, bool]:
    """
    Invokes the AI model command(s) with the given prompt and captures the output.

    Every command in ``model.commands`` is tried in order. The first one that
    produces usable stdout wins; a command that produces no output, or whose
    output is an "Overloaded" API error, is recorded and the next command in
    the chain is tried.

    Args:
        model: The model configuration holding the fallback command chain.
        prompt: The prompt text, passed to each command on stdin.
        cwd: The current working directory for executing the command.

    Returns:
        A tuple of (stdout, is_no_change_sentinel). The flag is True when the
        stripped stdout is exactly NO_CHANGES_SENTINEL.

    Raises:
        PatchGenerationError: If no command in the chain yields usable output.
    """
    errors: list[str] = []

    for command in model.commands:
        result = subprocess.run(
            command,
            input=prompt,
            text=True,
            capture_output=True,
            cwd=str(cwd),
            shell=True,  # Use shell=True to allow command to be a string with arguments.
        )

        raw_stdout = result.stdout or ""
        stdout = raw_stdout.strip()
        stderr = (result.stderr or "").strip()

        # The Claude CLI (and potentially others) might return a non-zero exit code
        # even on successful responses, so stdout content is checked before the
        # return code.
        if stdout:
            if stdout == NO_CHANGES_SENTINEL:
                return raw_stdout, True
            if "API Error:" in raw_stdout and "Overloaded" in raw_stdout:
                # An overloaded provider is a failure of this command only;
                # fall through to the next command instead of aborting the chain.
                errors.append(
                    f"{command!r}: Claude API is overloaded (500 error) - please retry later"
                )
                continue
            return raw_stdout, False

        if result.returncode == 0:
            errors.append(f"{command!r} produced no output")
        else:
            errors.append(f"{command!r} exited with {result.returncode}: {stderr or 'no stderr'}")

    raise PatchGenerationError("AI command(s) failed: " + "; ".join(errors))
def extract_patch_with_markers(raw_output: str) -> str:
@ -525,31 +560,49 @@ def apply_patch(repo_root: Path, patch_file: Path, patch_level: str, output_rel:
"""
absolute_patch = patch_file.resolve()
# First, try a dry-run check with --check to see if the patch applies cleanly.
check_args = ["git", "apply", patch_level, "--index", "--check", absolute_patch.as_posix()]
if run(check_args, cwd=repo_root, check=False).returncode == 0:
# If it applies cleanly, perform the actual apply.
run(["git", "apply", patch_level, "--index", absolute_patch.as_posix()], cwd=repo_root)
return
# Check if the output file is currently staged
# If it is, unstage it so the patch applies to the working directory version
was_staged = False
check_staged = run(
["git", "diff", "--cached", "--name-only", "--", output_rel.as_posix()],
cwd=repo_root,
check=False
)
if check_staged.returncode == 0 and check_staged.stdout.strip():
# File is staged, unstage it temporarily
was_staged = True
run(["git", "reset", "HEAD", "--", output_rel.as_posix()], cwd=repo_root, check=False)
# If direct apply --check fails, try 3-way merge with a check first.
three_way_check_args = ["git", "apply", patch_level, "--index", "--3way", "--recount", "--whitespace=nowarn", absolute_patch.as_posix()]
if run(three_way_check_args + ["--check"], cwd=repo_root, check=False).returncode == 0:
# If 3-way check passes, perform the actual 3-way apply.
run(three_way_check_args, cwd=repo_root)
return
# Special handling for new files: if the patch creates a new file ('--- /dev/null'),
# a simple 'git apply' might work, followed by 'git add'.
text = patch_file.read_text(encoding="utf-8")
if "--- /dev/null" in text:
# Try applying without --index and then explicitly add the file.
if run(["git", "apply", patch_level, absolute_patch.as_posix()], cwd=repo_root, check=False).returncode == 0:
run(["git", "add", "--", output_rel.as_posix()], cwd=repo_root)
try:
# First, try a dry-run check with --check to see if the patch applies cleanly.
check_args = ["git", "apply", patch_level, "--index", "--check", absolute_patch.as_posix()]
if run(check_args, cwd=repo_root, check=False).returncode == 0:
# If it applies cleanly, perform the actual apply.
run(["git", "apply", patch_level, "--index", absolute_patch.as_posix()], cwd=repo_root)
return
# If all attempts fail, raise an error.
raise PatchGenerationError("Failed to apply patch (strict and 3-way both failed)")
# If direct apply --check fails, try 3-way merge with a check first.
three_way_check_args = ["git", "apply", patch_level, "--index", "--3way", "--recount", "--whitespace=nowarn", absolute_patch.as_posix()]
if run(three_way_check_args + ["--check"], cwd=repo_root, check=False).returncode == 0:
# If 3-way check passes, perform the actual 3-way apply.
run(three_way_check_args, cwd=repo_root)
return
# Special handling for new files: if the patch creates a new file ('--- /dev/null'),
# a simple 'git apply' might work, followed by 'git add'.
text = patch_file.read_text(encoding="utf-8")
if "--- /dev/null" in text:
# Try applying without --index and then explicitly add the file.
if run(["git", "apply", patch_level, absolute_patch.as_posix()], cwd=repo_root, check=False).returncode == 0:
run(["git", "add", "--", output_rel.as_posix()], cwd=repo_root)
return
# If all attempts fail, raise an error.
raise PatchGenerationError("Failed to apply patch (strict and 3-way both failed)")
finally:
# If the file was staged before, re-stage it now (whether patch succeeded or failed)
if was_staged:
run(["git", "add", "--", output_rel.as_posix()], cwd=repo_root, check=False)
def run(args: list[str], cwd: Path, check: bool = True) -> subprocess.CompletedProcess[str]:

View File

@ -133,7 +133,7 @@ def main(argv: list[str] | None = None) -> int:
return 0
# Instantiate the model config and delegate to the processing pipeline.
model = ModelConfig(command=args.model or ModelConfig().command)
model = ModelConfig(args.model)
return process(repo_root, rules, model)

View File

@ -39,7 +39,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea
patch_file = tmp_path / "patch.txt"
patch_file.write_text(patch_text, encoding="utf-8")
model = ModelConfig(command=f"bash -lc 'cat {patch_file.as_posix()}'")
model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'")
rules = RulesConfig(root=temp_repo, global_rules={"file_associations": {}, "rules": {}})
generate_output(

View File

@ -57,7 +57,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea
patch_file.write_text(patch_text, encoding="utf-8")
rules = RulesConfig.load(repo)
model = ModelConfig(command=f"bash -lc 'cat {patch_file.as_posix()}'")
model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'")
rc = process(repo, rules, model)
assert rc == 0