from __future__ import annotations

import uuid
import warnings
from collections.abc import Sequence
from functools import partial
from typing import (
    Annotated,
    Any,
    Callable,
    Literal,
    Union,
    cast,
)

from langchain_core.messages import (
    AnyMessage,
    BaseMessage,
    BaseMessageChunk,
    MessageLikeRepresentation,
    RemoveMessage,
    convert_to_messages,
    message_chunk_to_message,
)
from typing_extensions import TypedDict

from langgraph.constants import CONF, CONFIG_KEY_SEND
from langgraph.graph.state import StateGraph

# A single message-like value or a list of them. `add_messages` accepts this
# for both of its operands and wraps a lone (non-list) value in a one-element list.
Messages = Union[list[MessageLikeRepresentation], MessageLikeRepresentation]

# Sentinel message ID: a `RemoveMessage` carrying this ID in the `right` operand
# of `add_messages` discards every earlier message (all of `left` and anything
# before the marker in `right`).
REMOVE_ALL_MESSAGES = "__remove_all__"


def _add_messages_wrapper(func: Callable) -> Callable[[Messages, Messages], Messages]:
    """Let a binary message reducer also be used as a reducer factory.

    The returned callable supports two calling conventions:

    * ``wrapped(left, right, **kwargs)`` calls ``func(left, right, **kwargs)``
      and returns its result.
    * ``wrapped(**kwargs)`` (neither ``left`` nor ``right`` given) returns
      ``functools.partial(func, **kwargs)``: a reducer with the keyword
      options pre-bound, e.g. ``add_messages(format="langchain-openai")``.

    Supplying exactly one of ``left`` / ``right`` raises ``ValueError``.
    Only ``func.__doc__`` is copied onto the wrapper.
    """

    def _add_messages(
        left: Messages | None = None, right: Messages | None = None, **kwargs: Any
    ) -> Messages | Callable[[Messages, Messages], Messages]:
        if left is not None and right is not None:
            return func(left, right, **kwargs)
        elif left is not None or right is not None:
            # Report the argument that was actually supplied. Compare against
            # None rather than relying on truthiness so that an empty `left`
            # (e.g. `[]`) is not misreported as `right`.
            received = "left" if left is not None else "right"
            msg = (
                f"Must specify non-null arguments for both 'left' and 'right'. Only "
                f"received: '{received}'."
            )
            raise ValueError(msg)
        else:
            return partial(func, **kwargs)

    # Expose the wrapped reducer's docstring on the wrapper.
    _add_messages.__doc__ = func.__doc__
    return cast(Callable[[Messages, Messages], Messages], _add_messages)


@_add_messages_wrapper
def add_messages(
    left: Messages,
    right: Messages,
    *,
    format: Literal["langchain-openai"] | None = None,
) -> Messages:
    """Merges two lists of messages, updating existing messages by ID.

    By default, this ensures the state is "append-only", unless the
    new message has the same ID as an existing message.

    Messages without an ID are assigned a fresh UUID4 string (the message
    objects are updated in place). A `RemoveMessage` in `right` deletes the
    message with the same ID. A `RemoveMessage` whose ID is
    `REMOVE_ALL_MESSAGES` discards every earlier message: the result is then
    built only from the messages that follow the last such marker in `right`,
    merged with the same rules as above.

    Args:
        left: The base list of messages.
        right: The list of messages (or single message) to merge
            into the base list.
        format: The format to return messages in. If None then messages will be
            returned as is. If 'langchain-openai' then messages will be returned as
            BaseMessage objects with their contents formatted to match OpenAI message
            format, meaning contents can be string, 'text' blocks, or 'image_url' blocks
            and tool responses are returned as their own ToolMessages.

            !!! important "Requirement"

                Must have ``langchain-core>=0.3.11`` installed to use this feature.

    Returns:
        A new list of messages with the messages from `right` merged into `left`.
        If a message in `right` has the same ID as a message in `left`, the
        message from `right` will replace the message from `left`.

    Raises:
        ValueError: If a `RemoveMessage` targets an ID that is not present, or
            if `format` is not a recognized value.

    Example:
        ```python title="Basic usage"
        from langchain_core.messages import AIMessage, HumanMessage
        msgs1 = [HumanMessage(content="Hello", id="1")]
        msgs2 = [AIMessage(content="Hi there!", id="2")]
        add_messages(msgs1, msgs2)
        # [HumanMessage(content='Hello', id='1'), AIMessage(content='Hi there!', id='2')]
        ```

        ```python title="Overwrite existing message"
        msgs1 = [HumanMessage(content="Hello", id="1")]
        msgs2 = [HumanMessage(content="Hello again", id="1")]
        add_messages(msgs1, msgs2)
        # [HumanMessage(content='Hello again', id='1')]
        ```

        ```python title="Use in a StateGraph"
        from typing import Annotated
        from typing_extensions import TypedDict
        from langgraph.graph import StateGraph

        class State(TypedDict):
            messages: Annotated[list, add_messages]

        builder = StateGraph(State)
        builder.add_node("chatbot", lambda state: {"messages": [("assistant", "Hello")]})
        builder.set_entry_point("chatbot")
        builder.set_finish_point("chatbot")
        graph = builder.compile()
        graph.invoke({})
        # {'messages': [AIMessage(content='Hello', id=...)]}
        ```

        ```python title="Use OpenAI message format"
        from typing import Annotated
        from typing_extensions import TypedDict
        from langgraph.graph import StateGraph, add_messages

        class State(TypedDict):
            messages: Annotated[list, add_messages(format='langchain-openai')]

        def chatbot_node(state: State) -> dict:
            return {"messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Here's an image:",
                            "cache_control": {"type": "ephemeral"},
                        },
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/jpeg",
                                "data": "1234",
                            },
                        },
                    ]
                },
            ]}

        builder = StateGraph(State)
        builder.add_node("chatbot", chatbot_node)
        builder.set_entry_point("chatbot")
        builder.set_finish_point("chatbot")
        graph = builder.compile()
        graph.invoke({"messages": []})
        # {
        #     'messages': [
        #         HumanMessage(
        #             content=[
        #                 {"type": "text", "text": "Here's an image:"},
        #                 {
        #                     "type": "image_url",
        #                     "image_url": {"url": "data:image/jpeg;base64,1234"},
        #                 },
        #             ],
        #         ),
        #     ]
        # }
        ```

    """
    # coerce to list
    if not isinstance(left, list):
        left = [left]  # type: ignore[assignment]
    if not isinstance(right, list):
        right = [right]  # type: ignore[assignment]
    # coerce to message (message chunks are collapsed into full messages)
    left = [
        message_chunk_to_message(cast(BaseMessageChunk, m))
        for m in convert_to_messages(left)
    ]
    right = [
        message_chunk_to_message(cast(BaseMessageChunk, m))
        for m in convert_to_messages(right)
    ]
    # Assign missing ids (mutating the message objects) and remember the
    # position of the LAST remove-all marker in `right`.
    remove_all_idx = None
    for m in left:
        if m.id is None:
            m.id = str(uuid.uuid4())
    for idx, m in enumerate(right):
        if m.id is None:
            m.id = str(uuid.uuid4())
        if isinstance(m, RemoveMessage) and m.id == REMOVE_ALL_MESSAGES:
            remove_all_idx = idx

    if remove_all_idx is not None:
        # Everything up to and including the last marker is discarded. The
        # remaining updates are merged into an empty history instead of being
        # returned verbatim, so they still get ID de-duplication,
        # RemoveMessage handling and `format` conversion below.
        left = []
        right = right[remove_all_idx + 1 :]

    # merge
    merged = left.copy()
    merged_by_id = {m.id: i for i, m in enumerate(merged)}
    ids_to_remove = set()
    for m in right:
        if (existing_idx := merged_by_id.get(m.id)) is not None:
            if isinstance(m, RemoveMessage):
                ids_to_remove.add(m.id)
            else:
                # A later non-remove update with the same ID cancels a
                # pending removal and replaces the message in place.
                ids_to_remove.discard(m.id)
                merged[existing_idx] = m
        else:
            if isinstance(m, RemoveMessage):
                raise ValueError(
                    f"Attempting to delete a message with an ID that doesn't exist ('{m.id}')"
                )

            merged_by_id[m.id] = len(merged)
            merged.append(m)
    merged = [m for m in merged if m.id not in ids_to_remove]

    if format == "langchain-openai":
        merged = _format_messages(merged)
    elif format:
        msg = f"Unrecognized {format=}. Expected one of 'langchain-openai', None."
        raise ValueError(msg)

    return merged


class MessageGraph(StateGraph):
    """A StateGraph where every node receives a list of messages as input and returns one or more messages as output.

    MessageGraph is a subclass of StateGraph whose entire state is a single list of
    messages. The list is append-only, with two exceptions: a message whose ID matches
    an existing one replaces it, and a `RemoveMessage` deletes the message with its ID.
    Each node in a MessageGraph takes a list of messages as input and returns zero or more
    messages as output. The `add_messages` function is used to merge the output messages from each node
    into the existing list of messages in the graph's state.

    Because the state is the list itself (not a dict), `invoke` returns a plain list
    of messages.

    Examples:
        ```pycon
        >>> from langgraph.graph.message import MessageGraph
        ...
        >>> builder = MessageGraph()
        >>> builder.add_node("chatbot", lambda state: [("assistant", "Hello!")])
        >>> builder.set_entry_point("chatbot")
        >>> builder.set_finish_point("chatbot")
        >>> builder.compile().invoke([("user", "Hi there.")])
        [HumanMessage(content='Hi there.', id='...'), AIMessage(content='Hello!', id='...')]
        ```

        ```pycon
        >>> from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
        >>> from langgraph.graph.message import MessageGraph
        ...
        >>> builder = MessageGraph()
        >>> builder.add_node(
        ...     "chatbot",
        ...     lambda state: [
        ...         AIMessage(
        ...             content="Hello!",
        ...             tool_calls=[{"name": "search", "id": "123", "args": {"query": "X"}}],
        ...         )
        ...     ],
        ... )
        >>> builder.add_node(
        ...     "search", lambda state: [ToolMessage(content="Searching...", tool_call_id="123")]
        ... )
        >>> builder.set_entry_point("chatbot")
        >>> builder.add_edge("chatbot", "search")
        >>> builder.set_finish_point("search")
        >>> builder.compile().invoke([HumanMessage(content="Hi there. Can you search for X?")])
        [HumanMessage(content='Hi there. Can you search for X?', id='...'),
         AIMessage(content='Hello!', id='...', tool_calls=[...]),
         ToolMessage(content='Searching...', id='...', tool_call_id='123')]
        ```
    """

    def __init__(self) -> None:
        # The whole graph state is a single list channel reduced with add_messages.
        super().__init__(Annotated[list[AnyMessage], add_messages])  # type: ignore[arg-type]


class MessagesState(TypedDict):
    """TypedDict state schema with a single `messages` key.

    Updates to `messages` are merged by the `add_messages` reducer. Subclass it
    to add further state keys.
    """

    # Reducer-annotated channel: new values are combined with the existing list
    # via add_messages (append, replace by matching ID, or delete via RemoveMessage).
    messages: Annotated[list[AnyMessage], add_messages]


def _format_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
    """Round-trip `messages` through the OpenAI message representation.

    The messages are converted to OpenAI-style dicts with
    `convert_to_openai_messages` and then back into `BaseMessage` objects.
    If that helper cannot be imported (older langchain-core), a warning is
    emitted and a shallow list copy of the input is returned unchanged.
    """
    try:
        from langchain_core.messages import convert_to_openai_messages
    except ImportError:
        warnings.warn(
            "Must have langchain-core>=0.3.11 installed to use automatic message "
            "formatting (format='langchain-openai'). Please update your langchain-core "
            "version or remove the 'format' flag. Returning un-formatted "
            "messages."
        )
        return list(messages)

    openai_style = convert_to_openai_messages(messages)
    return convert_to_messages(openai_style)


def push_message(
    message: MessageLikeRepresentation | BaseMessageChunk,
    *,
    state_key: str | None = "messages",
) -> AnyMessage:
    """Write a message manually to the `messages` / `messages-tuple` stream mode.

    Will automatically write to the channel specified in the `state_key` unless `state_key` is `None`.

    The current run config is read via `get_config()`, so this is presumably
    only usable while a graph is executing.

    Args:
        message: The message to write. It is coerced with `convert_to_messages`
            before being emitted.
        state_key: State channel to also write the message to. Pass `None`
            (or an empty string) to only emit it to the stream.

    Returns:
        The coerced message.

    Raises:
        ValueError: If the coerced message has no ID.
    """

    from langchain_core.callbacks.base import (
        BaseCallbackHandler,
        BaseCallbackManager,
    )

    from langgraph.config import get_config
    from langgraph.constants import NS_SEP
    from langgraph.pregel.messages import StreamMessagesHandler

    config = get_config()
    message = convert_to_messages([message])[0]

    if message.id is None:
        raise ValueError("Message ID is required")

    # Collect this run's callback handlers. Default to none when `callbacks`
    # is missing, None, or of an unexpected shape, so streaming is skipped
    # instead of failing on an unbound `handlers` / missing key.
    callbacks = config.get("callbacks")
    handlers: list[BaseCallbackHandler] = []
    if isinstance(callbacks, BaseCallbackManager):
        handlers = callbacks.handlers
    elif isinstance(callbacks, list) and all(
        isinstance(x, BaseCallbackHandler) for x in callbacks
    ):
        handlers = callbacks

    # Emit to the messages stream only when a StreamMessagesHandler is attached.
    if stream_handler := next(
        (x for x in handlers if isinstance(x, StreamMessagesHandler)), None
    ):
        metadata = config["metadata"]
        message_meta = (
            tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP)),
            metadata,
        )
        # NOTE(review): dedupe=False presumably forces emission even if this
        # message ID was streamed before — confirm against StreamMessagesHandler.
        stream_handler._emit(message_meta, message, dedupe=False)

    if state_key:
        # Hand the (channel, value) write to the send function stored in the
        # run config.
        config[CONF][CONFIG_KEY_SEND]([(state_key, message)])

    return message
