""" AI-powered patch generation and application utilities. This module ports the proven bash hook logic into Python so the orchestration pipeline can be tested and extended more easily. """ from __future__ import annotations import os import re import shutil import subprocess import tempfile from dataclasses import dataclass from pathlib import Path from automation.config import RulesConfig class PatchGenerationError(RuntimeError): pass @dataclass class ModelConfig: """Configuration for invoking the AI model command.""" # The command to execute the AI model. Defaults to 'claude -p' or CDEV_AI_COMMAND env var. command: str = os.environ.get("CDEV_AI_COMMAND", "claude -p") def generate_output( repo_root: Path, rules: RulesConfig, model: ModelConfig, source_rel: Path, output_rel: Path, instruction: str, ) -> None: """ Generates or refreshes an output artifact using AI, based on staged changes and a given instruction. The process involves: 1. Ensuring the output file is known to Git (for new files). 2. Gathering relevant context (source diff, source content, existing output content). 3. Building a prompt for the AI model. 4. Calling the AI model to generate a unified diff. 5. Extracting and sanitizing the generated patch. 6. Applying the patch to the repository. Args: repo_root: The root directory of the Git repository. rules: The RulesConfig object for resolving template variables and paths. model: The ModelConfig object specifying the AI command. source_rel: The path to the source file relative to repo_root. output_rel: The path to the output file relative to repo_root. instruction: The instruction for the AI model. Raises: PatchGenerationError: If AI output is empty or patch application fails. """ repo_root = repo_root.resolve() # Ensure paths are relative to the repository root for consistent Git operations. source_rel = source_rel output_rel = output_rel # Ensure the parent directory for the output file exists. (repo_root / output_rel).parent.mkdir(parents=True, exist_ok=True) # Ensure Git knows about the output file, especially if it's new. # This allows 'git apply --index' to work correctly for new files. ensure_intent_to_add(repo_root, output_rel) # Gather context for the AI prompt. source_diff = git_diff_cached(repo_root, source_rel) source_content = git_show_cached(repo_root, source_rel) # Get the current content of the output file (pre-image) and its Git hash. output_preimage, output_hash = read_output_preimage(repo_root, output_rel) # Build the comprehensive prompt for the AI model. prompt = build_prompt( source_rel=source_rel, output_rel=output_rel, source_diff=source_diff, source_content=source_content, output_content=output_preimage, instruction=instruction, ) # Call the AI model and get its raw output. raw_patch = call_model(model, prompt, cwd=repo_root) # Use a temporary directory for storing intermediate patch files for debugging. with tempfile.TemporaryDirectory(prefix="cdev-patch-") as tmpdir_str: tmpdir = Path(tmpdir_str) raw_path = tmpdir / "raw.out" clean_path = tmpdir / "clean.diff" sanitized_path = tmpdir / "sanitized.diff" raw_path.write_text(raw_patch, encoding="utf-8") try: # Extract the actual diff content from the AI's raw output, using markers. extracted = extract_patch_with_markers(raw_path.read_text(encoding="utf-8")) clean_path.write_text(extracted, encoding="utf-8") except PatchGenerationError as e: # If extraction fails, save debug artifacts and re-raise the error. save_debug_artifacts(repo_root, output_rel, raw_path, None, None, None) raise # Sanitize the extracted patch to remove unwanted git metadata lines. sanitized = sanitize_unified_patch(clean_path.read_text(encoding="utf-8")) # Special handling for new files: ensure 'new file mode' is present if '--- /dev/null' is. if "--- /dev/null" in sanitized and "new file mode" not in sanitized: sanitized = sanitized.replace("--- /dev/null", "new file mode 100644\n--- /dev/null", 1) sanitized_path.write_text(sanitized, encoding="utf-8") # Determine the patch level. -p1 is common for diffs generated from repo root. patch_level = "-p1" final_patch_path = sanitized_path # Save all intermediate patch files for debugging purposes. save_debug_artifacts(repo_root, output_rel, raw_path, clean_path, sanitized_path, final_patch_path) # Check if the final patch is empty after sanitization. # This is normal when AI determines no changes are needed (e.g., gated outputs) if not final_patch_path.read_text(encoding="utf-8").strip(): # Empty patch is not an error - AI decided not to make changes # Common cases: # - implementation_gate_writer when status != READY_FOR_IMPLEMENTATION # - AI determines output is already correct return # Skip silently # Apply the generated patch to the Git repository. apply_patch(repo_root, final_patch_path, patch_level, output_rel) def ensure_intent_to_add(repo_root: Path, rel_path: Path) -> None: """ Ensures Git has an 'intent to add' for a new file, so 'git apply --index' can create it correctly. If the file is already tracked, do nothing. Args: repo_root: The root directory of the Git repository. rel_path: The path to the file relative to repo_root. """ if git_ls_files(repo_root, rel_path): # File is already tracked by Git, no need for 'intent to add'. return # Use 'git add -N' to record the intent to add a new file without staging its content. run(["git", "add", "-N", "--", rel_path.as_posix()], cwd=repo_root, check=False) def git_ls_files(repo_root: Path, rel_path: Path) -> bool: """ Checks if a file is tracked by Git. Args: repo_root: The root directory of the Git repository. rel_path: The path to the file relative to repo_root. Returns: True if the file is tracked, False otherwise. """ result = run( ["git", "ls-files", "--error-unmatch", "--", rel_path.as_posix()], cwd=repo_root, check=False, ) return result.returncode == 0 def git_diff_cached(repo_root: Path, rel_path: Path) -> str: """ Retrieves the cached diff (staged changes) for a specific file. Args: repo_root: The root directory of the Git repository. rel_path: The path to the file relative to repo_root. Returns: The unified diff string for the staged changes of the file. """ result = run( ["git", "diff", "--cached", "--unified=2", "--", rel_path.as_posix()], cwd=repo_root, check=False, ) return result.stdout def git_show_cached(repo_root: Path, rel_path: Path) -> str: """ Retrieves the content of a file as it exists in the Git index (staged version). If the file is not in the index, it attempts to read it from the working tree. Args: repo_root: The root directory of the Git repository. rel_path: The path to the file relative to repo_root. Returns: The content of the file as a string. """ # Try to get content from the Git index. result = run( ["git", "show", f":{rel_path.as_posix()}"], cwd=repo_root, check=False, ) if result.returncode == 0: return result.stdout # If not in index, try to read from the working tree (e.g., for new files not yet staged). file_path = repo_root / rel_path if file_path.exists(): return file_path.read_text(encoding="utf-8") return "" def read_output_preimage(repo_root: Path, rel_path: Path) -> tuple[str, str]: """ Reads the current content and Git blob hash of the output file. This is used as the 'pre-image' for generating the AI patch. It first checks the Git index, then the working tree. Args: repo_root: The root directory of the Git repository. rel_path: The path to the output file relative to repo_root. Returns: A tuple containing (content_string, blob_hash_string). Returns empty string and default hash if the file does not exist. """ # Check if the file is staged and get its blob hash. staged_hash_result = run( ["git", "ls-files", "--stage", "--", rel_path.as_posix()], cwd=repo_root, check=False, ) blob_hash = "0" * 40 # Default hash if file is new or untracked. if staged_hash_result.returncode == 0 and staged_hash_result.stdout.strip(): # If staged, get its content from the index. show_result = run(["git", "show", f":{rel_path.as_posix()}"], cwd=repo_root, check=False) content = show_result.stdout if show_result.returncode == 0 else "" # Extract the blob hash from 'git ls-files --stage' output. # The output format is: first_field = staged_hash_result.stdout.strip().split()[1] blob_hash = first_field return content, blob_hash # If not staged, try to read from the working tree. file_path = repo_root / rel_path if file_path.exists(): content = file_path.read_text(encoding="utf-8") # Calculate blob hash for working tree file. blob_hash = run( ["git", "hash-object", file_path.as_posix()], cwd=repo_root, check=False, ).stdout.strip() or blob_hash return content, blob_hash # File does not exist. return "", blob_hash # This template defines the structure and content of the prompt sent to the AI model. # It includes context about the source file, current output, and specific instructions. PROMPT_TEMPLATE = """You are assisting with automated artifact generation during a git commit. SOURCE FILE: {source_path} OUTPUT FILE: {output_path} === SOURCE FILE CHANGES (staged) === {source_diff} === SOURCE FILE CONTENT (staged) === {source_content} === CURRENT OUTPUT CONTENT (use this as the preimage) === {output_content} === GENERATION INSTRUCTIONS === {instruction} === OUTPUT FORMAT REQUIREMENTS === Wrap your unified diff with these exact markers: <<>> [your diff here] <<>> For NEW FILES, use these headers exactly: --- /dev/null +++ b/{output_path} === TASK === Create or update {output_path} according to the instructions above. Output ONLY a unified diff patch in proper git format: - Use format: diff --git a/{output_path} b/{output_path} - (Optional) You may include an "index ..." line, but it will be ignored - Include complete hunks with context lines - No markdown fences, no explanations, just the patch Start with: <<>> End with: <<>> Only include the diff between these markers. If the output file doesn't exist, create it from scratch in the patch. """ def build_prompt( source_rel: Path, output_rel: Path, source_diff: str, source_content: str, output_content: str, instruction: str, ) -> str: """ Constructs the full prompt string for the AI model by formatting the PROMPT_TEMPLATE with the provided context and instructions. Args: source_rel: Relative path to the source file. output_rel: Relative path to the output file. source_diff: Git diff of the staged source file. source_content: Content of the staged source file. output_content: Current content of the output file (pre-image). instruction: Specific instructions for the AI. Returns: The formatted prompt string. """ return PROMPT_TEMPLATE.format( source_path=source_rel.as_posix(), output_path=output_rel.as_posix(), source_diff=source_diff.strip(), source_content=source_content.strip(), output_content=output_content.strip() or "(empty)", # Indicate if output content is empty. instruction=instruction.strip(), ) def call_model(model: ModelConfig, prompt: str, cwd: Path) -> str: """ Invokes the AI model command with the given prompt and captures its output. Args: model: The ModelConfig object containing the AI command. prompt: The input prompt string for the AI model. cwd: The current working directory for executing the command. Returns: The raw string output from the AI model. Raises: PatchGenerationError: If the AI command fails or returns an API error. """ command = model.command result = subprocess.run( command, input=prompt, text=True, capture_output=True, cwd=str(cwd), shell=True, # Use shell=True to allow command to be a string with arguments. ) # The Claude CLI (and potentially others) might return a non-zero exit code # even on successful responses if there are warnings or specific API behaviors. # We prioritize stdout content over return code for initial check. if result.stdout.strip(): # Check for specific API error messages in the output. if "API Error:" in result.stdout and "Overloaded" in result.stdout: raise PatchGenerationError("Claude API is overloaded (500 error) - please retry later") return result.stdout # If stdout is empty, then a non-zero return code indicates a true failure. if result.returncode != 0: raise PatchGenerationError(f"AI command failed ({result.returncode}): {result.stderr.strip()}") return result.stdout def extract_patch_with_markers(raw_output: str) -> str: """ Extracts the unified diff string from the AI model's raw output. It looks for specific start and end markers (<<>> / <<>>). If markers are not found, it attempts to find the first 'diff --git' line. Args: raw_output: The complete raw string output from the AI model. Returns: The extracted unified diff string. Raises: PatchGenerationError: If no valid diff markers or diff header are found. """ start_marker = "<<>>" end_marker = "<<>>" if start_marker in raw_output: start_idx = raw_output.index(start_marker) + len(start_marker) end_idx = raw_output.find(end_marker, start_idx) if end_idx == -1: raise PatchGenerationError("AI output missing end marker") return raw_output[start_idx:end_idx].strip() # Fallback: if markers are not present, try to find a 'diff --git' header. match = re.search(r"^diff --git .*", raw_output, re.MULTILINE | re.DOTALL) if match: return raw_output[match.start() :].strip() raise PatchGenerationError("AI output did not contain a diff") def sanitize_unified_patch(patch: str) -> str: """ Cleans up a unified diff patch by removing specific Git metadata lines that can interfere with 'git apply'. Args: patch: The raw unified diff string. Returns: The sanitized diff string, or empty string if no diff header found. An empty result indicates the AI chose not to make changes (e.g., gated output). """ lines = patch.replace("\r", "").splitlines() cleaned = [] for line in lines: # Skip lines that are typically added by 'git diff' but not needed for 'git apply'. if line.startswith("index ") or line.startswith("similarity index ") or line.startswith("rename from ") or line.startswith("rename to "): continue cleaned.append(line) text = "\n".join(cleaned) # Look for the 'diff --git' header diff_start = text.find("diff --git") if diff_start == -1: # No diff header means AI chose not to generate a patch # This is normal for gated outputs or when AI determines no changes needed return "" return text[diff_start:] + "\n" def rewrite_patch_for_p0(patch: str) -> str: """ Rewrites a unified diff patch to be compatible with 'git apply -p0'. This adjusts file paths in the diff headers to remove the 'a/' and 'b/' prefixes. Args: patch: The unified diff string. Returns: The rewritten diff string. """ rewritten_lines = [] diff_header_re = re.compile(r"^diff --git a/(.+?) b/(.+)$") for line in patch.splitlines(): if line.startswith("diff --git"): m = diff_header_re.match(line) if m: # Remove 'a/' and 'b/' prefixes for -p0. rewritten_lines.append(f"diff --git {m.group(1)} {m.group(2)}") else: rewritten_lines.append(line) elif line.startswith("+++ "): # Remove '+++ b/' prefix. rewritten_lines.append(line.replace("+++ b/", "+++ ", 1)) elif line.startswith("--- "): # Remove '--- a/' prefix, unless it's '--- /dev/null' for new files. if line != "--- /dev/null": rewritten_lines.append(line.replace("--- a/", "--- ", 1)) else: rewritten_lines.append(line) else: rewritten_lines.append(line) return "\n".join(rewritten_lines) + "\n" def save_debug_artifacts( repo_root: Path, output_rel: Path, raw_path: Path | None, clean_path: Path | None, sanitized_path: Path | None, final_path: Path | None, ) -> None: """ Saves various stages of the generated patch to a debug directory within .git. This is crucial for debugging AI model outputs and patch application issues. Args: repo_root: The root directory of the Git repository. output_rel: The relative path of the output file being processed. raw_path: Path to the raw AI output file. clean_path: Path to the extracted diff file. sanitized_path: Path to the sanitized diff file. final_path: Path to the final patch file that was attempted to be applied. """ debug_dir = repo_root / ".git" / "ai-rules-debug" debug_dir.mkdir(parents=True, exist_ok=True) # Create a unique identifier for the debug artifacts. identifier = f"{output_rel.as_posix().replace('/', '_')}-{os.getpid()}" if raw_path and raw_path.exists(): shutil.copy(raw_path, debug_dir / f"{identifier}.raw.out") if clean_path and clean_path.exists(): shutil.copy(clean_path, debug_dir / f"{identifier}.clean.diff") if sanitized_path and sanitized_path.exists(): shutil.copy(sanitized_path, debug_dir / f"{identifier}.sanitized.diff") if final_path and final_path.exists(): shutil.copy(final_path, debug_dir / f"{identifier}.final.diff") def apply_patch(repo_root: Path, patch_file: Path, patch_level: str, output_rel: Path) -> None: """ Applies a generated patch to the Git repository, attempting a 3-way merge first and falling back to a strict apply if necessary. Args: repo_root: The root directory of the Git repository. patch_file: The path to the patch file to apply. patch_level: The patch level (e.g., '-p1'). output_rel: The relative path of the file the patch is intended for. Raises: PatchGenerationError: If both 3-way and strict patch application fail. """ absolute_patch = patch_file.resolve() # First, try a dry-run check with --check to see if the patch applies cleanly. check_args = ["git", "apply", patch_level, "--index", "--check", absolute_patch.as_posix()] if run(check_args, cwd=repo_root, check=False).returncode == 0: # If it applies cleanly, perform the actual apply. run(["git", "apply", patch_level, "--index", absolute_patch.as_posix()], cwd=repo_root) return # If direct apply --check fails, try 3-way merge with a check first. three_way_check_args = ["git", "apply", patch_level, "--index", "--3way", "--recount", "--whitespace=nowarn", absolute_patch.as_posix()] if run(three_way_check_args + ["--check"], cwd=repo_root, check=False).returncode == 0: # If 3-way check passes, perform the actual 3-way apply. run(three_way_check_args, cwd=repo_root) return # Special handling for new files: if the patch creates a new file ('--- /dev/null'), # a simple 'git apply' might work, followed by 'git add'. text = patch_file.read_text(encoding="utf-8") if "--- /dev/null" in text: # Try applying without --index and then explicitly add the file. if run(["git", "apply", patch_level, absolute_patch.as_posix()], cwd=repo_root, check=False).returncode == 0: run(["git", "add", "--", output_rel.as_posix()], cwd=repo_root) return # If all attempts fail, raise an error. raise PatchGenerationError("Failed to apply patch (strict and 3-way both failed)") def run(args: list[str], cwd: Path, check: bool = True) -> subprocess.CompletedProcess[str]: result = subprocess.run( args, cwd=str(cwd), text=True, capture_output=True, ) if check and result.returncode != 0: raise PatchGenerationError(f"Command {' '.join(args)} failed: {result.stderr.strip()}") return result