"""
Utility for loading LLM models and supporting per-chain model overrides.
"""
import os
from typing import Optional, Dict, Any
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

# Load environment variables
load_dotenv()

# Default model from environment
DEFAULT_MODEL = os.getenv("DEFAULT_LLM_MODEL", "gpt-4.1-mini")

# Model configuration overrides for specific chains
MODEL_OVERRIDES = {
    "extract_topics": "gpt-4o",
    "review_content": "gpt-5.1",  # review benefits from a more capable model
    "expand_and_create_hierarchy": "gpt-5.1-mini",
}

# Optional forced model set at runtime (e.g., via CLI flag)
_FORCED_MODEL: Optional[str] = None


def set_forced_model(model_name: Optional[str]) -> None:
    """
    Set a forced model to use for all chains, overriding DEFAULT_MODEL and
    MODEL_OVERRIDES. An explicit override_model passed to get_model_for_chain()
    still takes precedence over the forced model.

    Args:
        model_name: Name of the model to force (None to clear)
    """
    global _FORCED_MODEL
    _FORCED_MODEL = model_name
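
# Example: a CLI entry point might wire a flag through like this (a sketch; the
# "--model" flag and this argparse setup are hypothetical, not part of this module):
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--model", default=None, help="Force one model for all chains")
#     set_forced_model(parser.parse_args().model)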


def get_model_for_chain(
    chain_name: str,
    override_model: Optional[str] = None,
    **model_kwargs
) -> ChatOpenAI:
    """
    Get the appropriate LLM model for a specific chain.

    Args:
        chain_name: Name of the chain (e.g., 'extract_topics', 'generate_lessons')
        override_model: Optional manual override for model name
        **model_kwargs: Additional keyword arguments to pass to the ChatOpenAI constructor

    Returns:
        ChatOpenAI: Configured LLM model
    """
    # Resolution order: explicit override > forced model > chain-specific override > default
    model_name = (
        override_model
        or _FORCED_MODEL
        or MODEL_OVERRIDES.get(chain_name, DEFAULT_MODEL)
    )

    # Default parameters
    default_params = {
        "temperature": 0.2,
        "streaming": False,
    }

    # Merge, letting caller-supplied kwargs override the defaults
    params = {**default_params, **model_kwargs}

    # Create and return the model
    return ChatOpenAI(
        model=model_name,
        **params
    )
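
# Example resolution (a sketch; "summarize" is a hypothetical chain name with no
# entry in MODEL_OVERRIDES, so it falls back to DEFAULT_MODEL):
#
#     get_model_for_chain("review_content")            # -> gpt-5.1
#     get_model_for_chain("summarize")                 # -> DEFAULT_MODEL
#     get_model_for_chain("summarize", temperature=0)  # kwargs pass through to ChatOpenAI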


def format_file_input(file_id: str) -> Dict[str, Any]:
    """
    Build a file-input content block for OpenAI chat models that accept file attachments.

    Args:
        file_id: The file ID returned by the OpenAI file upload API

    Returns:
        Dict: Content block referencing the uploaded file, suitable for use in a message's content list
    """
    return {
        "type": "file",
        "file": {"file_id": file_id}
    }
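
# Example: attaching an uploaded file to a message (a sketch; "file-abc123" is a
# placeholder file ID, and the content-list message shape is LangChain's
# multimodal convention for passing raw OpenAI content blocks through):
#
#     from langchain_core.messages import HumanMessage
#
#     message = HumanMessage(content=[
#         {"type": "text", "text": "Summarize the attached document."},
#         format_file_input("file-abc123"),
#     ])
#     response = get_model_for_chain("extract_topics").invoke([message])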
