"""
Shared chain builder and utilities for all chains in the system.
"""
import os
import json
from typing import Dict, Any, Optional, Callable, List

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from utils.llm import get_model_for_chain
from utils.io import save_output, load_latest_output
from utils.cost_tracker import create_cost_callback

# Opt-in switch for printing rendered prompts before they are sent to the LLM.
# Evaluated once at import time: enabled when the LOG_PROMPTS environment
# variable is "1", "true", "yes" or "on" (case-insensitive); defaults to off.
LOG_PROMPTS = os.getenv("LOG_PROMPTS", "false").lower() in ("1", "true", "yes", "on")

def load_prompt(chain_name: str, pipeline: str) -> str:
    """
    Load a prompt template from the prompts directory.

    The template is read as UTF-8 from ``prompts/<pipeline>/<chain_name>.txt``.
    The path is relative, so it is resolved against the current working
    directory of the process.

    Args:
        chain_name: Name of the chain (e.g., 'extract_topics')
        pipeline: Pipeline name (e.g., 'lessons', 'roadmap')

    Returns:
        str: The prompt template text

    Raises:
        ValueError: If the prompt file does not exist. The underlying
            FileNotFoundError is attached as the explicit cause.
    """
    prompt_path = f"prompts/{pipeline}/{chain_name}.txt"

    try:
        with open(prompt_path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError as exc:
        # Chain explicitly so the traceback shows the OS-level error as the
        # direct cause rather than as an incidental "during handling" context.
        raise ValueError(f"Prompt file not found: {prompt_path}") from exc

def parse_output(output: str, parser: Callable) -> Any:
    """
    Run ``parser`` over the raw LLM output and return whatever it produces.

    Args:
        output: Raw output string from the LLM
        parser: Callable that takes the raw string and returns the parsed value

    Returns:
        Any: The value returned by ``parser``

    Raises:
        Exception: Whatever ``parser`` raises is re-raised unchanged after the
            error and the raw output have been printed.
    """
    try:
        parsed = parser(output)
    except Exception as err:
        # Print both the failure and the offending text so a bad LLM response
        # can be diagnosed from the console, then let the caller handle it.
        print(f"Error parsing output: {err}")
        print(f"Raw output: {output}")
        raise
    return parsed

def default_json_parser(output: str) -> Dict[str, Any]:
    """
    Default parser for JSON outputs.

    Strategies are tried in order and the first one that succeeds wins:

    1. Parse the entire output as JSON.
    2. Parse each fenced block (``` or ```json), in order of appearance.
    3. Decode a single JSON object starting at the first ``{`` in the output,
       ignoring any text that follows the object.

    Args:
        output: Raw output string from the LLM

    Returns:
        Dict: Parsed JSON. Note that strategies 1 and 2 return whatever JSON
        value they find, so a response that is a JSON array (or scalar) is
        returned as-is rather than as a dict.

    Raises:
        ValueError: If none of the strategies yields valid JSON
    """
    # Local import: this module does not import ``re`` at file level.
    import re

    # 1. The whole response is JSON.
    try:
        return json.loads(output)
    except json.JSONDecodeError:
        pass

    # 2. JSON inside a fenced block. Every block is tried, so a leading
    #    non-JSON fence (e.g. a code sample) does not hide a later JSON one.
    for block in re.findall(r'```(?:json)?\s*([\s\S]*?)```', output):
        try:
            return json.loads(block)
        except json.JSONDecodeError:
            continue

    # 3. A JSON object embedded in prose. raw_decode stops at the end of the
    #    object, so trailing text (even text containing braces) is tolerated.
    #    Only the first "{" is attempted: retrying at later braces could return
    #    an inner fragment of a truncated response instead of failing loudly.
    start = output.find("{")
    if start != -1:
        try:
            parsed, _end = json.JSONDecoder().raw_decode(output, start)
            return parsed
        except json.JSONDecodeError:
            pass

    # Raised outside any handler so the error is not chained to an unrelated
    # JSONDecodeError from one of the attempts above.
    raise ValueError("Could not parse JSON from output")

def build_chain(
    chain_name: str,
    pipeline: str,
    run_id: str,
    model_override: Optional[str] = None,
    input_variables: Optional[Dict[str, Any]] = None,
    file_id: Optional[str] = None,
    force_text: bool = False
) -> Dict[str, Any]:
    """
    Run a single LLM step end to end and persist its raw output.

    The step loads its prompt template, selects a model, attaches a cost
    tracking callback, invokes the model (with or without a file attachment)
    and saves the raw text via ``save_output``.

    Args:
        chain_name: Name of the chain (e.g., 'extract_topics')
        pipeline: Pipeline name (e.g., 'lessons', 'roadmap')
        run_id: Run identifier
        model_override: Optional model override passed to the model selector
        input_variables: Variables to pass to the prompt template
        file_id: Optional file ID for PDF inputs
        force_text: If True, reuse a previously saved raw output when one
            exists instead of calling the API

    Returns:
        Dict: ``{"output": <raw text>}``, either the reused output or the
        fresh model response
    """
    # The "can-do-steps" pipeline loads and saves step outputs with
    # subfolder="logs"; every other pipeline passes no subfolder.
    if pipeline == "can-do-steps":
        output_subfolder = "logs"
    else:
        output_subfolder = None

    # Optional short-circuit: reuse the latest saved raw output.
    if force_text:
        cached_text = load_latest_output(
            pipeline=pipeline,
            step=chain_name,
            run_id=run_id,
            as_text=True,
            raw=True,
            subfolder=output_subfolder
        )
        if cached_text:
            print(f"Using existing raw output for {chain_name} in {pipeline}/{run_id}")
            return {"output": cached_text}
        # Nothing usable was found, so fall through and call the model.
        print(f"No existing raw output found for {chain_name} in {pipeline}/{run_id}")

    # Prompt template (ChatPromptTemplate for the runnables pattern) and model.
    template_text = load_prompt(chain_name, pipeline)
    chat_prompt = ChatPromptTemplate.from_template(template_text)
    llm = get_model_for_chain(chain_name, model_override)

    # Cost tracking is attached to whichever invocation path runs below.
    cost_tracker = create_cost_callback(
        model_name=llm.model_name,
        chain_name=chain_name,
        pipeline=pipeline,
        run_id=run_id
    )
    run_config = {"callbacks": [cost_tracker]}

    template_values = input_variables or {}

    if not file_id:
        if LOG_PROMPTS:
            _log_prompt(chain_name, run_id, chat_prompt, template_values)
        # Runnables pattern: prompt | model | StrOutputParser
        runnable = chat_prompt | llm | StrOutputParser()
        generated = runnable.invoke(template_values, config=run_config)
    else:
        # File input for PDF processing (left as-is for now): the raw template
        # text is sent directly, without substituting the input variables.
        from utils.llm import format_file_input
        import logging
        logging.info(f"Using file input with ID: {file_id}")
        user_content = [
            {"type": "text", "text": template_text},
            format_file_input(file_id)
        ]
        reply = llm.invoke(
            [{"role": "user", "content": user_content}],
            config=run_config
        )
        generated = reply.content

    # Report the per-query cost when the callback exposes a non-empty one.
    query_cost = getattr(cost_tracker, 'query_cost', None)
    if query_cost:
        print(f"💰 Query cost: {query_cost}")

    # Persist the raw output before handing it back.
    save_output(
        data=generated,
        pipeline=pipeline,
        step=chain_name,
        run_id=run_id,
        as_text=True,
        subfolder=output_subfolder
    )
    return {"output": generated}


def _log_prompt(chain_name: str, run_id: str, prompt: ChatPromptTemplate, inputs: Dict[str, Any]) -> None:
    """Print each message of the prompt as rendered with ``inputs``.

    Rendering failures are reported with a warning line and otherwise ignored,
    so this helper never raises because of a bad template or missing variable.
    """
    try:
        rendered = prompt.format_messages(**inputs)
    except Exception as exc:  # noqa: BLE001
        print(f"⚠️ Could not render prompt for {chain_name}: {exc}")
        return

    print(f"=== Prompt for {chain_name} ({run_id}) ===")
    for position, msg in enumerate(rendered, 1):
        # Prefer the message's "type" attribute; fall back to its class name.
        role = getattr(msg, "type", msg.__class__.__name__)
        body = getattr(msg, "content", "")
        if isinstance(body, list):
            # Multi-part content: use the "text" entry of dict parts and the
            # string form of anything else, one part per line.
            body = "\n".join(
                str(part["text"]) if isinstance(part, dict) and "text" in part else str(part)
                for part in body
            )
        print(f"[{position}] {role}\n{body}\n")
