CascadingDev/automation/patcher.py

331 lines
11 KiB
Python

"""
AI-powered patch generation and application utilities.
This module ports the proven bash hook logic into Python so the orchestration
pipeline can be tested and extended more easily.
"""
from __future__ import annotations

import os
import re
import shutil
import subprocess
import tempfile
from dataclasses import dataclass, field
from pathlib import Path

from automation.config import RulesConfig
class PatchGenerationError(RuntimeError):
    """Error raised by this module when a patch cannot be produced or applied.

    Used for a failing AI command, AI output without a usable diff, a failing
    checked git command, and a patch that git refuses to apply.
    """
@dataclass
class ModelConfig:
    """Settings for the external AI command used to generate patches."""

    # Shell command line the prompt is piped into.  The default is looked up
    # when an instance is created (not when the module is imported), so a
    # CDEV_AI_COMMAND value set after import is still honoured.
    command: str = field(
        default_factory=lambda: os.environ.get("CDEV_AI_COMMAND", "claude -p")
    )
def generate_output(
    repo_root: Path,
    rules: RulesConfig,
    model: ModelConfig,
    source_rel: Path,
    output_rel: Path,
    instruction: str,
) -> None:
    """
    Generate/refresh an output artifact using staged context + AI diff.

    Collects the staged diff and content of ``source_rel`` and the current
    content of ``output_rel``, asks the AI command for a unified diff, cleans
    the answer up and applies it with ``git apply``.

    Args:
        repo_root: Repository root; resolved to an absolute path before use.
        rules: Accepted for interface compatibility; not read by this function.
        model: Supplies the AI command line that receives the prompt.
        source_rel: Source file path, used relative to ``repo_root``.
        output_rel: Output file path, used relative to ``repo_root``.
        instruction: Generation instructions embedded in the prompt.

    Raises:
        PatchGenerationError: If the AI command fails, its output contains no
            usable diff, or the resulting patch cannot be applied.
    """
    repo_root = repo_root.resolve()
    (repo_root / output_rel).parent.mkdir(parents=True, exist_ok=True)
    ensure_intent_to_add(repo_root, output_rel)

    source_diff = git_diff_cached(repo_root, source_rel)
    source_content = git_show_cached(repo_root, source_rel)
    # Only the preimage text is needed here; the blob hash is not used.
    output_preimage, _ = read_output_preimage(repo_root, output_rel)

    prompt = build_prompt(
        source_rel=source_rel,
        output_rel=output_rel,
        source_diff=source_diff,
        source_content=source_content,
        output_content=output_preimage,
        instruction=instruction,
    )
    raw_patch = call_model(model, prompt, cwd=repo_root)

    with tempfile.TemporaryDirectory(prefix="cdev-patch-") as tmpdir_str:
        tmpdir = Path(tmpdir_str)
        raw_path = tmpdir / "raw.out"
        clean_path = tmpdir / "clean.diff"
        sanitized_path = tmpdir / "sanitized.diff"

        raw_path.write_text(raw_patch, encoding="utf-8")
        # Pre-create the later stages so the debug copy below works even when
        # extraction or sanitizing raises before they are filled in.
        clean_path.write_text("", encoding="utf-8")
        sanitized_path.write_text("", encoding="utf-8")

        try:
            # Read the raw output back from disk: text-mode reading applies
            # universal-newline translation, so CR/CRLF line endings coming
            # from the model are normalized before extraction.
            extracted = extract_patch_with_markers(raw_path.read_text(encoding="utf-8"))
            clean_path.write_text(extracted, encoding="utf-8")

            sanitized = sanitize_unified_patch(extracted)
            # A creation patch needs a "new file mode" header; add one if the
            # model left it out.
            if "--- /dev/null" in sanitized and "new file mode" not in sanitized:
                sanitized = sanitized.replace(
                    "--- /dev/null", "new file mode 100644\n--- /dev/null", 1
                )
            sanitized_path.write_text(sanitized, encoding="utf-8")
        finally:
            # Keep the artifacts on failure too: a response that cannot be
            # parsed is exactly the case that needs inspecting afterwards.
            save_debug_artifacts(
                repo_root, output_rel, raw_path, clean_path, sanitized_path, sanitized_path
            )

        if not sanitized.strip():
            raise PatchGenerationError("AI returned empty patch")
        apply_patch(repo_root, sanitized_path, "-p1", output_rel)
def ensure_intent_to_add(repo_root: Path, rel_path: Path) -> None:
    """Run ``git add -N`` for *rel_path* unless git already lists the path.

    The ``git add`` call is made with ``check=False``, so its failure is
    ignored.
    """
    already_listed = git_ls_files(repo_root, rel_path)
    if not already_listed:
        intent_cmd = ["git", "add", "-N", "--", rel_path.as_posix()]
        run(intent_cmd, cwd=repo_root, check=False)
def git_ls_files(repo_root: Path, rel_path: Path) -> bool:
    """Return True when ``git ls-files --error-unmatch`` exits 0 for *rel_path*."""
    probe_cmd = ["git", "ls-files", "--error-unmatch", "--", rel_path.as_posix()]
    exit_code = run(probe_cmd, cwd=repo_root, check=False).returncode
    return exit_code == 0
def git_diff_cached(repo_root: Path, rel_path: Path) -> str:
    """Return the stdout of ``git diff --cached --unified=2`` for *rel_path*.

    The command runs with ``check=False``; whatever it printed to stdout is
    returned regardless of the exit status.
    """
    diff_cmd = ["git", "diff", "--cached", "--unified=2", "--", rel_path.as_posix()]
    completed = run(diff_cmd, cwd=repo_root, check=False)
    return completed.stdout
def git_show_cached(repo_root: Path, rel_path: Path) -> str:
    """Return the content of *rel_path* as reported by ``git show :<path>``.

    When that command fails, the file under ``repo_root`` is read as UTF-8
    instead; if it does not exist either, an empty string is returned.
    """
    staged = run(["git", "show", f":{rel_path.as_posix()}"], cwd=repo_root, check=False)
    if staged.returncode == 0:
        return staged.stdout

    on_disk = repo_root / rel_path
    return on_disk.read_text(encoding="utf-8") if on_disk.exists() else ""
def read_output_preimage(repo_root: Path, rel_path: Path) -> tuple[str, str]:
    """Return ``(content, blob_hash)`` describing the current state of *rel_path*.

    Lookup order:
      1. If ``git ls-files --stage`` prints an entry, the content comes from
         ``git show :<path>`` (empty when that fails) and the hash is the
         second whitespace-separated field of the listing.
      2. Otherwise, if the file exists under ``repo_root``, it is read as
         UTF-8 and hashed with ``git hash-object``.
      3. Otherwise the content is empty.

    The hash falls back to forty zeros whenever none could be determined.
    """
    posix_path = rel_path.as_posix()
    null_hash = "0" * 40

    index_entry = run(["git", "ls-files", "--stage", "--", posix_path], cwd=repo_root, check=False)
    listing = index_entry.stdout.strip()
    if index_entry.returncode == 0 and listing:
        staged = run(["git", "show", f":{posix_path}"], cwd=repo_root, check=False)
        staged_text = staged.stdout if staged.returncode == 0 else ""
        return staged_text, listing.split()[1]

    worktree_file = repo_root / rel_path
    if not worktree_file.exists():
        return "", null_hash

    worktree_text = worktree_file.read_text(encoding="utf-8")
    hashed = run(["git", "hash-object", worktree_file.as_posix()], cwd=repo_root, check=False)
    return worktree_text, hashed.stdout.strip() or null_hash
# Prompt text sent to the AI command; build_prompt() fills it via str.format().
# Placeholders: {source_path}, {output_path}, {source_diff}, {source_content},
# {output_content}, {instruction}.  Any literal brace added to this text must
# be doubled ("{{" / "}}"), otherwise str.format() treats it as a placeholder.
# The <<<AI_DIFF_START>>> / <<<AI_DIFF_END>>> markers named here are the ones
# extract_patch_with_markers() searches for in the model's answer.
PROMPT_TEMPLATE = """You are assisting with automated artifact generation during a git commit.
SOURCE FILE: {source_path}
OUTPUT FILE: {output_path}
=== SOURCE FILE CHANGES (staged) ===
{source_diff}
=== SOURCE FILE CONTENT (staged) ===
{source_content}
=== CURRENT OUTPUT CONTENT (use this as the preimage) ===
{output_content}
=== GENERATION INSTRUCTIONS ===
{instruction}
=== OUTPUT FORMAT REQUIREMENTS ===
Wrap your unified diff with these exact markers:
<<<AI_DIFF_START>>>
[your diff here]
<<<AI_DIFF_END>>>
For NEW FILES, use these headers exactly:
--- /dev/null
+++ b/{output_path}
=== TASK ===
Create or update {output_path} according to the instructions above.
Output ONLY a unified diff patch in proper git format:
- Use format: diff --git a/{output_path} b/{output_path}
- (Optional) You may include an "index ..." line, but it will be ignored
- Include complete hunks with context lines
- No markdown fences, no explanations, just the patch
Start with: <<<AI_DIFF_START>>>
End with: <<<AI_DIFF_END>>>
Only include the diff between these markers.
If the output file doesn't exist, create it from scratch in the patch.
"""
def build_prompt(
    source_rel: Path,
    output_rel: Path,
    source_diff: str,
    source_content: str,
    output_content: str,
    instruction: str,
) -> str:
    """Fill ``PROMPT_TEMPLATE`` with the given paths and texts.

    Paths are rendered in POSIX form and every text is stripped of
    surrounding whitespace.  Output content that is empty after stripping is
    shown as ``(empty)``.
    """
    current_output = output_content.strip()
    if not current_output:
        current_output = "(empty)"

    template_fields = {
        "source_path": source_rel.as_posix(),
        "output_path": output_rel.as_posix(),
        "source_diff": source_diff.strip(),
        "source_content": source_content.strip(),
        "output_content": current_output,
        "instruction": instruction.strip(),
    }
    return PROMPT_TEMPLATE.format(**template_fields)
def call_model(model: ModelConfig, prompt: str, cwd: Path) -> str:
    """Pipe *prompt* into the configured AI command and return its stdout.

    Raises:
        PatchGenerationError: If the command exits with a non-zero status;
            the message carries the exit code and the stripped stderr.
    """
    # NOTE(review): the command line is handed to the shell (shell=True), so
    # model.command must only ever come from trusted configuration.
    completed = subprocess.run(
        model.command,
        shell=True,
        cwd=str(cwd),
        input=prompt,
        capture_output=True,
        text=True,
    )
    if completed.returncode == 0:
        return completed.stdout
    raise PatchGenerationError(
        f"AI command failed ({completed.returncode}): {completed.stderr.strip()}"
    )
def _trim_patch_text(text: str) -> str:
    """Trim blank space around a patch without eating a final blank context line.

    In a unified diff a line made of exactly one space is the context line
    for an empty source line.  A plain ``str.strip()`` would delete it when
    it is the last line of the patch, leaving the final hunk shorter than its
    ``@@`` header announces.  Leading whitespace and any other trailing
    whitespace-only lines are removed.
    """
    lines = text.lstrip().split("\n")
    while lines:
        tail = lines[-1].rstrip("\r")
        # Stop at real content, or at a lone " " (blank context line).
        if tail == " " or tail.strip():
            break
        lines.pop()
    if lines:
        lines[-1] = lines[-1].rstrip("\r")
    return "\n".join(lines)


def extract_patch_with_markers(raw_output: str) -> str:
    """Pull the diff text out of the raw AI answer.

    If the start marker ``<<<AI_DIFF_START>>>`` is present, the text between
    it and the next ``<<<AI_DIFF_END>>>`` is returned.  Without a start
    marker, everything from the first line beginning with ``diff --git `` to
    the end of the output is returned.  The result is trimmed of surrounding
    blank space, except that a trailing blank context line is kept.

    Raises:
        PatchGenerationError: If the start marker has no matching end marker,
            or if the output contains neither markers nor a diff header.
    """
    start_marker = "<<<AI_DIFF_START>>>"
    end_marker = "<<<AI_DIFF_END>>>"

    marker_at = raw_output.find(start_marker)
    if marker_at != -1:
        body_start = marker_at + len(start_marker)
        body_end = raw_output.find(end_marker, body_start)
        if body_end == -1:
            raise PatchGenerationError("AI output missing end marker")
        return _trim_patch_text(raw_output[body_start:body_end])

    # Fallback: no markers, so take everything from the first diff header on.
    header = re.search(r"^diff --git ", raw_output, re.MULTILINE)
    if header:
        return _trim_patch_text(raw_output[header.start():])
    raise PatchGenerationError("AI output did not contain a diff")
def sanitize_unified_patch(patch: str) -> str:
    """Normalize an extracted patch before it is handed to ``git apply``.

    Carriage returns are removed, ``index``, ``similarity index``,
    ``rename from`` and ``rename to`` lines are dropped, and everything
    before the first ``diff --git`` is discarded.  The result ends with
    exactly one added newline.

    Raises:
        PatchGenerationError: If no ``diff --git`` header remains.
    """
    dropped_prefixes = ("index ", "similarity index ", "rename from ", "rename to ")
    kept_lines = [
        entry
        for entry in patch.replace("\r", "").splitlines()
        if not entry.startswith(dropped_prefixes)
    ]
    joined = "\n".join(kept_lines)

    header_at = joined.find("diff --git")
    if header_at < 0:
        raise PatchGenerationError("Sanitized patch missing diff header")
    return f"{joined[header_at:]}\n"
def rewrite_patch_for_p0(patch: str) -> str:
    """Rewrite a patch so its paths no longer carry the ``a/`` / ``b/`` prefixes.

    ``diff --git a/X b/Y`` becomes ``diff --git X Y``; in lines starting with
    ``+++ `` the first ``+++ b/`` becomes ``+++ ``; in lines starting with
    ``--- `` the first ``--- a/`` becomes ``--- `` (``--- /dev/null`` is kept).
    All other lines pass through unchanged and the result ends with a newline.
    """
    header_pattern = re.compile(r"^diff --git a/(.+?) b/(.+)$")

    def convert(entry: str) -> str:
        if entry.startswith("diff --git"):
            found = header_pattern.match(entry)
            if found is None:
                return entry
            return f"diff --git {found.group(1)} {found.group(2)}"
        if entry.startswith("+++ "):
            return entry.replace("+++ b/", "+++ ", 1)
        if entry.startswith("--- ") and entry != "--- /dev/null":
            return entry.replace("--- a/", "--- ", 1)
        return entry

    return "\n".join(convert(entry) for entry in patch.splitlines()) + "\n"
def save_debug_artifacts(
    repo_root: Path,
    output_rel: Path,
    raw_path: Path,
    clean_path: Path,
    sanitized_path: Path,
    final_path: Path,
) -> None:
    """Copy the intermediate patch files into ``.git/ai-rules-debug``.

    File names start with the output path (``/`` replaced by ``_``) followed
    by the current process id.  The raw, clean and sanitized files are always
    copied; the final patch is copied only when it exists.
    """
    target_dir = repo_root / ".git" / "ai-rules-debug"
    target_dir.mkdir(parents=True, exist_ok=True)

    flattened = output_rel.as_posix().replace("/", "_")
    stem = f"{flattened}-{os.getpid()}"

    stages = (
        (raw_path, "raw.out"),
        (clean_path, "clean.diff"),
        (sanitized_path, "sanitized.diff"),
    )
    for stage_file, suffix in stages:
        shutil.copy(stage_file, target_dir / f"{stem}.{suffix}")

    if final_path.exists():
        shutil.copy(final_path, target_dir / f"{stem}.final.diff")
def apply_patch(repo_root: Path, patch_file: Path, patch_level: str, output_rel: Path) -> None:
    """Apply *patch_file* inside *repo_root*, trying progressively looser modes.

    Attempts, in order:
      1. ``git apply --index`` after a successful ``--check``.
      2. ``git apply --index --3way --recount --whitespace=nowarn`` after a
         successful ``--check``.
      3. Only when the patch contains ``--- /dev/null``: a plain ``git apply``
         followed by ``git add`` of *output_rel*.

    Raises:
        PatchGenerationError: If none of the attempts succeeds, or if a
            checked git command fails after its dry run passed.
    """
    patch_arg = patch_file.resolve().as_posix()

    def succeeds(command: list[str]) -> bool:
        # Run without raising and report only whether git exited cleanly.
        return run(command, cwd=repo_root, check=False).returncode == 0

    strict_cmd = ["git", "apply", patch_level, "--index"]
    if succeeds(strict_cmd + ["--check", patch_arg]):
        run(strict_cmd + [patch_arg], cwd=repo_root)
        return

    merge_cmd = [
        "git",
        "apply",
        patch_level,
        "--index",
        "--3way",
        "--recount",
        "--whitespace=nowarn",
        patch_arg,
    ]
    if succeeds(merge_cmd + ["--check"]):
        run(merge_cmd, cwd=repo_root)
        return

    patch_text = patch_file.read_text(encoding="utf-8")
    creates_file = "--- /dev/null" in patch_text
    if creates_file and succeeds(["git", "apply", patch_level, patch_arg]):
        run(["git", "add", "--", output_rel.as_posix()], cwd=repo_root)
        return

    raise PatchGenerationError("Failed to apply patch (strict and 3-way both failed)")
def run(args: list[str], cwd: Path, check: bool = True) -> subprocess.CompletedProcess[str]:
    """Run *args* in *cwd*, capturing stdout and stderr as text.

    Returns the completed process.  When *check* is true, a non-zero exit
    status raises ``PatchGenerationError`` with the command line and the
    stripped stderr.
    """
    completed = subprocess.run(
        args,
        capture_output=True,
        text=True,
        cwd=str(cwd),
    )
    failed = completed.returncode != 0
    if failed and check:
        command_line = " ".join(args)
        raise PatchGenerationError(
            f"Command {command_line} failed: {completed.stderr.strip()}"
        )
    return completed