"""
Cost tracking utility for LLM usage across the pipeline.
Tracks individual query costs and accumulates totals per step and pipeline.
"""
import os
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field
from datetime import datetime
import threading
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

# OpenAI text-model rates in USD per 1K tokens; calculate_query_cost divides
# token counts by 1000 before multiplying by these rates.
# Source: https://platform.openai.com/docs/pricing (last updated August 2025).
MODEL_PRICING = {
    "gpt-4.1-mini": {"input": 0.0004, "output": 0.0016},
    "gpt-4o": {"input": 0.0025, "output": 0.01},
    "gpt-5-mini": {"input": 0.00025, "output": 0.002},
    "gpt-5.1": {"input": 0.00125, "output": 0.0100},
    # Fallback for unrecognized models. It is deliberately inflated so that an
    # unpriced model stands out in the cost summary instead of looking cheap.
    "default": {"input": 99, "output": 99},
}

# Output tokens billed per generated gpt-image-1 image, keyed first by quality
# tier and then by "WIDTHxHEIGHT" size.
IMAGE_TOKEN_COUNTS = {
    "gpt-image-1": {
        "low": {"1024x1024": 272, "1024x1536": 408, "1536x1024": 400},
        "medium": {"1024x1024": 1056, "1024x1536": 1584, "1536x1024": 1568},
        "high": {"1024x1024": 4160, "1024x1536": 6240, "1536x1024": 6208},
    },
}

# gpt-image-1 rates in USD per 1M tokens, split by token modality.
# In this module, calculate_image_cost reads only text "input" (prompt tokens)
# and image "output" (generated image tokens).
# NOTE(review): gpt-image-1 produces no text output tokens. The text "output"
# value 1.25 appears to match the published *cached* text-input rate; confirm
# before relying on it.
IMAGE_TOKEN_PRICING = {
    "gpt-image-1": {
        "text": {"input": 5.00, "output": 1.25},
        "image": {"input": 10.00, "output": 40.00},
    },
}

# Gemini 3 image pricing (token based):
#   input  - a fixed 560 tokens per image, about $0.0011 per image
#   output - $120 per 1M tokens; images up to 2K (<= 2048x2048) use 1120
#            tokens (~$0.134 each), larger images up to 4K use 2000 tokens
#            (~$0.24 each)
GEMINI3_IMAGE_INPUT_TOKENS = 560
GEMINI3_IMAGE_INPUT_COST_PER_IMAGE = 0.0011
GEMINI3_IMAGE_OUTPUT_TOKENS = {"2k": 1120, "4k": 2000}
GEMINI3_OUTPUT_COST_PER_MILLION = 120.0

# Gemini 2.5 Flash Image pricing (token based):
#   output - $30 per 1M tokens; an image up to 1024x1024 uses 1290 tokens
#            (~$0.039 each)
GEMINI25_IMAGE_OUTPUT_TOKENS = 1290
GEMINI25_OUTPUT_COST_PER_MILLION = 30.0

@dataclass
class QueryCost:
    """Token usage and cost of a single LLM call (text or image generation).

    Costs are dollar amounts (``__str__`` renders them with a ``$`` prefix).
    Image calls also record how many images were produced and at what size.
    """
    model_name: str
    input_tokens: int
    output_tokens: int
    total_tokens: int
    input_cost: float
    output_cost: float
    total_cost: float
    timestamp: datetime = field(default_factory=datetime.now)  # creation time of this record
    chain_name: str = ""
    pipeline: str = ""
    run_id: str = ""
    query_type: str = "text"  # "text" or "image"; chooses the __str__ format

    # Populated for image calls only.
    image_count: int = 0
    image_size: str = ""

    def __str__(self) -> str:
        """Return a one-line summary: total cost plus token or image details."""
        if self.query_type != "image":
            return f"${self.total_cost:.4f} ({self.model_name}: {self.input_tokens}+{self.output_tokens}={self.total_tokens} tokens)"
        return f"${self.total_cost:.4f} ({self.model_name}: {self.image_count} image(s) @ {self.image_size})"

@dataclass
class StepCostSummary:
    """Running token and cost totals for one pipeline step.

    ``add_query_cost`` keeps each individual record and updates the aggregate
    counters; ``finalize`` stamps ``end_time`` when the step is closed.
    """
    step_name: str
    pipeline: str
    run_id: str
    query_costs: List[QueryCost] = field(default_factory=list)  # records in the order added
    total_cost: float = 0.0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_tokens: int = 0
    start_time: datetime = field(default_factory=datetime.now)  # set when the summary is created
    end_time: Optional[datetime] = None  # set by finalize()

    def add_query_cost(self, query_cost: QueryCost):
        """Record *query_cost* and fold its tokens and cost into the step totals."""
        self.query_costs.append(query_cost)
        self.total_cost = self.total_cost + query_cost.total_cost
        self.total_input_tokens = self.total_input_tokens + query_cost.input_tokens
        self.total_output_tokens = self.total_output_tokens + query_cost.output_tokens
        self.total_tokens = self.total_tokens + query_cost.total_tokens

    def finalize(self):
        """Stamp the step's end time with the current local time."""
        self.end_time = datetime.now()

    def __str__(self) -> str:
        """Summarise the total cost, query count and mean cost per query."""
        query_count = len(self.query_costs)
        avg_cost = 0
        if query_count:
            avg_cost = self.total_cost / query_count
        return f"${self.total_cost:.4f} ({query_count} queries, avg: ${avg_cost:.4f})"

class CostTrackingCallback(BaseCallbackHandler):
    """LangChain callback handler that prices each completed LLM call.

    When a finished call reports token usage, it is priced with
    ``calculate_query_cost`` for ``model_name``, stored on ``self.query_cost``
    (replacing any earlier value) and recorded on the module-level
    ``_COST_TRACKER``. Calls without any usage information are ignored.
    """

    def __init__(self, model_name: str, chain_name: str = "", pipeline: str = "", run_id: str = ""):
        super().__init__()
        self.model_name = model_name  # pricing key passed to calculate_query_cost
        self.chain_name = chain_name
        self.pipeline = pipeline
        self.run_id = run_id
        # Cost of the most recently priced call; None until one has been priced.
        self.query_cost: Optional[QueryCost] = None

    @staticmethod
    def _extract_token_usage(response: LLMResult) -> Optional[tuple]:
        """Return ``(input_tokens, output_tokens)`` for *response*, or ``None``.

        The OpenAI-style ``llm_output["token_usage"]`` dict (``prompt_tokens`` /
        ``completion_tokens``) is preferred. When it is missing or empty, which
        can happen for streamed chat responses where ``llm_output`` is not
        populated, the ``usage_metadata`` (``input_tokens`` / ``output_tokens``)
        of each generated chat message is summed instead. Returns ``None`` when
        neither source reports usage.
        """
        token_usage = (response.llm_output or {}).get("token_usage") or {}
        if token_usage:
            # ``or 0`` guards against providers that report a key with a None value.
            return (
                token_usage.get("prompt_tokens") or 0,
                token_usage.get("completion_tokens") or 0,
            )

        found = False
        input_tokens = output_tokens = 0
        for candidates in response.generations or []:
            for generation in candidates:
                # Plain (non-chat) generations have no ``message`` attribute.
                message = getattr(generation, "message", None)
                usage = getattr(message, "usage_metadata", None)
                if not usage:
                    continue
                found = True
                input_tokens += usage.get("input_tokens") or 0
                output_tokens += usage.get("output_tokens") or 0
        return (input_tokens, output_tokens) if found else None

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Price the finished call and record it on the global tracker.

        Args:
            response: The LangChain result passed to the callback.
            **kwargs: Extra callback arguments (unused).
        """
        usage = self._extract_token_usage(response)
        if usage is None:
            return
        input_tokens, output_tokens = usage

        self.query_cost = calculate_query_cost(
            model_name=self.model_name,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            chain_name=self.chain_name,
            pipeline=self.pipeline,
            run_id=self.run_id
        )

        # Add to global tracker
        _COST_TRACKER.add_query_cost(self.query_cost)

def calculate_query_cost(
    model_name: str,
    input_tokens: int,
    output_tokens: int,
    chain_name: str = "",
    pipeline: str = "",
    run_id: str = ""
) -> QueryCost:
    """Price one text LLM call and return it as a text ``QueryCost``.

    Rates are looked up by exact model name in ``MODEL_PRICING`` (rates per 1K
    tokens). A name without an entry prints a warning and is priced with the
    deliberately inflated ``"default"`` rates.

    Args:
        model_name: Model identifier used as the pricing key.
        input_tokens: Prompt tokens consumed.
        output_tokens: Completion tokens produced.
        chain_name, pipeline, run_id: Attribution labels copied onto the result.

    Returns:
        A ``QueryCost`` with ``query_type="text"`` and
        ``total_tokens == input_tokens + output_tokens``.
    """
    rates = MODEL_PRICING.get(model_name)
    if rates is None:
        print(f"⚠️ Unknown model '{model_name}', using default pricing")
        rates = MODEL_PRICING["default"]

    prompt_cost = (input_tokens / 1000) * rates["input"]
    completion_cost = (output_tokens / 1000) * rates["output"]

    return QueryCost(
        model_name=model_name,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=input_tokens + output_tokens,
        input_cost=prompt_cost,
        output_cost=completion_cost,
        total_cost=prompt_cost + completion_cost,
        chain_name=chain_name,
        pipeline=pipeline,
        run_id=run_id,
        query_type="text",
    )

def _gemini3_output_tier(image_size: str) -> str:
    """Return Gemini 3's output-token tier ("2k" or "4k") for *image_size*.

    Gemini-style size labels ("1K", "2K", "4K", any case) map directly.
    "WIDTHxHEIGHT" sizes map to "2k" when the longest side is at most 2048 px
    and to "4k" otherwise. Anything unparseable falls back to "2k".
    """
    label = str(image_size).strip().lower() if image_size else ""
    if label in ("1k", "2k"):
        return "2k"
    if label == "4k":
        return "4k"
    try:
        width, height = (int(part) for part in label.split("x"))
    except ValueError:
        # Not "WIDTHxHEIGHT": wrong number of parts or non-numeric parts.
        return "2k"
    return "2k" if max(width, height) <= 2048 else "4k"


def _image_query_cost(
    *,
    model_name: str,
    input_tokens: int,
    output_tokens: int,
    input_cost: float,
    output_cost: float,
    image_count: int,
    image_size: str,
    chain_name: str,
    pipeline: str,
    run_id: str,
) -> QueryCost:
    """Build an image ``QueryCost``, deriving total tokens and total cost as sums."""
    return QueryCost(
        model_name=model_name,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=input_tokens + output_tokens,
        input_cost=input_cost,
        output_cost=output_cost,
        total_cost=input_cost + output_cost,
        chain_name=chain_name,
        pipeline=pipeline,
        run_id=run_id,
        query_type="image",
        image_count=image_count,
        image_size=image_size,
    )


def calculate_image_cost(
    model_name: str,
    image_count: int,
    image_size: str = "1024x1024",
    quality: str = "low",
    prompt_tokens: int = 0,
    chain_name: str = "",
    pipeline: str = "",
    run_id: str = ""
) -> QueryCost:
    """Estimate the cost of an image-generation call and wrap it in a ``QueryCost``.

    Pricing depends on the model:

    * ``gemini-2.5-flash-image*``: a fixed number of output tokens per image.
      Prompt tokens are recorded but not billed.
    * ``gemini-3*``: a fixed per-image input charge plus per-image output
      tokens, tiered by size. Prompt tokens are recorded but not billed.
    * ``gpt-image-1``: output tokens from ``IMAGE_TOKEN_COUNTS`` by quality and
      size, billed at the image-output rate, plus prompt tokens billed at the
      text-input rate. Unknown quality or size values fall back with a warning.
    * anything else: a flat $0.10 per image, with a warning.

    Args:
        model_name: Image model identifier.
        image_count: Number of images produced by the call.
        image_size: "WIDTHxHEIGHT"; for Gemini 3 also "1K"/"2K"/"4K".
        quality: gpt-image-1 quality tier ("low", "medium" or "high").
        prompt_tokens: Prompt text tokens. They are tracked for every model and
            billed only for gpt-image-1.
        chain_name, pipeline, run_id: Attribution labels copied onto the result.

    Returns:
        A ``QueryCost`` with ``query_type="image"``.
    """
    context = {
        "model_name": model_name,
        "image_count": image_count,
        "image_size": image_size,
        "chain_name": chain_name,
        "pipeline": pipeline,
        "run_id": run_id,
    }

    # Gemini 2.5 flash image: output-token pricing only.
    if model_name.startswith("gemini-2.5-flash-image"):
        output_tokens = GEMINI25_IMAGE_OUTPUT_TOKENS * image_count
        return _image_query_cost(
            input_tokens=prompt_tokens,
            output_tokens=output_tokens,
            input_cost=0,  # no input charge is modelled for this image pricing
            output_cost=(output_tokens / 1_000_000) * GEMINI25_OUTPUT_COST_PER_MILLION,
            **context,
        )

    # Gemini 3 image models: per-image input charge plus tiered output tokens.
    if model_name.startswith("gemini-3"):
        tier = _gemini3_output_tier(image_size)
        tokens_per_image = GEMINI3_IMAGE_OUTPUT_TOKENS.get(tier, GEMINI3_IMAGE_OUTPUT_TOKENS["2k"])
        output_tokens = tokens_per_image * image_count
        return _image_query_cost(
            # The fixed input allowance is charged per image, so its token
            # count scales with image_count too (keeping tokens consistent with
            # input_cost). Prompt tokens are added for tracking only.
            input_tokens=GEMINI3_IMAGE_INPUT_TOKENS * image_count + prompt_tokens,
            output_tokens=output_tokens,
            input_cost=GEMINI3_IMAGE_INPUT_COST_PER_IMAGE * image_count,
            output_cost=(output_tokens / 1_000_000) * GEMINI3_OUTPUT_COST_PER_MILLION,
            **context,
        )

    # gpt-image-1: token-based pricing by quality and size.
    if model_name == "gpt-image-1":
        quality_tokens = IMAGE_TOKEN_COUNTS.get(model_name, {}).get(quality)
        if quality_tokens is None:
            # Fall back to the low-quality 1024x1024 token count.
            image_tokens_per_image = 272
            print(f"⚠️ Unknown quality '{quality}' for model '{model_name}', using default")
        elif image_size in quality_tokens:
            image_tokens_per_image = quality_tokens[image_size]
        else:
            image_tokens_per_image = quality_tokens.get("1024x1024", 272)
            print(f"⚠️ Unknown image size '{image_size}', using 1024x1024 token count")

        pricing = IMAGE_TOKEN_PRICING.get(model_name)
        if pricing is not None:
            total_image_tokens = image_tokens_per_image * image_count
            return _image_query_cost(
                input_tokens=prompt_tokens,  # text prompt tokens
                output_tokens=total_image_tokens,  # generated image tokens
                input_cost=(prompt_tokens / 1_000_000) * pricing["text"]["input"],
                output_cost=(total_image_tokens / 1_000_000) * pricing["image"]["output"],
                **context,
            )

    # Unknown model: flat per-image estimate.
    print(f"⚠️ Unknown image model '{model_name}', using high fallback pricing")
    fallback_cost = image_count * 0.10
    return _image_query_cost(
        input_tokens=prompt_tokens,
        output_tokens=0,
        input_cost=0,
        output_cost=fallback_cost,
        **context,
    )

class CostTracker:
    """Accumulator of LLM query costs, grouped by pipeline step.

    Every recorded ``QueryCost`` is appended to ``all_query_costs``. While a
    step is active, the cost is also added to that step's ``StepCostSummary``.
    Closed steps are kept, in order, in ``completed_steps``. Public methods
    take ``self._lock`` around all state access.
    """

    def __init__(self):
        self.current_step: Optional[StepCostSummary] = None
        self.completed_steps: List[StepCostSummary] = []
        self.all_query_costs: List[QueryCost] = []
        self.pipeline_start_time: Optional[datetime] = None
        self.pipeline_end_time: Optional[datetime] = None
        # threading.Lock is not reentrant, so code that already holds it uses
        # the *_locked helpers instead of calling other public methods.
        self._lock = threading.Lock()

    # Helpers below require the caller to already hold self._lock.

    def _close_current_step_locked(self) -> Optional[StepCostSummary]:
        """Finalize the active step, move it to ``completed_steps`` and return it."""
        step = self.current_step
        if step is None:
            return None
        step.finalize()
        self.completed_steps.append(step)
        self.current_step = None
        return step

    def _total_cost_locked(self) -> float:
        """Sum every recorded query's cost, including queries recorded outside any step."""
        return sum(query.total_cost for query in self.all_query_costs)

    @staticmethod
    def _format_elapsed(seconds: float) -> str:
        """Format whole elapsed seconds as "HH:MM:SS", or "MM:SS" when under an hour."""
        hours, remainder = divmod(int(seconds), 3600)
        minutes, secs = divmod(remainder, 60)
        if hours > 0:
            return f"{hours:02d}:{minutes:02d}:{secs:02d}"
        return f"{minutes:02d}:{secs:02d}"

    def start_pipeline_timing(self):
        """Start tracking overall pipeline timing and clear any previous end time."""
        with self._lock:
            self.pipeline_start_time = datetime.now()
            self.pipeline_end_time = None

    def end_pipeline_timing(self):
        """End tracking overall pipeline timing."""
        with self._lock:
            self.pipeline_end_time = datetime.now()

    def get_pipeline_duration(self) -> Optional[float]:
        """Return the pipeline duration in seconds, or None unless both start and end are set."""
        with self._lock:
            if self.pipeline_start_time and self.pipeline_end_time:
                return (self.pipeline_end_time - self.pipeline_start_time).total_seconds()
            return None

    def start_step(self, step_name: str, pipeline: str, run_id: str):
        """Start a new step, first closing any step that is still active."""
        with self._lock:
            self._close_current_step_locked()
            self.current_step = StepCostSummary(
                step_name=step_name,
                pipeline=pipeline,
                run_id=run_id
            )

    def add_query_cost(self, query_cost: QueryCost):
        """Record a query cost globally and, if a step is active, on that step."""
        with self._lock:
            self.all_query_costs.append(query_cost)
            if self.current_step:
                self.current_step.add_query_cost(query_cost)

    def finish_step(self) -> Optional[StepCostSummary]:
        """Close the active step and return its summary, or None if no step is active."""
        with self._lock:
            return self._close_current_step_locked()

    def get_current_step_cost(self) -> float:
        """Get the total cost of the active step (0.0 when no step is active)."""
        with self._lock:
            return self.current_step.total_cost if self.current_step else 0.0

    def get_total_cost(self) -> float:
        """Get the total cost of every recorded query.

        This includes queries recorded while no step was active, so the result
        matches the query count reported by ``print_summary``.
        """
        with self._lock:
            return self._total_cost_locked()

    def get_step_summary(self, step_name: str, pipeline: str, run_id: str) -> Optional[StepCostSummary]:
        """Return the first completed step, else the active step, matching all three keys."""
        with self._lock:
            for step in self.completed_steps:
                if (step.step_name == step_name and
                    step.pipeline == pipeline and
                    step.run_id == run_id):
                    return step

            if (self.current_step and
                self.current_step.step_name == step_name and
                self.current_step.pipeline == pipeline and
                self.current_step.run_id == run_id):
                return self.current_step

            return None

    def reset(self):
        """Reset all tracking data."""
        with self._lock:
            self.current_step = None
            self.completed_steps.clear()
            self.all_query_costs.clear()
            self.pipeline_start_time = None
            self.pipeline_end_time = None

    def print_summary(self):
        """Print timing (if started), total cost, query count and per-step costs."""
        try:
            with self._lock:
                total_cost = self._total_cost_locked()
                total_queries = len(self.all_query_costs)

                if self.pipeline_start_time:
                    print("\n⏱️ Timing Summary:")
                    print(f"   Start time: {self.pipeline_start_time.strftime('%Y-%m-%d %H:%M:%S')}")
                    if self.pipeline_end_time:
                        print(f"   End time: {self.pipeline_end_time.strftime('%Y-%m-%d %H:%M:%S')}")
                        duration = (self.pipeline_end_time - self.pipeline_start_time).total_seconds()
                        print(f"   Total time: {self._format_elapsed(duration)} ({duration:.2f}s)")
                    else:
                        print("   End time: (in progress)")
                        duration = (datetime.now() - self.pipeline_start_time).total_seconds()
                        print(f"   Total time: {self._format_elapsed(duration)} (ongoing)")

                print("\n💰 Cost Summary:")
                print(f"   Total Cost: ${total_cost:.4f}")
                print(f"   Total Queries: {total_queries}")
                if total_queries > 0:
                    print(f"   Average Cost/Query: ${total_cost/total_queries:.4f}")

                if self.completed_steps or self.current_step:
                    print("   Steps:")
                    for step in self.completed_steps:
                        print(f"     ✅ {step.step_name}: {step}")

                    if self.current_step:
                        print(f"     🔄 {self.current_step.step_name}: {self.current_step}")
        except KeyboardInterrupt:
            # Best effort: still report the query count if printing is interrupted.
            print("\n💰 Cost Summary (interrupted):")
            print(f"   Total Queries: {len(self.all_query_costs)}")

# Module-wide tracker shared by CostTrackingCallback and the public helper
# functions below.
_COST_TRACKER: CostTracker = CostTracker()

# Public API functions
def start_step_tracking(step_name: str, pipeline: str, run_id: str):
    """Begin a new cost-tracking step on the shared tracker.

    Delegates to ``CostTracker.start_step`` on the module-level tracker.
    """
    _COST_TRACKER.start_step(step_name=step_name, pipeline=pipeline, run_id=run_id)

def finish_step_tracking() -> Optional[StepCostSummary]:
    """Close the shared tracker's active step and return its summary, if any."""
    summary = _COST_TRACKER.finish_step()
    return summary

def get_current_step_cost() -> float:
    """Return the accumulated cost of the shared tracker's active step."""
    step_cost = _COST_TRACKER.get_current_step_cost()
    return step_cost

def get_total_cost() -> float:
    """Return the total cost reported by the shared tracker's ``get_total_cost``."""
    total = _COST_TRACKER.get_total_cost()
    return total

def print_cost_summary():
    """Print the shared tracker's timing and cost summary to stdout."""
    tracker = _COST_TRACKER
    tracker.print_summary()

def reset_cost_tracking():
    """Discard all steps, query costs and timing held by the shared tracker."""
    tracker = _COST_TRACKER
    tracker.reset()

def start_pipeline_timing():
    """Record the pipeline start time on the shared tracker."""
    tracker = _COST_TRACKER
    tracker.start_pipeline_timing()

def end_pipeline_timing():
    """Record the pipeline end time on the shared tracker."""
    tracker = _COST_TRACKER
    tracker.end_pipeline_timing()

def get_pipeline_duration() -> Optional[float]:
    """Return the shared tracker's pipeline duration in seconds, or None if incomplete."""
    duration = _COST_TRACKER.get_pipeline_duration()
    return duration

def create_cost_callback(model_name: str, chain_name: str = "", pipeline: str = "", run_id: str = "") -> CostTrackingCallback:
    """Build a LangChain callback that prices *model_name* calls and records them.

    Args:
        model_name: Pricing key for ``calculate_query_cost``.
        chain_name, pipeline, run_id: Attribution labels for the recorded costs.
    """
    return CostTrackingCallback(
        model_name=model_name,
        chain_name=chain_name,
        pipeline=pipeline,
        run_id=run_id,
    )

def track_image_cost(
    model_name: str,
    image_count: int,
    image_size: str = "1024x1024",
    quality: str = "low",
    prompt_tokens: int = 0,
    chain_name: str = "",
    pipeline: str = "",
    run_id: str = ""
):
    """Price an image-generation call and record it on the shared tracker.

    All arguments are forwarded unchanged to ``calculate_image_cost``. The
    resulting ``QueryCost`` is added to the module-level tracker and returned.
    """
    recorded = calculate_image_cost(
        model_name=model_name,
        image_count=image_count,
        image_size=image_size,
        quality=quality,
        prompt_tokens=prompt_tokens,
        chain_name=chain_name,
        pipeline=pipeline,
        run_id=run_id,
    )
    _COST_TRACKER.add_query_cost(recorded)
    return recorded
