"""
Can-Do-Steps split hierarchy component for the Lumabit lesson-generation pipeline.
Phase 3: Split complete nested hierarchy into individual JSON files (deterministic, no LLM).
"""
from typing import Dict, Any, List, Optional

from utils.io import save_output, load_latest_output, create_clean_copy
from utils.storage import path_exists, read_json

def _save_split_file(data: Dict[str, Any], step: str, run_id: str) -> None:
    """
    Persist one split artifact and its clean copy.

    Saves a timestamped output under the logs subfolder, then creates the
    clean copy named {step}-{run_id}.json.

    Args:
        data: JSON-serializable payload for this artifact
        step: Pipeline step name ("tracks", "paths", or "steps")
        run_id: Run identifier
    """
    timestamped_filepath = save_output(
        data=data,
        pipeline="can-do-steps",
        step=step,
        run_id=run_id,
        subfolder="logs"
    )
    create_clean_copy(
        timestamped_filepath=timestamped_filepath,
        pipeline="can-do-steps",
        step=step,
        run_id=run_id,
        subfolder="logs"
    )

def split_hierarchy(run_id: str) -> Dict[str, Any]:
    """
    Split the complete hierarchy into individual JSON files.

    This is a deterministic process that doesn't use LLM - just data transformation.

    Args:
        run_id: Run identifier

    Returns:
        Dict: Summary of the split operation

    Raises:
        ValueError: If no complete hierarchy exists for the run
    """
    # Load the complete hierarchy from expand_and_create_hierarchy (unified pipeline).
    # NOTE(review): local import kept from the original — presumably avoids a
    # circular import between chain modules; confirm before hoisting to the top.
    from chains.can_do_steps.expand_and_create_hierarchy import load_expanded_hierarchy
    hierarchy = load_expanded_hierarchy(run_id)

    if not hierarchy:
        raise ValueError(f"No complete hierarchy found for run ID: {run_id}")

    print(f"Splitting hierarchy for run ID: {run_id}")

    # Flatten each level of the hierarchy into its own payload
    tracks_data = extract_tracks_json(hierarchy)
    paths_data = extract_paths_json(hierarchy)
    steps_data = extract_steps_json(hierarchy)

    # Save each JSON file separately and create clean copies:
    # tracks-{run_id}.json, paths-{run_id}.json, steps-{run_id}.json
    _save_split_file(tracks_data, "tracks", run_id)
    _save_split_file(paths_data, "paths", run_id)
    _save_split_file(steps_data, "steps", run_id)

    tracks_count = len(tracks_data["tracks"])
    paths_count = len(paths_data["paths"])
    steps_count = len(steps_data["steps"])

    # Create summary
    summary = {
        "tracks_count": tracks_count,
        "paths_count": paths_count,
        "steps_count": steps_count,
        "split_summary": f"Successfully split hierarchy into {tracks_count} tracks, {paths_count} paths, and {steps_count} steps"
    }

    # Save the split summary
    save_output(
        data=summary,
        pipeline="can-do-steps",
        step="split_hierarchy",
        run_id=run_id,
        subfolder="logs"
    )

    print(f"✅ Split complete: {summary['tracks_count']} tracks, {summary['paths_count']} paths, {summary['steps_count']} steps")

    return summary

def extract_tracks_json(hierarchy: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build the tracks.json payload from the complete hierarchy.

    Args:
        hierarchy: Complete nested hierarchy

    Returns:
        Dict: tracks.json format data
    """
    flattened: List[Dict[str, Any]] = []

    for position, entry in enumerate(hierarchy["tracks"], start=1):
        declared_order = entry.get("order")
        flattened.append({
            "id": entry["id"],
            "slug": entry["slug"],
            "title": entry["title"],
            "description": entry["description"],
            # Fall back to the 1-based position when no integer order is present
            "order": declared_order if isinstance(declared_order, int) else position,
        })

    return {"tracks": flattened}

def extract_paths_json(hierarchy: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build the paths.json payload from the complete hierarchy.

    Args:
        hierarchy: Complete nested hierarchy

    Returns:
        Dict: paths.json format data
    """
    flattened: List[Dict[str, Any]] = []

    for parent_track in hierarchy["tracks"]:
        parent_id = parent_track["id"]
        for position, entry in enumerate(parent_track["paths"], start=1):
            declared_order = entry.get("order")
            flattened.append({
                "id": entry["id"],
                "slug": entry["slug"],
                "title": entry["title"],
                "description": entry["description"],
                "trackId": parent_id,  # Reference to parent track
                # Fall back to the 1-based position when no integer order is present
                "order": declared_order if isinstance(declared_order, int) else position,
            })

    return {"paths": flattened}

def extract_steps_json(hierarchy: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build the steps.json payload from the complete hierarchy.

    Args:
        hierarchy: Complete nested hierarchy

    Returns:
        Dict: steps.json format data
    """
    valid_levels = {"beginner", "intermediate", "advanced"}
    flattened: List[Dict[str, Any]] = []

    for track in hierarchy["tracks"]:
        for parent_path in track["paths"]:
            for position, entry in enumerate(parent_path["steps"], start=1):
                raw_level = entry.get("level")
                declared_order = entry.get("order")

                record: Dict[str, Any] = {
                    "id": entry["id"],
                    "slug": entry["slug"],
                    "title": entry["title"],
                    "description": entry["description"],
                    # Unknown/missing levels are normalized to "unspecified"
                    "level": raw_level if raw_level in valid_levels else "unspecified",
                    "pathId": parent_path["id"],  # Reference to parent path
                    # Fall back to the 1-based position when no integer order is present
                    "order": declared_order if isinstance(declared_order, int) else position,
                }

                # Preserve optional classification for downstream prompt selection
                if "type" in entry:
                    record["type"] = entry["type"]

                # Handle different formats - if step has statements field, use it.
                # Otherwise, use the title as the statement (expand_and_create_hierarchy format).
                record["statements"] = entry["statements"] if "statements" in entry else [entry["title"]]

                flattened.append(record)

    return {"steps": flattened}

def load_roadmap(run_id: str) -> Optional[Dict[str, Any]]:
    """
    Load the roadmap clean copy if it exists.

    Args:
        run_id: Run identifier

    Returns:
        Dict: Roadmap data or None if not found/invalid
    """
    candidate = f"output/can-do-steps/{run_id}/roadmap-{run_id}.json"
    if not path_exists(candidate):
        return None

    try:
        return read_json(candidate)
    except Exception as exc:
        # Best-effort loader: a corrupt/unreadable file is reported, not raised
        print(f"⚠️ Error loading roadmap file {candidate}: {exc}")
        return None

def _load_from_roadmap(run_id: str, extractor, label: str) -> Optional[Dict[str, Any]]:
    """
    Attempt to derive split data from the roadmap clean copy.

    Args:
        run_id: Run identifier
        extractor: Function to extract the desired split data from a roadmap dict
        label: Human-friendly label for logging

    Returns:
        Dict containing the extracted data, or None if roadmap unavailable
    """
    source = load_roadmap(run_id)
    if not source:
        return None

    print(f"ℹ️ Derived {label} from roadmap-{run_id}.json (split_hierarchy not required)")
    return extractor(source)

def _load_split_artifact(run_id: str, step: str, extractor, label: str) -> Optional[Dict[str, Any]]:
    """
    Load one split artifact with the shared three-stage lookup.

    Lookup order:
      1. Latest output under the logs subfolder
      2. Latest output at the run root (NOTE(review): presumably an older
         layout still supported for existing runs — confirm)
      3. Re-derive from the roadmap clean copy via *extractor*

    Args:
        run_id: Run identifier
        step: Pipeline step name ("tracks", "paths", or "steps")
        extractor: Function that derives the split data from a roadmap dict
        label: Human-friendly label for fallback logging

    Returns:
        Dict with the artifact data, or None if nothing is available
    """
    data = load_latest_output(
        pipeline="can-do-steps",
        step=step,
        run_id=run_id,
        subfolder="logs"
    )
    if data:
        return data

    data = load_latest_output(
        pipeline="can-do-steps",
        step=step,
        run_id=run_id
    )
    if data:
        return data

    return _load_from_roadmap(run_id, extractor, label)

def load_tracks(run_id: str) -> Optional[Dict[str, Any]]:
    """
    Load the tracks.json data.

    Args:
        run_id: Run identifier

    Returns:
        Dict: tracks.json data or None if not found
    """
    return _load_split_artifact(run_id, "tracks", extract_tracks_json, "tracks")

def load_paths(run_id: str) -> Optional[Dict[str, Any]]:
    """
    Load the paths.json data.

    Args:
        run_id: Run identifier

    Returns:
        Dict: paths.json data or None if not found
    """
    return _load_split_artifact(run_id, "paths", extract_paths_json, "paths")

def load_steps(run_id: str) -> Optional[Dict[str, Any]]:
    """
    Load the steps.json data.

    Args:
        run_id: Run identifier

    Returns:
        Dict: steps.json data or None if not found
    """
    return _load_split_artifact(run_id, "steps", extract_steps_json, "steps")

def get_track_by_id(run_id: str, track_id: str) -> Optional[Dict[str, Any]]:
    """
    Get a specific track by ID.

    Args:
        run_id: Run identifier
        track_id: Track identifier

    Returns:
        Dict: Track data or None if not found
    """
    tracks_data = load_tracks(run_id)
    if not tracks_data:
        return None

    # Tracks are serialized with an "id" field (see extract_tracks_json);
    # the previous lookup on "trackId" raised KeyError for every track.
    for track in tracks_data["tracks"]:
        if track["id"] == track_id:
            return track

    return None

def get_paths_for_track(run_id: str, track_id: str) -> List[Dict[str, Any]]:
    """
    Get all paths for a specific track.

    Args:
        run_id: Run identifier
        track_id: Track identifier

    Returns:
        List[Dict]: Paths belonging to the track
    """
    paths_data = load_paths(run_id)
    if not paths_data:
        return []

    matching = []
    for candidate in paths_data["paths"]:
        if candidate["trackId"] == track_id:
            matching.append(candidate)
    return matching

def get_steps_for_path(run_id: str, path_id: str) -> List[Dict[str, Any]]:
    """
    Get all steps for a specific path.

    Args:
        run_id: Run identifier
        path_id: Path identifier

    Returns:
        List[Dict]: Steps belonging to the path, ordered by their order field
    """
    steps_data = load_steps(run_id)
    if not steps_data:
        return []

    matching = [entry for entry in steps_data["steps"] if entry["pathId"] == path_id]
    matching.sort(key=lambda entry: entry["order"])
    return matching

def get_all_statements_from_steps(run_id: str) -> List[str]:
    """
    Get all can-do statements from all steps.

    Args:
        run_id: Run identifier

    Returns:
        List[str]: All can-do statements
    """
    steps_data = load_steps(run_id)
    if not steps_data:
        return []

    # Flatten every step's statements list into one sequence
    return [
        statement
        for step in steps_data["steps"]
        for statement in step["statements"]
    ]

def validate_hierarchy_consistency(run_id: str) -> Dict[str, Any]:
    """
    Validate that the split hierarchy maintains consistency.

    Checks referential integrity (every path references a real track, every
    step references a real path) and gathers statement/level statistics.

    Args:
        run_id: Run identifier

    Returns:
        Dict: Validation results and statistics

    Raises:
        ValueError: If any of the split artifacts cannot be loaded
    """
    tracks_data = load_tracks(run_id)
    paths_data = load_paths(run_id)
    steps_data = load_steps(run_id)

    if not all([tracks_data, paths_data, steps_data]):
        raise ValueError("Missing split hierarchy data")

    validation = {
        "tracks_count": len(tracks_data["tracks"]),
        "paths_count": len(paths_data["paths"]),
        "steps_count": len(steps_data["steps"]),
        "orphaned_paths": [],
        "orphaned_steps": [],
        "total_statements": 0,
        "level_distribution": {"beginner": 0, "intermediate": 0, "advanced": 0, "unspecified": 0}
    }

    # Check for orphaned paths (paths with invalid trackId references).
    # Tracks/paths/steps are all serialized with an "id" field (see the
    # extract_*_json functions); the previous code read nonexistent
    # "trackId"/"pathId"/"stepId" keys here and raised KeyError.
    track_ids = {track["id"] for track in tracks_data["tracks"]}
    for path in paths_data["paths"]:
        if path["trackId"] not in track_ids:
            validation["orphaned_paths"].append(path["id"])

    # Check for orphaned steps (steps with invalid pathId references)
    path_ids = {path["id"] for path in paths_data["paths"]}
    for step in steps_data["steps"]:
        if step["pathId"] not in path_ids:
            validation["orphaned_steps"].append(step["id"])

        # Count statements and level distribution
        validation["total_statements"] += len(step["statements"])
        level = step.get("level")
        if level not in {"beginner", "intermediate", "advanced"}:
            level = "unspecified"
        validation["level_distribution"].setdefault(level, 0)
        validation["level_distribution"][level] += 1

    validation["is_consistent"] = len(validation["orphaned_paths"]) == 0 and len(validation["orphaned_steps"]) == 0

    return validation
