"""
Can-Do-Steps organize bits component for the Lumabit lesson-generation pipeline.
Phase 2D: Generate learning bits from learning steps.
"""
import os
import json
import concurrent.futures
from typing import Dict, Any, List, Optional, Tuple

from utils.io import save_output, load_latest_output, create_clean_copy
from chains.base import build_chain, default_json_parser, parse_output
from chains.can_do_steps.split_hierarchy import load_steps
from validation.json_schemas import (
    BITS_SCHEMA_VERSION,
    validate_bits,
    JsonSchemaValidationError,
)

def parse_organized_bits(output: str, expected_step_ids: List[str]) -> Dict[str, Any]:
    """
    Parse and validate the organized bits from the LLM output.

    The parsed output must contain a 'generation_summary' key. A missing,
    non-list or empty 'bits' value is tolerated: a warning is printed and
    per-bit validation is skipped ('bits' is replaced by an empty list when it
    was missing or not a list).

    Args:
        output: Raw output from the LLM
        expected_step_ids: List of valid step IDs for validation

    Returns:
        Dict: Parsed organized bits data

    Raises:
        ValueError: If a required key is missing, a bit is malformed, or a
            step does not end up with 1-3 bits. Any other exception raised
            while parsing is printed and re-raised unchanged.
    """
    required_fields = ("generation_summary",)
    bit_required_fields = ("id", "slug", "title", "description", "stepId", "order")
    # Loop-invariant; str.startswith accepts a tuple of prefixes.
    can_do_prefixes = ("I can ", "I know ", "I have ", "I use ", "I understand ")

    try:
        # Parse the JSON from the output
        parsed = parse_output(output, default_json_parser)

        # Validate the expected top-level structure
        for field in required_fields:
            if field not in parsed:
                raise ValueError(f"Expected '{field}' key in parsed output")

        # Validate bits structure when present
        bits = parsed.get("bits")
        if bits is None:
            print("Warning: Parsed output missing 'bits'; skipping bit validation.")
            parsed["bits"] = []
            return parsed

        if not isinstance(bits, list):
            print("Warning: Parsed output has invalid 'bits' type; skipping bit validation.")
            parsed["bits"] = []
            return parsed

        if not bits:
            print("Warning: Parsed output contained an empty 'bits' array; skipping bit validation.")
            return parsed

        bit_ids = set()
        step_bit_counts = {}

        for i, bit in enumerate(bits):
            if not isinstance(bit, dict):
                raise ValueError(f"Bit at index {i} is not a dictionary")

            for field in bit_required_fields:
                if field not in bit:
                    raise ValueError(f"Bit at index {i} is missing required field '{field}'")

            # IDs must be unique and must double as the slug
            if bit["id"] in bit_ids:
                raise ValueError(f"Duplicate bit ID: {bit['id']}")
            bit_ids.add(bit["id"])

            if bit["id"] != bit["slug"]:
                raise ValueError(f"Bit ID and slug must match: {bit['id']} != {bit['slug']}")

            # The bit must belong to one of the known steps
            step_id = bit["stepId"]
            if step_id not in expected_step_ids:
                raise ValueError(f"Invalid stepId '{step_id}'. Must be one of: {expected_step_ids}")
            step_bit_counts[step_id] = step_bit_counts.get(step_id, 0) + 1

            # A title phrased as a can-do statement only triggers a warning
            title = bit["title"]
            if not isinstance(title, str):
                raise ValueError(f"Bit title must be a string: {title!r}")
            if title.startswith(can_do_prefixes):
                print(f"Warning: Bit title should not be a can-do statement: '{title}'")

            # Description length is measured in whitespace-separated words (4-15)
            description = bit["description"]
            if not isinstance(description, str):
                raise ValueError(f"Bit description must be a string: {description!r}")
            word_count = len(description.split())
            if word_count < 4 or word_count > 15:
                raise ValueError(f"Bit description must be 4-15 words: '{description}'")

            # bool is a subclass of int, so it has to be rejected explicitly
            order = bit["order"]
            if isinstance(order, bool) or not isinstance(order, int) or order < 1:
                raise ValueError(f"Bit order must be positive integer: {order}")

        # Every expected step needs 1-3 bits. A step with no bits has a count
        # of 0 and is rejected here, so no separate "missing steps" check is needed.
        for step_id in expected_step_ids:
            count = step_bit_counts.get(step_id, 0)
            if count < 1 or count > 3:
                raise ValueError(f"Step '{step_id}' must have 1-3 bits, got {count}")

        return parsed
    except Exception as e:
        print(f"Error parsing organized bits: {e}")
        raise

def validate_single_bit(bit: Dict[str, Any], expected_step_id: str) -> None:
    """
    Validate a single bit structure.

    A description outside the 4-15 word range only prints a warning; every
    other violation raises.

    Args:
        bit: Single bit data to validate
        expected_step_id: Expected step ID for validation

    Raises:
        ValueError: If a required field is missing, the stepId differs from
            expected_step_id, id and slug differ, the description is not a
            string, or the order is not a positive integer.
    """
    required_fields = ("id", "slug", "title", "description", "stepId", "order")

    for field in required_fields:
        if field not in bit:
            raise ValueError(f"Missing required field '{field}'")

    if bit["stepId"] != expected_step_id:
        raise ValueError(f"Invalid stepId: {bit['stepId']}, expected: {expected_step_id}")

    if bit["id"] != bit["slug"]:
        raise ValueError(f"ID and slug must match: {bit['id']} != {bit['slug']}")

    # A non-string description would otherwise fail with AttributeError on .split()
    description = bit["description"]
    if not isinstance(description, str):
        raise ValueError(f"Description must be a string: {description!r}")

    # Length is measured in whitespace-separated words; out of range is a warning only
    word_count = len(description.split())
    if word_count < 4 or word_count > 15:
        print(f"WARNING: Description must be 4-15 words: '{description}'")

    # bool is a subclass of int, so True/False must be rejected explicitly
    order = bit["order"]
    if isinstance(order, bool) or not isinstance(order, int) or order < 1:
        raise ValueError(f"Order must be positive integer: {order}")

def resolve_organize_bits_prompt(
    local_run_id: str,
    prompt_id: Optional[str],
    step_type: Optional[str] = None
) -> Tuple[str, str]:
    """
    Pick the chain name and prompt file for organize_bits generation.

    Candidates are tried in this order and the first one whose file exists
    on disk wins:
      1. organize_bits_type_<step_type>.txt (only when step_type is given)
      2. organize_bits_<prompt_id>.txt (only when prompt_id is given)
      3. organize_bits_<local_run_id>.txt (only when local_run_id is given)
      4. organize_bits.txt (returned without an existence check)

    Returns:
        Tuple[str, str]: (chain name, prompt file path)
    """
    base_dir = "prompts/can-do-steps"
    fallback_name = "organize_bits"

    def path_for(chain_name: str) -> str:
        # Prompt files live next to each other and are named after the chain.
        return os.path.join(base_dir, f"{chain_name}.txt")

    fallback_path = path_for(fallback_name)

    if step_type:
        typed_name = f"organize_bits_type_{step_type}"
        typed_path = path_for(typed_name)
        if os.path.exists(typed_path):
            print(f"Using type-specific prompt: {typed_path}")
            return typed_name, typed_path

    # Empty/None identifiers are skipped; prompt_id outranks the run id.
    for identifier in filter(None, (prompt_id, local_run_id)):
        specific_name = f"organize_bits_{identifier}"
        specific_path = path_for(specific_name)
        if os.path.exists(specific_path):
            print(f"Using pipeline-specific prompt: {specific_path}")
            return specific_name, specific_path

    print(f"Using default prompt: {fallback_path}")
    return fallback_name, fallback_path


def generate_bits_for_single_step(
    step: Dict[str, Any],
    run_id: Optional[str] = None,
    prompt_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    Run the organize_bits chain for one step and normalise its output.

    Bit IDs and slugs are derived from the step ID and the bit's 1-based
    position in the returned list ("<step id>-bit-<n>"), and every bit's
    stepId is set to the step's own ID.

    Args:
        step: Single step data; 'title', 'id', 'description', 'level', 'pathId'
            and 'order' are read, 'type' is optional
        run_id: Optional run identifier passed to the chain; falls back to
            "temp-<step id>" when not provided
        prompt_id: Optional override identifier for selecting a specific prompt file

    Returns:
        Dict[str, Any]: Dict containing 'bits', 'lesson_content', and 'generation_summary'

    Raises:
        ValueError: If the parsed response has no 'lesson' field. Any other
            exception raised inside the generation step is printed and re-raised.
    """
    # Built before the try block, so a step missing one of these keys fails
    # without going through the error print below.
    context_lines = [
        f"Step: {step['title']} (ID: {step['id']})",
        f"Description: {step['description']}",
        f"Level: {step['level']}",
        f"Path: {step['pathId']}",
        f"Order: {step['order']}",
    ]
    step_context = "\n".join(context_lines)

    try:
        effective_run_id = run_id or f"temp-{step['id']}"

        # Only the chain name is needed here; the resolved path is ignored.
        chain_name, _ = resolve_organize_bits_prompt(
            local_run_id=effective_run_id,
            prompt_id=prompt_id,
            step_type=step.get("type")
        )

        chain_result = build_chain(
            chain_name=chain_name,
            pipeline="can-do-steps",
            run_id=effective_run_id,
            input_variables={"step_data": step_context}
        )

        parsed = parse_output(chain_result["output"], default_json_parser)
        raw_bits = parsed.get("bits")

        bits_with_ids = []
        if isinstance(raw_bits, list) and raw_bits:
            for position, raw_bit in enumerate(raw_bits, start=1):
                bit_id = f"{step['id']}-bit-{position}"
                bits_with_ids.append({
                    "id": bit_id,
                    "slug": bit_id,
                    "title": raw_bit["title"],
                    "description": raw_bit["description"],
                    "stepId": step["id"],  # taken from the step, never from the model output
                    "order": raw_bit.get("order", position)
                })
        elif raw_bits is None:
            print(f"Warning: Step {step['id']} output missing 'bits'; skipping bit generation.")
        elif not isinstance(raw_bits, list):
            print(f"Warning: Step {step['id']} output has invalid 'bits' type ({type(raw_bits).__name__}); skipping bit generation.")
        else:
            print(f"Warning: Step {step['id']} output contained an empty 'bits' array; skipping bit generation.")

        # Lesson content is mandatory even when no bits were produced
        if "lesson" not in parsed:
            raise ValueError("Invalid response: missing 'lesson' field")

        # The model's 'lesson' field is exposed to callers as 'lesson_content'
        return {
            "bits": bits_with_ids,
            "lesson_content": parsed["lesson"],
            "generation_summary": parsed.get("generation_summary", "No summary provided")
        }

    except Exception as e:
        print(f"Error generating bits for step {step['id']}: {e}")
        raise

def organize_bits_iteratively(
    run_id: str,
    limit: Optional[int] = None,
    max_workers: Optional[int] = None,
    prompt_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    Generate bits for every step through a thread pool and aggregate the results.

    A step whose generation or validation raises is recorded under
    'failed_step_details' in the summary and does not abort the run.

    Args:
        run_id: Run identifier
        limit: Maximum number of steps to process (for testing); a falsy value means no limit
        max_workers: Maximum number of worker threads; a falsy value selects
            min(8, 2 * CPU count)
        prompt_id: Optional override identifier for selecting a specific prompt file

    Returns:
        Dict: 'schema_version', one entry per successful step ID (each holding
        'bits', 'lesson_content' and 'generation_summary'), and a run-level
        'generation_summary'

    Raises:
        ValueError: If no steps are found for the run or the aggregated result
            fails schema validation.
    """
    steps_data = load_steps(run_id)
    if not steps_data:
        raise ValueError(f"No steps found for run ID: {run_id}")

    all_steps = steps_data["steps"]
    steps_to_process = all_steps[:limit] if limit else all_steps
    total_steps = len(steps_to_process)

    if not max_workers:
        max_workers = min(8, (os.cpu_count() or 4) * 2)

    limit_note = f' (limited to {limit})' if limit else ''
    print(f"Processing {total_steps} steps concurrently with up to {max_workers} workers{limit_note}...")

    bits_by_step = {}
    failed_steps = []
    total_bits = 0

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Map each future back to its step so results can be attributed on completion
        futures_map = {
            executor.submit(generate_bits_for_single_step, step, run_id, prompt_id): step
            for step in steps_to_process
        }

        # Results are handled in completion order, not submission order
        finished = concurrent.futures.as_completed(futures_map)
        for completed, future in enumerate(finished, start=1):
            step = futures_map[future]
            try:
                print(f"[{completed}/{total_steps}] Generating bits for: {step['id']}")
                step_result = future.result()

                step_bits = step_result.get("bits", [])
                lesson_content = step_result["lesson_content"]
                generation_summary = step_result.get("generation_summary", "")

                if step_bits:
                    # Validate every bit before any of them is reshaped
                    for bit in step_bits:
                        validate_single_bit(bit, step["id"])

                    # The stored form carries the description under 'content'
                    formatted_bits = [
                        {
                            "id": bit["id"],
                            "slug": bit["slug"],
                            "title": bit["title"],
                            "content": bit["description"],
                            "stepId": bit["stepId"],
                            "order": bit["order"]
                        }
                        for bit in step_bits
                    ]
                else:
                    formatted_bits = []
                    print(f"  ⚠️  No bits generated for step {step['id']}; skipping bit validation.")

                bits_by_step[step["id"]] = {
                    "bits": formatted_bits,
                    "lesson_content": lesson_content,
                    "generation_summary": generation_summary
                }

                total_bits += len(step_bits)
                print(f"  ✅ Generated {len(step_bits)} bits + lesson content")

            except Exception as e:
                # Record the failure and keep draining the remaining futures
                print(f"  ❌ Failed: {e}")
                failed_steps.append({"step_id": step["id"], "error": str(e)})

    failed_count = len(failed_steps)
    successful_steps = total_steps - failed_count
    if successful_steps > 0:
        average_bits = total_bits / successful_steps
    else:
        average_bits = 0

    summary = {
        "total_steps_processed": total_steps,
        "successful_steps": successful_steps,
        "failed_steps": failed_count,
        "total_bits_generated": total_bits,
        "average_bits_per_step": average_bits,
        "failed_step_details": failed_steps
    }

    # Step IDs become top-level keys, placed between the schema version and the summary
    result = {"schema_version": BITS_SCHEMA_VERSION}
    result.update(bits_by_step)
    result["generation_summary"] = summary

    try:
        validate_bits(result)
    except JsonSchemaValidationError as exc:
        raise ValueError(f"Bits schema validation failed for run '{run_id}'. {exc}") from exc

    print(f"✅ Completed: {successful_steps}/{total_steps} steps successful, {total_bits} total bits generated")

    if failed_steps:
        print(f"⚠️  {failed_count} steps failed - see generation_summary for details")

    return result

def organize_bits(
    run_id: str,
    force_text: bool = False,
    limit: Optional[int] = None,
    max_workers: Optional[int] = None,
    prompt_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    Generate learning bits from learning steps and archive the result.

    Args:
        run_id: Run identifier
        force_text: If True, re-parse the latest raw text output from the 'logs'
            subfolder instead of calling the API; when no such output is found
            the normal generation path is used
        limit: Maximum number of steps to process (for testing)
        max_workers: Maximum number of concurrent worker threads to use for LLM calls (optional)
        prompt_id: Optional override identifier for selecting a specific prompt file

    Returns:
        Dict: Organized bits data

    Raises:
        ValueError: On the force_text path, if no steps are found for the run
            or the re-parsed output fails schema validation.
    """
    pipeline_name = "can-do-steps"
    step_name = "bits"

    if force_text:
        raw_text = load_latest_output(
            pipeline=pipeline_name,
            step=step_name,
            run_id=run_id,
            as_text=True,
            raw=True,
            subfolder="logs"
        )

        if raw_text:
            print(f"Using existing raw output for organize_bits in can-do-steps/{run_id}")

            # The step IDs are needed to validate each bit's stepId
            steps_data = load_steps(run_id)
            if not steps_data:
                raise ValueError(f"No steps found for run ID: {run_id}")

            known_step_ids = [step["id"] for step in steps_data["steps"]]
            parsed_bits = parse_organized_bits(raw_text, known_step_ids)
            # Keep a schema version already present in the raw output
            parsed_bits.setdefault("schema_version", BITS_SCHEMA_VERSION)

            try:
                validate_bits(parsed_bits)
            except JsonSchemaValidationError as exc:
                raise ValueError(f"Bits schema validation failed for run '{run_id}'. {exc}") from exc

            # This path archives the parsed output but creates no clean copy
            save_output(
                data=parsed_bits,
                pipeline=pipeline_name,
                step=step_name,
                run_id=run_id,
                subfolder="archived"
            )
            return parsed_bits

    # Normal path: one chain call per step, run in parallel
    result = organize_bits_iteratively(run_id, limit, max_workers, prompt_id)

    archived_path = save_output(
        data=result,
        pipeline=pipeline_name,
        step=step_name,
        run_id=run_id,
        subfolder="archived"
    )

    # Create clean copy: bits-{run_id}.json
    create_clean_copy(
        timestamped_filepath=archived_path,
        pipeline=pipeline_name,
        step=step_name,
        run_id=run_id
    )

    return result

def load_organized_bits(run_id: str) -> Optional[Dict[str, Any]]:
    """
    Load the latest archived bits output for a run.

    Args:
        run_id: Run identifier

    Returns:
        Dict: Whatever load_latest_output returns for the 'bits' step in the
        'archived' subfolder (callers treat a falsy result as "not found")
    """
    lookup = {
        "pipeline": "can-do-steps",
        "step": "bits",
        "run_id": run_id,
        "subfolder": "archived",
    }
    return load_latest_output(**lookup)

def get_bits_only(run_id: str) -> List[Dict[str, Any]]:
    """
    Collect the bits of every step into one flat list.

    Args:
        run_id: Run identifier

    Returns:
        List[Dict[str, Any]]: Bit definitions from all steps, in stored order

    Raises:
        ValueError: If no organized bits are found for the run.
    """
    organized_data = load_organized_bits(run_id)
    if not organized_data:
        raise ValueError(f"No organized bits found for run ID: {run_id}")

    # Step entries are the top-level dict values that carry a 'bits' key;
    # the run-level summary is excluded by name.
    step_entries = (
        entry
        for name, entry in organized_data.items()
        if name != "generation_summary" and isinstance(entry, dict) and "bits" in entry
    )
    return [bit for entry in step_entries for bit in entry["bits"]]

def get_bits_for_step(run_id: str, step_id: str) -> List[Dict[str, Any]]:
    """
    Get all bits for a specific step.

    Args:
        run_id: Run identifier
        step_id: Step identifier

    Returns:
        List[Dict[str, Any]]: Bits for the step, ordered by their order field;
        an empty list when the step is unknown or has no 'bits' entry

    Raises:
        ValueError: If no organized bits are found for the run.
    """
    organized_data = load_organized_bits(run_id)
    if not organized_data:
        raise ValueError(f"No organized bits found for run ID: {run_id}")

    # Top-level values are not all step entries (e.g. 'schema_version'), so
    # require a dict before probing it, as get_step_data does.
    step_entry = organized_data.get(step_id)
    if not isinstance(step_entry, dict) or "bits" not in step_entry:
        return []

    return sorted(step_entry["bits"], key=lambda bit: bit["order"])

def get_lesson_content_for_step(run_id: str, step_id: str) -> str:
    """
    Get lesson content for a specific step.

    Args:
        run_id: Run identifier
        step_id: Step identifier

    Returns:
        str: Lesson content for the step, or empty string if not found

    Raises:
        ValueError: If no organized bits are found for the run.
    """
    organized_data = load_organized_bits(run_id)
    if not organized_data:
        raise ValueError(f"No organized bits found for run ID: {run_id}")

    # Top-level values are not all step entries (e.g. 'schema_version'), so
    # require a dict before probing it, as get_step_data does.
    step_entry = organized_data.get(step_id)
    if isinstance(step_entry, dict) and "lesson_content" in step_entry:
        return step_entry["lesson_content"]

    return ""

def get_step_data(run_id: str, step_id: str) -> Dict[str, Any]:
    """
    Return the stored entry for one step.

    Args:
        run_id: Run identifier
        step_id: Step identifier

    Returns:
        Dict[str, Any]: The dict stored under step_id, or a placeholder with
        empty 'bits', 'lesson_content' and 'generation_summary' when there is
        no dict under that key

    Raises:
        ValueError: If no organized bits are found for the run.
    """
    organized_data = load_organized_bits(run_id)
    if not organized_data:
        raise ValueError(f"No organized bits found for run ID: {run_id}")

    stored_entry = organized_data.get(step_id)
    if isinstance(stored_entry, dict):
        return stored_entry

    empty_entry = {"bits": [], "lesson_content": "", "generation_summary": ""}
    return empty_entry

def get_all_lesson_content(run_id: str) -> Dict[str, str]:
    """
    Map every step ID to its lesson content.

    Args:
        run_id: Run identifier

    Returns:
        Dict[str, str]: Lesson content keyed by step ID; top-level entries
        without a 'lesson_content' key are left out

    Raises:
        ValueError: If no organized bits are found for the run.
    """
    organized_data = load_organized_bits(run_id)
    if not organized_data:
        raise ValueError(f"No organized bits found for run ID: {run_id}")

    # The run-level summary is excluded by name; other non-step values are
    # excluded because they are not dicts with a 'lesson_content' key.
    return {
        name: entry["lesson_content"]
        for name, entry in organized_data.items()
        if name != "generation_summary"
        and isinstance(entry, dict)
        and "lesson_content" in entry
    }

def get_total_bit_count(run_id: str) -> int:
    """
    Count the bits stored for a run across all of its steps.

    Args:
        run_id: Run identifier

    Returns:
        int: Total number of bits
    """
    return len(get_bits_only(run_id))

def get_bit_statistics(run_id: str) -> Dict[str, Any]:
    """
    Get statistics about the generated bits.

    Args:
        run_id: Run identifier

    Returns:
        Dict[str, Any]: Total bits, total steps, average bits per step, bit
        counts per step and per level, and the min/max bits per step. The
        per-step distribution (and therefore min/max) only covers steps that
        have at least one bit. 'bits_by_level' always contains 'beginner',
        'intermediate' and 'advanced'; any other level found in the steps data
        gets its own key, and bits whose stepId is absent from the steps data
        are counted under 'unknown'.

    Raises:
        ValueError: If no organized bits or no steps are found for the run.
    """
    bits = get_bits_only(run_id)
    steps_data = load_steps(run_id)
    if not steps_data:
        raise ValueError(f"No steps found for run ID: {run_id}")

    steps = steps_data["steps"]
    total_bits = len(bits)
    total_steps = len(steps)

    step_levels = {step["id"]: step["level"] for step in steps}

    step_bit_counts = {}
    level_bit_counts = {"beginner": 0, "intermediate": 0, "advanced": 0}

    for bit in bits:
        step_id = bit["stepId"]
        step_bit_counts[step_id] = step_bit_counts.get(step_id, 0) + 1

        # Tolerate levels outside the three defaults and bits that point at a
        # step missing from the steps data, instead of failing with KeyError.
        level = step_levels.get(step_id, "unknown")
        level_bit_counts[level] = level_bit_counts.get(level, 0) + 1

    return {
        "total_bits": total_bits,
        "total_steps": total_steps,
        "average_bits_per_step": total_bits / total_steps if total_steps > 0 else 0,
        "bits_per_step_distribution": step_bit_counts,
        "bits_by_level": level_bit_counts,
        "min_bits_per_step": min(step_bit_counts.values()) if step_bit_counts else 0,
        "max_bits_per_step": max(step_bit_counts.values()) if step_bit_counts else 0
    }
