"""
Can-Do-Steps review content component for the Lumabit lesson-generation pipeline.
Evaluates roadmap learning steps against quality criteria before publication.
"""
import os
import re
import json
import traceback
import concurrent.futures
from typing import Dict, Any, List, Optional, Tuple

from utils.io import (
    save_output,
    load_latest_output,
    create_clean_copy,
    load_optional_input_text,
    build_input_path,
)
from chains.base import build_chain, default_json_parser, parse_output
from chains.can_do_steps.split_hierarchy import load_steps, load_paths, load_tracks
from utils.storage import path_exists, read_json

# Default criteria used when no review-{run_id}.txt file supplies custom ones.
# NOTE(review): these defaults are phrased around "translation"; confirm they
# fit every roadmap this pipeline reviews, not only language-learning ones.
REVIEW_CRITERIA: List[str] = [
    "Alignment: The translation is correct for the language being learned.",
    (
        "Suitability: The translation is correct and would be suitable "
        "for everyday conversations."
    ),
]

def _clean_criteria_line(line: str) -> str:
    """
    Normalize a line from a review criteria file by removing numbering/bullets.

    Strips surrounding whitespace, then a leading list number, then a leading
    ``-``, ``*`` or ``•`` bullet. A leading number counts as list numbering when:

    * it is followed by separator characters (``.``, ``)``, ``-``, ``:``) that
      are not immediately followed by another digit (``1.``, ``2)``, ``1.2.``), or
    * it is followed by whitespace or ends the line (``1 Alignment``, ``7``).

    Content that merely starts with digits ("3D models", "1.5 hours",
    "1-2 sentences") is therefore preserved instead of having its digits eaten.

    Args:
        line: One raw line from the criteria file.

    Returns:
        The cleaned criterion text, or "" for blank/number-only lines.
    """
    cleaned = line.strip()
    if not cleaned:
        return ""

    # Leading numbering like "1.", "1) ", "1.2.", "1 " — but not "3D" or "1.5".
    cleaned = re.sub(r"^(?:\d+(?:\.\d+)*[.):\-]+(?!\d)|\d+(?=\s|$))\s*", "", cleaned)
    # Remove leading bullet characters
    cleaned = re.sub(r"^[-*•]\s*", "", cleaned)
    return cleaned.strip()


def load_review_criteria(run_id: str, input_prefix: Optional[str] = None) -> Tuple[List[str], Optional[str]]:
    """
    Resolve the review criteria to use for a run.

    Reads the optional ``review-{run_id}.txt`` input through
    ``load_optional_input_text``. Each line is cleaned with
    ``_clean_criteria_line`` and non-empty results become criteria. If reading
    raises, the helper returns None, or no usable line remains, a fresh copy of
    ``REVIEW_CRITERIA`` is returned instead.

    Args:
        run_id: Run identifier / roadmap slug.
        input_prefix: Optional input prefix/path passed to ``build_input_path``
            and ``load_optional_input_text``.

    Returns:
        Tuple of (criteria list, filepath if a custom file was used, else None).
    """
    criteria_filename = f"review-{run_id}.txt"
    review_path = build_input_path(criteria_filename, prefix=input_prefix)

    try:
        raw_text = load_optional_input_text(criteria_filename, prefix=input_prefix, strip=False)
    except Exception as exc:
        print(f"⚠️ Error reading review criteria file {review_path}: {exc}. Using default criteria.")
        raw_text = None

    # A None result (missing file, or a read error above) means "use defaults".
    if raw_text is None:
        print(f"ℹ️ No optional review criteria file found at {review_path}; using defaults.")
        return list(REVIEW_CRITERIA), None

    custom_criteria = [
        entry
        for entry in map(_clean_criteria_line, raw_text.splitlines())
        if entry
    ]
    if not custom_criteria:
        print(f"⚠️ Custom review criteria file {review_path} had no usable items; falling back to defaults.")
        return list(REVIEW_CRITERIA), None

    print(f"✓ Loaded custom review criteria from {review_path}: {len(custom_criteria)} items")
    return custom_criteria, review_path


def slugify_criterion(text: str) -> str:
    """
    Derive a stable key from a criterion description.

    Only the label before the first ``:`` is used. It is stripped and
    lower-cased, every run of characters outside ``[a-z0-9]`` becomes one
    hyphen, and leading/trailing hyphens are removed. The result may be "".
    """
    label, _sep, _rest = text.partition(":")
    hyphenated = re.sub(r"[^a-z0-9]+", "-", label.strip().lower())
    return re.sub(r"-{2,}", "-", hyphenated).strip("-")


def build_criteria_items(criteria_list: List[str]) -> List[Dict[str, str]]:
    """
    Build structured criteria entries with explicit, unique keys for prompt clarity.

    Each key starts as ``slugify_criterion(criterion)``, or ``criterion-{n}``
    (1-based position) when the slug is empty. If that key was already emitted,
    the lowest free ``-2``, ``-3``, ... suffix is appended. Every returned key is
    unique, even when a suffixed key would collide with another criterion's own
    slug (e.g. "Clarity", "Clarity", "Clarity 2").

    Args:
        criteria_list: List of user or default criteria strings

    Returns:
        List of dicts containing 'key' and 'text', in input order.
    """
    items: List[Dict[str, str]] = []
    used_keys = set()

    for idx, criterion in enumerate(criteria_list, start=1):
        base_key = slugify_criterion(criterion) or f"criterion-{idx}"
        key = base_key
        suffix = 2
        # Compare against every emitted key (not only base keys) so a suffixed
        # key such as "clarity-2" can never be handed out twice.
        while key in used_keys:
            key = f"{base_key}-{suffix}"
            suffix += 1
        used_keys.add(key)
        items.append({"key": key, "text": criterion})

    return items


def load_step_details(run_id: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """
    Load step detail data from the bits file for a roadmap run.

    Tries ``bits-{run_id}.json`` then ``bits_{run_id}.json`` under
    ``output/can-do-steps/{run_id}/``. The first candidate that exists and loads
    is returned. A candidate that exists but fails to load is reported, and the
    next candidate is still tried.

    Args:
        run_id: roadmap identifier

    Returns:
        Tuple of (data as returned by ``read_json``, file path), or (None, None)
        if no candidate could be loaded. Callers look steps up in the data by
        step id, so it is presumably a dict keyed by step_id — verify against
        the bits generator.
    """
    candidate_filenames = [
        f"bits-{run_id}.json",
        f"bits_{run_id}.json",
    ]
    load_failed = False

    for filename in candidate_filenames:
        bits_filepath = os.path.join("output", "can-do-steps", run_id, filename)
        if not path_exists(bits_filepath):
            continue
        try:
            data = read_json(bits_filepath)
        except Exception as exc:
            print(f"⚠️ Error loading bits file {bits_filepath}: {exc}")
            # Don't give up yet: the alternate filename may still be loadable.
            load_failed = True
            continue
        print(f"✓ Loaded bits file for step details: {bits_filepath}")
        return data, bits_filepath

    # Only report "not found" when no candidate existed; load errors were already printed.
    if not load_failed:
        print(f"⚠️ Bits file not found for run {run_id}. Tried: {', '.join(candidate_filenames)}")
    return None, None


def load_roadmap_file(run_id: str) -> Optional[Dict[str, Any]]:
    """
    Read ``output/can-do-steps/{run_id}/roadmap-{run_id}.json`` if it exists.

    Args:
        run_id: Run identifier

    Returns:
        The data returned by ``read_json``, or None when the file is absent or
        fails to load. A warning is printed in both failure cases.
    """
    roadmap_path = f"output/can-do-steps/{run_id}/roadmap-{run_id}.json"

    if not path_exists(roadmap_path):
        print(f"⚠️ Roadmap file not found: {roadmap_path}")
        return None

    try:
        roadmap = read_json(roadmap_path)
    except Exception as exc:
        print(f"⚠️ Error loading roadmap file {roadmap_path}: {exc}")
        return None

    print(f"✓ Loaded roadmap file: {roadmap_path}")
    return roadmap


def build_lookup(items: Optional[List[Dict[str, Any]]], key: str = "id") -> Dict[str, Dict[str, Any]]:
    """
    Index a list of dictionaries by one of their fields.

    Entries that are not dicts, or that lack ``key``, are skipped. When several
    entries share the same value, the last one wins. A falsy ``items`` gives {}.

    Args:
        items: List of dictionaries containing the lookup key
        key: Key to index by

    Returns:
        Dict mapping each entry's ``key`` value to the entry itself
    """
    return {
        entry[key]: entry
        for entry in (items or [])
        if isinstance(entry, dict) and key in entry
    }


def parse_review_output(output: str, expected_step_id: str) -> Dict[str, Any]:
    """
    Parse and validate the review output returned by the LLM.

    Args:
        output: Raw output from the LLM
        expected_step_id: The step ID we asked the model to review. It always
            wins over any different ``step_id`` the model echoes back.

    Returns:
        Dict with ``step_id``, ``passes_review`` (bool), ``criteria_results``
        (the model's object, or {} if missing/not an object) and ``feedback``
        (stripped string, "" if missing or null).

    Raises:
        ValueError: If the parsed output is not a JSON object or lacks a boolean
            ``passes_review``. (``parse_output`` may raise its own errors too.)
    """
    parsed = parse_output(output, default_json_parser)

    if not isinstance(parsed, dict):
        raise ValueError("Parsed review output must be a JSON object.")

    step_id = parsed.get("step_id", expected_step_id)
    if step_id != expected_step_id:
        print(f"⚠️ Step ID mismatch in review output. Expected {expected_step_id}, got {step_id}. Using expected ID.")
        step_id = expected_step_id

    passes_review = parsed.get("passes_review")
    if not isinstance(passes_review, bool):
        raise ValueError("Review output must include boolean 'passes_review'.")

    criteria_results = parsed.get("criteria_results")
    if not isinstance(criteria_results, dict):
        criteria_results = {}

    feedback = parsed.get("feedback")
    if feedback is None:
        # A missing key or JSON null must not become the literal text "None".
        feedback = ""
    elif not isinstance(feedback, str):
        feedback = str(feedback)

    return {
        "step_id": step_id,
        "passes_review": passes_review,
        "criteria_results": criteria_results,
        "feedback": feedback.strip(),
    }


def collect_review_targets(
    steps: List[Dict[str, Any]],
    path_lookup: Dict[str, Dict[str, Any]],
    track_lookup: Dict[str, Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Pair every reviewable step with its path and track context.

    Non-dict entries and steps without a truthy ``id`` are skipped. The path is
    found through the step's ``pathId``. The track comes from the path's
    ``trackId`` when a path was found, otherwise from the step's own ``trackId``.

    Args:
        steps: List of roadmap steps to evaluate
        path_lookup: Lookup table of paths by ID
        track_lookup: Lookup table of tracks by ID

    Returns:
        List of {"step", "path", "track"} dicts; path/track may be None.
    """
    targets: List[Dict[str, Any]] = []

    for step in steps:
        if not (isinstance(step, dict) and step.get("id")):
            continue

        path_info = path_lookup.get(step.get("pathId")) if path_lookup else None

        if path_info:
            track_info = track_lookup.get(path_info.get("trackId"))
        elif track_lookup:
            track_info = track_lookup.get(step.get("trackId"))
        else:
            track_info = None

        targets.append({"step": step, "path": path_info, "track": track_info})

    return targets


def _format_step_statements(statements: Any) -> str:
    """Render a step's ``statements`` field as one prompt string ("None provided" if empty)."""
    if isinstance(statements, list):
        # Stringify items so numbers/objects cannot crash the join; drop None/blank items.
        parts = [str(item).strip() for item in statements if item is not None]
        parts = [part for part in parts if part]
        return "; ".join(parts) if parts else "None provided"
    if isinstance(statements, str) and statements.strip():
        return statements
    return "None provided"


def _review_single_target(
    index: int,
    total_targets: int,
    target: Dict[str, Any],
    run_id: str,
    roadmap_title: str,
    roadmap_description: str,
    criteria_text: str,
    criteria_keys_instruction: str,
    criteria_example_pairs: str,
    step_details_map: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    """
    Review a single step target and return the aggregated result with logs.

    Runs the ``review_content`` chain for one step and parses its output. Any
    exception from the chain call or from parsing is caught and recorded as a
    failed review (``passes_review`` False, error text in ``feedback``).

    Args:
        index: Zero-based position of the target (also used in log lines).
        total_targets: Number of targets in this run, for progress logging.
        target: {"step", "path", "track"} dict as built by collect_review_targets.
        run_id: Run identifier.
        roadmap_title: Roadmap title passed to the prompt.
        roadmap_description: Roadmap description passed to the prompt.
        criteria_text: Numbered criteria list for the prompt.
        criteria_keys_instruction: Bullet list of expected criteria keys.
        criteria_example_pairs: Example ``criteria_results`` JSON pairs.
        step_details_map: Bits data looked up by step id, or None if unavailable.

    Returns:
        {"index": index, "result": <aggregated result dict>, "logs": [str, ...]}
        Logs are returned rather than printed so concurrent callers can emit them
        in one piece.
    """
    step = target.get("step") or {}
    path = target.get("path") or {}
    track = target.get("track") or {}

    step_id = step.get("id", f"step-{index + 1}")
    logs = [f"\n➡️ Reviewing step {index + 1}/{total_targets}: {step_id}"]

    step_title = step.get("title", "")
    step_description = step.get("description", "")
    step_level = step.get("level", "")
    step_statements_text = _format_step_statements(step.get("statements", []))

    path_id = path.get("id", "")
    path_title = path.get("title", "")
    path_description = path.get("description", "")

    track_id = track.get("id", "")
    track_title = track.get("title", "")
    track_description = track.get("description", "")

    step_details_text = "Step details unavailable."
    if isinstance(step_details_map, dict):
        raw_step_details = step_details_map.get(step_id)
        if raw_step_details is None:
            logs.append(f"⚠️ Step details not found for {step_id} in bits file; proceeding without them.")
        else:
            step_details_text = json.dumps(raw_step_details, indent=2, ensure_ascii=False)

    try:
        result = build_chain(
            chain_name="review_content",
            pipeline="can-do-steps",
            run_id=run_id,
            input_variables={
                "run_id": run_id,
                "roadmap_title": roadmap_title,
                "roadmap_description": roadmap_description,
                "track_title": track_title,
                "track_id": track_id,
                "track_description": track_description,
                "path_title": path_title,
                "path_id": path_id,
                "path_description": path_description,
                "step_title": step_title,
                "step_id": step_id,
                "step_description": step_description,
                "step_level": step_level,
                "step_statements": step_statements_text,
                "evaluation_criteria": criteria_text,
                "evaluation_criteria_keys": criteria_keys_instruction,
                "criteria_results_example": criteria_example_pairs,
                "step_details": step_details_text,
            },
        )

        parsed_output = parse_review_output(result["output"], expected_step_id=step_id)
        passes_review = parsed_output["passes_review"]
        criteria_results = parsed_output.get("criteria_results", {})
        feedback = parsed_output.get("feedback", "")

    except Exception as exc:
        # Any chain/parse failure is recorded as a failed review rather than aborting the run.
        passes_review = False
        criteria_results = {}
        feedback = f"Error during review: {exc}"
        logs.append(f"⚠️ Error reviewing step {step_id}: {exc}")
        logs.append(traceback.format_exc())

    status_icon = "✅" if passes_review else "❌"
    logs.append(f"{status_icon} {step_id}: {str(passes_review).lower()}")
    if not passes_review and feedback:
        logs.append(f"   Feedback: {feedback}")

    aggregated_result = {
        "step_id": step_id,
        "step_title": step_title,
        "passes_review": passes_review,
        "criteria_results": criteria_results,
        "feedback": feedback,
        "path_id": path_id,
        "path_title": path_title,
        "track_id": track_id,
        "track_title": track_title,
        "path_description": path_description,
        "track_description": track_description,
        "step_level": step_level,
        "step_description": step_description,
        "step_statements": step_statements_text,
    }

    return {
        "index": index,
        "result": aggregated_result,
        "logs": logs,
    }


def _build_criteria_prompt_parts(criteria_items: List[Dict[str, str]]) -> Tuple[str, str, str]:
    """
    Render structured criteria into the three prompt strings used by the review chain.

    Returns:
        Tuple of (numbered criteria text, bullet list of keys, example
        ``criteria_results`` JSON pairs built from at most the first two keys).
    """
    criteria_text = "\n".join(
        f"{idx + 1}. [{item['key']}] {item['text']}"
        for idx, item in enumerate(criteria_items)
    )
    criteria_keys_instruction = "\n".join(f"- {item['key']}" for item in criteria_items)
    criteria_example_pairs = ",\n    ".join(
        f"\"{item['key']}\": true" for item in criteria_items[:2]
    )
    if not criteria_example_pairs:
        # No criteria at all: still give the model a syntactically valid example.
        criteria_example_pairs = "\"criterion\": true"
    return criteria_text, criteria_keys_instruction, criteria_example_pairs


def _failed_review_result(
    index: int,
    total_targets: int,
    target: Dict[str, Any],
    exc: Exception,
    traceback_text: str,
) -> Dict[str, Any]:
    """Build a failed-review payload, shaped like _review_single_target's, for a target whose review raised."""
    step = target.get("step") or {}
    path = target.get("path") or {}
    track = target.get("track") or {}

    step_id = step.get("id", f"step-{index + 1}")
    feedback = f"Error during review: {exc}"
    raw_statements = step.get("statements", [])
    if isinstance(raw_statements, list):
        statements_text = "; ".join(str(item) for item in raw_statements)
    else:
        statements_text = str(raw_statements or "")

    result = {
        "step_id": step_id,
        "step_title": step.get("title", ""),
        "passes_review": False,
        "criteria_results": {},
        "feedback": feedback,
        "path_id": path.get("id", ""),
        "path_title": path.get("title", ""),
        "track_id": track.get("id", ""),
        "track_title": track.get("title", ""),
        "path_description": path.get("description", ""),
        "track_description": track.get("description", ""),
        "step_level": step.get("level", ""),
        "step_description": step.get("description", ""),
        "step_statements": statements_text or "None provided",
    }
    logs = [
        f"\n➡️ Reviewing step {index + 1}/{total_targets}: {step_id}",
        f"⚠️ Error reviewing step {step_id}: {exc}",
        traceback_text,
        f"❌ {step_id}: false",
        f"   Feedback: {feedback}",
    ]
    return {"index": index, "result": result, "logs": logs}


def _review_target_safely(
    index: int,
    total_targets: int,
    target: Dict[str, Any],
    review_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Call _review_single_target and turn any exception it lets escape into a failed result.

    _review_single_target already catches errors from the model call itself.
    This guards the preparation code around that call, so one bad step cannot
    abort the run. Without it, in the thread pool the failure would surface only
    after every other queued review had run and been discarded.
    """
    try:
        return _review_single_target(
            index=index,
            total_targets=total_targets,
            target=target,
            **review_kwargs,
        )
    except Exception as exc:
        return _failed_review_result(index, total_targets, target, exc, traceback.format_exc())


def review_content(
    run_id: str,
    force_text: bool = False,
    limit: Optional[int] = None,
    max_workers: Optional[int] = None,
    input_prefix: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Review the steps of a roadmap run and report pass/fail against quality criteria.

    Each step is sent to the ``review_content`` chain together with its path and
    track context and, when available, its entry from the bits file. A step
    whose review raises is recorded as failed instead of aborting the run.
    Results are saved to ``archived/reviews`` (timestamped), with a clean copy
    written to ``reviews``.

    Args:
        run_id: Run identifier
        force_text: If True, return the latest saved review output (from
            ``archived/reviews``) instead of re-running the model, when one exists
        limit: Optional cap on the number of steps to review (for testing);
            None or a value <= 0 reviews every step
        max_workers: Maximum number of concurrent worker threads; defaults to
            ``min(8, 2 * os.cpu_count())``, and values below 1 are treated as 1
        input_prefix: Optional input prefix/path for review criteria files

    Returns:
        Dict containing the aggregated review results

    Raises:
        ValueError: If the roadmap file cannot be loaded or there are no steps to review.
    """
    print(f"🔎 Starting content review for run ID: {run_id}")

    if force_text:
        existing_output = load_latest_output(
            pipeline="can-do-steps",
            step="review_content",
            run_id=run_id,
            subfolder="archived/reviews"
        )
        if existing_output:
            print(f"Using existing review_content output for can-do-steps/{run_id}")
            for result in existing_output.get("results", []):
                step_id = result.get("step_id", "unknown-step")
                passes_review = bool(result.get("passes_review"))
                print(f"{step_id}: {str(passes_review).lower()}")
            return existing_output
        print("⚠️ No existing review content output found; running review now.")

    roadmap_data = load_roadmap_file(run_id)
    if not roadmap_data:
        raise ValueError(f"Could not load roadmap file for run ID: {run_id}")

    steps_data = load_steps(run_id) or {"steps": []}
    paths_data = load_paths(run_id) or {"paths": []}
    tracks_data = load_tracks(run_id) or {"tracks": []}

    path_lookup = build_lookup(paths_data.get("paths", []))
    track_lookup = build_lookup(tracks_data.get("tracks", []))

    steps_list = steps_data.get("steps", [])
    if not steps_list:
        raise ValueError(f"No steps found for run ID: {run_id}")

    all_targets = collect_review_targets(steps_list, path_lookup, track_lookup)
    total_available = len(all_targets)

    if total_available == 0:
        raise ValueError(f"No steps available to review for run ID: {run_id}")

    if limit is not None and limit > 0:
        review_targets = all_targets[:limit]
        print(f"Reviewing {len(review_targets)} of {total_available} steps (limit={limit}).")
    else:
        review_targets = all_targets
        print(f"Reviewing all {total_available} steps.")

    selected_criteria, criteria_file = load_review_criteria(run_id, input_prefix=input_prefix)
    criteria_items = build_criteria_items(selected_criteria)
    step_details_map, step_details_file = load_step_details(run_id)

    criteria_text, criteria_keys_instruction, criteria_example_pairs = _build_criteria_prompt_parts(criteria_items)

    roadmap_title = roadmap_data.get("title", run_id)
    roadmap_description = roadmap_data.get("description", "")
    roadmap_id = roadmap_data.get("id", run_id)

    default_workers = min(8, (os.cpu_count() or 4) * 2)
    workers = default_workers if max_workers is None else max(1, max_workers)

    review_count = len(review_targets)
    if review_count > 1:
        worker_desc = "1 worker" if workers == 1 else f"up to {workers} workers"
        print(f"Processing {review_count} reviews with {worker_desc}.")

    # Shared, read-only arguments for every per-step review.
    review_kwargs: Dict[str, Any] = {
        "run_id": run_id,
        "roadmap_title": roadmap_title,
        "roadmap_description": roadmap_description,
        "criteria_text": criteria_text,
        "criteria_keys_instruction": criteria_keys_instruction,
        "criteria_example_pairs": criteria_example_pairs,
        "step_details_map": step_details_map,
    }
    # Results land at their original index, so output order matches input order
    # even though threaded reviews complete in arbitrary order.
    aggregated_slots: List[Optional[Dict[str, Any]]] = [None] * review_count

    def _record(result_data: Dict[str, Any]) -> None:
        aggregated_slots[result_data["index"]] = result_data["result"]
        for line in result_data["logs"]:
            print(line)

    if workers == 1 or review_count == 1:
        for idx, target in enumerate(review_targets):
            _record(_review_target_safely(idx, review_count, target, review_kwargs))
    else:
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [
                executor.submit(_review_target_safely, idx, review_count, target, review_kwargs)
                for idx, target in enumerate(review_targets)
            ]
            for future in concurrent.futures.as_completed(futures):
                _record(future.result())

    aggregated_results: List[Dict[str, Any]] = [
        result for result in aggregated_slots if result is not None
    ]

    total_reviewed = len(aggregated_results)
    total_passed = sum(1 for item in aggregated_results if item.get("passes_review"))
    total_failed = total_reviewed - total_passed

    final_result = {
        "run_id": run_id,
        "roadmap_id": roadmap_id,
        "roadmap_title": roadmap_title,
        "total_steps_available": total_available,
        "total_steps_reviewed": total_reviewed,
        "total_passed": total_passed,
        "total_failed": total_failed,
        "evaluation_criteria": selected_criteria,
        "evaluation_criteria_items": criteria_items,
        "custom_review_criteria_used": bool(criteria_file),
        "custom_review_criteria_file": criteria_file,
        "evaluation_criteria_source": criteria_file or "default",
        "step_details_file": step_details_file,
        "step_details_available": bool(step_details_file),
        "results": aggregated_results,
    }

    timestamped_filepath = save_output(
        data=final_result,
        pipeline="can-do-steps",
        step="review_content",
        run_id=run_id,
        subfolder="archived/reviews"
    )

    create_clean_copy(
        timestamped_filepath=timestamped_filepath,
        pipeline="can-do-steps",
        step="review_content",
        run_id=run_id,
        subfolder="reviews"
    )

    print(f"\n✅ Content review complete for {run_id}: {total_passed}/{total_reviewed} steps passed.")
    if total_reviewed < total_available:
        print(f"ℹ️ Reviewed subset due to limit; {total_available - total_reviewed} steps remaining.")

    return final_result


if __name__ == "__main__":
    import argparse

    # CLI entry point: parse flags, run the review, and print the summary as JSON.
    parser = argparse.ArgumentParser(description="Review generated steps for a can-do-steps roadmap.")
    parser.add_argument("--run-id", required=True, help="Run identifier")
    parser.add_argument("--force-text", action="store_true", help="Reuse existing review output if available")
    # The limit applies to roadmap steps (see review_content); <= 0 means "all".
    parser.add_argument("--limit", type=int, help="Limit number of steps to review (<= 0 reviews all)")
    parser.add_argument(
        "--max-workers",
        type=int,
        dest="max_workers",
        help="Maximum number of concurrent worker threads to use"
    )
    parser.add_argument("--input-prefix", default="input", help="Input prefix/path for review criteria file")
    args = parser.parse_args()

    summary = review_content(
        run_id=args.run_id,
        force_text=args.force_text,
        limit=args.limit,
        max_workers=args.max_workers,
        input_prefix=args.input_prefix,
    )

    print(json.dumps(summary, indent=2, ensure_ascii=False))
