feat: add ai fallback chain with no-change sentinel
This commit is contained in:
parent
eed12ce749
commit
feb8580b3a
|
|
@ -11,7 +11,7 @@ import re
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from automation.config import RulesConfig
|
from automation.config import RulesConfig
|
||||||
|
|
@ -21,11 +21,36 @@ class PatchGenerationError(RuntimeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
NO_CHANGES_SENTINEL = "CASCADINGDEV_NO_CHANGES"
|
||||||
|
COMMAND_DELIMITER = "||"
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_command_chain(raw: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Split a fallback command string into individual commands.
|
||||||
|
|
||||||
|
Commands are separated by '||' to align with common shell semantics.
|
||||||
|
Whitespace-only segments are ignored.
|
||||||
|
"""
|
||||||
|
parts = [segment.strip() for segment in raw.split(COMMAND_DELIMITER)]
|
||||||
|
commands = [segment for segment in parts if segment]
|
||||||
|
return commands or ["claude -p"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
"""Configuration for invoking the AI model command."""
|
"""Configuration for invoking the AI model command."""
|
||||||
# The command to execute the AI model. Defaults to 'claude -p' or CDEV_AI_COMMAND env var.
|
raw: str | None = None
|
||||||
command: str = os.environ.get("CDEV_AI_COMMAND", "claude -p")
|
commands: list[str] = field(init=False, repr=False)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
source = self.raw or os.environ.get("CDEV_AI_COMMAND", "claude -p")
|
||||||
|
self.commands = _parse_command_chain(source)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def command(self) -> str:
|
||||||
|
"""Return the primary command (first in the fallback chain)."""
|
||||||
|
return self.commands[0]
|
||||||
|
|
||||||
|
|
||||||
def generate_output(
|
def generate_output(
|
||||||
|
|
@ -86,7 +111,9 @@ def generate_output(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call the AI model and get its raw output.
|
# Call the AI model and get its raw output.
|
||||||
raw_patch = call_model(model, prompt, cwd=repo_root)
|
raw_patch, no_changes = call_model(model, prompt, cwd=repo_root)
|
||||||
|
if no_changes:
|
||||||
|
return
|
||||||
|
|
||||||
# Use a temporary directory for storing intermediate patch files for debugging.
|
# Use a temporary directory for storing intermediate patch files for debugging.
|
||||||
with tempfile.TemporaryDirectory(prefix="cdev-patch-") as tmpdir_str:
|
with tempfile.TemporaryDirectory(prefix="cdev-patch-") as tmpdir_str:
|
||||||
|
|
@ -304,6 +331,7 @@ End with: <<<AI_DIFF_END>>>
|
||||||
|
|
||||||
Only include the diff between these markers.
|
Only include the diff between these markers.
|
||||||
If the output file doesn't exist, create it from scratch in the patch.
|
If the output file doesn't exist, create it from scratch in the patch.
|
||||||
|
If no changes are needed, output only this exact token (no diff): {no_change_token}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -337,10 +365,11 @@ def build_prompt(
|
||||||
source_content=source_content.strip(),
|
source_content=source_content.strip(),
|
||||||
output_content=output_content.strip() or "(empty)", # Indicate if output content is empty.
|
output_content=output_content.strip() or "(empty)", # Indicate if output content is empty.
|
||||||
instruction=instruction.strip(),
|
instruction=instruction.strip(),
|
||||||
|
no_change_token=NO_CHANGES_SENTINEL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def call_model(model: ModelConfig, prompt: str, cwd: Path) -> str:
|
def call_model(model: ModelConfig, prompt: str, cwd: Path) -> tuple[str, bool]:
|
||||||
"""
|
"""
|
||||||
Invokes the AI model command with the given prompt and captures its output.
|
Invokes the AI model command with the given prompt and captures its output.
|
||||||
|
|
||||||
|
|
@ -350,33 +379,39 @@ def call_model(model: ModelConfig, prompt: str, cwd: Path) -> str:
|
||||||
cwd: The current working directory for executing the command.
|
cwd: The current working directory for executing the command.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The raw string output from the AI model.
|
A tuple of (stdout, is_no_change_sentinel).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
PatchGenerationError: If the AI command fails or returns an API error.
|
PatchGenerationError: If the AI command fails or returns an API error.
|
||||||
"""
|
"""
|
||||||
command = model.command
|
errors: list[str] = []
|
||||||
result = subprocess.run(
|
|
||||||
command,
|
|
||||||
input=prompt,
|
|
||||||
text=True,
|
|
||||||
capture_output=True,
|
|
||||||
cwd=str(cwd),
|
|
||||||
shell=True, # Use shell=True to allow command to be a string with arguments.
|
|
||||||
)
|
|
||||||
# The Claude CLI (and potentially others) might return a non-zero exit code
|
|
||||||
# even on successful responses if there are warnings or specific API behaviors.
|
|
||||||
# We prioritize stdout content over return code for initial check.
|
|
||||||
if result.stdout.strip():
|
|
||||||
# Check for specific API error messages in the output.
|
|
||||||
if "API Error:" in result.stdout and "Overloaded" in result.stdout:
|
|
||||||
raise PatchGenerationError("Claude API is overloaded (500 error) - please retry later")
|
|
||||||
return result.stdout
|
|
||||||
|
|
||||||
# If stdout is empty, then a non-zero return code indicates a true failure.
|
for command in model.commands:
|
||||||
if result.returncode != 0:
|
result = subprocess.run(
|
||||||
raise PatchGenerationError(f"AI command failed ({result.returncode}): {result.stderr.strip()}")
|
command,
|
||||||
return result.stdout
|
input=prompt,
|
||||||
|
text=True,
|
||||||
|
capture_output=True,
|
||||||
|
cwd=str(cwd),
|
||||||
|
shell=True,
|
||||||
|
)
|
||||||
|
raw_stdout = result.stdout or ""
|
||||||
|
stdout = raw_stdout.strip()
|
||||||
|
stderr = result.stderr.strip()
|
||||||
|
|
||||||
|
if stdout:
|
||||||
|
if stdout == NO_CHANGES_SENTINEL:
|
||||||
|
return raw_stdout, True
|
||||||
|
if "API Error:" in raw_stdout and "Overloaded" in raw_stdout:
|
||||||
|
raise PatchGenerationError("Claude API is overloaded (500 error) - please retry later")
|
||||||
|
return raw_stdout, False
|
||||||
|
|
||||||
|
if result.returncode == 0:
|
||||||
|
errors.append(f"{command!r} produced no output")
|
||||||
|
else:
|
||||||
|
errors.append(f"{command!r} exited with {result.returncode}: {stderr or 'no stderr'}")
|
||||||
|
|
||||||
|
raise PatchGenerationError("AI command(s) failed: " + "; ".join(errors))
|
||||||
|
|
||||||
|
|
||||||
def extract_patch_with_markers(raw_output: str) -> str:
|
def extract_patch_with_markers(raw_output: str) -> str:
|
||||||
|
|
@ -525,31 +560,49 @@ def apply_patch(repo_root: Path, patch_file: Path, patch_level: str, output_rel:
|
||||||
"""
|
"""
|
||||||
absolute_patch = patch_file.resolve()
|
absolute_patch = patch_file.resolve()
|
||||||
|
|
||||||
# First, try a dry-run check with --check to see if the patch applies cleanly.
|
# Check if the output file is currently staged
|
||||||
check_args = ["git", "apply", patch_level, "--index", "--check", absolute_patch.as_posix()]
|
# If it is, unstage it so the patch applies to the working directory version
|
||||||
if run(check_args, cwd=repo_root, check=False).returncode == 0:
|
was_staged = False
|
||||||
# If it applies cleanly, perform the actual apply.
|
check_staged = run(
|
||||||
run(["git", "apply", patch_level, "--index", absolute_patch.as_posix()], cwd=repo_root)
|
["git", "diff", "--cached", "--name-only", "--", output_rel.as_posix()],
|
||||||
return
|
cwd=repo_root,
|
||||||
|
check=False
|
||||||
|
)
|
||||||
|
if check_staged.returncode == 0 and check_staged.stdout.strip():
|
||||||
|
# File is staged, unstage it temporarily
|
||||||
|
was_staged = True
|
||||||
|
run(["git", "reset", "HEAD", "--", output_rel.as_posix()], cwd=repo_root, check=False)
|
||||||
|
|
||||||
# If direct apply --check fails, try 3-way merge with a check first.
|
try:
|
||||||
three_way_check_args = ["git", "apply", patch_level, "--index", "--3way", "--recount", "--whitespace=nowarn", absolute_patch.as_posix()]
|
# First, try a dry-run check with --check to see if the patch applies cleanly.
|
||||||
if run(three_way_check_args + ["--check"], cwd=repo_root, check=False).returncode == 0:
|
check_args = ["git", "apply", patch_level, "--index", "--check", absolute_patch.as_posix()]
|
||||||
# If 3-way check passes, perform the actual 3-way apply.
|
if run(check_args, cwd=repo_root, check=False).returncode == 0:
|
||||||
run(three_way_check_args, cwd=repo_root)
|
# If it applies cleanly, perform the actual apply.
|
||||||
return
|
run(["git", "apply", patch_level, "--index", absolute_patch.as_posix()], cwd=repo_root)
|
||||||
|
|
||||||
# Special handling for new files: if the patch creates a new file ('--- /dev/null'),
|
|
||||||
# a simple 'git apply' might work, followed by 'git add'.
|
|
||||||
text = patch_file.read_text(encoding="utf-8")
|
|
||||||
if "--- /dev/null" in text:
|
|
||||||
# Try applying without --index and then explicitly add the file.
|
|
||||||
if run(["git", "apply", patch_level, absolute_patch.as_posix()], cwd=repo_root, check=False).returncode == 0:
|
|
||||||
run(["git", "add", "--", output_rel.as_posix()], cwd=repo_root)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# If all attempts fail, raise an error.
|
# If direct apply --check fails, try 3-way merge with a check first.
|
||||||
raise PatchGenerationError("Failed to apply patch (strict and 3-way both failed)")
|
three_way_check_args = ["git", "apply", patch_level, "--index", "--3way", "--recount", "--whitespace=nowarn", absolute_patch.as_posix()]
|
||||||
|
if run(three_way_check_args + ["--check"], cwd=repo_root, check=False).returncode == 0:
|
||||||
|
# If 3-way check passes, perform the actual 3-way apply.
|
||||||
|
run(three_way_check_args, cwd=repo_root)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Special handling for new files: if the patch creates a new file ('--- /dev/null'),
|
||||||
|
# a simple 'git apply' might work, followed by 'git add'.
|
||||||
|
text = patch_file.read_text(encoding="utf-8")
|
||||||
|
if "--- /dev/null" in text:
|
||||||
|
# Try applying without --index and then explicitly add the file.
|
||||||
|
if run(["git", "apply", patch_level, absolute_patch.as_posix()], cwd=repo_root, check=False).returncode == 0:
|
||||||
|
run(["git", "add", "--", output_rel.as_posix()], cwd=repo_root)
|
||||||
|
return
|
||||||
|
|
||||||
|
# If all attempts fail, raise an error.
|
||||||
|
raise PatchGenerationError("Failed to apply patch (strict and 3-way both failed)")
|
||||||
|
finally:
|
||||||
|
# If the file was staged before, re-stage it now (whether patch succeeded or failed)
|
||||||
|
if was_staged:
|
||||||
|
run(["git", "add", "--", output_rel.as_posix()], cwd=repo_root, check=False)
|
||||||
|
|
||||||
|
|
||||||
def run(args: list[str], cwd: Path, check: bool = True) -> subprocess.CompletedProcess[str]:
|
def run(args: list[str], cwd: Path, check: bool = True) -> subprocess.CompletedProcess[str]:
|
||||||
|
|
|
||||||
|
|
@ -133,7 +133,7 @@ def main(argv: list[str] | None = None) -> int:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Instantiate the model config and delegate to the processing pipeline.
|
# Instantiate the model config and delegate to the processing pipeline.
|
||||||
model = ModelConfig(command=args.model or ModelConfig().command)
|
model = ModelConfig(args.model)
|
||||||
return process(repo_root, rules, model)
|
return process(repo_root, rules, model)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea
|
||||||
patch_file = tmp_path / "patch.txt"
|
patch_file = tmp_path / "patch.txt"
|
||||||
patch_file.write_text(patch_text, encoding="utf-8")
|
patch_file.write_text(patch_text, encoding="utf-8")
|
||||||
|
|
||||||
model = ModelConfig(command=f"bash -lc 'cat {patch_file.as_posix()}'")
|
model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'")
|
||||||
rules = RulesConfig(root=temp_repo, global_rules={"file_associations": {}, "rules": {}})
|
rules = RulesConfig(root=temp_repo, global_rules={"file_associations": {}, "rules": {}})
|
||||||
|
|
||||||
generate_output(
|
generate_output(
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ diff --git a/Docs/features/FR_1/discussions/example.discussion.sum.md b/Docs/fea
|
||||||
patch_file.write_text(patch_text, encoding="utf-8")
|
patch_file.write_text(patch_text, encoding="utf-8")
|
||||||
|
|
||||||
rules = RulesConfig.load(repo)
|
rules = RulesConfig.load(repo)
|
||||||
model = ModelConfig(command=f"bash -lc 'cat {patch_file.as_posix()}'")
|
model = ModelConfig(f"bash -lc 'cat {patch_file.as_posix()}'")
|
||||||
|
|
||||||
rc = process(repo, rules, model)
|
rc = process(repo, rules, model)
|
||||||
assert rc == 0
|
assert rc == 0
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue