"""
Can-Do-Steps describe-images component for the Lumabit lesson-generation pipeline.
Generates illustration descriptions for roadmap, track, path, and step content.
"""
import os
import json
import traceback
from datetime import datetime
from typing import Dict, Any, List, Optional, Tuple

from utils.io import save_output, load_latest_output, ensure_dir
from chains.base import build_chain, default_json_parser, parse_output
from utils.storage import path_exists, read_json, write_json, copy_path


def load_roadmap_file(run_id: str) -> Optional[Dict[str, Any]]:
    """
    Load the roadmap JSON file for the given run ID.

    Args:
        run_id: Run identifier

    Returns:
        Dict: Loaded roadmap data, or None if not found
    """
    # Look for roadmap-{run_id}.json in the output directory
    roadmap_path = f"output/can-do-steps/{run_id}/roadmap-{run_id}.json"

    if path_exists(roadmap_path):
        try:
            roadmap_data = read_json(roadmap_path)
            print(f"✓ Loaded roadmap file: {roadmap_path}")
            return roadmap_data
        except Exception as e:
            print(f"⚠ Error loading roadmap file {roadmap_path}: {e}")
            return None
    else:
        print(f"⚠ Roadmap file not found: {roadmap_path}")
        return None


def resolve_describe_images_prompt(run_id: Optional[str] = None, prompt_id: Optional[str] = None) -> Tuple[str, str]:
    """
    Resolve the prompt file to use for describe_images.

    Priority order:
      1. prompt_id override (describe_images_<prompt_id>.txt)
      2. run-specific prompt (describe_images_<run_id>.txt)
      3. Default prompt (describe_images.txt)
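
    Example resolution (hypothetical identifiers):
      prompt_id="v2"  -> prompts/can-do-steps/describe_images_v2.txt (if it exists)
      run_id="r-001"  -> prompts/can-do-steps/describe_images_r-001.txt (if it exists)
      otherwise       -> prompts/can-do-steps/describe_images.txt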
    """
    prompt_dir = os.path.join("prompts", "can-do-steps")
    default_chain_name = "describe_images"
    default_prompt_path = os.path.join(prompt_dir, f"{default_chain_name}.txt")

    if prompt_id:
        candidate_filename = f"{default_chain_name}_{prompt_id}.txt"
        candidate_prompt_path = os.path.join(prompt_dir, candidate_filename)
        if os.path.exists(candidate_prompt_path):
            chain_name = f"{default_chain_name}_{prompt_id}"
            print(f"Using prompt override: {candidate_prompt_path}")
            return chain_name, candidate_prompt_path
        else:
            print(f"Prompt override not found ({candidate_prompt_path}); falling back to run-specific/default prompt.")

    if run_id:
        candidate_filename = f"{default_chain_name}_{run_id}.txt"
        candidate_prompt_path = os.path.join(prompt_dir, candidate_filename)
        if os.path.exists(candidate_prompt_path):
            chain_name = f"{default_chain_name}_{run_id}"
            print(f"Using roadmap-specific prompt: {candidate_prompt_path}")
            return chain_name, candidate_prompt_path

    print(f"Using default prompt: {default_prompt_path}")
    return default_chain_name, default_prompt_path


def parse_media_output(output: str) -> Dict[str, Any]:
    """
    Parse the illustration description output from the LLM.

    Args:
        output: Raw output from the LLM

    Returns:
        Dict: Parsed illustration description data
    """
    try:
        # Parse the JSON from the output
        parsed = parse_output(output, default_json_parser)

        # Validate the expected structure
        required_fields = ["illustrations", "generation_summary"]
        for field in required_fields:
            if field not in parsed:
                raise ValueError(f"Expected '{field}' key in parsed output")

        # Validate illustrations structure
        illustrations = parsed["illustrations"]
        if not isinstance(illustrations, list):
            raise ValueError("Expected 'illustrations' to be a list")

        # Validate each illustration
        for i, illustration in enumerate(illustrations):
            if not isinstance(illustration, dict):
                raise ValueError(f"Illustration at index {i} is not a dictionary")

            required_illustration_fields = ["type", "id", "title", "description", "illustration"]
            for field in required_illustration_fields:
                if field not in illustration:
                    raise ValueError(f"Illustration at index {i} is missing required field '{field}'")

            # Validate type
            if illustration["type"] not in ["roadmap", "track", "path", "step"]:
                raise ValueError(f"Invalid illustration type: {illustration['type']}")

            # Validate illustration description length
            desc_words = illustration["illustration"].split()
            if len(desc_words) < 10 or len(desc_words) > 60:
                print(f"WARNING: Illustration description should be 10-60 words: '{illustration['illustration']}'")

        print(f"✅ Successfully parsed {len(illustrations)} illustration descriptions")
        return parsed

    except Exception as e:
        print(f"Error parsing illustration output: {e}")
        print(f"First 500 characters of output: {output[:500]}...")
        raise


def extract_content_for_prompts(roadmap_data: Dict[str, Any], target_level: str = "all") -> str:
    """
    Extract relevant content from roadmap data for prompt generation.

    Args:
        roadmap_data: Loaded roadmap JSON
        target_level: "roadmap", "track", "path", "step", or "all"

    Returns:
        str: Formatted content for the prompt
    """
    content_lines = []

    if target_level in ["roadmap", "all"]:
        content_lines.append("=== ROADMAP ===")
        content_lines.append(f"Roadmap ID: {roadmap_data.get('id', 'N/A')}")
        content_lines.append(f"Title: {roadmap_data.get('title', 'N/A')}")
        content_lines.append(f"Description: {roadmap_data.get('description', 'N/A')}")
        if "illustration" in roadmap_data:
            content_lines.append(f"Existing Illustration: {roadmap_data['illustration']}")
        content_lines.append("---")

    if "tracks" in roadmap_data and (target_level in ["track", "all"]):
        content_lines.append("=== TRACKS ===")
        for track in roadmap_data["tracks"]:
            content_lines.append(f"Track ID: {track.get('id', 'N/A')}")
            content_lines.append(f"Title: {track.get('title', 'N/A')}")
            content_lines.append(f"Description: {track.get('description', 'N/A')}")
            if "illustration" in track:
                content_lines.append(f"Existing Illustration: {track['illustration']}")
            content_lines.append("---")

    if "tracks" in roadmap_data and (target_level in ["path", "all"]):
        content_lines.append("\n=== PATHS ===")
        for track in roadmap_data["tracks"]:
            if "paths" in track:
                for path in track["paths"]:
                    content_lines.append(f"Path ID: {path.get('id', 'N/A')}")
                    content_lines.append(f"Title: {path.get('title', 'N/A')}")
                    content_lines.append(f"Description: {path.get('description', 'N/A')}")
                    content_lines.append(f"Track: {track.get('title', 'N/A')}")
                    if "illustration" in path:
                        content_lines.append(f"Existing Illustration: {path['illustration']}")
                    content_lines.append("---")

    if "tracks" in roadmap_data and (target_level in ["step", "all"]):
        content_lines.append("\n=== STEPS ===")
        for track in roadmap_data["tracks"]:
            if "paths" in track:
                for path in track["paths"]:
                    if "steps" in path:
                        for step in path["steps"]:
                            content_lines.append(f"Step ID: {step.get('id', 'N/A')}")
                            content_lines.append(f"Title: {step.get('title', 'N/A')}")
                            content_lines.append(f"Description: {step.get('description', 'N/A')}")
                            content_lines.append(f"Level: {step.get('level', 'N/A')}")
                            content_lines.append(f"Path: {path.get('title', 'N/A')}")
                            if "illustration" in step:
                                content_lines.append(f"Existing Illustration: {step['illustration']}")
                            content_lines.append("---")

    return "\n".join(content_lines)


def apply_illustrations_to_roadmap(
    roadmap_data: Dict[str, Any],
    illustrations: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """
    Apply generated illustration descriptions back to the roadmap data.

    Args:
        roadmap_data: Original roadmap data
        illustrations: Generated illustrations

    Returns:
        Dict: Updated roadmap data with illustration fields
    """
    # Create a lookup dictionary for illustrations by ID
    illustration_lookup = {ill["id"]: ill for ill in illustrations}

    # Collect the IDs that exist in the roadmap so we can warn about mismatches
    structure_lookup: Dict[str, str] = {}

    roadmap_id = roadmap_data.get("id")
    if roadmap_id:
        structure_lookup[roadmap_id] = "roadmap"

    for track in roadmap_data.get("tracks", []):
        track_id = track.get("id")
        if track_id:
            structure_lookup[track_id] = "track"

        for path in track.get("paths", []):
            path_id = path.get("id")
            if path_id:
                structure_lookup[path_id] = "path"

            for step in path.get("steps", []):
                step_id = step.get("id")
                if step_id:
                    structure_lookup[step_id] = "step"

    applied_ids: List[str] = []

    # Apply to roadmap level (top-level illustration)
    if roadmap_id and roadmap_id in illustration_lookup:
        roadmap_data["illustration"] = illustration_lookup[roadmap_id]["illustration"]
        applied_ids.append(roadmap_id)

    # Apply to tracks
    for track in roadmap_data.get("tracks", []):
        track_id = track.get("id")
        if track_id and track_id in illustration_lookup:
            track["illustration"] = illustration_lookup[track_id]["illustration"]
            applied_ids.append(track_id)

        # Apply to paths
        for path in track.get("paths", []):
            path_id = path.get("id")
            if path_id and path_id in illustration_lookup:
                path["illustration"] = illustration_lookup[path_id]["illustration"]
                applied_ids.append(path_id)

            # Apply to steps
            for step in path.get("steps", []):
                step_id = step.get("id")
                if step_id and step_id in illustration_lookup:
                    step["illustration"] = illustration_lookup[step_id]["illustration"]
                    applied_ids.append(step_id)

    # Issue warnings for any illustration IDs that did not match the roadmap structure
    unmatched = [ill for ill in illustrations if ill["id"] not in applied_ids]
    if unmatched:
        print("⚠ Illustration IDs with no matching roadmap item found:")
        for ill in unmatched:
            expected = ill.get("type")
            actual = structure_lookup.get(ill["id"])
            detail = f"expected type '{expected}'"
            if actual:
                detail += f", but roadmap contains type '{actual}'"
            else:
                detail += ", but ID is missing from roadmap"
            print(f"   - {ill['id']}: {detail}")

    print(f"🖼 Applied {len(applied_ids)} illustration(s) to roadmap (requested {len(illustrations)}).")

    return roadmap_data


def describe_images(
    run_id: str,
    force_text: bool = False,
    target_level: str = "all",
    prompt_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    Generate illustration descriptions for roadmap content.

    Args:
        run_id: Run identifier
        force_text: If True, use existing raw text output instead of calling API
        target_level: "roadmap", "track", "path", "step", or "all" to limit generation
        prompt_id: Optional override identifier for selecting a specific prompt file

    Returns:
        Dict: Illustration description generation results
    """
    print(f"Starting illustration description generation for run ID: {run_id}")
    print(f"Target level: {target_level}")

    try:
        # Load the roadmap file
        roadmap_data = load_roadmap_file(run_id)
        if not roadmap_data:
            raise ValueError(f"Could not load roadmap file for run ID: {run_id}")

        # Only the chain name is needed here; the resolved prompt path is discarded
        # because build_chain is assumed to locate the prompt file from the chain name.
        chain_name, _ = resolve_describe_images_prompt(run_id, prompt_id)

        # Extract topic information if available
        topic = roadmap_data.get("title", roadmap_data.get("id", run_id))

        # Check if we should use existing output
        if force_text:
            existing_output = load_latest_output(
                pipeline="can-do-steps",
                step=chain_name,
                run_id=run_id,
                as_text=True,
                raw=True,
                subfolder="logs"
            )

            if existing_output:
                print(f"Using existing raw output for {chain_name} in can-do-steps/{run_id}")
                parsed_media = parse_media_output(existing_output)

                # Save the parsed output. Note: this force_text path only re-parses
                # existing raw text; it does not re-apply illustrations to the roadmap file.
                save_output(
                    data=parsed_media,
                    pipeline="can-do-steps",
                    step="describe_images",
                    run_id=run_id,
                    subfolder="logs"
                )

                return parsed_media

        # Extract content for prompt
        roadmap_content = extract_content_for_prompts(roadmap_data, target_level)

        # Generate illustration descriptions using LLM
        print("Generating illustration descriptions...")
        result = build_chain(
            chain_name=chain_name,
            pipeline="can-do-steps",
            run_id=run_id,
            input_variables={
                "run_id": run_id,
                "topic": topic,
                "target_level": target_level,
                "roadmap_content": roadmap_content
            }
        )

        # Parse the result
        parsed_media = parse_media_output(result["output"])
        illustrations = parsed_media["illustrations"]

        # Apply illustrations back to roadmap data
        updated_roadmap = apply_illustrations_to_roadmap(roadmap_data, illustrations)

        # Create backup of original roadmap file with timestamp
        original_roadmap_path = f"output/can-do-steps/{run_id}/roadmap-{run_id}.json"
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_roadmap_path = f"output/can-do-steps/{run_id}/archived/roadmap-{run_id}-{timestamp}.json"

        if path_exists(original_roadmap_path):
            # Make sure the archive directory exists before copying
            # (ensure_dir is assumed to accept a directory path).
            ensure_dir(os.path.dirname(backup_roadmap_path))
            copy_path(original_roadmap_path, backup_roadmap_path)
            print(f"✅ Created backup: {backup_roadmap_path}")

        # Replace original roadmap file with enriched version
        write_json(original_roadmap_path, updated_roadmap)
        print(f"✅ Updated original roadmap with illustrations: {original_roadmap_path}")

        # Compile final results
        final_result = {
            "run_id": run_id,
            "target_level": target_level,
            "illustrations_generated": len(illustrations),
            "original_roadmap_path": original_roadmap_path,
            "backup_roadmap_path": backup_roadmap_path if path_exists(backup_roadmap_path) else None,
            "illustrations": illustrations,
            "generation_summary": parsed_media.get("generation_summary", {})
        }

        # Save the complete results
        save_output(
            data=final_result,
            pipeline="can-do-steps",
            step="describe_images",
            run_id=run_id,
            subfolder="logs"
        )

        print(f"✅ Illustration description generation completed successfully")
        print(f"   Illustrations: {len(illustrations)}")

        return final_result

    except Exception as e:
        tb = traceback.format_exc()
        print(f"Error in describe_images: {e}")
        print(tb)
        raise


def load_generated_descriptions(run_id: str) -> Optional[Dict[str, Any]]:
    """
    Load previously generated illustration descriptions.

    Args:
        run_id: Run identifier

    Returns:
        Dict: Previously generated descriptions, or None if not found
    """
    return load_latest_output(
        pipeline="can-do-steps",
        step="describe_images",
        run_id=run_id,
        subfolder="logs"
    )


if __name__ == "__main__":
    import argparse

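    # Example invocation (assuming this file is saved as describe_images.py):
    #   python describe_images.py --run-id demo-run --target-level step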
    parser = argparse.ArgumentParser(description="Generate illustration descriptions for can-do-steps roadmap")
    parser.add_argument("--run-id", required=True, help="Run identifier")
    parser.add_argument("--force-text", action="store_true", help="Use existing raw text output")
    parser.add_argument("--target-level", choices=["roadmap", "track", "path", "step", "all"],
                        default="all", help="Target level for generation")
    args = parser.parse_args()

    result = describe_images(
        run_id=args.run_id,
        force_text=args.force_text,
        target_level=args.target_level,
        prompt_id=args.prompt_id
    )

    print(json.dumps({
        "run_id": result["run_id"],
        "illustrations_generated": result["illustrations_generated"],
    }, indent=2))
