"""
Can-Do-Steps change media component.
Transforms existing images into a target style via image-to-image editing.
"""
import concurrent.futures
import os
import re
from datetime import datetime
from typing import Dict, Any, List, Optional, Tuple

from utils.io import save_output, load_latest_output, ensure_dir
from chains.base import build_chain  # noqa: F401 (reserved for future prompt-based changes)
from utils.image_generator import edit_and_save_image
from utils.media_selector import select_media_files
from utils.storage import path_exists, read_json, write_json, copy_path


def resolve_change_media_prompt(run_id: Optional[str] = None, prompt_id: Optional[str] = None) -> Tuple[str, str]:
    """
    Pick the chain name and prompt file used by change_media.

    Lookup order, all under ``prompts/can-do-steps``:

    1. ``change_media_<prompt_id>.txt`` when ``prompt_id`` is given and the file exists.
    2. ``change_media_<run_id>.txt`` when ``run_id`` is given and the file exists.
    3. ``change_media.txt`` otherwise (returned without checking that it exists).

    Returns:
        ``(chain_name, prompt_path)`` where ``chain_name`` is the file stem.
    """
    base_chain = "change_media"
    folder = os.path.join("prompts", "can-do-steps")

    def _candidate(suffix: str) -> Tuple[str, str]:
        # Build the (stem, path) pair for a suffixed prompt variant.
        stem = f"{base_chain}_{suffix}"
        return stem, os.path.join(folder, f"{stem}.txt")

    if prompt_id:
        override_chain, override_path = _candidate(prompt_id)
        if not os.path.exists(override_path):
            print(f"Prompt override not found ({override_path}); falling back to run-specific/default prompt.")
        else:
            print(f"Using prompt override: {override_path}")
            return override_chain, override_path

    if run_id:
        run_chain, run_path = _candidate(run_id)
        if os.path.exists(run_path):
            print(f"Using roadmap-specific prompt: {run_path}")
            return run_chain, run_path

    fallback_path = os.path.join(folder, f"{base_chain}.txt")
    print(f"Using default prompt: {fallback_path}")
    return base_chain, fallback_path


def load_prompt_template(prompt_path: str) -> str:
    """
    Return the UTF-8 text of the prompt template at ``prompt_path``.

    Best-effort: a missing file or any error while checking/reading it is
    reported with a warning and yields an empty string, which callers treat
    as "no usable template".
    """
    try:
        if os.path.exists(prompt_path):
            with open(prompt_path, "r", encoding="utf-8") as handle:
                template_text = handle.read()
            return template_text
        print(f"  ⚠️ Prompt file not found: {prompt_path}")
        return ""
    except Exception as err:
        print(f"  ⚠️ Could not read prompt template: {err}")
        return ""


def load_roadmap_file(run_id: str) -> Optional[Dict[str, Any]]:
    """
    Read ``output/can-do-steps/<run_id>/roadmap-<run_id>.json``.

    Returns the parsed data, or None when the file is absent or reading it
    raises; both cases print a warning instead of propagating.
    """
    path = f"output/can-do-steps/{run_id}/roadmap-{run_id}.json"

    if path_exists(path):
        try:
            loaded = read_json(path)
            print(f"✓ Loaded roadmap file: {path}")
            return loaded
        except Exception as err:
            print(f"⚠️ Error loading roadmap file {path}: {err}")
            return None

    print(f"⚠️ Roadmap file not found: {path}")
    return None


def apply_custom_images_to_roadmap(
    roadmap_data: Dict[str, Any],
    updates: List[Dict[str, str]]
) -> Tuple[Dict[str, Any], bool, List[str]]:
    """
    Prepend generated custom images to the ``images`` of matching roadmap steps.

    A step matches an update when the update's ``step_key`` equals the step's
    ``id`` or its ``slug`` (``id`` is tried first). Failing that, a step matches
    when the update's ``source_filename`` already appears among the step's
    image files. ``roadmap_data`` is modified in place.

    Existing ``images`` values are normalised into a list with the new entry
    first. A dict is kept as the second item; a non-empty string becomes
    ``{"file": <string>}``. A step whose images already reference the new file,
    in any supported shape, is left untouched.

    Args:
        roadmap_data: Roadmap dict with ``tracks -> paths -> steps`` nesting.
        updates: Entries carrying ``step_key``, ``source_filename``, ``file``
            and ``thumbnail``.

    Returns:
        ``(roadmap_data, modified, updated_steps)``. ``modified`` is True if any
        step gained a new image. ``updated_steps`` lists the matched keys of
        those steps in roadmap order.
    """
    if not roadmap_data or not updates:
        return roadmap_data, False, []

    update_lookup = {entry["step_key"]: entry for entry in updates if entry.get("step_key")}
    matched_keys = set()
    updated_steps: List[str] = []
    modified = False

    def _extract_step_image_files(step_data: Dict[str, Any]) -> List[str]:
        """Collect image filenames from a step's ``images`` (list, dict or str)."""
        files: List[str] = []
        images = step_data.get("images")
        if isinstance(images, list):
            for item in images:
                if isinstance(item, dict) and item.get("file"):
                    files.append(item["file"])
                elif isinstance(item, str):
                    files.append(item)
        elif isinstance(images, dict):
            if images.get("file"):
                files.append(images["file"])
        elif isinstance(images, str):
            files.append(images)
        return files

    for track in roadmap_data.get("tracks", []):
        for path in track.get("paths", []):
            for step in path.get("steps", []):
                step_files = _extract_step_image_files(step)

                # Try the explicit key on id first, then slug. These are the
                # same identifiers find_item_context accepts for a filename stem.
                update_entry = None
                match_key = None
                for candidate in (step.get("id"), step.get("slug")):
                    if candidate and candidate in update_lookup:
                        update_entry = update_lookup[candidate]
                        match_key = candidate
                        break

                # Fall back to a step that already references the source image.
                if update_entry is None:
                    for entry in updates:
                        source_name = entry.get("source_filename")
                        if source_name and source_name in step_files:
                            update_entry = entry
                            match_key = entry.get("step_key") or step.get("id") or step.get("slug")
                            break

                if not update_entry:
                    continue

                matched_keys.add(match_key)
                image_entry = {
                    "file": update_entry["file"],
                    "thumbnail": update_entry["thumbnail"]
                }

                # Dedupe against every supported shape, including plain-string
                # list items, so re-runs do not insert the same image twice.
                if image_entry["file"] in step_files:
                    print(f"  ⏭️ Roadmap already has image for step {match_key}: {image_entry['file']}")
                    continue

                existing = step.get("images")
                if isinstance(existing, list):
                    step["images"] = [image_entry] + existing
                elif isinstance(existing, dict):
                    step["images"] = [image_entry, existing]
                elif isinstance(existing, str) and existing:
                    step["images"] = [image_entry, {"file": existing}]
                else:
                    step["images"] = [image_entry]
                modified = True
                updated_steps.append(match_key)

    for step_key in update_lookup:
        if step_key not in matched_keys:
            print(f"  ⚠️ No matching step found for key: {step_key}")

    return roadmap_data, modified, updated_steps


def render_change_prompt(
    prompt_template: str,
    change_request: str,
    run_id: str,
    topic: str,
    target_level: str,
    item_title: str,
    item_description: str,
    source_filename: str
) -> str:
    """
    Substitute known ``{placeholder}`` tokens into the change media prompt template.

    Supported placeholders: ``{change_request}``, ``{run_id}``, ``{topic}``,
    ``{target_level}``, ``{item_title}``, ``{item_description}`` and
    ``{source_filename}``.

    All placeholders are replaced in a single pass over the template. Brace
    text inside a substituted value is therefore inserted literally rather than
    expanded again; for example, a change request that mentions ``{topic}``
    stays as written. Falsy values (e.g. None) become empty strings, and
    unknown ``{...}`` tokens are left unchanged.
    """
    replacements = {
        "change_request": change_request,
        "run_id": run_id,
        "topic": topic,
        "target_level": target_level,
        "item_title": item_title,
        "item_description": item_description,
        "source_filename": source_filename
    }

    placeholder_pattern = re.compile(
        r"\{(" + "|".join(re.escape(name) for name in replacements) + r")\}"
    )
    return placeholder_pattern.sub(
        lambda match: replacements[match.group(1)] or "",
        prompt_template
    )


def find_item_context(roadmap_data: Dict[str, Any], base_name: str) -> Tuple[str, str]:
    """
    Look up a title and description for a filename stem within the roadmap.

    Nodes are visited in document order: the roadmap itself, then each track
    followed by its paths, each path followed by its steps. The first node
    whose ``id`` or ``slug`` equals ``base_name`` supplies its ``title``
    (defaulting to ``base_name``) and ``description`` (defaulting to "").
    If nothing matches, or the roadmap is empty, ``(base_name, "")`` is returned.
    """
    if not roadmap_data:
        return base_name, ""

    def _walk():
        # Lazily yield nodes in the same order as a nested track/path/step scan.
        yield roadmap_data
        for track in roadmap_data.get("tracks", []):
            yield track
            for path in track.get("paths", []):
                yield path
                yield from path.get("steps", [])

    for node in _walk():
        if node.get("id") == base_name or node.get("slug") == base_name:
            return node.get("title", base_name), node.get("description", "")

    return base_name, ""


def change_media(
    run_id: str,
    source_dir: str,
    match_mode: Optional[str] = None,
    match: Optional[str] = None,
    overwrite: bool = False,
    change_request: str = "Restyle the source image into the target botanical illustration style while preserving subject fidelity.",
    prompt_id: Optional[str] = None,
    model: Optional[str] = None,
    quality: str = "standard",
    target_level: str = "all",
    limit: Optional[int] = None,
    max_workers: Optional[int] = None,
    size: str = "1024x1024",
    add_to_roadmap: bool = False
) -> Dict[str, Any]:
    """
    Apply a style change to images in a source directory using image-to-image editing.

    Each selected source image is processed as follows. A prompt is rendered
    from the resolved change_media template, with title and description looked
    up in the run's roadmap by filename stem. The prompt is passed to
    ``edit_and_save_image``, which writes into
    ``output/can-do-steps/<run_id>/change_media``. Images are processed on a
    thread pool when more than one worker and more than one image are
    available.

    A failure for one image does not abort the rest of the batch, and the run
    log is still saved. This holds whether the failure is reported by
    ``edit_and_save_image`` or raised as an exception; it is recorded in that
    image's ``error`` field.

    Args:
        run_id: Run identifier; selects the roadmap, prompt override and output folders.
        source_dir: Directory holding the source images.
        match_mode: Selection strategy forwarded to ``select_media_files``.
        match: Filename mask forwarded to ``select_media_files``.
        overwrite: Regenerate outputs that already exist.
        change_request: Text substituted for ``{change_request}`` in the prompt.
        prompt_id: Optional prompt override suffix (see ``resolve_change_media_prompt``).
        model: Model override forwarded to ``edit_and_save_image``.
        quality: Quality label forwarded to ``edit_and_save_image``.
        target_level: Text substituted for ``{target_level}`` in the prompt.
        limit: If positive, only the first ``limit`` selected images are processed.
        max_workers: Thread count. Defaults to ``min(8, 2 * cpu_count)``, using 4
            when the CPU count is unknown; the result is floored at 1.
        size: Requested output size forwarded to ``edit_and_save_image``.
        add_to_roadmap: Prepend newly generated images to matching roadmap steps,
            after copying the current roadmap file into ``archived/``.

    Returns:
        Summary dict with counts and one entry per processed image in
        ``changed_images``. It is also saved via ``save_output`` under ``logs``.
        If no source images are selected, a short summary is returned and no
        log is saved.

    Raises:
        ValueError: If ``source_dir`` is not a directory, or the prompt template
            is empty or unreadable.
    """
    print(f"Starting change_media for run ID: {run_id}")
    print(f"Source dir: {source_dir}, match_mode: {match_mode or 'all'}, match: {match or 'none'}, quality: {quality}")

    if not source_dir or not os.path.isdir(source_dir):
        raise ValueError(f"Source directory not found or invalid: {source_dir}")

    # load_roadmap_file already reports its own read errors. This guard also
    # covers roadmap data without a dict-like .get (e.g. a JSON list).
    try:
        roadmap_data = load_roadmap_file(run_id)
        topic = roadmap_data.get("title", roadmap_data.get("id", run_id)) if roadmap_data else run_id
    except Exception:
        roadmap_data = None
        topic = run_id

    chain_name_in_use, prompt_path_in_use = resolve_change_media_prompt(run_id, prompt_id)
    prompt_template_in_use = load_prompt_template(prompt_path_in_use)
    if not prompt_template_in_use:
        raise ValueError(f"Prompt template empty or unreadable: {prompt_path_in_use}")

    sources = select_media_files(
        run_id=run_id,
        source_dir=source_dir,
        match_mode=match_mode,
        match=match
    )
    if not sources:
        print("⚠️ No source images found for change_media.")
        return {
            "run_id": run_id,
            "source_dir": source_dir,
            "match_mode": match_mode,
            "match": match,
            "images_found": 0,
            "images_changed": 0,
            "changed_images": [],
            "message": "No source images found"
        }

    total_sources = len(sources)
    if limit is not None and limit > 0:
        limited_count = min(limit, total_sources)
        if limited_count < total_sources:
            print(f"Limiting processing to first {limited_count} images (limit={limit}).")
        sources = sources[:limited_count]

    output_dir = f"output/can-do-steps/{run_id}/change_media"
    ensure_dir(output_dir)

    default_workers = min(8, (os.cpu_count() or 4) * 2)
    # Zero or negative overrides fall back to sequential processing.
    workers = max(1, max_workers if max_workers is not None else default_workers)

    print(f"Processing {len(sources)} images with up to {workers} workers...")

    def _process_single(index: int, source_path: str) -> Tuple[int, Dict[str, Any]]:
        """Edit one source image; return its index and a result record."""
        print(f"  Processing source image: {os.path.basename(source_path)}")
        base_name = os.path.splitext(os.path.basename(source_path))[0]
        item_title, item_description = find_item_context(roadmap_data, base_name) if roadmap_data else (base_name, "")

        final_prompt = render_change_prompt(
            prompt_template=prompt_template_in_use,
            change_request=change_request,
            run_id=run_id,
            topic=topic,
            target_level=target_level,
            item_title=item_title,
            item_description=item_description or "",
            source_filename=os.path.basename(source_path)
        )

        item_id = f"{base_name}-custom"
        try:
            result = edit_and_save_image(
                description=final_prompt,
                source_image_path=source_path,
                item_id=item_id,
                output_dir=output_dir,
                context=f"Change media for {run_id} ({topic})",
                overwrite=overwrite,
                model=model,
                size=size,
                quality=quality,
                create_large_variant=True,
                create_webp_variant=True,
                chain_name=chain_name_in_use,
                pipeline="can-do-steps",
                run_id=run_id
            )
        except Exception as exc:
            # Treat a raised error like any other unsuccessful result. Otherwise
            # future.result() would re-raise it, the executor would wait for the
            # remaining edits, and the whole batch and its log would be lost.
            result = {"success": False, "error": f"{type(exc).__name__}: {exc}"}

        if result.get("success"):
            if result.get("skipped"):
                print(f"  ⏭️ Image skipped (already exists): {os.path.basename(result.get('image_path', ''))}")
            else:
                print(f"  ✅ Image generated: {os.path.basename(result.get('image_path', ''))}")
        else:
            print(f"  ❌ Failed to generate image for {item_id}: {result.get('error', 'Unknown error')}")

        return index, {
            "source_path": source_path,
            "item_id": item_id,
            "step_key": base_name,
            "source_filename": os.path.basename(source_path),
            "item_title": item_title,
            "item_description": item_description,
            "output_path": result.get("image_path", ""),
            "metadata_path": result.get("metadata_path", ""),
            "generated": result.get("success", False) and not result.get("skipped", False),
            "skipped": result.get("skipped", False),
            "webp_filename": result.get("webp_filename"),
            "thumbnail_filename": result.get("thumbnail_filename"),
            "error": result.get("error")
        }

    # Pre-sized so parallel results keep the original source order.
    changed_images: List[Optional[Dict[str, Any]]] = [None] * len(sources)

    if workers > 1 and len(sources) > 1:
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [
                executor.submit(_process_single, idx, source_path)
                for idx, source_path in enumerate(sources)
            ]

            for future in concurrent.futures.as_completed(futures):
                idx, result = future.result()
                changed_images[idx] = result
    else:
        for idx, source_path in enumerate(sources):
            returned_idx, result = _process_single(idx, source_path)
            changed_images[returned_idx] = result

    changed_images = [img for img in changed_images if img is not None]

    if add_to_roadmap and roadmap_data:
        # Only freshly generated images with both a webp and a thumbnail are
        # inserted into the roadmap.
        update_items: List[Dict[str, str]] = []
        for entry in changed_images:
            if not entry.get("generated"):
                continue
            webp_name = entry.get("webp_filename")
            thumb_name = entry.get("thumbnail_filename")
            if not webp_name or not thumb_name:
                continue
            update_items.append({
                "step_key": entry.get("step_key", ""),
                "source_filename": entry.get("source_filename", ""),
                "file": webp_name,
                "thumbnail": thumb_name
            })

        if update_items:
            roadmap_data, modified, updated_steps = apply_custom_images_to_roadmap(roadmap_data, update_items)
            roadmap_path = f"output/can-do-steps/{run_id}/roadmap-{run_id}.json"
            try:
                if modified:
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    backup_path = f"output/can-do-steps/{run_id}/archived/roadmap-{run_id}-{timestamp}.json"
                    if path_exists(roadmap_path):
                        copy_path(roadmap_path, backup_path)
                        print(f"✅ Created roadmap backup: {backup_path}")

                    write_json(roadmap_path, roadmap_data)
                    print(f"✅ Updated roadmap with custom images: {roadmap_path}")
                    if updated_steps:
                        print(f"   Steps updated: {len(updated_steps)}")
                else:
                    print("ℹ️ No roadmap updates applied; skipping file write.")
            except Exception as exc:
                print(f"⚠️ Failed to update roadmap file {roadmap_path}: {exc}")

    final_result = {
        "run_id": run_id,
        "source_dir": source_dir,
        "match_mode": match_mode,
        "match": match,
        "model": model,
        "quality": quality,
        "target_level": target_level,
        "add_to_roadmap": add_to_roadmap,
        "images_found": len(sources),
        "total_images_available": total_sources,
        "images_changed": len([img for img in changed_images if img["generated"]]),
        "changed_images": changed_images
    }

    save_output(
        data=final_result,
        pipeline="can-do-steps",
        step="change_media",
        run_id=run_id,
        subfolder="logs"
    )

    print(f"✅ change_media completed: {final_result['images_changed']}/{len(sources)} images processed.")
    return final_result


def load_changed_media(run_id: str) -> Optional[Dict[str, Any]]:
    """
    Fetch the most recent change_media log saved for ``run_id``.

    Thin wrapper around ``load_latest_output`` for the ``can-do-steps`` /
    ``change_media`` step in the ``logs`` subfolder; its return value is
    passed through unchanged.
    """
    lookup = {
        "pipeline": "can-do-steps",
        "step": "change_media",
        "run_id": run_id,
        "subfolder": "logs",
    }
    return load_latest_output(**lookup)


if __name__ == "__main__":
    # Command-line entry point: parse options and run a single change_media pass.
    import argparse

    parser = argparse.ArgumentParser(description="Change media style for can-do-steps pipeline")
    parser.add_argument("--run-id", required=True, help="Run identifier")
    parser.add_argument("--source-dir", required=True, help="Directory containing source images")
    parser.add_argument("--match-mode", help="Selection strategy for source images (e.g., first-image)")
    parser.add_argument("--match", help="Optional filename mask (e.g., '*.jpg')")
    parser.add_argument("--add-to-roadmap", action="store_true", help="Insert generated images into roadmap steps")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing outputs")
    parser.add_argument("--quality", choices=["low", "medium", "high", "standard"], default="standard", help="Quality label for cost tracking")
    parser.add_argument("--model", help="Model override (OpenAI image model)")
    parser.add_argument("--size", default="1024x1024", help="Requested output size (provider-specific)")
    # Options change_media already supports but the CLI previously could not reach.
    parser.add_argument("--prompt-id", help="Use prompts/can-do-steps/change_media_<prompt-id>.txt when it exists")
    parser.add_argument("--change-request", help="Override the default style-change instruction")
    parser.add_argument("--target-level", default="all", help="Target level substituted into the prompt (default: all)")
    parser.add_argument("--limit", type=int, help="Process at most this many source images")
    parser.add_argument("--max-workers", type=int, help="Maximum concurrent image edits (default: derived from CPU count)")
    args = parser.parse_args()

    cli_kwargs = dict(
        run_id=args.run_id,
        source_dir=args.source_dir,
        match_mode=args.match_mode,
        match=args.match,
        overwrite=args.overwrite,
        quality=args.quality,
        model=args.model,
        add_to_roadmap=args.add_to_roadmap,
        size=args.size,
        prompt_id=args.prompt_id,
        target_level=args.target_level,
        limit=args.limit,
        max_workers=args.max_workers
    )
    # Only pass change_request when given, so change_media's default applies otherwise.
    if args.change_request is not None:
        cli_kwargs["change_request"] = args.change_request

    change_media(**cli_kwargs)
