CascadingDev/automation/patcher.py

767 lines
28 KiB
Python

"""
AI-powered patch generation and application utilities.
This module ports the proven bash hook logic into Python so the orchestration
pipeline can be tested and extended more easily.
"""
from __future__ import annotations
import os
import re
import shutil
import subprocess
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import json
import shlex
import sys
from automation.ai_config import (
DEFAULT_COMMAND_CHAIN,
DEFAULT_SENTINEL,
load_ai_settings,
parse_command_chain,
)
# Optional: support tqdm progress bar if available but optional
try: # pragma: no cover - optional dependency
from tqdm import tqdm # type: ignore
except Exception: # pragma: no cover
tqdm = None
from automation.config import RulesConfig
class PatchGenerationError(RuntimeError):
    """Raised when an AI command chain fails, no diff can be extracted, a patch
    cannot be applied, or a checked git command exits non-zero."""
@dataclass
class ModelConfig:
    """Describes which AI command(s) to run and how a "no change" reply looks."""

    # Ordered chain of shell commands; call_model tries them in sequence.
    commands: list[str]
    # Exact reply (after whitespace stripping) that means "nothing to change".
    sentinel: str
    # RunnerSettings from ai_config; consulted for hint-specific command chains.
    runner_settings: Any | None = None

    @classmethod
    def from_sources(cls, repo_root: Path, override: str | None = None) -> "ModelConfig":
        """
        Build a ModelConfig from, in precedence order: the CLI ``override``,
        the ``CDEV_AI_COMMAND`` environment variable, or the repository's AI
        settings. Blank overrides are ignored. If the chosen source parses to an
        empty chain, fall back to the repository chain and finally to
        ``DEFAULT_COMMAND_CHAIN``.
        """
        settings = load_ai_settings(repo_root)
        runner = settings.runner
        sentinel = runner.sentinel or DEFAULT_SENTINEL

        # First non-blank explicit command string wins (override before env var).
        explicit = next(
            (
                candidate
                for candidate in (override, os.environ.get("CDEV_AI_COMMAND"))
                if candidate and candidate.strip()
            ),
            None,
        )
        chain = parse_command_chain(explicit) if explicit is not None else runner.command_chain
        if not chain:
            chain = runner.command_chain or DEFAULT_COMMAND_CHAIN.copy()
        return cls(commands=chain, sentinel=sentinel, runner_settings=runner)

    def get_commands_for_hint(self, hint: str) -> list[str]:
        """
        Return the command chain to use for a model hint.

        Args:
            hint: 'fast', 'quality', or empty string.

        Returns:
            The runner's hint-specific chain when one exists and is non-empty,
            otherwise ``self.commands``.
        """
        if not self.runner_settings:
            return self.commands
        return self.runner_settings.get_chain_for_hint(hint) or self.commands
def _as_repo_relative(repo_root: Path, path: Path) -> Path:
    """Return ``path`` relative to ``repo_root`` if it is an absolute path inside the repo."""
    if not path.is_absolute():
        return path
    try:
        return path.resolve().relative_to(repo_root)
    except ValueError:
        # Outside the repository: leave it unchanged and let Git report the problem.
        return path


def _ensure_new_file_mode(patch: str) -> str:
    """Insert ``new file mode 100644`` before the first ``--- /dev/null`` header line if missing."""
    if "new file mode" in patch:
        return patch
    # Zero-width match at the start of a real header line, so content lines that merely
    # contain the text "--- /dev/null" are not touched.
    return re.sub(
        r"^(?=--- /dev/null(?:\s|$))",
        "new file mode 100644\n",
        patch,
        count=1,
        flags=re.MULTILINE,
    )


def generate_output(
    repo_root: Path,
    rules: RulesConfig,
    model: ModelConfig,
    source_rel: Path,
    output_rel: Path,
    instruction: str,
    model_hint: str = "",
) -> None:
    """
    Generate or refresh an output artifact with AI, based on staged changes.

    Steps:
        1. Normalize paths to repo-relative form, create the output's parent
           directory and record an intent-to-add for the output file.
        2. Gather context: staged diff and content of the source file and the
           current (pre-image) content of the output file.
        3. Build the prompt and call the AI command chain.
        4. Extract the diff between the markers, sanitize it, and add a
           ``new file mode`` line to ``--- /dev/null`` patches that lack one.
        5. Save debug artifacts and apply the patch (patch level ``-p1``).

    Returns early without touching the output when the model replies with the
    no-change sentinel or the sanitized patch is empty.

    Args:
        repo_root: The root directory of the Git repository.
        rules: Rules configuration; accepted for interface compatibility but not
            referenced in this function.
        model: The ModelConfig object specifying the AI command chain.
        source_rel: Path to the source file (relative to repo_root, or absolute
            inside it).
        output_rel: Path to the output file (relative to repo_root, or absolute
            inside it).
        instruction: The instruction for the AI model.
        model_hint: Optional hint ('fast' or 'quality') to guide model selection.

    Raises:
        PatchGenerationError: If every AI command fails, the AI output contains
            no extractable diff, or the patch cannot be applied.
    """
    repo_root = repo_root.resolve()
    # Git pathspecs, the prompt and the "+++ b/<path>" headers all need
    # repo-relative paths, so normalize absolute inputs that live in the repo.
    source_rel = _as_repo_relative(repo_root, source_rel)
    output_rel = _as_repo_relative(repo_root, output_rel)

    (repo_root / output_rel).parent.mkdir(parents=True, exist_ok=True)
    # Lets 'git apply --index' work for outputs Git does not yet know about.
    ensure_intent_to_add(repo_root, output_rel)

    source_diff = git_diff_cached(repo_root, source_rel)
    source_content = git_show_cached(repo_root, source_rel)
    # Only the pre-image content is used for the prompt; the blob hash is not needed here.
    output_preimage, _output_hash = read_output_preimage(repo_root, output_rel)

    prompt = build_prompt(
        source_rel=source_rel,
        output_rel=output_rel,
        source_diff=source_diff,
        source_content=source_content,
        output_content=output_preimage,
        instruction=instruction,
        no_change_token=model.sentinel,
        model_hint=model_hint,
    )

    raw_patch, no_changes = call_model(model, prompt, model_hint, cwd=repo_root)
    if no_changes:
        return

    # Intermediate files are kept in a temp dir and copied to the debug dir.
    with tempfile.TemporaryDirectory(prefix="cdev-patch-") as tmpdir_str:
        tmpdir = Path(tmpdir_str)
        raw_path = tmpdir / "raw.out"
        clean_path = tmpdir / "clean.diff"
        sanitized_path = tmpdir / "sanitized.diff"

        raw_path.write_text(raw_patch, encoding="utf-8")
        try:
            extracted = extract_patch_with_markers(raw_path.read_text(encoding="utf-8"))
        except PatchGenerationError:
            # Keep the raw model output around for inspection, then propagate.
            save_debug_artifacts(repo_root, output_rel, raw_path, None, None, None)
            raise
        clean_path.write_text(extracted, encoding="utf-8")

        sanitized = sanitize_unified_patch(clean_path.read_text(encoding="utf-8"))
        sanitized = _ensure_new_file_mode(sanitized)
        sanitized_path.write_text(sanitized, encoding="utf-8")

        patch_level = "-p1"  # Diffs are requested with a/ and b/ prefixes from repo root.
        final_patch_path = sanitized_path
        save_debug_artifacts(
            repo_root, output_rel, raw_path, clean_path, sanitized_path, final_patch_path
        )

        # An empty patch means the AI decided no change is needed (e.g. a gated
        # output whose precondition is not met); that is not an error.
        if not final_patch_path.read_text(encoding="utf-8").strip():
            return

        apply_patch(repo_root, final_patch_path, patch_level, output_rel)
def ensure_intent_to_add(repo_root: Path, rel_path: Path) -> None:
    """
    Record a Git "intent to add" for ``rel_path`` unless it is already tracked.

    This lets a later ``git apply --index`` create the file. The ``git add -N``
    call runs with ``check=False``, so any failure is ignored.

    Args:
        repo_root: The root directory of the Git repository.
        rel_path: The path to the file relative to repo_root.
    """
    if not git_ls_files(repo_root, rel_path):
        # -N stages the path without its content.
        run(["git", "add", "-N", "--", rel_path.as_posix()], cwd=repo_root, check=False)
def git_ls_files(repo_root: Path, rel_path: Path) -> bool:
    """
    Report whether ``rel_path`` is tracked by Git.

    Args:
        repo_root: The root directory of the Git repository.
        rel_path: The path to the file relative to repo_root.

    Returns:
        True when ``git ls-files --error-unmatch`` succeeds for the path.
    """
    probe = run(
        ["git", "ls-files", "--error-unmatch", "--", rel_path.as_posix()],
        cwd=repo_root,
        check=False,
    )
    return not probe.returncode
def git_diff_cached(repo_root: Path, rel_path: Path) -> str:
    """
    Return the staged (index vs HEAD) diff of ``rel_path`` with 2 context lines.

    Args:
        repo_root: The root directory of the Git repository.
        rel_path: The path to the file relative to repo_root.

    Returns:
        The unified diff text; empty when the command produces no stdout.
    """
    diff_cmd = ["git", "diff", "--cached", "--unified=2", "--", rel_path.as_posix()]
    return run(diff_cmd, cwd=repo_root, check=False).stdout
def git_show_cached(repo_root: Path, rel_path: Path) -> str:
    """
    Return the staged (index) content of ``rel_path``.

    Falls back to the working-tree file (read as UTF-8) when ``git show`` fails,
    and to an empty string when the file does not exist on disk either.

    Args:
        repo_root: The root directory of the Git repository.
        rel_path: The path to the file relative to repo_root.

    Returns:
        The file content as a string.
    """
    staged = run(["git", "show", f":{rel_path.as_posix()}"], cwd=repo_root, check=False)
    if staged.returncode == 0:
        return staged.stdout
    on_disk = repo_root / rel_path
    return on_disk.read_text(encoding="utf-8") if on_disk.exists() else ""
def read_output_preimage(repo_root: Path, rel_path: Path) -> tuple[str, str]:
    """
    Return the output file's current content and Git blob id (the patch pre-image).

    Lookup order:
        1. The index: content from ``git show :<path>`` and the object id from
           ``git ls-files --stage``.
        2. The working tree: content read as UTF-8 and id from ``git hash-object``.
        3. Neither: an empty string and forty zeros.

    Args:
        repo_root: The root directory of the Git repository.
        rel_path: The path to the output file relative to repo_root.

    Returns:
        A ``(content, blob_hash)`` tuple.
    """
    posix = rel_path.as_posix()
    null_hash = "0" * 40

    listing = run(["git", "ls-files", "--stage", "--", posix], cwd=repo_root, check=False)
    listing_text = listing.stdout.strip()
    if listing.returncode == 0 and listing_text:
        shown = run(["git", "show", f":{posix}"], cwd=repo_root, check=False)
        staged_content = shown.stdout if shown.returncode == 0 else ""
        # Each entry reads "<mode> <object> <stage>\t<path>"; the object id is field 2.
        return staged_content, listing_text.split()[1]

    on_disk = repo_root / rel_path
    if not on_disk.exists():
        return "", null_hash

    disk_content = on_disk.read_text(encoding="utf-8")
    hashed = run(
        ["git", "hash-object", on_disk.as_posix()],
        cwd=repo_root,
        check=False,
    ).stdout.strip()
    return disk_content, hashed or null_hash
# Prompt sent to the AI model for each generated artifact. build_prompt() fills it
# with str.format(); placeholders: source_path, output_path, model_hint_line,
# source_diff, source_content, output_content, instruction, no_change_token.
# Any literal "{" or "}" added to this text must be doubled ("{{" / "}}").
# The <<<AI_DIFF_START>>> / <<<AI_DIFF_END>>> markers requested here are the ones
# call_model() and extract_patch_with_markers() look for in the model's reply.
PROMPT_TEMPLATE = """You are assisting with automated artifact generation during a git commit.
SOURCE FILE: {source_path}
OUTPUT FILE: {output_path}
{model_hint_line}
=== SOURCE FILE CHANGES (staged) ===
{source_diff}
=== SOURCE FILE CONTENT (staged) ===
{source_content}
=== CURRENT OUTPUT CONTENT (use this as the preimage) ===
{output_content}
=== GENERATION INSTRUCTIONS ===
{instruction}
=== OUTPUT FORMAT REQUIREMENTS ===
Wrap your unified diff with these exact markers:
<<<AI_DIFF_START>>>
[your diff here]
<<<AI_DIFF_END>>>
For NEW FILES, use these headers exactly:
--- /dev/null
+++ b/{output_path}
=== TASK ===
Create or update {output_path} according to the instructions above.
Output ONLY a unified diff patch in proper git format:
- Use format: diff --git a/{output_path} b/{output_path}
- (Optional) You may include an "index ..." line, but it will be ignored
- Include complete hunks with context lines
- No markdown fences, no explanations, just the patch
Start with: <<<AI_DIFF_START>>>
End with: <<<AI_DIFF_END>>>
Only include the diff between these markers.
If the output file doesn't exist, create it from scratch in the patch.
If no changes are needed, output only this exact token (no diff): {no_change_token}
"""
def build_prompt(
    source_rel: Path,
    output_rel: Path,
    source_diff: str,
    source_content: str,
    output_content: str,
    instruction: str,
    no_change_token: str,
    model_hint: str = "",
) -> str:
    """
    Render PROMPT_TEMPLATE with the context for one artifact.

    Text arguments are whitespace-stripped before insertion; an empty output
    pre-image is shown as ``(empty)``.

    Args:
        source_rel: Relative path to the source file.
        output_rel: Relative path to the output file.
        source_diff: Git diff of the staged source file.
        source_content: Content of the staged source file.
        output_content: Current content of the output file (pre-image).
        instruction: Specific instructions for the AI.
        no_change_token: Sentinel the model should reply with when nothing changes.
        model_hint: Optional hint ('fast' or 'quality'); when set, a
            ``TASK COMPLEXITY: <HINT>`` line is included.

    Returns:
        The formatted prompt string.
    """
    hint_line = f"TASK COMPLEXITY: {model_hint.upper()}\n" if model_hint else ""
    fields = {
        "source_path": source_rel.as_posix(),
        "output_path": output_rel.as_posix(),
        "model_hint_line": hint_line,
        "source_diff": source_diff.strip(),
        "source_content": source_content.strip(),
        "output_content": output_content.strip() or "(empty)",
        "instruction": instruction.strip(),
        "no_change_token": no_change_token,
    }
    return PROMPT_TEMPLATE.format(**fields)
def call_model(
    model: ModelConfig,
    prompt: str,
    model_hint: str,
    cwd: Path,
    context: str = "[runner]",
) -> tuple[str, bool]:
    """
    Run the AI command chain on ``prompt`` and return the first usable reply.

    Providers come from ``model.get_commands_for_hint(model_hint)`` and are tried
    in order. A provider succeeds when its stripped stdout equals the no-change
    sentinel, or when its stdout contains ``<<<AI_DIFF_START>>>``. Every other
    result (an overloaded-API reply, non-diff output, empty output, or a non-zero
    exit) is recorded and the next provider is tried.

    Args:
        model: The ModelConfig object containing the AI command chain.
        prompt: The input prompt string for the AI model.
        model_hint: Optional hint ('fast' or 'quality') for model selection.
        cwd: The working directory for executing the commands.
        context: Prefix for the progress lines written to stderr.

    Returns:
        A tuple of (raw stdout, is_no_change_sentinel).

    Raises:
        PatchGenerationError: If no commands are configured or every provider fails.
    """

    def _log(message: str) -> None:
        sys.stderr.write(message + "\n")
        sys.stderr.flush()

    commands = model.get_commands_for_hint(model_hint)
    if not commands:
        raise PatchGenerationError("AI command(s) failed: no AI commands configured")

    total = len(commands)
    errors: list[str] = []
    for idx, command in enumerate(commands, start=1):
        parts = command.split()
        provider_name = parts[0] if parts else "<empty command>"
        label = f"{context} provider {idx}/{total} -> {provider_name}"
        _log(label)

        executor, raw_stdout, stderr, returncode = _run_ai_command(command, prompt, cwd)

        if raw_stdout:
            stripped = raw_stdout.strip()
            if stripped == model.sentinel:
                _log(f"{label} returned sentinel (no change)")
                return raw_stdout, True
            if "API Error:" in raw_stdout and "Overloaded" in raw_stdout:
                # Treat an overloaded API as a provider failure so fallbacks still run.
                errors.append(
                    f"{executor!r}: Claude API is overloaded (500 error) - please retry later"
                )
                _log(f"{label} API overloaded; trying next")
                continue
            if "<<<AI_DIFF_START>>>" in raw_stdout:
                _log(f"{label} produced diff")
                return raw_stdout, False
            # Non-empty output without diff markers counts as failure so we can try fallbacks.
            errors.append(f"{executor!r} produced non-diff output: {stripped[:80]}")
            _log(f"{label} non-diff output; trying next")
            continue

        if returncode == 0:
            errors.append(f"{executor!r} produced no output")
            _log(f"{label} returned no output")
        else:
            errors.append(f"{executor!r} exited with {returncode}: {stderr or 'no stderr'}")
            _log(f"{label} exited with {returncode}")

    raise PatchGenerationError("AI command(s) failed: " + "; ".join(errors))
def extract_patch_with_markers(raw_output: str) -> str:
    """
    Pull the unified diff out of the model's raw reply.

    If ``<<<AI_DIFF_START>>>`` is present, return the whitespace-stripped text
    between its first occurrence and the next ``<<<AI_DIFF_END>>>``. Otherwise,
    return everything from the first line beginning with ``diff --git `` onward.

    Args:
        raw_output: The complete raw string output from the AI model.

    Returns:
        The extracted unified diff string.

    Raises:
        PatchGenerationError: If the start marker has no end marker, or there is
            neither a marker nor a ``diff --git`` header.
    """
    start_marker, end_marker = "<<<AI_DIFF_START>>>", "<<<AI_DIFF_END>>>"

    _, found_start, after_start = raw_output.partition(start_marker)
    if found_start:
        body, found_end, _ = after_start.partition(end_marker)
        if not found_end:
            raise PatchGenerationError("AI output missing end marker")
        return body.strip()

    header = re.search(r"^diff --git ", raw_output, re.MULTILINE)
    if header is None:
        raise PatchGenerationError("AI output did not contain a diff")
    return raw_output[header.start():].strip()
def sanitize_unified_patch(patch: str) -> str:
    """
    Normalize a model-produced diff so ``git apply`` accepts it.

    Removes carriage returns and any ``index``/``similarity index``/``rename
    from``/``rename to`` lines. Output starts at the first ``diff --git``
    occurrence. When no such header exists but a ``+++`` line does, a
    ``diff --git <old> <new>`` header is synthesized (``/dev/null`` is replaced
    by the new path).

    Args:
        patch: The raw unified diff string.

    Returns:
        The sanitized diff ending in a newline, or an empty string when there is
        neither a ``diff --git`` header nor a ``+++`` line (the AI chose not to
        make changes, e.g. for a gated output).
    """
    dropped_prefixes = ("index ", "similarity index ", "rename from ", "rename to ")
    kept = [
        line
        for line in patch.replace("\r", "").splitlines()
        if not line.startswith(dropped_prefixes)
    ]
    body = "\n".join(kept)

    header_at = body.find("diff --git")
    if header_at != -1:
        return body[header_at:] + "\n"

    # Some providers omit the diff header; rebuild one from the ---/+++ lines.
    new_side = next((line for line in kept if line.startswith("+++ ")), None)
    if new_side is None:
        return ""
    old_side = next((line for line in kept if line.startswith("--- ")), None)

    new_path = new_side[4:].strip()
    old_path = old_side[4:].strip() if old_side else new_path
    if old_path == "/dev/null":
        old_path = new_path
    return f"diff --git {old_path} {new_path}\n{body}\n"
def rewrite_patch_for_p0(patch: str) -> str:
    """
    Strip the ``a/`` and ``b/`` path prefixes so the patch applies with ``-p0``.

    ``diff --git a/X b/Y`` headers become ``diff --git X Y``. The first
    ``+++ b/`` and ``--- a/`` on their respective lines lose the prefix.
    ``--- /dev/null`` is left untouched. All other lines are kept verbatim.

    Args:
        patch: The unified diff string.

    Returns:
        The rewritten diff, newline-terminated.
    """
    header_re = re.compile(r"^diff --git a/(.+?) b/(.+)$")

    def _strip_prefixes(line: str) -> str:
        if line.startswith("diff --git"):
            found = header_re.match(line)
            return f"diff --git {found.group(1)} {found.group(2)}" if found else line
        if line.startswith("+++ "):
            return line.replace("+++ b/", "+++ ", 1)
        if line.startswith("--- "):
            # "--- /dev/null" contains no "--- a/", so this leaves it unchanged.
            return line.replace("--- a/", "--- ", 1)
        return line

    return "\n".join(_strip_prefixes(line) for line in patch.splitlines()) + "\n"
def _resolve_git_dir(repo_root: Path) -> Path:
    """Return the Git directory for ``repo_root``, following a ``.git`` file's ``gitdir:`` pointer."""
    dot_git = repo_root / ".git"
    if dot_git.is_file():
        # Worktrees and submodules use a ".git" file containing "gitdir: <path>".
        pointer = dot_git.read_text(encoding="utf-8").strip()
        if pointer.startswith("gitdir:"):
            target = Path(pointer[len("gitdir:"):].strip())
            return target if target.is_absolute() else repo_root / target
    return dot_git


def save_debug_artifacts(
    repo_root: Path,
    output_rel: Path,
    raw_path: Path | None,
    clean_path: Path | None,
    sanitized_path: Path | None,
    final_path: Path | None,
) -> None:
    """
    Copy each stage of a generated patch into ``<git-dir>/ai-rules-debug``.

    Files are named ``<output path with '/' -> '_'>-<pid>.<stage>``; stages whose
    path is None or does not exist are skipped. Saving is best-effort. Any
    OSError (for example an unwritable Git directory) is reported on stderr
    instead of raised, so it cannot mask the caller's real result or error.

    Args:
        repo_root: The root directory of the Git repository.
        output_rel: The relative path of the output file being processed.
        raw_path: Path to the raw AI output file.
        clean_path: Path to the extracted diff file.
        sanitized_path: Path to the sanitized diff file.
        final_path: Path to the final patch file that was attempted to be applied.
    """
    stages = (
        (raw_path, "raw.out"),
        (clean_path, "clean.diff"),
        (sanitized_path, "sanitized.diff"),
        (final_path, "final.diff"),
    )
    try:
        debug_dir = _resolve_git_dir(repo_root) / "ai-rules-debug"
        debug_dir.mkdir(parents=True, exist_ok=True)
        identifier = f"{output_rel.as_posix().replace('/', '_')}-{os.getpid()}"
        for source, suffix in stages:
            if source and source.exists():
                shutil.copy(source, debug_dir / f"{identifier}.{suffix}")
    except OSError as exc:
        sys.stderr.write(f"[patcher] warning: could not save debug artifacts: {exc}\n")
        sys.stderr.flush()
def apply_patch(repo_root: Path, patch_file: Path, patch_level: str, output_rel: Path) -> None:
    """
    Apply a generated patch to both the working tree and the index.

    Strategies, each dry-run with ``--check`` before being applied:
        1. Strict ``git apply --index``.
        2. ``git apply --index --3way --recount --whitespace=nowarn``.
        3. For patches with a ``--- /dev/null`` header line only: plain
           ``git apply`` followed by ``git add`` of ``output_rel``. This step
           has no separate dry-run.

    Patches are generated against staged content, so files are deliberately not
    unstaged first; that would change the base and make the apply fail.

    Args:
        repo_root: The root directory of the Git repository.
        patch_file: The path to the patch file to apply.
        patch_level: The patch level (e.g., '-p1').
        output_rel: The relative path of the file the patch is intended for.

    Raises:
        PatchGenerationError: If every strategy fails. The message includes git's
            stderr from each failed attempt.
    """
    patch_arg = patch_file.resolve().as_posix()
    failures: list[str] = []

    def _reason(result: subprocess.CompletedProcess[str]) -> str:
        return (result.stderr or "").strip() or "no stderr"

    strict_check = run(
        ["git", "apply", patch_level, "--index", "--check", patch_arg],
        cwd=repo_root,
        check=False,
    )
    if strict_check.returncode == 0:
        run(["git", "apply", patch_level, "--index", patch_arg], cwd=repo_root)
        return
    failures.append(f"strict: {_reason(strict_check)}")

    three_way_args = [
        "git", "apply", patch_level, "--index", "--3way", "--recount",
        "--whitespace=nowarn", patch_arg,
    ]
    three_way_check = run(three_way_args + ["--check"], cwd=repo_root, check=False)
    if three_way_check.returncode == 0:
        run(three_way_args, cwd=repo_root)
        return
    failures.append(f"3-way: {_reason(three_way_check)}")

    # New-file patches can sometimes apply to the working tree only and then be staged.
    text = patch_file.read_text(encoding="utf-8")
    if re.search(r"^--- /dev/null(?:\s|$)", text, re.MULTILINE):
        plain = run(["git", "apply", patch_level, patch_arg], cwd=repo_root, check=False)
        if plain.returncode == 0:
            run(["git", "add", "--", output_rel.as_posix()], cwd=repo_root)
            return
        failures.append(f"new-file: {_reason(plain)}")

    raise PatchGenerationError(
        "Failed to apply patch (strict and 3-way both failed): " + "; ".join(failures)
    )
def run(args: list[str], cwd: Path, check: bool = True) -> subprocess.CompletedProcess[str]:
    """
    Run ``args`` (no shell) in ``cwd`` and capture stdout/stderr as text.

    Output is decoded as UTF-8, matching how this module reads working-tree
    files. Undecodable bytes are replaced with U+FFFD instead of raising, so a
    stray non-UTF-8 byte in git output cannot abort the commit hook.

    Args:
        args: Command and arguments.
        cwd: Working directory for the command.
        check: When true, raise on a non-zero exit status.

    Returns:
        The completed process with ``stdout``/``stderr`` as ``str``.

    Raises:
        PatchGenerationError: If ``check`` is true and the command exits non-zero.
    """
    result = subprocess.run(
        args,
        cwd=str(cwd),
        capture_output=True,
        text=True,
        encoding="utf-8",
        errors="replace",
    )
    if check and result.returncode != 0:
        raise PatchGenerationError(f"Command {' '.join(args)} failed: {result.stderr.strip()}")
    return result
def _run_ai_command(command: str, prompt: str, cwd: Path) -> tuple[str, str, str, int]:
"""Run an AI command and return (executed_command, stdout, stderr, returncode)."""
if command.strip().startswith("codex"):
return _run_codex_command(command, prompt, cwd)
result = subprocess.run(
command,
input=prompt,
text=True,
capture_output=True,
cwd=str(cwd),
shell=True,
)
raw_stdout = result.stdout or ""
stderr = result.stderr.strip()
return command, raw_stdout, stderr, result.returncode
def _run_codex_command(command: str, prompt: str, cwd: Path) -> tuple[str, str, str, int]:
    """
    Run the codex CLI in JSON mode and return the text of its last agent message.

    The command is normalized by ``_ensure_codex_json``; the prompt is sent on
    stdin. I/O is UTF-8, with undecodable bytes replaced, so behavior does not
    depend on the platform's locale encoding.

    Returns:
        ``(executed_command, last_agent_message, stripped_stderr, returncode)``.
    """
    json_command = _ensure_codex_json(command)
    completed = subprocess.run(
        json_command,
        input=prompt,
        capture_output=True,
        cwd=str(cwd),
        shell=True,
        text=True,
        encoding="utf-8",
        errors="replace",
    )
    last_message = _extract_codex_last_message(completed.stdout or "")
    return json_command, last_message, (completed.stderr or "").strip(), completed.returncode
def _ensure_codex_json(command: str) -> str:
"""Ensure codex command runs via `codex exec --json --color=never -` for machine parsing."""
tokens = shlex.split(command)
if not tokens:
return command
if tokens[0] != "codex":
return command
if len(tokens) == 1 or tokens[1] != "exec":
tokens.insert(1, "exec")
if "--json" not in tokens:
tokens.append("--json")
if not any(t.startswith("--color") for t in tokens):
tokens.append("--color=never")
if "-" not in tokens:
tokens.append("-")
return shlex.join(tokens)
def _extract_codex_last_message(stdout_text: str) -> str:
"""Parse codex JSONL output and return the final agent message text."""
last_text = ""
for line in stdout_text.splitlines():
line = line.strip()
if not line:
continue
try:
payload = json.loads(line)
except json.JSONDecodeError:
continue
item = payload.get("item")
if isinstance(item, dict) and item.get("type") == "agent_message":
text = item.get("text")
if isinstance(text, str):
last_text = text
return last_text