"""Can-Do-Steps content fixer.

Reads reviewer instructions, plans the required fixes, applies them to the
roadmap/bits JSON files inside a target directory, and creates a git commit for
each change.
"""
from __future__ import annotations

import json
import subprocess
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence

from chains.base import build_chain, default_json_parser, parse_output
from utils.io import save_output
from utils.storage import path_exists, read_text

# Content files a fix item may target; _normalize_targets drops any other name
# and falls back to "bits" when nothing valid remains.
ALLOWED_TARGETS = {"roadmap", "bits"}


def fix_content(
    run_id: str,
    fix_input: str,
    content_dir: str,
    branch_name: str,
    force_text: bool = False,
) -> Dict[str, Any]:
    """Main entry point for the fix-content pipeline.

    Plans fixes from the reviewer instructions in ``fix_input`` via the
    ``plan_fix`` chain. Each planned item is applied to the roadmap/bits JSON
    files found in ``content_dir`` via the ``apply_fix`` chain. One git commit
    is created on ``branch_name`` for every item that changed at least one file.

    Args:
        run_id: Identifier used to locate the roadmap/bits files and to tag
            chain runs and saved logs.
        fix_input: Location of the reviewer instructions file.
        content_dir: Directory holding the roadmap/bits JSON files. It must
            live inside a git repository.
        branch_name: Branch to check out (created if missing) before committing.
        force_text: Passed through to ``build_chain``.

    Returns:
        A summary dict with ``planned_items``, ``applied``, ``skipped``,
        ``errors`` and ``commits``. The same dict is saved to the run's logs.

    Raises:
        FileNotFoundError: If ``content_dir``, ``fix_input`` or one of the
            roadmap/bits files cannot be found.
        RuntimeError: If ``content_dir`` is not inside a git repository or
            the target branch cannot be checked out.
    """
    target_path = Path(content_dir).expanduser().resolve()
    if not target_path.exists():
        raise FileNotFoundError(f"Content directory not found: {target_path}")

    target_repo_root = _detect_repo_root(target_path)

    if not path_exists(fix_input):
        raise FileNotFoundError(f"Fix instructions file not found: {fix_input}")
    fix_input_path: Optional[Path]
    try:
        fix_input_path = Path(fix_input).expanduser().resolve()
    except Exception:  # noqa: BLE001 - fall back to handing read_text the raw string
        fix_input_path = None

    roadmap_path = _resolve_file(target_path, run_id, prefix="roadmap")
    bits_path = _resolve_file(target_path, run_id, prefix="bits")

    roadmap_data = _load_json(roadmap_path)
    bits_data = _load_json(bits_path)

    # Prefer the resolved local path; otherwise pass the original string through.
    instructions_source = (
        str(fix_input_path) if fix_input_path and fix_input_path.exists() else fix_input
    )
    raw_instructions = read_text(instructions_source)
    plan_result = build_chain(
        chain_name="plan_fix",
        pipeline="can-do-steps",
        run_id=run_id,
        input_variables={"raw_instructions": raw_instructions},
        force_text=force_text,
    )
    plan_data = parse_output(plan_result["output"], default_json_parser)
    plan_items = _extract_plan_items(plan_data)
    save_output(
        data={"fix_items": plan_items},
        pipeline="can-do-steps",
        step="plan_fix_structured",
        run_id=run_id,
        subfolder="logs",
    )

    summary = _new_fix_summary(len(plan_items))
    if not plan_items:
        save_output(summary, "can-do-steps", "fix_content", run_id, subfolder="logs")
        print("No actionable fixes found in instructions.")
        return summary

    _ensure_branch(branch_name, target_repo_root)

    for item in plan_items:
        # The plan is model output; record malformed entries instead of crashing on .get().
        if not isinstance(item, dict):
            summary["skipped"].append(
                {"step_id": "unknown", "reason": "plan item is not an object"}
            )
            continue

        step_id = item.get("step_id")
        instruction = str(item.get("instruction") or "").strip()
        targets = _normalize_targets(item.get("targets"))

        if not step_id or not instruction:
            summary["skipped"].append(
                {"step_id": step_id or "unknown", "reason": "missing step_id or instruction"}
            )
            continue

        print(f"🔧 Applying fix for step '{step_id}' targeting {targets}")

        roadmap_context = _build_roadmap_context(roadmap_data, step_id)
        bits_context = _build_bits_context(bits_data, step_id)

        include_roadmap = "roadmap" in targets
        include_bits = "bits" in targets

        missing_targets: List[str] = []
        if include_roadmap and roadmap_context is None:
            missing_targets.append("roadmap")
        if include_bits and bits_context is None:
            missing_targets.append("bits")
        if missing_targets:
            summary["errors"].append(
                {
                    "step_id": step_id,
                    "error": f"step not found in {', '.join(missing_targets)} data",
                }
            )
            print(
                f"⚠️ Step '{step_id}' missing in {', '.join(missing_targets)} data; skipping"
            )
            continue

        # Everything that interprets model output stays inside this try. A chain
        # failure, a non-object payload or invalid snippet JSON is recorded
        # against this step instead of aborting the run and losing the summary.
        try:
            apply_payload = _run_apply_fix_chain(
                run_id=run_id,
                step_id=step_id,
                instruction=instruction,
                targets=targets,
                roadmap_context=roadmap_context,
                bits_context=bits_context,
                force_text=force_text,
            )
            if not isinstance(apply_payload, dict):
                raise ValueError(
                    f"apply_fix returned {type(apply_payload).__name__}, expected a JSON object"
                )
            roadmap_snippet_updated = (
                _extract_updated_snippet(apply_payload, "roadmap_snippet_updated")
                if include_roadmap
                else None
            )
            bits_snippet_updated = (
                _extract_updated_snippet(apply_payload, "bits_snippet_updated")
                if include_bits
                else None
            )
        except Exception as exc:  # noqa: BLE001
            summary["errors"].append({"step_id": step_id, "error": str(exc)})
            print(f"⚠️ Apply fix failed for {step_id}: {exc}")
            continue

        commit_message = _commit_message_for(apply_payload, step_id)

        files_to_commit: List[Path] = []

        if include_roadmap and roadmap_snippet_updated is not None:
            if _replace_roadmap_snippet(roadmap_data, step_id, roadmap_snippet_updated):
                _write_json(roadmap_path, roadmap_data)
                files_to_commit.append(roadmap_path)

        if include_bits and bits_snippet_updated is not None:
            if _replace_bits_snippet(bits_data, step_id, bits_snippet_updated):
                _write_json(bits_path, bits_data)
                files_to_commit.append(bits_path)

        if not files_to_commit:
            summary["skipped"].append(
                {"step_id": step_id, "reason": "model returned no actionable changes"}
            )
            continue

        # NOTE(review): the edited JSON is already written to disk and kept in
        # memory at this point. If the commit fails, that edit is not rolled back
        # and will ride along with a later commit that touches the same file.
        try:
            _stage_and_commit(files_to_commit, commit_message, target_repo_root)
        except Exception as exc:  # noqa: BLE001
            summary["errors"].append({"step_id": step_id, "error": str(exc)})
            print(f"⚠️ Git commit failed for {step_id}: {exc}")
            continue

        summary["applied"] += 1
        summary["commits"].append(
            {
                "step_id": step_id,
                "commit_message": commit_message,
                "files": [str(_relativize_path(p, target_repo_root)) for p in files_to_commit],
            }
        )

    save_output(summary, "can-do-steps", "fix_content", run_id, subfolder="logs")
    return summary


def _new_fix_summary(planned_items: int) -> Dict[str, Any]:
    """Return an empty fix-content summary for ``planned_items`` planned fixes."""
    return {
        "planned_items": planned_items,
        "applied": 0,
        "skipped": [],
        "errors": [],
        "commits": [],
    }


def _commit_message_for(payload: Dict[str, Any], step_id: str) -> str:
    """Return the model's commit message, or a default when it is missing, blank or not text.

    The value is stripped before the emptiness check, so a whitespace-only
    message falls back to the default instead of producing an empty
    ``git commit -m ""``.
    """
    message = payload.get("commit_message")
    if isinstance(message, str) and message.strip():
        return message.strip()
    return f"Fix content for {step_id}"


def _resolve_file(base_dir: Path, run_id: str, prefix: str) -> Path:
    """Locate the ``prefix`` JSON file for ``run_id`` inside ``base_dir``.

    Candidate names are tried in priority order:
    ``<prefix>-<run_id>.json``, ``<prefix>_<run_id>.json``,
    ``<run_id>-<prefix>.json``, ``<run_id>_<prefix>.json``, and finally the
    run-agnostic ``<prefix>.json``. The first existing one is returned resolved.

    Raises:
        FileNotFoundError: If none of the candidate names exists.
    """
    names = (
        f"{prefix}-{run_id}.json",
        f"{prefix}_{run_id}.json",
        f"{run_id}-{prefix}.json",
        f"{run_id}_{prefix}.json",
        f"{prefix}.json",
    )
    found = next((base_dir / name for name in names if (base_dir / name).exists()), None)
    if found is None:
        raise FileNotFoundError(f"Could not find {prefix} file for run '{run_id}' in {base_dir}")
    return found.resolve()


def _load_json(path: Path) -> Any:
    """Parse the UTF-8 JSON document stored at ``path`` and return it."""
    with open(path, encoding="utf-8") as handle:
        document = json.load(handle)
    return document


def _write_json(path: Path, data: Any) -> None:
    """Write ``data`` to ``path`` as 2-space-indented UTF-8 JSON plus a trailing newline.

    The document is serialized completely before the file is opened. A
    serialization error therefore leaves the existing file untouched instead of
    truncated and half-written. Non-ASCII characters are written as-is
    (``ensure_ascii=False``).

    Raises:
        TypeError: If ``data`` contains a value ``json`` cannot serialize.
        ValueError: Propagated from ``json.dumps`` (e.g. circular references).
    """
    text = json.dumps(data, indent=2, ensure_ascii=False) + "\n"
    with path.open("w", encoding="utf-8") as outfile:
        outfile.write(text)


def _extract_plan_items(plan_data: Any) -> List[Dict[str, Any]]:
    """Pull the list of fix items out of the parsed ``plan_fix`` output.

    Accepts either a bare list or an object whose ``fix_items`` key holds a
    list. The list is returned as-is (not copied); any other shape yields an
    empty list.
    """
    if isinstance(plan_data, list):
        return plan_data
    if isinstance(plan_data, dict):
        items = plan_data.get("fix_items")
        if isinstance(items, list):
            return items
    return []


def _normalize_targets(raw_targets: Any) -> List[str]:
    """Normalize a plan item's ``targets`` field to a de-duplicated list of names.

    Accepted input:

    * a string, which may name several targets separated by commas and/or
      whitespace (e.g. ``"roadmap, bits"``);
    * a list/tuple (any ``Sequence``) or a set of values, each converted with
      ``str`` and also split on commas/whitespace.

    Tokens are lower-cased. Only names in ``ALLOWED_TARGETS`` are kept, in
    first-seen order (sets are sorted first so the order is deterministic).
    Falls back to ``["bits"]`` when nothing valid remains.
    """
    if isinstance(raw_targets, str):
        parts = [raw_targets]
    elif isinstance(raw_targets, (set, frozenset)):
        # Sets have no stable iteration order; sort so the result is deterministic.
        parts = sorted(str(item) for item in raw_targets)
    elif isinstance(raw_targets, Sequence):
        parts = [str(item) for item in raw_targets]
    else:
        parts = []

    cleaned: List[str] = []
    for part in parts:
        # "roadmap, bits" / "roadmap bits" previously failed as one token and
        # silently collapsed to the ["bits"] default, dropping the roadmap edit.
        for token in part.replace(",", " ").split():
            normalized = token.lower()
            if normalized in ALLOWED_TARGETS and normalized not in cleaned:
                cleaned.append(normalized)

    return cleaned or ["bits"]


def _build_roadmap_context(roadmap_data: Dict[str, Any], step_id: str) -> Optional[Dict[str, Any]]:
    """Find the roadmap step whose ``id`` or ``slug`` equals ``step_id``.

    Walks ``tracks -> paths -> steps`` and returns the first match, wrapped with
    the ``id``/``title`` of its enclosing track and path. Returns ``None`` when
    no step matches or ``roadmap_data`` is not a dict. The ``step`` value is the
    original object, not a copy.
    """
    if not isinstance(roadmap_data, dict):
        return None

    def _is_target(candidate: Dict[str, Any]) -> bool:
        return candidate.get("id") == step_id or candidate.get("slug") == step_id

    for track in roadmap_data.get("tracks", []):
        for path in track.get("paths", []):
            hit = next((step for step in path.get("steps", []) if _is_target(step)), None)
            if hit is not None:
                track_info = {"id": track.get("id"), "title": track.get("title")}
                path_info = {"id": path.get("id"), "title": path.get("title")}
                return {"track": track_info, "path": path_info, "step": hit}
    return None


def _build_bits_context(bits_data: Any, step_id: str) -> Optional[Any]:
    """Collect the bits that belong to ``step_id``.

    * Dict-shaped bits: returns ``bits_data[step_id]`` unchanged when the key exists.
    * List-shaped bits: returns ``{"bits": [...]}`` holding every entry whose
      ``stepId`` equals ``step_id``, but only if at least one matches.

    Returns ``None`` when nothing is found or ``bits_data`` is neither shape.
    """
    if isinstance(bits_data, dict):
        return bits_data[step_id] if step_id in bits_data else None
    if isinstance(bits_data, list):
        owned = list(filter(lambda bit: bit.get("stepId") == step_id, bits_data))
        return {"bits": owned} if owned else None
    return None


def _run_apply_fix_chain(
    run_id: str,
    step_id: str,
    instruction: str,
    targets: List[str],
    roadmap_context: Optional[Dict[str, Any]],
    bits_context: Optional[Any],
    force_text: bool,
) -> Dict[str, Any]:
    """Ask the ``apply_fix`` chain to rewrite the snippets for one step.

    Only targets listed in ``targets`` are serialized to JSON for the model.
    Every other target is replaced with a placeholder telling the model to
    ignore it.

    Returns:
        The result of ``parse_output`` on the chain output with
        ``default_json_parser``. Callers expect a dict, but that is not
        checked here.
    """
    include_roadmap = "roadmap" in targets
    include_bits = "bits" in targets

    # Compare against None rather than relying on truthiness. An existing but
    # empty bits entry ({} / []) is still a requested target that fix_content
    # accepted, so it must be sent as JSON, not as the "not requested" placeholder.
    if include_roadmap and roadmap_context is not None:
        roadmap_text = json.dumps(roadmap_context["step"], indent=2, ensure_ascii=False)
    else:
        roadmap_text = "Roadmap edits not requested; ignore this snippet."

    if include_bits and bits_context is not None:
        bits_text = json.dumps(bits_context, indent=2, ensure_ascii=False)
    else:
        bits_text = "Bits edits not requested; ignore this snippet."

    apply_result = build_chain(
        chain_name="apply_fix",
        pipeline="can-do-steps",
        run_id=run_id,
        input_variables={
            "step_id": step_id,
            "instruction": instruction,
            "targets": ", ".join(targets),
            "roadmap_step": roadmap_text,
            "bits_data": bits_text,
        },
        force_text=force_text,
    )
    return parse_output(apply_result["output"], default_json_parser)


def _extract_updated_snippet(payload: Dict[str, Any], field_name: str) -> Optional[Any]:
    """Return the decoded value of ``payload[field_name]``, or ``None`` when absent.

    Non-string values (including ``None``) are returned unchanged. String
    values are treated as JSON text. Blank strings and strings that echo the
    "ignore this snippet" placeholder (case-insensitive) yield ``None``;
    anything else is parsed.

    Raises:
        ValueError: If a string value is not valid JSON.
    """
    raw = payload.get(field_name)
    if not isinstance(raw, str):
        return raw

    text = raw.strip()
    if text == "" or "ignore this snippet" in text.lower():
        return None

    try:
        return json.loads(text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in {field_name}: {exc}") from exc


def _replace_roadmap_snippet(roadmap_data: Dict[str, Any], step_id: str, new_snippet: Any) -> bool:
    """Swap the roadmap step matching ``step_id`` for ``new_snippet``, in place.

    The first step (in ``tracks -> paths -> steps`` order) whose ``id`` or
    ``slug`` equals ``step_id`` is replaced. Returns ``True`` only when the
    step list was actually modified. A non-dict snippet, an identical snippet
    or a missing step all return ``False``; the first and last cases are also
    logged.
    """
    if not isinstance(new_snippet, dict):
        print("   • roadmap_snippet_updated must be an object; skipping")
        return False

    tracks = roadmap_data.get("tracks", []) if isinstance(roadmap_data, dict) else []
    for track in tracks:
        for path in track.get("paths", []):
            step_list = path.get("steps", [])
            for position, current in enumerate(step_list):
                if current.get("id") != step_id and current.get("slug") != step_id:
                    continue
                if current == new_snippet:
                    return False
                step_list[position] = new_snippet
                print(f"   • Replaced roadmap snippet for {step_id}")
                return True

    print(f"   • Roadmap step not found for {step_id}; skipping roadmap update")
    return False


def _replace_bits_snippet(bits_data: Any, step_id: str, new_snippet: Any) -> bool:
    """Apply the model's updated bits for ``step_id`` to ``bits_data``, in place.

    Two layouts are supported:

    * Dict keyed by step id: ``bits_data[step_id]`` is replaced wholesale by
      ``new_snippet``, which must be a dict.
    * Flat list of bit objects: ``new_snippet`` is either a list of bits or an
      object with a ``bits`` list. Each new bit replaces the first existing bit
      of the same step (``stepId == step_id``) whose ``id`` or ``slug`` equals
      the new bit's ``id`` (falling back to its ``slug``). New bits that match
      nothing are reported and skipped; bits are never appended.

    Returns:
        ``True`` if ``bits_data`` was modified, ``False`` otherwise.
    """
    if isinstance(bits_data, dict):
        if not isinstance(new_snippet, dict):
            print("   • bits_snippet_updated must be an object for dict-based bits")
            return False
        if step_id not in bits_data:
            print(f"   • Bits entry missing for {step_id}")
            return False
        if bits_data[step_id] != new_snippet:
            bits_data[step_id] = new_snippet
            print(f"   • Replaced bits entry for {step_id}")
            return True
        return False

    if isinstance(bits_data, list):
        new_bits = new_snippet.get("bits") if isinstance(new_snippet, dict) else new_snippet
        if not isinstance(new_bits, list):
            print("   • bits_snippet_updated must include a 'bits' list for list-based bits")
            return False

        # Only this step's bits were sent to the model, so only those may be
        # replaced. Bit ids need not be unique across steps, and searching the
        # whole root list could overwrite a same-id bit that belongs to another
        # step. Step membership is fixed before any replacement happens, and
        # non-dict entries are ignored rather than crashing on .get().
        step_positions = [
            i
            for i, bit in enumerate(bits_data)
            if isinstance(bit, dict) and bit.get("stepId") == step_id
        ]

        changed = False
        for new_bit in new_bits:
            if not isinstance(new_bit, dict):
                continue
            bit_id = new_bit.get("id") or new_bit.get("slug")
            if not bit_id:
                continue
            idx = next(
                (
                    i
                    for i in step_positions
                    if bits_data[i].get("id") == bit_id or bits_data[i].get("slug") == bit_id
                ),
                None,
            )
            if idx is None:
                print(f"   • Bit {bit_id} not found among bits for {step_id}; skipping")
                continue
            if bits_data[idx] != new_bit:
                bits_data[idx] = new_bit
                changed = True
        if changed:
            print(f"   • Updated bits for {step_id}")
        return changed

    print("   • Unsupported bits data format; skipping bits update")
    return False


def _detect_repo_root(start_path: Path) -> Path:
    """Return the top-level directory of the git work tree containing ``start_path``.

    Runs ``git rev-parse --show-toplevel`` with ``start_path`` as the working
    directory, logs the resolved result and returns it.

    Raises:
        RuntimeError: If git exits non-zero. The message includes git's stderr,
            or stdout when stderr is empty.
    """
    proc = subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        cwd=start_path,
        capture_output=True,
        text=True,
    )
    if proc.returncode:
        detail = proc.stderr.strip() or proc.stdout.strip()
        raise RuntimeError(
            f"Unable to determine git repository for target directory: {start_path} -> {detail}"
        )
    root = Path(proc.stdout.strip()).resolve()
    print(f"📁 Using repo root {root} for git operations")
    return root


def _ensure_branch(branch_name: str, repo_root: Path) -> None:
    """Check out ``branch_name`` in ``repo_root``, creating it from HEAD when missing.

    Existence is tested against ``refs/heads/<branch_name>`` specifically, so a
    tag, remote-tracking ref or commit that shares the name is not mistaken for
    a local branch. Checking one of those out would detach HEAD and leave the
    fix commits off any branch.

    Raises:
        RuntimeError: If the checkout or branch creation fails (raised by ``_run_git``).
    """
    exists = subprocess.run(
        ["git", "rev-parse", "--verify", "--quiet", f"refs/heads/{branch_name}"],
        cwd=repo_root,
        capture_output=True,
        text=True,
    ).returncode == 0

    if exists:
        print(f"📂 Checking out existing branch '{branch_name}'")
        _run_git(["checkout", branch_name], repo_root)
    else:
        print(f"🌱 Creating branch '{branch_name}'")
        _run_git(["checkout", "-b", branch_name], repo_root)


def _stage_and_commit(files: List[Path], message: str, repo_root: Path) -> None:
    """Stage ``files`` and commit exactly those paths with ``message``.

    Paths are made relative to ``repo_root``, where git runs. The commit is
    limited to these paths via ``git commit -- <paths>``. Anything else staged
    in the repository (pre-existing user changes, or leftovers from an earlier
    failed commit) is therefore not swept into this commit.

    Raises:
        ValueError: If ``files`` is empty; a path-less commit would commit
            whatever happens to be staged.
        RuntimeError: If ``git add`` or ``git commit`` fails (raised by ``_run_git``).
    """
    if not files:
        raise ValueError("No files to commit")
    rel_paths = [str(_relativize_path(path, repo_root)) for path in files]
    _run_git(["add", "--", *rel_paths], repo_root)
    _run_git(["commit", "-m", message, "--", *rel_paths], repo_root)
    print(f"✅ Created commit: {message}")


def _run_git(args: List[str], repo_root: Path) -> None:
    """Run ``git <args>`` with ``repo_root`` as the working directory.

    Output is captured, not echoed.

    Raises:
        RuntimeError: On a non-zero exit. The message is git's stderr, else its
            stdout, else a generic "git command failed".
    """
    completed = subprocess.run(
        ["git"] + list(args), cwd=repo_root, capture_output=True, text=True
    )
    if completed.returncode == 0:
        return
    reason = completed.stderr.strip() or completed.stdout.strip() or "git command failed"
    raise RuntimeError(reason)


def _relativize_path(path: Path, repo_root: Path) -> Path:
    """Express ``path`` relative to ``repo_root`` when it lies beneath it.

    ``path`` is resolved first. If it is not under ``repo_root`` (or resolving
    raises ``ValueError``), the original unresolved ``path`` object is
    returned unchanged.
    """
    try:
        relative = path.resolve().relative_to(repo_root)
    except ValueError:
        relative = path
    return relative
