import React, { useCallback } from "react";
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { getQueryKey } from "@trpc/react-query";
import { useAtom } from "jotai";
import { useNavigate } from "react-router-dom";

import type { Chat } from "@charry/models";

import type { RouterOutput } from "~/lib/trpc";
import { api, trpc } from "~/lib/trpc";
import { ChatState } from "../jotai/chat.jotai";
import useIsAuthenticated from "./useIsAuthenticated.hook";

/**
 * Comparator ordering records oldest-first by their `createdAt` value.
 * Used to keep a chat's turns in chronological order after every cache update.
 */
function byCreatedAtAscending(
  a: { createdAt: string | number | Date },
  b: { createdAt: string | number | Date },
): number {
  return new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime();
}

/**
 * Hook exposing the signed-in user's chat list together with the mutations
 * and stream helpers that keep the cached list up to date.
 *
 * The chat list lives in the react-query cache under the `chat.list` query
 * key; every mutation below writes its result back into that cache through
 * `setChatList` instead of triggering a refetch.
 *
 * @param chatId - Optional id of the "current" chat. It selects the returned
 *   `chat` and is the chat whose turns are patched by `createTurns`,
 *   `regenerateTurn` and `streamTurn`.
 * @returns The chat list (full and paginated), the current chat, pagination
 *   state, the create-chat params atom, and the mutation / stream helpers.
 */
export default function useChat(chatId?: string) {
  const { isAuthenticated } = useIsAuthenticated();
  const queryClient = useQueryClient();
  const navigate = useNavigate();

  // Only fetch (and poll once a minute) while authenticated.
  const chatListQuery = api.chat.list.useQuery(undefined, {
    enabled: isAuthenticated,
    refetchInterval: isAuthenticated ? 1000 * 60 : false,
    staleTime: isAuthenticated ? 0 : Infinity,
  });

  const chatList = React.useMemo(() => {
    if (!isAuthenticated || !chatListQuery.data) return [];
    return chatListQuery.data.chats ?? [];
  }, [isAuthenticated, chatListQuery.data]);

  /**
   * Applies `fn` to the cached chat array and writes the result back into
   * the `chat.list` query cache. Does nothing while unauthenticated.
   */
  const setChatList = useCallback(
    (fn: (current: Chat[]) => Chat[]) => {
      if (!isAuthenticated) return;
      const queryKey = getQueryKey(api.chat.list, undefined, "query");
      const current =
        queryClient.getQueryData<RouterOutput["chat"]["list"]>(queryKey);
      queryClient.setQueryData(queryKey, {
        ...current,
        chats: fn(current?.chats ?? []),
      });
    },
    [queryClient, isAuthenticated],
  );

  const pageSize = 50;
  const totalPages = Math.ceil(chatList.length / pageSize);
  const [requestedPage, setCurrentPage] = React.useState(1);

  // Clamp the requested page into [1, max(totalPages, 1)] so that a list
  // that shrinks (or an out-of-range page request) never yields a bogus slice.
  const currentPage = Math.min(
    Math.max(requestedPage, 1),
    Math.max(totalPages, 1),
  );

  const paginatedChatList = chatList.slice(
    (currentPage - 1) * pageSize,
    currentPage * pageSize,
  );

  // Imperative refresh: fetches the list through the vanilla tRPC client and
  // overwrites the cached chats with the response.
  const refreshChatList = useMutation({
    mutationFn: async () => {
      return await trpc.chat.list.query();
    },
    mutationKey: ["refreshChatList"],
    onSuccess: ({ chats }) => {
      if (!chats) return;

      setChatList(() => chats);
    },
  });

  const chat = chatList.find((c) => c.id === chatId);

  // Creates a chat, prepends it to the cached list and navigates to it.
  const createChat = api.chat.create.useMutation({
    onSuccess: ({ chat, error }) => {
      if (!chat) {
        throw new Error(error?.message ?? "Failed to create chat");
      }

      setChatList((current) => [chat, ...current]);
      navigate(`/chat/${chat.id}`);
    },
  });

  const [createChatParams, setCreateChatParams] = useAtom(
    ChatState.createChatParams,
  );

  // Upserts the returned chat: replaces the cached entry with the same id,
  // or prepends the chat when it is not cached yet.
  const getLatestChat = api.chat.latest.useMutation({
    onSuccess: ({ chat }) => {
      if (!chat) return;

      setChatList((current) =>
        current.some((c) => c.id === chat.id)
          ? current.map((c) => (c.id === chat.id ? chat : c))
          : [chat, ...current],
      );
    },
  });

  // Merges the returned turns into the current chat: turns with a known id
  // replace the cached version, unknown ones are appended, and the result is
  // re-sorted chronologically.
  const createTurns = api.turn.create.useMutation({
    onSuccess: ({ error, turns }) => {
      if (!turns || turns.length === 0) {
        throw new Error(error?.message ?? "Failed to create turn");
      }

      setChatList((current) =>
        current.map((c) => {
          if (c.id !== chatId) return c;

          const updatedTurns = c.turns.map(
            (t) => turns.find((newT) => newT.id === t.id) ?? t,
          );
          const newTurns = turns.filter(
            (t) => !c.turns.some((existingT) => existingT.id === t.id),
          );
          return {
            ...c,
            turns: [...updatedTurns, ...newTurns].sort(byCreatedAtAscending),
          };
        }),
      );
    },
  });

  // Replaces the matching turn of the current chat with the server's version.
  const updateTurn = api.turn.update.useMutation({
    onSuccess: ({ error, turn }) => {
      if (!turn) {
        throw new Error(error?.message ?? "Failed to update turn");
      }

      setChatList((current) =>
        current.map((c) =>
          c.id === chatId
            ? {
                ...c,
                turns: c.turns
                  .map((t) => (t.id === turn.id ? turn : t))
                  .sort(byCreatedAtAscending),
              }
            : c,
        ),
      );
    },
  });

  /**
   * Regenerates a turn of the current chat: marks it REGENERATING, then
   * subscribes to the turn stream and mirrors every streamed update into the
   * cache until the turn reports COMPLETED or REGENERATED.
   *
   * Resolves without doing anything when the current chat or the turn is not
   * in the cached list. Rejects if the initial REGENERATING update fails.
   *
   * @param turnId - Id of the turn to regenerate.
   */
  const regenerateTurn = useCallback(
    async (turnId: string) => {
      const turnToRegenerate = chatList
        .find((c) => c.id === chatId)
        ?.turns.find((t) => t.id === turnId);
      if (!turnToRegenerate) return;

      // Fields shared by every state transition sent for this turn.
      const baseUpdate = {
        characterId: turnToRegenerate.characterId ?? "",
        chatId: turnToRegenerate.chatId,
        content: null,
        id: turnId,
        role: turnToRegenerate.role,
        versions: turnToRegenerate.versions,
      };
      const markRegenerationFailed = () => {
        updateTurn.mutate({ ...baseUpdate, state: "REGENERATION_FAILED" });
      };

      // Update the turn state to REGENERATING
      await updateTurn.mutateAsync({ ...baseUpdate, state: "REGENERATING" });

      try {
        // Subscribe to the stream for the regenerated turn
        const subscription = trpc.turn.stream.subscribe(
          { id: turnId },
          {
            onData: (message) => {
              const { turn } = message.data;
              if (!turn) return;

              const isFinished =
                turn.state === "COMPLETED" || turn.state === "REGENERATED";

              setChatList((current) =>
                current.map((c) =>
                  c.id === chatId
                    ? {
                        ...c,
                        turns: c.turns
                          // Match by id, not by a previously captured index:
                          // the turn list is re-sorted and may gain entries
                          // while the stream is running.
                          .map((t) =>
                            t.id === turnId
                              ? {
                                  ...t,
                                  content: turn.content,
                                  state: turn.state,
                                  // Record the final text as a new version
                                  // once generation has finished.
                                  versions: isFinished
                                    ? [...t.versions, turn.content ?? ""]
                                    : t.versions,
                                }
                              : t,
                          )
                          .sort(byCreatedAtAscending),
                      }
                    : c,
                ),
              );

              if (isFinished) {
                subscription.unsubscribe();
              }
            },
            onError: (error) => {
              console.error("Stream error:", error);
              markRegenerationFailed();
              subscription.unsubscribe();
            },
          },
        );
      } catch (error) {
        // Reached only if subscribing itself throws synchronously.
        console.error("Error regenerating message:", error);
        markRegenerationFailed();
      }
    },
    [chatId, chatList, setChatList, updateTurn],
  );

  /**
   * Subscribes to the stream of a turn and replaces the cached turn of the
   * current chat with each streamed version. The subscription is closed when
   * the turn reports COMPLETED or when the stream errors.
   *
   * @param turnId - Id of the turn to stream.
   */
  const streamTurn = useCallback(
    (turnId: string) => {
      const stream = trpc.turn.stream.subscribe(
        { id: turnId },
        {
          onData: (message) => {
            const { turn } = message.data;
            if (!turn) return;

            setChatList((current) =>
              current.map((c) =>
                c.id === chatId
                  ? {
                      ...c,
                      turns: c.turns.map((t) => (t.id === turnId ? turn : t)),
                    }
                  : c,
              ),
            );

            if (turn.state === "COMPLETED") {
              stream.unsubscribe();
            }
          },
          // Surface stream failures instead of dropping them silently, and
          // release the subscription, mirroring regenerateTurn.
          onError: (error) => {
            console.error("Stream error:", error);
            stream.unsubscribe();
          },
        },
      );
    },
    [chatId, setChatList],
  );

  // Replaces the cached chat with the version returned by the archive call.
  const archiveChat = api.chat.archive.useMutation({
    onSuccess: ({ chat }) => {
      if (!chat) return;

      setChatList((current) =>
        current.map((c) => (c.id === chat.id ? chat : c)),
      );
    },
  });

  return {
    archiveChat,
    chat,
    chatList,
    createChat,
    createChatParams,
    createTurns,
    currentPage,
    getLatestChat,
    paginatedChatList,
    refreshChatList,
    regenerateTurn,
    setChatList,
    setCreateChatParams,
    setCurrentPage,
    streamTurn,
    totalPages,
  };
}
