"""Generate roadmap lessons using retrieved context from a vectorstore."""
from __future__ import annotations

import concurrent.futures
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urljoin

from langchain_core.documents import Document

from chains.base import build_chain, default_json_parser, parse_output
from utils.io import create_clean_copy, load_latest_output, save_output
from utils.rag import (
    DEFAULT_EMBED_BATCH_SIZE,
    DEFAULT_EMBED_MODEL,
    get_vectorstore,
)
from utils.storage import path_exists, read_json


def _load_roadmap(roadmap_path: str) -> Dict[str, Any]:
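    """Load the roadmap JSON from roadmap_path, raising FileNotFoundError if it is missing."""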
    if not path_exists(roadmap_path):
        raise FileNotFoundError(f"Roadmap file not found: {roadmap_path}")
    return read_json(roadmap_path)


def _flatten_steps(roadmap: Dict[str, Any]) -> List[Dict[str, Any]]:
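    """Flatten the roadmap's tracks/paths/steps into a list of {track, path, step} bundles."""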
    steps: List[Dict[str, Any]] = []
    for track in roadmap.get("tracks", []):
        for path in track.get("paths", []):
            for step in path.get("steps", []):
                steps.append(
                    {
                        "track": track,
                        "path": path,
                        "step": step,
                    }
                )
    return steps


def _format_step_context(
    roadmap: Dict[str, Any],
    bundle: Dict[str, Any],
) -> str:
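    """Build a plain-text summary of the step and its parent track and path for the prompt."""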
    track = bundle["track"]
    path = bundle["path"]
    step = bundle["step"]

    lines = [
        f"Roadmap: {roadmap.get('title', 'Untitled Roadmap')} ({roadmap.get('id', 'unknown-id')})",
        f"Track {track.get('order', '?')}: {track.get('title')} ({track.get('id')})",
        f"Track description: {track.get('description', 'No description provided.')}",
        f"Path {path.get('order', '?')}: {path.get('title')} ({path.get('id')})",
        f"Path description: {path.get('description', 'No description provided.')}",
        f"Step {step.get('order', '?')}: {step.get('title')} ({step.get('id')})",
        f"Step level: {step.get('level', 'unspecified')}",
        f"Step description: {step.get('description', 'No description provided.')}",
    ]
    return "\n".join(lines)


def _format_retrieval(documents: List[Tuple[Document, float]]) -> str:
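    """Render retrieved (Document, score) pairs as labelled text blocks for the prompt."""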
    blocks: List[str] = []
    for idx, (doc, score) in enumerate(documents, start=1):
        metadata = doc.metadata or {}
        header = (
            f"[Chunk {idx} | score={score:.3f} | source={metadata.get('source_filename', metadata.get('source_file', 'unknown'))} | "
            f"chunk_index={metadata.get('chunk_index', 'n/a')}]"
        )
        content = doc.page_content.strip()
        blocks.append(f"{header}\n{content}")
    return "\n\n".join(blocks)


def _truncate_context(context: str, max_chars: int = 6000) -> str:
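    """Trim context to at most max_chars characters, appending an ellipsis when truncated."""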
    if len(context) <= max_chars:
        return context
    return context[: max_chars - 3].rstrip() + "..."


def _combine_source_base(base: str, reference: str) -> str:
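    """Join a relative reference onto the source base URL; absolute URLs are returned unchanged."""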
    if not reference:
        return reference
    if reference.startswith(("http://", "https://")):
        return reference
    normalized_base = base.rstrip("/") + "/"
    return urljoin(normalized_base, reference.lstrip("/"))


def _apply_source_base_to_references(values: Any, source_base: str) -> Any:
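    """Recursively prefix reference strings and url/link/href values with the source base."""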
    if isinstance(values, list):
        prefixed: List[Any] = []
        for item in values:
            if isinstance(item, str):
                prefixed.append(_combine_source_base(source_base, item.strip()))
            elif isinstance(item, dict):
                updated: Dict[str, Any] = {}
                for key, value in item.items():
                    if isinstance(value, str) and key.lower() in {"url", "link", "href"}:
                        updated[key] = _combine_source_base(source_base, value.strip())
                    elif isinstance(value, list):
                        updated[key] = _apply_source_base_to_references(value, source_base)
                    else:
                        updated[key] = value
                prefixed.append(updated)
            elif isinstance(item, list):
                prefixed.append(_apply_source_base_to_references(item, source_base))
            else:
                prefixed.append(item)
        return prefixed
    if isinstance(values, dict):
        mapped: Dict[str, Any] = {}
        for key, value in values.items():
            if isinstance(value, (list, dict)):
                mapped[key] = _apply_source_base_to_references(value, source_base)
            elif isinstance(value, str) and key.lower() in {"url", "link", "href"}:
                mapped[key] = _combine_source_base(source_base, value.strip())
            else:
                mapped[key] = value
        return mapped
    if isinstance(values, str):
        return _combine_source_base(source_base, values.strip())
    return values


def _parse_step_lesson(raw_output: str) -> Dict[str, Any]:
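    """Parse the model output as JSON and validate that it contains a well-formed 'lesson' object."""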
    parsed = parse_output(raw_output, default_json_parser)
    if "lesson" not in parsed:
        raise ValueError("Expected 'lesson' key in model output")

    lesson = parsed["lesson"]
    if not isinstance(lesson, dict):
        raise ValueError("'lesson' must be an object")

    required_fields = ["title", "intro", "body"]
    for field in required_fields:
        if field not in lesson:
            raise ValueError(f"Lesson missing required field '{field}'")

    if not isinstance(lesson["body"], str) or not lesson["body"].strip():
        raise ValueError("Lesson body must be non-empty text")

    return parsed


def _retrieve_context(
    query: str,
    collection_name: str,
    persist_directory: str,
    top_k: int,
    embedding_model: Optional[str] = None,
    embedding_batch_size: Optional[int] = None,
) -> List[Tuple[Document, float]]:
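    """Query the persisted vectorstore and return the top_k (Document, score) matches."""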
    vectorstore = get_vectorstore(
        collection_name=collection_name,
        persist_directory=persist_directory,
        embedding_model=embedding_model,
        embedding_batch_size=embedding_batch_size,
    )

    # similarity_search_with_score returns a list of tuples (Document, score)
    try:
        return vectorstore.similarity_search_with_score(query, k=top_k)
    except Exception as exc:
        raise RuntimeError(f"Failed to retrieve context from vectorstore: {exc}") from exc


def _generate_single_step_lesson(
    index: int,
    total_steps: int,
    bundle: Dict[str, Any],
    roadmap: Dict[str, Any],
    run_id: str,
    collection: str,
    persist_dir: str,
    top_k: int,
    embed_model: str,
    embed_batch_size: int,
    source_base: Optional[str],
) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]]]:
    """Generate a lesson for a single step."""

    track = bundle["track"]
    path = bundle["path"]
    step = bundle["step"]
    step_id = step.get("id", f"step-{index}")

    step_context = _format_step_context(roadmap, bundle)
    retrieval_results = _retrieve_context(
        query=f"{step.get('title', '')} {step.get('description', '')}",
        collection_name=collection,
        persist_directory=persist_dir,
        top_k=top_k,
        embedding_model=embed_model,
        embedding_batch_size=embed_batch_size,
    )

    formatted_context = _format_retrieval(retrieval_results)
    truncated_context = _truncate_context(formatted_context)

    input_variables = {
        "roadmap_title": roadmap.get("title", "Untitled Roadmap"),
        "step_context": step_context,
        "resource_context": truncated_context or "",
    }

    print(f"🧠 Generating lesson for step {step_id} ({index}/{total_steps})")
    result = build_chain(
        chain_name="generate_step_lesson",
        pipeline="rag-roadmap",
        run_id=run_id,
        input_variables=input_variables,
    )

    parsed = _parse_step_lesson(result["output"])

    # Prefix relative reference links with the configured source base, both
    # inside the lesson object and at the top level of the parsed output.
    if source_base:
        lesson_obj = parsed.get("lesson")
        if isinstance(lesson_obj, dict) and "references" in lesson_obj:
            lesson_obj["references"] = _apply_source_base_to_references(
                lesson_obj.get("references", []),
                source_base,
            )

        if "references" in parsed:
            parsed["references"] = _apply_source_base_to_references(
                parsed.get("references", []),
                source_base,
            )

    lesson_entry = {
        "track": {
            "id": track.get("id"),
            "title": track.get("title"),
            "order": track.get("order"),
        },
        "path": {
            "id": path.get("id"),
            "title": path.get("title"),
            "order": path.get("order"),
        },
        "step": step,
        "lesson": parsed.get("lesson"),
        "key_points": parsed.get("key_points", []),
        "quiz": parsed.get("quiz", []),
        "references": parsed.get("references", []),
        "generation_summary": parsed.get("generation_summary", ""),
    }

    retrieval_entry = [
        {
            "chunk": idx,
            "score": float(f"{score:.4f}"),
            "metadata": doc.metadata,
        }
        for idx, (doc, score) in enumerate(retrieval_results, start=1)
    ]

    return step_id, lesson_entry, retrieval_entry


def generate_lessons_from_roadmap(
    run_id: str,
    roadmap_path: str,
    collection_name: Optional[str] = None,
    persist_directory: Optional[str] = None,
    top_k: int = 4,
    max_steps: Optional[int] = None,
    force_text: bool = False,
    embedding_model: Optional[str] = None,
    embedding_batch_size: Optional[int] = None,
    max_workers: Optional[int] = None,
    source_base: Optional[str] = None,
) -> Dict[str, Any]:
    """Generate lessons for each roadmap step using retrieved context."""
    if force_text:
        existing = load_latest_output(
            pipeline="rag-roadmap",
            step="lessons",
            run_id=run_id,
        )
        if existing:
            if source_base:
                lessons_section = existing.get("lessons", {})
                for step_data in lessons_section.values():
                    if not isinstance(step_data, dict):
                        continue
                    lesson_obj = step_data.get("lesson")
                    if isinstance(lesson_obj, dict) and "references" in lesson_obj:
                        lesson_obj["references"] = _apply_source_base_to_references(
                            lesson_obj.get("references", []),
                            source_base,
                        )
                    if "references" in step_data:
                        step_data["references"] = _apply_source_base_to_references(
                            step_data.get("references", []),
                            source_base,
                        )

                existing_config = existing.setdefault("config", {})
                existing_config["source_base"] = source_base

            print("Using existing generated lessons output")
            return existing

    roadmap = _load_roadmap(roadmap_path)
    steps = _flatten_steps(roadmap)
    if max_steps is not None:
        steps = steps[:max_steps]

    if not steps:
        raise ValueError("Roadmap contains no steps to process")

    ingestion_details = load_latest_output(
        pipeline="rag-roadmap",
        step="ingest_resource",
        run_id=run_id,
    )

    ingestion_details = ingestion_details or {}

    collection = collection_name or ingestion_details.get("collection_name") or f"roadmap_rag_{run_id}"
    persist_dir = (
        persist_directory
        or ingestion_details.get("persist_directory")
        or f".chroma_{collection}"
    )
    embed_model = embedding_model or ingestion_details.get("embedding_model") or DEFAULT_EMBED_MODEL
    embed_batch_size = (
        embedding_batch_size
        or ingestion_details.get("embedding_batch_size")
        or DEFAULT_EMBED_BATCH_SIZE
    )

    lessons_by_step: Dict[str, Any] = {}
    retrieval_metadata: Dict[str, Any] = {}

    total_steps = len(steps)

    # The per-step work is dominated by I/O-bound retrieval and LLM calls, so
    # default to twice the CPU count, capped at 8 workers.
    default_workers = min(8, (os.cpu_count() or 4) * 2)
    workers = max_workers if max_workers is not None else default_workers

    if workers < 1:
        workers = 1

    def _collect_result(future: concurrent.futures.Future) -> None:
        step_id, lesson_entry, retrieval_entry = future.result()
        lessons_by_step[step_id] = lesson_entry
        retrieval_metadata[step_id] = retrieval_entry

    if workers > 1:
        print(f"Processing {total_steps} steps with up to {workers} workers...")
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [
                executor.submit(
                    _generate_single_step_lesson,
                    index,
                    total_steps,
                    bundle,
                    roadmap,
                    run_id,
                    collection,
                    persist_dir,
                    top_k,
                    embed_model,
                    embed_batch_size,
                    source_base,
                )
                for index, bundle in enumerate(steps, start=1)
            ]

            for future in concurrent.futures.as_completed(futures):
                _collect_result(future)
    else:
        print(f"Processing {total_steps} steps sequentially...")
        for index, bundle in enumerate(steps, start=1):
            step_result = _generate_single_step_lesson(
                index,
                total_steps,
                bundle,
                roadmap,
                run_id,
                collection,
                persist_dir,
                top_k,
                embed_model,
                embed_batch_size,
                source_base,
            )
            step_id, lesson_entry, retrieval_entry = step_result
            lessons_by_step[step_id] = lesson_entry
            retrieval_metadata[step_id] = retrieval_entry

    output_payload = {
        "run_id": run_id,
        "roadmap": {
            "id": roadmap.get("id"),
            "title": roadmap.get("title"),
            "description": roadmap.get("description"),
        },
        "config": {
            "collection_name": collection,
            "persist_directory": str(Path(persist_dir).expanduser().resolve()),
            "top_k": top_k,
            "max_steps": max_steps,
            "embedding_model": embed_model,
            "embedding_batch_size": embed_batch_size,
            "source_base": source_base,
        },
        "lessons": lessons_by_step,
        "retrieval": retrieval_metadata,
        "total_steps_processed": len(lessons_by_step),
    }

    save_output(
        data=output_payload,
        pipeline="rag-roadmap",
        step="lessons",
        run_id=run_id,
    )

    # Create lesson-only "bits" style exports grouped by step ID for downstream use
    roadmap_id = roadmap.get("id") or run_id
    bits_step_name = f"bits_{roadmap_id}"

    lessons_subset: Dict[str, Any] = {}
    for step_id, step_data in lessons_by_step.items():
        lesson = step_data.get("lesson", {})
        lessons_subset[step_id] = {
            "lesson_content": {
                "title": lesson.get("title", ""),
                "intro": lesson.get("intro", ""),
                "body": lesson.get("body", ""),
                "definitions": lesson.get("definitions", []) or [],
                "references": lesson.get("references", []) or [],
                "illustrations": lesson.get("illustrations", []) or [],
            }
        }

    bits_filepath = save_output(
        data=lessons_subset,
        pipeline="rag-roadmap",
        step=bits_step_name,
        run_id=run_id,
    )

    create_clean_copy(
        timestamped_filepath=bits_filepath,
        pipeline="rag-roadmap",
        step=bits_step_name,
        run_id=run_id,
        clean_filename=f"bits_{roadmap_id}.json",
    )

    return output_payload


__all__ = ["generate_lessons_from_roadmap"]
