feat: add ai fallback chain with no-change sentinel

This commit is contained in:
rob 2025-11-01 14:25:22 -03:00
parent eed12ce749
commit feb8580b3a
4 changed files with 106 additions and 53 deletions

View File

@ -11,7 +11,7 @@ import re
import shutil import shutil
import subprocess import subprocess
import tempfile import tempfile
from dataclasses import dataclass from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from automation.config import RulesConfig from automation.config import RulesConfig
@ -21,11 +21,36 @@ class PatchGenerationError(RuntimeError):
pass pass
# Exact token the prompt asks the model to emit (instead of a diff) when no
# changes are needed; call_model compares stripped stdout against it.
NO_CHANGES_SENTINEL = "CASCADINGDEV_NO_CHANGES"
# Separator between individual commands in a fallback command-chain string.
COMMAND_DELIMITER = "||"
def _parse_command_chain(raw: str) -> list[str]:
    """
    Break a fallback command string into its individual commands.

    The '||' delimiter separates commands, mirroring common shell semantics.
    Empty or whitespace-only segments are dropped; when nothing remains the
    default 'claude -p' command is returned as a single-element chain.
    """
    chain: list[str] = []
    for piece in raw.split(COMMAND_DELIMITER):
        candidate = piece.strip()
        if candidate:
            chain.append(candidate)

    if not chain:
        return ["claude -p"]
    return chain
@dataclass
class ModelConfig:
    """Settings describing how the AI model command chain is invoked."""

    # Optional '||'-delimited command string. When unset or empty, the
    # CDEV_AI_COMMAND environment variable (default 'claude -p') is used.
    raw: str | None = None
    # Parsed fallback chain; derived in __post_init__, not an init argument.
    commands: list[str] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Resolve the command source and parse it into the fallback chain."""
        if self.raw:
            chain_source = self.raw
        else:
            chain_source = os.environ.get("CDEV_AI_COMMAND", "claude -p")
        self.commands = _parse_command_chain(chain_source)

    @property
    def command(self) -> str:
        """Return the primary command, i.e. the first entry of the chain."""
        primary = self.commands[0]
        return primary
def generate_output( def generate_output(
@ -86,7 +111,9 @@ def generate_output(
) )
# Call the AI model and get its raw output. # Call the AI model and get its raw output.
raw_patch = call_model(model, prompt, cwd=repo_root) raw_patch, no_changes = call_model(model, prompt, cwd=repo_root)
if no_changes:
return
# Use a temporary directory for storing intermediate patch files for debugging. # Use a temporary directory for storing intermediate patch files for debugging.
with tempfile.TemporaryDirectory(prefix="cdev-patch-") as tmpdir_str: with tempfile.TemporaryDirectory(prefix="cdev-patch-") as tmpdir_str:
@ -304,6 +331,7 @@ End with: <<<AI_DIFF_END>>>
Only include the diff between these markers. Only include the diff between these markers.
If the output file doesn't exist, create it from scratch in the patch. If the output file doesn't exist, create it from scratch in the patch.
If no changes are needed, output only this exact token (no diff): {no_change_token}
""" """
@ -337,10 +365,11 @@ def build_prompt(
source_content=source_content.strip(), source_content=source_content.strip(),
output_content=output_content.strip() or "(empty)", # Indicate if output content is empty. output_content=output_content.strip() or "(empty)", # Indicate if output content is empty.
instruction=instruction.strip(), instruction=instruction.strip(),
no_change_token=NO_CHANGES_SENTINEL,
) )
def call_model(model: ModelConfig, prompt: str, cwd: Path) -> tuple[str, bool]:
    """
    Invokes the AI model command chain with the given prompt and captures its output.

    Every command in ``model.commands`` is tried in order. The first command
    that produces usable stdout wins; a command that produces no output, or
    whose output reports an overloaded API, is recorded as a failure and the
    next command in the chain is attempted.

    Args:
        model: Model configuration providing the ordered fallback commands.
        prompt: The prompt text, passed to each command on stdin.
        cwd: The current working directory for executing the command.

    Returns:
        A tuple of (stdout, is_no_change_sentinel). The stdout is returned
        unstripped; the flag is True when the stripped stdout equals
        NO_CHANGES_SENTINEL.

    Raises:
        PatchGenerationError: If every command in the chain fails (no output,
            or an overloaded-API response).
    """
    errors: list[str] = []

    for command in model.commands:
        result = subprocess.run(
            command,
            input=prompt,
            text=True,
            capture_output=True,
            cwd=str(cwd),
            shell=True,  # Commands are shell strings that carry their own arguments.
        )

        raw_stdout = result.stdout or ""
        stdout = raw_stdout.strip()
        stderr = (result.stderr or "").strip()

        # Stdout content takes priority over the exit code: a command that
        # printed something is judged by what it printed.
        if stdout:
            if stdout == NO_CHANGES_SENTINEL:
                return raw_stdout, True
            if "API Error:" in raw_stdout and "Overloaded" in raw_stdout:
                # An overloaded provider is exactly the case the fallback
                # chain exists for: record it and try the next command
                # instead of aborting the whole chain.
                errors.append(
                    f"{command!r}: Claude API is overloaded (500 error) - please retry later"
                )
                continue
            return raw_stdout, False

        if result.returncode == 0:
            errors.append(f"{command!r} produced no output")
        else:
            errors.append(f"{command!r} exited with {result.returncode}: {stderr or 'no stderr'}")

    raise PatchGenerationError("AI command(s) failed: " + "; ".join(errors))
def extract_patch_with_markers(raw_output: str) -> str: def extract_patch_with_markers(raw_output: str) -> str:
@ -524,32 +559,50 @@ def apply_patch(repo_root: Path, patch_file: Path, patch_level: str, output_rel:
PatchGenerationError: If both 3-way and strict patch application fail. PatchGenerationError: If both 3-way and strict patch application fail.
""" """
absolute_patch = patch_file.resolve() absolute_patch = patch_file.resolve()
# First, try a dry-run check with --check to see if the patch applies cleanly.
check_args = ["git", "apply", patch_level, "--index", "--check", absolute_patch.as_posix()]
if run(check_args, cwd=repo_root, check=False).returncode == 0:
# If it applies cleanly, perform the actual apply.
run(["git", "apply", patch_level, "--index", absolute_patch.as_posix()], cwd=repo_root)
return
# If direct apply --check fails, try 3-way merge with a check first. # Check if the output file is currently staged
three_way_check_args = ["git", "apply", patch_level, "--index", "--3way", "--recount", "--whitespace=nowarn", absolute_patch.as_posix()] # If it is, unstage it so the patch applies to the working directory version
if run(three_way_check_args + ["--check"], cwd=repo_root, check=False).returncode == 0: was_staged = False
# If 3-way check passes, perform the actual 3-way apply. check_staged = run(
run(three_way_check_args, cwd=repo_root) ["git", "diff", "--cached", "--name-only", "--", output_rel.as_posix()],
return cwd=repo_root,
check=False
# Special handling for new files: if the patch creates a new file ('--- /dev/null'), )
# a simple 'git apply' might work, followed by 'git add'. if check_staged.returncode == 0 and check_staged.stdout.strip():
text = patch_file.read_text(encoding="utf-8") # File is staged, unstage it temporarily
if "--- /dev/null" in text: was_staged = True
# Try applying without --index and then explicitly add the file. run(["git", "reset", "HEAD", "--", output_rel.as_posix()], cwd=repo_root, check=False)
if run(["git", "apply", patch_level, absolute_patch.as_posix()], cwd=repo_root, check=False).returncode == 0:
run(["git", "add", "--", output_rel.as_posix()], cwd=repo_root) try:
# First, try a dry-run check with --check to see if the patch applies cleanly.
check_args = ["git", "apply", patch_level, "--index", "--check", absolute_patch.as_posix()]
if run(check_args, cwd=repo_root, check=False).returncode == 0:
# If it applies cleanly, perform the actual apply.
run(["git", "apply", patch_level, "--index", absolute_patch.as_posix()], cwd=repo_root)
return return
# If all attempts fail, raise an error. # If direct apply --check fails, try 3-way merge with a check first.
raise PatchGenerationError("Failed to apply patch (strict and 3-way both failed)") three_way_check_args = ["git", "apply", patch_level, "--index", "--3way", "--recount", "--whitespace=nowarn", absolute_patch.as_posix()]
if run(three_way_check_args + ["--check"], cwd=repo_root, check=False).returncode == 0:
# If 3-way check passes, perform the actual 3-way apply.
run(three_way_check_args, cwd=repo_root)
return
# Special handling for new files: if the patch creates a new file ('--- /dev/null'),
# a simple 'git apply' might work, followed by 'git add'.
text = patch_file.read_text(encoding="utf-8")
if "--- /dev/null" in text:
# Try applying without --index and then explicitly add the file.
if run(["git", "apply", patch_level, absolute_patch.as_posix()], cwd=repo_root, check=False).returncode == 0:
run(["git", "add", "--", output_rel.as_posix()], cwd=repo_root)
return
# If all attempts fail, raise an error.
raise PatchGenerationError("Failed to apply patch (strict and 3-way both failed)")
finally:
# If the file was staged before, re-stage it now (whether patch succeeded or failed)
if was_staged:
run(["git", "add", "--", output_rel.as_posix()], cwd=repo_root, check=False)
def run(args: list[str], cwd: Path, check: bool = True) -> subprocess.CompletedProcess[str]: def run(args: list[str], cwd: Path, check: bool = True) -> subprocess.CompletedProcess[str]:

View File

@ -133,7 +133,7 @@ def main(argv: list[str] | None = None) -> int:
return 0 return 0
# Instantiate the model config and delegate to the processing pipeline. # Instantiate the model config and delegate to the processing pipeline.
model = ModelConfig(command=args.model or ModelConfig().command) model = ModelConfig(args.model)
return process(repo_root, rules, model) return process(repo_root, rules, model)

View File

@ -39,7 +39,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea
patch_file = tmp_path / "patch.txt" patch_file = tmp_path / "patch.txt"
patch_file.write_text(patch_text, encoding="utf-8") patch_file.write_text(patch_text, encoding="utf-8")
model = ModelConfig(command=f"bash -lc 'cat {patch_file.as_posix()}'") model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'")
rules = RulesConfig(root=temp_repo, global_rules={"file_associations": {}, "rules": {}}) rules = RulesConfig(root=temp_repo, global_rules={"file_associations": {}, "rules": {}})
generate_output( generate_output(

View File

@ -57,7 +57,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea
patch_file.write_text(patch_text, encoding="utf-8") patch_file.write_text(patch_text, encoding="utf-8")
rules = RulesConfig.load(repo) rules = RulesConfig.load(repo)
model = ModelConfig(command=f"bash -lc 'cat {patch_file.as_posix()}'") model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'")
rc = process(repo, rules, model) rc = process(repo, rules, model)
assert rc == 0 assert rc == 0