"""
Can-Do-Steps generate media component for the Lumabit lesson-generation pipeline.
Generates images from existing illustration descriptions in roadmap content.
"""
import os
import json
import traceback
import concurrent.futures
from typing import Dict, Any, List, Optional, Tuple

from utils.io import save_output, load_latest_output, ensure_dir
from chains.base import build_chain
from utils.storage import path_exists, read_json, copy_path


def resolve_generate_media_prompt(run_id: Optional[str] = None, prompt_id: Optional[str] = None) -> Tuple[str, str]:
    """
    Pick the prompt file (and matching chain name) to use for generate_media.

    Resolution order:
      1. Explicit prompt override (``prompt_id``), if that file exists.
      2. Roadmap-specific prompt (``run_id``), if that file exists.
      3. The default generate_media prompt.

    Args:
        run_id: Optional run identifier to look up a custom prompt.
        prompt_id: Optional override identifier for the prompt file.

    Returns:
        Tuple containing (chain_name, prompt_path)
    """
    base_name = "generate_media"
    prompt_dir = os.path.join("prompts", "can-do-steps")

    # Probe each candidate suffix in priority order; only a missing explicit
    # override warrants a fallback warning.
    candidates = (
        (prompt_id, "prompt override", True),
        (run_id, "roadmap-specific prompt", False),
    )
    for suffix, label, warn_on_miss in candidates:
        if not suffix:
            continue
        candidate_path = os.path.join(prompt_dir, f"{base_name}_{suffix}.txt")
        if os.path.exists(candidate_path):
            print(f"Using {label}: {candidate_path}")
            return f"{base_name}_{suffix}", candidate_path
        if warn_on_miss:
            print(f"Prompt override not found ({candidate_path}); falling back to run-specific/default prompt.")

    default_path = os.path.join(prompt_dir, f"{base_name}.txt")
    print(f"Using default prompt: {default_path}")
    return base_name, default_path


def load_roadmap_file(run_id: str) -> Optional[Dict[str, Any]]:
    """
    Load the roadmap JSON file for the given run ID.

    Args:
        run_id: Run identifier

    Returns:
        Dict: Loaded roadmap data, or None if not found
    """
    # The roadmap is written as roadmap-{run_id}.json inside the run's output dir.
    roadmap_path = f"output/can-do-steps/{run_id}/roadmap-{run_id}.json"

    if not path_exists(roadmap_path):
        print(f"⚠ Roadmap file not found: {roadmap_path}")
        return None

    try:
        roadmap_data = read_json(roadmap_path)
    except Exception as e:
        print(f"⚠ Error loading roadmap file {roadmap_path}: {e}")
        return None

    print(f"✓ Loaded roadmap file: {roadmap_path}")
    return roadmap_data


def _illustration_record(node: Dict[str, Any], node_type: str, default_id: str = "") -> Optional[Dict[str, Any]]:
    """
    Build a normalized illustration record for one roadmap node.

    Args:
        node: A roadmap/track/path/step dict.
        node_type: One of "roadmap", "track", "path", "step".
        default_id: Fallback id when the node has no "id" key.

    Returns:
        The record dict, or None when the node has no truthy "illustration".
    """
    if not node.get("illustration"):
        return None
    return {
        "type": node_type,
        "id": node.get("id", default_id),
        "title": node.get("title", ""),
        "description": node.get("description", ""),
        "illustration": node["illustration"],
    }


def extract_existing_illustrations(roadmap_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Extract existing illustration descriptions from roadmap data.

    Walks roadmap -> tracks -> paths -> steps and collects every node that
    carries a non-empty "illustration" field.

    Args:
        roadmap_data: Loaded roadmap JSON

    Returns:
        List: Existing illustrations found in the roadmap
    """
    illustrations: List[Dict[str, Any]] = []

    def _collect(node: Dict[str, Any], node_type: str, default_id: str = "") -> None:
        # Shared record-building logic for all four nesting levels.
        record = _illustration_record(node, node_type, default_id)
        if record is not None:
            illustrations.append(record)

    # Roadmap-level illustration; the roadmap id falls back to "roadmap".
    _collect(roadmap_data, "roadmap", "roadmap")

    # `or []` tolerates "tracks"/"paths"/"steps" keys that are present but null,
    # which previously raised a TypeError when iterated.
    for track in roadmap_data.get("tracks") or []:
        _collect(track, "track")
        for path in track.get("paths") or []:
            _collect(path, "path")
            for step in path.get("steps") or []:
                _collect(step, "step")

    print(f"✅ Found {len(illustrations)} existing illustrations in roadmap")
    return illustrations


def filter_illustrations_by_target_level(illustrations: List[Dict[str, Any]], target_level: str = "all") -> List[Dict[str, Any]]:
    """
    Restrict a list of illustrations to a single roadmap level.

    Args:
        illustrations: List of all illustrations
        target_level: "roadmap", "track", "path", "step", or "all"

    Returns:
        List: Filtered illustrations ("all" returns the input list unchanged)
    """
    if target_level == "all":
        return illustrations

    matching: List[Dict[str, Any]] = []
    for entry in illustrations:
        if entry["type"] == target_level:
            matching.append(entry)
    return matching


def load_prompt_template(prompt_path: str) -> str:
    """
    Read the full prompt template from disk.

    Args:
        prompt_path: Path to the prompt file

    Returns:
        str: Contents of the prompt file, or empty string on failure
    """
    try:
        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as handle:
                return handle.read()
        print(f"  ⚠️ Prompt file not found: {prompt_path}")
    except Exception as e:
        print(f"  ⚠️ Could not read prompt template: {e}")
    # Both the missing-file and the read-error paths fall through to "".
    return ""


def render_image_prompt(
    prompt_template: str,
    image_description: str,
    run_id: str = "",
    topic: str = "",
    target_level: str = ""
) -> str:
    """
    Substitute known placeholders into the prompt template.

    Only these four placeholders are replaced; other braces in the template
    are left untouched (str.format is deliberately avoided for that reason).
    Replacements happen in a fixed order, so placeholder text introduced by
    an earlier substitution may itself be substituted by a later one.

    Args:
        prompt_template: Raw prompt template text
        image_description: Illustration description from the roadmap
        run_id: Run identifier
        topic: Topic for context
        target_level: Target level for image generation

    Returns:
        str: Prompt with placeholders substituted
    """
    substitutions = (
        ("{image_description}", image_description),
        ("{run_id}", run_id),
        ("{topic}", topic),
        ("{target_level}", target_level),
    )
    rendered = prompt_template
    for placeholder, value in substitutions:
        rendered = rendered.replace(placeholder, value)
    return rendered


def _copy_image_to_global_dir(image_path: str, run_id: str) -> None:
    """
    Best-effort copy of a freshly generated image into the shared can-do-steps
    images directory as ``{run_id}-front.png``. Failures are logged, never raised.
    """
    try:
        global_images_dir = "output/can-do-steps/images"
        ensure_dir(global_images_dir)

        # NOTE(review): every image for a run lands on the same
        # "{run_id}-front.png" name, so later copies overwrite earlier ones —
        # confirm this single-front-image behavior is intended.
        global_image_path = os.path.join(global_images_dir, f"{run_id}-front.png")

        copy_path(image_path, global_image_path)
        print(f"  📋 Image copied to global directory: {global_image_path}")
    except Exception as copy_error:
        print(f"  ⚠️ Failed to copy image to global directory: {copy_error}")


def generate_image_from_description(
    description: str,
    item_id: str,
    run_id: str,
    output_dir: str,
    topic: str = "",
    overwrite: bool = False,
    quality: str = "low",
    chain_name: Optional[str] = None,
    provider: str = "open-ai",
    prompt_template: Optional[str] = None,
    target_level: str = ""
) -> Tuple[str, bool]:
    """
    Generate an actual image from an illustration description using the full prompt template.

    Args:
        description: Illustration description
        item_id: Unique identifier for the item
        run_id: Run identifier
        output_dir: Output directory for images
        topic: Topic for styling context
        overwrite: Whether to overwrite existing files
        quality: Image quality - "low", "medium", or "high" (default: "low")
        chain_name: Optional chain name for cost tracking and provenance
        provider: Image provider ("open-ai" or "google")
        prompt_template: Full prompt template to use for generation
        target_level: Target level for context substitution

    Returns:
        Tuple[str, bool]: (image_path, success)
    """
    try:
        # Imported lazily so this module can be loaded without the image backend.
        from utils.image_generator import generate_and_save_image

        # Without a template, the raw description is sent to the generator as-is.
        final_prompt = description
        if prompt_template:
            final_prompt = render_image_prompt(
                prompt_template=prompt_template,
                image_description=description,
                run_id=run_id,
                topic=topic,
                target_level=target_level
            )

        print(f"  Generating image for {item_id} (quality: {quality})...")
        result = generate_and_save_image(
            description=final_prompt,
            item_id=item_id,
            output_dir=output_dir,
            context=f"Educational roadmap content for {run_id}",
            overwrite=overwrite,
            quality=quality,
            chain_name=chain_name or "generate_media",
            pipeline="can-do-steps",
            run_id=run_id,
            provider=provider,
            create_large_variant=True
        )

        if not result["success"]:
            print(f"  ❌ Failed to generate image for {item_id}: {result.get('error', 'Unknown error')}")
            return "", False

        if result.get("skipped"):
            # An existing file was kept; no global copy happens in this case.
            print(f"  ⏭️ Image skipped (already exists): {result['image_path']}")
        else:
            print(f"  ✅ Image generated: {result['image_path']}")
            # Copy failure is deliberately non-fatal for the overall operation.
            _copy_image_to_global_dir(result["image_path"], run_id)

        return result["image_path"], True

    except Exception as e:
        print(f"  ❌ Error generating image for {item_id}: {e}")
        return "", False




def generate_media(
    run_id: str,
    overwrite: bool = False,
    target_level: str = "all",
    generate_images: bool = True,
    quality: str = "low",
    limit: Optional[int] = None,
    max_workers: Optional[int] = None,
    prompt_id: Optional[str] = None,
    provider: str = "open-ai"
) -> Dict[str, Any]:
    """
    Generate images from existing illustration descriptions in roadmap content.

    Pipeline: load the roadmap JSON -> resolve and validate the prompt template
    -> extract, filter, and optionally limit illustrations -> generate images
    (in a thread pool when more than one worker is configured) -> save a result
    log and return a summary dict.

    Args:
        run_id: Run identifier
        overwrite: If True, overwrite existing images
        target_level: "roadmap", "track", "path", "step", or "all" to limit generation
        generate_images: If True, generate actual image files
        quality: Image quality - "low", "medium", or "high" (default: "low")
        limit: Maximum number of illustrations to process (useful for testing)
        max_workers: Maximum number of concurrent workers for image generation
        prompt_id: Optional override identifier for the prompt file
        provider: Image provider ("open-ai" or "google")

    Returns:
        Dict: Media generation results. Note "illustrations_found" reflects the
        post-filter, post-limit count, while "total_illustrations_available" is
        the pre-limit count for the target level.

    Raises:
        ValueError: If the roadmap file cannot be loaded, or a non-empty prompt
            template lacks the required "{image_description}" placeholder.
    """
    print(f"Starting image generation for run ID: {run_id}")
    print(f"Target level: {target_level}, Generate images: {generate_images}, Quality: {quality}, Provider: {provider}")

    try:
        # Load the roadmap file
        roadmap_data = load_roadmap_file(run_id)
        if not roadmap_data:
            raise ValueError(f"Could not load roadmap file for run ID: {run_id}")

        # An empty/unreadable template only downgrades to raw descriptions,
        # but a non-empty template without the placeholder is a hard error
        # (it would silently produce images that ignore the descriptions).
        chain_name_in_use, prompt_path_in_use = resolve_generate_media_prompt(run_id, prompt_id)
        prompt_template_in_use = load_prompt_template(prompt_path_in_use)
        if not prompt_template_in_use:
            print("⚠️ Prompt template empty or unreadable; falling back to illustration description only.")
        elif "{image_description}" not in prompt_template_in_use:
            raise ValueError(
                f"Prompt template is missing required placeholder '{{image_description}}': {prompt_path_in_use}"
            )

        # Extract topic information (falls back to id, then the run id itself)
        topic = roadmap_data.get("title", roadmap_data.get("id", run_id))

        # Extract existing illustrations from roadmap
        all_illustrations = extract_existing_illustrations(roadmap_data)
        if not all_illustrations:
            print("⚠️ No existing illustrations found in roadmap. Please ensure illustrations have been generated first.")
            # Early-exit result mirrors the success payload's key shape.
            return {
                "run_id": run_id,
                "target_level": target_level,
                "illustrations_found": 0,
                "images_generated": 0,
                "generated_images": [],
                "message": "No illustrations found in roadmap"
            }

        # Filter illustrations by target level
        illustrations = filter_illustrations_by_target_level(all_illustrations, target_level)

        if not illustrations:
            print(f"⚠️ No illustrations found for target level: {target_level}")
            return {
                "run_id": run_id,
                "target_level": target_level,
                "illustrations_found": 0,
                "images_generated": 0,
                "generated_images": [],
                "message": f"No illustrations found for target level: {target_level}"
            }

        total_illustrations_available = len(illustrations)

        # Apply the optional limit; limit <= 0 is treated as "no limit".
        if limit is not None and limit > 0:
            limited_count = min(limit, total_illustrations_available)
            if limited_count < total_illustrations_available:
                print(
                    f"Limiting processing to first {limited_count} illustrations "
                    f"out of {total_illustrations_available} (limit={limit})."
                )
            illustrations = illustrations[:limited_count]
        else:
            limited_count = total_illustrations_available

        # NOTE: prints the pre-limit count for the target level, not the
        # (possibly smaller) number actually processed.
        print(f"Found {total_illustrations_available} illustrations for target level: {target_level}")

        # Generate actual images if requested
        generated_images = []
        if generate_images and illustrations:
            images_dir = f"output/can-do-steps/{run_id}/images"
            ensure_dir(images_dir)

            # Default worker count: 2x CPUs, capped at 8 (image generation is
            # I/O-bound, so threads are appropriate).
            default_workers = min(8, (os.cpu_count() or 4) * 2)
            workers = max_workers if max_workers is not None else default_workers
            # Defensive clamp: workers cannot actually be None here, but any
            # non-positive request falls back to sequential processing.
            if workers is None or workers < 1:
                workers = 1

            def _process_single(index: int, illustration: Dict[str, Any]) -> Tuple[int, Dict[str, Any]]:
                # Worker task: generate one image and return (original index,
                # result record) so completion order doesn't matter.
                image_path, success = generate_image_from_description(
                    description=illustration["illustration"],
                    item_id=illustration["id"],
                    run_id=run_id,
                    output_dir=images_dir,
                    topic=topic,
                    overwrite=overwrite,
                    quality=quality,
                    chain_name=chain_name_in_use,
                    provider=provider,
                    prompt_template=prompt_template_in_use,
                    target_level=target_level
                )

                return index, {
                    "id": illustration["id"],
                    "type": illustration["type"],
                    "title": illustration["title"],
                    "illustration": illustration["illustration"],
                    "image_path": image_path if success else "",
                    "generated": success
                }

            # Pre-sized slots so results land at their input position even
            # when futures complete out of order.
            generated_images = [None] * len(illustrations)

            if workers > 1 and len(illustrations) > 1:
                print(f"Generating {len(illustrations)} images with up to {workers} workers...")
                with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
                    futures = [
                        executor.submit(_process_single, index, illustration)
                        for index, illustration in enumerate(illustrations)
                    ]

                    # future.result() re-raises any worker exception here,
                    # which then propagates out via the outer try/except.
                    for future in concurrent.futures.as_completed(futures):
                        index, result = future.result()
                        generated_images[index] = result
            else:
                print(f"Generating {len(illustrations)} images sequentially...")
                for index, illustration in enumerate(illustrations):
                    returned_index, result = _process_single(index, illustration)
                    generated_images[returned_index] = result

            # Drop any slots never filled (defensive; normally all are set).
            generated_images = [img for img in generated_images if img is not None]

        # Compile final results
        final_result = {
            "run_id": run_id,
            "target_level": target_level,
            "provider": provider,
            "illustrations_found": len(illustrations),
            "total_illustrations_available": total_illustrations_available,
            "images_generated": len([img for img in generated_images if img["generated"]]) if generate_images else 0,
            "illustrations": illustrations,
            "generated_images": generated_images if generate_images else []
        }

        # Save the complete results (log artifact for load_generated_media)
        save_output(
            data=final_result,
            pipeline="can-do-steps",
            step="generate_media",
            run_id=run_id,
            subfolder="logs"
        )

        print(f"✅ Image generation completed successfully")
        print(f"   Illustrations found: {len(illustrations)}")
        if generate_images:
            success_count = len([img for img in generated_images if img["generated"]])
            print(f"   Images: {success_count}/{len(illustrations)} generated")

        return final_result

    except Exception as e:
        # Log the full traceback for operability, then re-raise so callers
        # (and the CLI) see the failure.
        tb = traceback.format_exc()
        print(f"Error in generate_media: {e}")
        print(tb)
        raise


def load_generated_media(run_id: str) -> Optional[Dict[str, Any]]:
    """
    Fetch the most recently saved generate_media results for a run.

    Args:
        run_id: Run identifier

    Returns:
        Dict: Previously generated media, or None if not found
    """
    # Mirrors the save_output() call in generate_media, keyed on the same
    # pipeline/step/subfolder triple.
    query = {
        "pipeline": "can-do-steps",
        "step": "generate_media",
        "run_id": run_id,
        "subfolder": "logs",
    }
    return load_latest_output(**query)


def _cli() -> None:
    """Command-line entry point: parse arguments and run generate_media once."""
    import argparse

    parser = argparse.ArgumentParser(description="Generate images for can-do-steps roadmap")
    parser.add_argument("--run-id", required=True, help="Run identifier")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing images")
    parser.add_argument("--target-level", choices=["roadmap", "track", "path", "step", "all"],
                        default="all", help="Target level for generation")
    parser.add_argument("--no-images", action="store_true", help="Skip image generation")
    parser.add_argument("--quality", choices=["low", "medium", "high"],
                        default="low", help="Image quality (default: low)")
    parser.add_argument("--limit", type=int, help="Limit the number of illustrations to process")
    parser.add_argument("--max-workers", type=int, dest="max_workers",
                        help="Maximum number of concurrent image generation workers (default: auto)")
    parser.add_argument("--provider", choices=["open-ai", "google"], default="open-ai",
                        help="Image provider to use (default: open-ai)")
    args = parser.parse_args()

    result = generate_media(
        run_id=args.run_id,
        overwrite=args.overwrite,
        target_level=args.target_level,
        generate_images=not args.no_images,
        quality=args.quality,
        limit=args.limit,
        max_workers=args.max_workers,
        provider=args.provider
    )

    # Compact machine-readable summary for calling scripts.
    summary = {
        "run_id": result["run_id"],
        "illustrations_found": result["illustrations_found"],
        "images_generated": result.get("images_generated", 0)
    }
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    _cli()
