"""
Complete pipeline orchestrator for the can-do-steps stage of the Lumabit lesson-generation pipeline.
Manages the full workflow from expansion through export.
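
Typical usage (sketch; "demo-001" is a hypothetical run ID whose matching
input/cando-demo-001.txt and input/tracks-demo-001.txt files exist):

    results = generate_complete_can_do_hierarchy(run_id="demo-001")
    print(results["statistics"])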
"""
import os
import json
from typing import Dict, Any, List, Optional, Literal
from datetime import datetime

from utils.io import save_output, load_latest_output
from chains.can_do_steps.organize_bits import organize_bits, load_organized_bits
from chains.can_do_steps.split_hierarchy import split_hierarchy, load_tracks, load_paths, load_steps
from chains.can_do_steps.expand_and_create_hierarchy import expand_and_create_hierarchy, load_expanded_hierarchy

def generate_complete_can_do_hierarchy(
    run_id: str,
    force_text: bool = False,
    skip_phases: Optional[List[str]] = None,
    export_json: bool = True,
    output_dir: Optional[str] = None,
    topic: Optional[str] = None,
    audience: Optional[str] = None,
    purpose: Optional[str] = None,
    style: Optional[str] = None,
    notes: Optional[str] = None,
    language: Optional[str] = None,
    prompt_id: Optional[str] = None,
    input_prefix: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Run the complete can-do-steps pipeline using the unified architecture.

    Unified architecture phases:
    1. Expand statements and create the hierarchy in a single step
    2. Split the hierarchy (deterministic JSON file splitting)
    3. Organize bits (generate learning bits from steps)
    4. Export JSON files

    Args:
        run_id: Run identifier
        force_text: If True, use existing raw outputs instead of calling API
        skip_phases: Optional list of phase names to skip; valid names are
            "expand_and_create_hierarchy", "split_hierarchy", and "organize_bits"
        export_json: If True, export JSON files at the end
        output_dir: Optional output directory for JSON exports
        topic: Optional topic for journey-based can-do statement expansion
        audience: Optional intended audience for the statements
        purpose: Optional purpose / learning-objective focus for the statements
        style: Optional language style to use for the statements
        notes: Optional additional guidance or context for the prompt
        language: Optional output language to request from the model
        prompt_id: Optional override identifier for selecting specific prompt files
        input_prefix: Optional input prefix/path for run-scoped input files

    Returns:
        Dict: Complete pipeline results and statistics

    Note:
        This function will automatically look for input files:
        - input/cando-{run_id}.txt for required can-do statements
        - input/tracks-{run_id}.txt for required tracks
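
    Example (hypothetical run ID; reuses existing outputs for the first two phases):

        results = generate_complete_can_do_hierarchy(
            run_id="demo-001",
            skip_phases=["expand_and_create_hierarchy", "split_hierarchy"],
            force_text=True,
        )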
    """
    return run_unified_pipeline(
        run_id, force_text, skip_phases, export_json, output_dir,
        topic, audience, purpose, style, notes, language, prompt_id, input_prefix
    )


def run_unified_pipeline(
    run_id: str,
    force_text: bool = False,
    skip_phases: Optional[List[str]] = None,
    export_json: bool = True,
    output_dir: Optional[str] = None,
    topic: Optional[str] = None,
    audience: Optional[str] = None,
    purpose: Optional[str] = None,
    style: Optional[str] = None,
    notes: Optional[str] = None,
    language: Optional[str] = None,
    prompt_id: Optional[str] = None,
    input_prefix: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Run the unified pipeline.

    1. Expand and create hierarchy (single LLM call)
    2. Split hierarchy (deterministic)
    3. Organize bits
    4. Export JSON

    Args:
        run_id: Run identifier
        force_text: If True, use existing raw outputs
        skip_phases: Optional list of phase names to skip; valid names are
            "expand_and_create_hierarchy", "split_hierarchy", and "organize_bits"
        export_json: If True, export JSON files
        output_dir: Optional output directory
        topic: Optional topic for journey-based can-do statement expansion
        audience: Optional intended audience for the statements
        purpose: Optional purpose / learning-objective focus for the statements
        style: Optional language style to use for the statements
        notes: Optional additional guidance or context for the prompt
        language: Optional output language to request from the model
        prompt_id: Optional override identifier for selecting specific prompt files
        input_prefix: Optional input prefix/path for run-scoped input files

    Returns:
        Dict: Pipeline results

    Note:
        This function will automatically look for input files:
        - input/cando-{run_id}.txt for required can-do statements
        - input/tracks-{run_id}.txt for required tracks
    """
    skip_phases = skip_phases or []
    start_time = datetime.now()

    print(f"=== Starting Can-Do-Steps Pipeline for run ID: {run_id} ===")
    print(f"Force text mode: {force_text}")
    print(f"Skip phases: {skip_phases}")
    print(f"Export JSON: {export_json}")

    results = {
        "run_id": run_id,
        "start_time": start_time.isoformat(),
        "phases_completed": [],
        "phases_skipped": skip_phases,
        "data": {},
        "statistics": {},
        "export_paths": {},
        "errors": []
    }

    try:
        # Phase 1: Expand And Create Hierarchy
        if "expand_and_create_hierarchy" not in skip_phases:
            print(f"\n--- Phase 1: Expanding Statements & Creating Hierarchy ---")
            expanded_hierarchy = expand_and_create_hierarchy(
                run_id=run_id,
                force_text=force_text,
                topic=topic,
                audience=audience,
                purpose=purpose,
                style=style,
                notes=notes,
                language=language,
                prompt_id=prompt_id,
                input_prefix=input_prefix,
            )
            results["data"]["expanded_hierarchy"] = expanded_hierarchy
            results["phases_completed"].append("expand_and_create_hierarchy")

            original_count = expanded_hierarchy.get("original_count", 0)
            new_count = expanded_hierarchy.get("new_count", 0)
            total_count = expanded_hierarchy.get("total_count", 0)
            track_count = len(expanded_hierarchy["tracks"])

            # Count paths and steps
            path_count = sum(len(track["paths"]) for track in expanded_hierarchy["tracks"])
            step_count = sum(sum(len(path["steps"]) for path in track["paths"]) for track in expanded_hierarchy["tracks"])

            print(f"✓ Expanded {original_count} original statements to {total_count} total statements ({new_count} new)")
            print(f"✓ Created complete hierarchy with {track_count} tracks, {path_count} paths, and {step_count} steps")
        else:
            print(f"\n--- Phase 1: Skipped (using existing data) ---")
            expanded_hierarchy = load_expanded_hierarchy(run_id)
            if expanded_hierarchy:
                results["data"]["expanded_hierarchy"] = expanded_hierarchy

        # Phase 2: Split Hierarchy
        if "split_hierarchy" not in skip_phases:
            print(f"\n--- Phase 2: Splitting Hierarchy ---")
            # First, save the hierarchy in the format expected by split_hierarchy
            if "expanded_hierarchy" in results["data"]:
                hierarchy_data = results["data"]["expanded_hierarchy"]
                save_output(
                    data=hierarchy_data,
                    pipeline="can-do-steps",
                    step="expand_and_create_hierarchy",
                    run_id=run_id,
                    subfolder="archived",
                )

            split_data = split_hierarchy(run_id)
            results["data"]["split_hierarchy"] = split_data
            results["phases_completed"].append("split_hierarchy")

            print(f"✓ Split hierarchy into individual files: {split_data['tracks_count']} tracks, {split_data['paths_count']} paths, {split_data['steps_count']} steps")

            # Load the split data for further processing
            results["data"]["tracks"] = load_tracks(run_id)
            results["data"]["paths"] = load_paths(run_id)
            results["data"]["steps"] = load_steps(run_id)
        else:
            print(f"\n--- Phase 2: Skipped (using existing data) ---")
            results["data"]["tracks"] = load_tracks(run_id)
            results["data"]["paths"] = load_paths(run_id)
            results["data"]["steps"] = load_steps(run_id)

        # Phase 3: Organize Bits
        if "organize_bits" not in skip_phases:
            print(f"\n--- Phase 3: Organizing into Learning Bits ---")
            bits_data = organize_bits(run_id, force_text, prompt_id=prompt_id)
            results["data"]["bits"] = bits_data
            results["phases_completed"].append("organize_bits")

            bit_count = len(bits_data["bits"])
            step_count = len(results["data"]["steps"]["steps"]) if "steps" in results["data"] else 0
            avg_bits_per_step = bit_count / step_count if step_count > 0 else 0

            print(f"✓ Created {bit_count} learning bits ({avg_bits_per_step:.1f} per step)")
        else:
            print(f"\n--- Phase 3: Skipped (using existing data) ---")
            bits_data = load_organized_bits(run_id)
            if bits_data:
                results["data"]["bits"] = bits_data

        # Compile final statistics
        results["statistics"] = compile_pipeline_statistics(results["data"])

        # Validate data integrity
        print(f"\n--- Validating Data Integrity ---")
        validation_results = validate_export_data_integrity(run_id)
        results["validation"] = validation_results

        if validation_results["valid"]:
            print("✓ Data integrity validation passed")
        else:
            print("⚠ Data integrity validation failed:")
            for error in validation_results["errors"]:
                print(f"   ERROR: {error}")

        if validation_results["warnings"]:
            print("Warnings:")
            for warning in validation_results["warnings"]:
                print(f"   WARNING: {warning}")

        # Final summary
        end_time = datetime.now()
        duration = end_time - start_time
        results["end_time"] = end_time.isoformat()
        results["duration_seconds"] = duration.total_seconds()

        print(f"\n=== Pipeline Complete ===")
        print(f"Duration: {duration.total_seconds():.1f} seconds")
        print(f"Phases completed: {', '.join(results['phases_completed'])}")

        if results["errors"]:
            print(f"Errors encountered: {len(results['errors'])}")
            for error in results["errors"]:
                print(f"   - {error}")

        # Save the complete pipeline results
        save_output(
            data=results,
            pipeline="can-do-steps",
            step="complete_pipeline",
            run_id=run_id,
            subfolder="logs",
        )

        return results

    except Exception as e:
        error_msg = f"Pipeline failed: {e}"
        print(f"❌ {error_msg}")
        results["errors"].append(error_msg)
        results["end_time"] = datetime.now().isoformat()

        # Save partial results even on failure
        save_output(
            data=results,
            pipeline="can-do-steps",
            step="complete_pipeline",
            run_id=run_id,
            subfolder="logs",
        )

        raise

def compile_pipeline_statistics(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Compile comprehensive statistics from pipeline data.

    Args:
        data: Pipeline data dictionary

    Returns:
        Dict[str, Any]: Compiled statistics
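
    Example shape (illustrative values; each section appears only when the
    corresponding data is present):

        {
            "statements": {"original_count": 12, "new_count": 28, "total_count": 40},
            "tracks": {"count": 3, "titles": ["Track A", "Track B", "Track C"]},
            "paths": {"count": 9, "average_per_track": 3.0},
            "path_distribution": {"track-1": 3, "track-2": 3, "track-3": 3},
            "steps": {
                "count": 27,
                "level_distribution": {"beginner": 27},
                "level_percentages": {"beginner": 100.0}
            },
            "step_distribution": {"path-1": 3},
            "bits": {"count": 81, "average_per_step": 3.0},
            "bit_distribution": {"step-1": 3}
        }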
    """
    stats = {}

    # Expanded hierarchy statistics
    if "expanded_hierarchy" in data:
        expanded = data["expanded_hierarchy"]
        stats["statements"] = {
            "original_count": expanded.get("original_count", 0),
            "new_count": expanded.get("new_count", 0),
            "total_count": expanded.get("total_count", 0)
        }

    # Tracks statistics
    if "tracks" in data:
        tracks = data["tracks"]["tracks"]
        stats["tracks"] = {
            "count": len(tracks),
            "titles": [track["title"] for track in tracks]
        }

    # Paths statistics
    if "paths" in data:
        paths = data["paths"]["paths"]
        track_count = len(stats.get("tracks", {}).get("titles", []))
        stats["paths"] = {
            "count": len(paths),
            "average_per_track": len(paths) / track_count if track_count else 0
        }

        # Path distribution by track
        track_path_counts = {}
        for path in paths:
            track_id = path["trackId"]
            track_path_counts[track_id] = track_path_counts.get(track_id, 0) + 1
        stats["path_distribution"] = track_path_counts

    # Steps statistics
    if "steps" in data:
        steps = data["steps"]["steps"]
        level_dist = data["steps"]["level_distribution"]
        stats["steps"] = {
            "count": len(steps),
            "level_distribution": level_dist,
            "level_percentages": {
                level: (count / len(steps) * 100) if len(steps) > 0 else 0
                for level, count in level_dist.items()
            }
        }

        # Step distribution by path
        path_step_counts = {}
        for step in steps:
            path_id = step["pathId"]
            path_step_counts[path_id] = path_step_counts.get(path_id, 0) + 1
        stats["step_distribution"] = path_step_counts

    # Bits statistics
    if "bits" in data:
        bits = data["bits"]["bits"]
        step_count = stats.get("steps", {}).get("count", 0)
        stats["bits"] = {
            "count": len(bits),
            "average_per_step": len(bits) / step_count if step_count else 0
        }

        # Bit distribution by step
        step_bit_counts = {}
        for bit in bits:
            step_id = bit["stepId"]
            step_bit_counts[step_id] = step_bit_counts.get(step_id, 0) + 1
        stats["bit_distribution"] = step_bit_counts

    return stats

def run_pipeline_phase(
    phase_name: str,
    run_id: str,
    force_text: bool = False
) -> Dict[str, Any]:
    """
    Run a single phase of the pipeline.

    Args:
        phase_name: Name of the phase to run
        run_id: Run identifier
        force_text: If True, use existing raw text outputs

    Returns:
        Dict: Phase results
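
    Example (hypothetical run ID):

        bits_data = run_pipeline_phase("organize_bits", "demo-001", force_text=True)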
    """
    phase_functions = {
        # UNIFIED architecture phases
        "expand_and_create_hierarchy": expand_and_create_hierarchy,
        "split_hierarchy": split_hierarchy,
        "organize_bits": organize_bits,
    }

    if phase_name not in phase_functions:
        raise ValueError(f"Unknown phase: {phase_name}. Valid phases: {list(phase_functions.keys())}")

    print(f"Running phase: {phase_name}")
    if phase_name == "split_hierarchy":
        # Splitting is deterministic and takes only the run ID (see Phase 2 above).
        return split_hierarchy(run_id)
    return phase_functions[phase_name](run_id, force_text)

def get_pipeline_status(run_id: str) -> Dict[str, Any]:
    """
    Get the current status of a pipeline run.

    Args:
        run_id: Run identifier

    Returns:
        Dict: Pipeline status information
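
    Example (hypothetical run ID):

        status = get_pipeline_status("demo-001")
        if status["ready_for_export"]:
            print(status["export_summary"])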
    """
    status = {
        "run_id": run_id,
        "phases_completed": [],
        "phases_available": [],
        "ready_for_export": False,
        "data_summary": {}
    }

    # Check each phase for UNIFIED architecture
    unified_phases = [
        ("expand_and_create_hierarchy", load_expanded_hierarchy),
        ("split_hierarchy", lambda run_id: {"tracks": load_tracks(run_id), "paths": load_paths(run_id), "steps": load_steps(run_id)}),
        ("organize_bits", load_organized_bits)
    ]

    for phase_name, load_function in unified_phases:
        try:
            data = load_function(run_id)
            if data:
                status["phases_completed"].append(phase_name)
            else:
                status["phases_available"].append(phase_name)
        except Exception:
            status["phases_available"].append(phase_name)

    # Check if ready for export
    unified_required_phases = ["expand_and_create_hierarchy", "split_hierarchy", "organize_bits"]
    status["ready_for_export"] = all(phase in status["phases_completed"] for phase in unified_required_phases)

    # Get export summary if ready
    if status["ready_for_export"]:
        try:
            status["export_summary"] = get_export_summary(run_id)
        except Exception as e:
            status["export_error"] = str(e)

    return status

def cleanup_pipeline_run(run_id: str) -> Dict[str, Any]:
    """
    Clean up intermediate files for a pipeline run.

    Args:
        run_id: Run identifier

    Returns:
        Dict: Cleanup results
    """
    # This would implement cleanup of intermediate files
    # For now, just return a status
    return {
        "run_id": run_id,
        "cleanup_performed": False,
        "message": "Cleanup functionality not yet implemented"
    }
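

# Minimal manual-run sketch. Assumptions: this module is executed directly and the
# chosen run ID has matching input/cando-<run_id>.txt and input/tracks-<run_id>.txt files.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run the can-do-steps pipeline for a run ID")
    parser.add_argument("run_id", help="Run identifier")
    parser.add_argument("--force-text", action="store_true",
                        help="Reuse existing raw outputs instead of calling the API")
    parser.add_argument("--status", action="store_true",
                        help="Only report pipeline status for the run, without running it")
    args = parser.parse_args()

    if args.status:
        # Report which phases are complete and whether the run is ready for export.
        print(json.dumps(get_pipeline_status(args.run_id), indent=2, default=str))
    else:
        generate_complete_can_do_hierarchy(args.run_id, force_text=args.force_text)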
