"""
Tests for the base chain builder and utilities.
"""
import os
import json
from unittest.mock import MagicMock, patch

import pytest

from chains.base import load_prompt, parse_output, default_json_parser, build_chain

@pytest.fixture
def mock_prompt():
    """Provide a prompt template containing a single ``{variable}`` placeholder."""
    template = "This is a test prompt with a variable: {variable}."
    return template

def test_load_prompt(monkeypatch):
    """Test loading a prompt template.

    ``builtins.open`` is replaced (for this test only, via ``monkeypatch``)
    with a ``unittest.mock.mock_open`` double, so no real file is read. The
    test checks both the returned text and the exact path, mode and encoding
    passed to ``open``.
    """
    from unittest.mock import mock_open

    # mock_open yields read_data from read(), readline(), readlines() and
    # iteration alike, so the test does not hinge on which of those
    # load_prompt happens to call.
    fake_open = mock_open(read_data="Test prompt content")
    monkeypatch.setattr("builtins.open", fake_open)

    prompt = load_prompt("test_chain", "test_pipeline")

    # The prompt text comes straight from the (mocked) file contents.
    assert prompt == "Test prompt content"

    # The file is resolved as prompts/<pipeline>/<chain>.txt and opened as UTF-8 text.
    fake_open.assert_called_once_with("prompts/test_pipeline/test_chain.txt", "r", encoding="utf-8")

def test_parse_output():
    """parse_output must hand the raw text to the given parser and let its errors escape."""
    raw = "This is a test output"

    def wrap_parser(output):
        # Happy path: wrap the raw text so the round-trip is observable.
        return {"parsed": output}

    def failing_parser(output):
        # Failure path: parse_output should not swallow parser exceptions.
        raise ValueError("Test error")

    assert parse_output(raw, wrap_parser) == {"parsed": raw}

    with pytest.raises(ValueError):
        parse_output(raw, failing_parser)

def test_default_json_parser():
    """Test the default JSON parser.

    Covers plain JSON, fenced code blocks with and without a ``json``
    language tag, and JSON embedded between lines of surrounding text. It
    also checks that input containing no JSON raises ``ValueError``.
    """
    # (input text, expected parsed value)
    test_cases = [
        # Plain JSON
        ('{"key": "value"}', {"key": "value"}),

        # JSON in a code block
        ('```json\n{"key": "value"}\n```', {"key": "value"}),

        # JSON without language specifier
        ('```\n{"key": "value"}\n```', {"key": "value"}),

        # JSON embedded in text
        ('Some text\n{"key": "value"}\nMore text', {"key": "value"}),
    ]

    for input_text, expected_output in test_cases:
        result = default_json_parser(input_text)
        # All cases share one test function, so name the failing input
        # explicitly; a bare assert would not reveal which case broke.
        assert result == expected_output, (
            f"default_json_parser({input_text!r}) returned {result!r}, "
            f"expected {expected_output!r}"
        )

    # Input with no JSON at all must be rejected.
    with pytest.raises(ValueError):
        default_json_parser("This is not JSON")

@patch("chains.base.load_prompt")
@patch("chains.base.get_model_for_chain")
@patch("chains.base.save_output")
def test_build_chain(mock_save_output, mock_get_model, mock_load_prompt):
    """Run build_chain end-to-end against mocked collaborators.

    The ``@patch`` decorators apply bottom-up, so the injected mocks arrive
    in the order save_output, get_model_for_chain, load_prompt.
    """
    # Stub every chains.base collaborator that build_chain reaches for.
    mock_load_prompt.return_value = "Test prompt with {variable}."
    mock_save_output.return_value = "/path/to/output.txt"

    fake_model = MagicMock(**{"invoke.return_value": {"text": "Test output"}})
    mock_get_model.return_value = fake_model

    chain_kwargs = dict(
        chain_name="test_chain",
        pipeline="test_pipeline",
        run_id="test_run",
        input_variables={"variable": "test_value"},
    )
    result = build_chain(**chain_kwargs)

    # The model's "text" comes back under the "output" key.
    assert result == {"output": "Test output"}

    # The model is requested once for "test_chain" (second argument None)
    # and invoked exactly once.
    mock_get_model.assert_called_once_with("test_chain", None)
    fake_model.invoke.assert_called_once()

    # The result is handed to save_output exactly once.
    mock_save_output.assert_called_once()