import {
  closestCorners,
  type CollisionDetection,
  DndContext,
  getFirstCollision,
  pointerWithin,
  rectIntersection,
  type UniqueIdentifier,
} from '@dnd-kit/core'
import {
  arrayMove,
  SortableContext,
  verticalListSortingStrategy,
} from '@dnd-kit/sortable'
import {
  memo,
  type ReactNode,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from 'react'

import { DragOverlayPortal } from './drag-overlay-portal'
import { pastVerticalCenter } from './past-vertical-center'
import { type RenderItem, type SortableItem } from './sortable-treeview.types'
import { VirtualizedGroupedList } from './virtualized-grouped-list'

/**
 * Payload passed to `onDrop` when a drag finishes.
 *
 * When a whole group is dragged, only the container indices are provided and
 * the item indices are omitted.
 */
export type OnDropArgs = {
  /** Index (into the top-level `items`) of the group the move starts from. */
  fromContainerIndex: number
  /** Index (into the top-level `items`) of the group the move ends in. */
  toContainerIndex: number
  /** Child index the dragged item moves from; absent for group drags. */
  fromItemIndex?: number
  /** Child index the dragged item moves to; absent for group drags. */
  toItemIndex?: number
}

/** Payload passed to `onDragStart` when a drag begins. */
export type OnDragStartArgs<T extends SortableItem> = {
  /** The dragged group or child; `undefined` if it could not be found in `items`. */
  activeItem?: T
}

/** Props of `SortableTreeviewVirtualizedDnd`. */
export type SortableTreeviewVirtualizedDndProps<T extends SortableItem> = {
  /** Top-level groups; each group's `children` are the sortable leaf items. */
  items: T[]
  /** Renders a single group or child row (forwarded to the virtualized list). */
  renderItem: RenderItem<T>
  /**
   * Called when a drag ends with a move. Return `true` to keep the new order,
   * `false` to restore the order captured at drag start.
   */
  onDrop: (args: OnDropArgs) => boolean
  /** Optional notification fired when a drag begins. */
  onDragStart?: (args: OnDragStartArgs<T>) => void
  /** Called with the id of the row whose expand toggle was clicked. */
  onClickExpand: (itemId: string) => void
  /** Row height forwarded to the virtualized list. */
  itemHeight: number
  /** Optional placeholder renderer forwarded to the virtualized list. */
  renderPlaceholder?: (item: T) => ReactNode
  /** Forwarded to the virtualized list. */
  disableDrag?: boolean
  /** When `true`, dragging over another group does not move the item into it. */
  disableDragToAnotherGroup?: boolean
  /** Forwarded to the virtualized list. */
  disableDragOnGroups?: boolean
}

/**
 * Sortable two-level tree (groups containing children) rendered through a
 * virtualized list, with drag-and-drop powered by dnd-kit.
 *
 * A local copy of `props.items` is kept so that drags can be previewed
 * optimistically, including moving a child into another group while hovering.
 * When the drag ends, `props.onDrop` decides whether the move is accepted.
 * If it returns `false`, the snapshot taken at drag start is restored.
 *
 * All local updates are immutable, so the objects passed in through
 * `props.items` are never modified by this component.
 */
export const SortableTreeviewVirtualizedDnd = memo(
  function SortableTreeviewVirtualizedDnd<T extends SortableItem>(
    props: SortableTreeviewVirtualizedDndProps<T>
  ) {
    const [items, setItems] = useState(props.items)
    const [activeId, setActiveId] = useState<UniqueIdentifier | null>(null)
    // Group the dragged element belonged to when the drag started. Used to
    // report the true origin to `onDrop` after a cross-group preview.
    const initialContainerId = useRef<UniqueIdentifier | null>(null)

    useEffect(() => {
      setItems(props.items)
    }, [props.items])

    // Deep snapshot taken on drag start; restored on cancel or rejected drop.
    const [clonedItems, setClonedItems] = useState<typeof items | null>(null)

    /** Returns the top-level group with the given id, if any. */
    const getContainer = (containerId: UniqueIdentifier | null) => {
      if (!containerId) return undefined
      return items.find((i) => i.id === containerId)
    }

    /** Returns the group with this id, or the group whose children contain it. */
    function findChildContainer(childId?: UniqueIdentifier) {
      if (!childId) return
      // If the id is from a container, return it.
      const container = getContainer(childId)
      if (container) return container

      return items.find((i) => i.children.some((c) => c.id === childId))
    }

    /** Index of the group whose children contain `childId`, or -1. */
    function findChildContainerIndex(childId?: UniqueIdentifier) {
      if (!childId) return -1

      return items.findIndex((i) => i.children.some((c) => c.id === childId))
    }

    /** Looks up a group or child by id within `items`. */
    const getActiveItem = useCallback(
      (activeId: UniqueIdentifier | null, items: T[]) => {
        if (!activeId) return

        for (const container of items) {
          if (container.id === activeId) return container
          const child = container.children.find((i) => i.id === activeId)
          if (child) return child as T
        }

        return
      },
      []
    )

    const activeItem = useMemo(() => {
      return getActiveItem(activeId, items)
    }, [activeId, getActiveItem, items])

    /** Clears the per-drag bookkeeping (active id and origin group). */
    const resetDragState = () => {
      setActiveId(null)
      initialContainerId.current = null
    }

    const collisionDetectionStrategy: CollisionDetection = (args) => {
      // A dragged group may only collide with other groups.
      if (getContainer(activeId)) {
        return pastVerticalCenter({
          ...args,
          droppableContainers: args.droppableContainers.filter((container) =>
            getContainer(container.id)
          ),
        })
      }

      const pointerIntersections = pointerWithin(args)
      // If there are droppables intersecting with the pointer, return those
      const intersections =
        pointerIntersections.length > 0
          ? pointerIntersections
          : rectIntersection(args)
      let overId = getFirstCollision(intersections, 'id')

      if (overId == null) {
        return []
      }

      const overContainer = getContainer(overId)
      if (overContainer) {
        const children = overContainer.children
        // Empty groups are droppable as a whole.
        if (!children || children.length < 1) {
          return [{ id: overId }]
        }

        // Keep only the children of the container.
        const droppableContainers = args.droppableContainers.filter(
          (container) => children.some((i) => i.id === container.id)
        )

        if (droppableContainers.length) {
          overId = closestCorners({
            ...args,
            droppableContainers,
          })[0]?.id
        }
      }

      return [{ id: overId }]
    }

    const onDragCancel = () => {
      resetDragState()
      if (clonedItems) {
        setItems(clonedItems)
      }
    }

    return (
      <DndContext
        collisionDetection={collisionDetectionStrategy}
        onDragStart={({ active }) => {
          initialContainerId.current = findChildContainer(active.id)?.id ?? null
          setActiveId(active.id)
          setClonedItems(deepCopy(items))
          props.onDragStart?.({ activeItem: getActiveItem(active.id, items) })
        }}
        onDragEnd={({ over, active }) => {
          const originContainerId = initialContainerId.current
          resetDragState()

          // Dragging a whole group: reorder the top-level list.
          if (getContainer(active.id)) {
            const activeIndex = items.findIndex((i) => i.id === active.id)
            const overIndex = over
              ? items.findIndex((i) => i.id === over.id)
              : -1
            // Dropped outside any group, or back onto itself: nothing moved.
            // (Previously -1 was passed to arrayMove/onDrop here.)
            if (overIndex < 0 || activeIndex === overIndex) return

            const accepted = props.onDrop({
              fromContainerIndex: activeIndex,
              toContainerIndex: overIndex,
            })

            if (accepted) {
              setItems(arrayMove(items, activeIndex, overIndex))
            } else if (clonedItems) {
              setItems(clonedItems)
            }
            return
          }

          // Undo any cross-group preview applied by onDragOver.
          const restoreSnapshot = () => {
            if (clonedItems) setItems(clonedItems)
          }

          if (!over?.id) {
            restoreSnapshot()
            return
          }

          // `over` is normally a child. For empty or collapsed groups the
          // collision strategy reports the group itself.
          const overContainerIndex = getContainer(over.id)
            ? items.findIndex((i) => i.id === over.id)
            : findChildContainerIndex(over.id)
          const overContainer = items[overContainerIndex]
          if (!overContainer) {
            restoreSnapshot()
            return
          }

          const activeIndex = overContainer.children.findIndex(
            (i) => i.id === active.id
          )
          // The dragged item is not in the target group (e.g. cross-group
          // moves are disabled). Reordering with index -1 would corrupt it.
          if (activeIndex < 0) {
            restoreSnapshot()
            return
          }
          const overChildIndex = overContainer.children.findIndex(
            (i) => i.id === over.id
          )
          // Dropped on the group itself: keep the position onDragOver chose.
          const overIndex = overChildIndex < 0 ? activeIndex : overChildIndex

          // Report the origin as it was when the drag started, not where the
          // cross-group preview temporarily placed the item.
          const snapshot = clonedItems ?? items
          const originIndex = snapshot.findIndex(
            (i) => i.id === originContainerId
          )
          const fromContainerIndex =
            originIndex >= 0 ? originIndex : overContainerIndex
          const originItemIndex =
            snapshot[fromContainerIndex]?.children.findIndex(
              (i) => i.id === active.id
            ) ?? -1
          const fromItemIndex =
            originItemIndex >= 0 ? originItemIndex : activeIndex

          // Immutable reorder: never touch the container objects in `items`.
          const newItems =
            activeIndex === overIndex
              ? items
              : items.map((item, index) =>
                  index === overContainerIndex
                    ? {
                        ...item,
                        children: arrayMove(
                          item.children,
                          activeIndex,
                          overIndex
                        ),
                      }
                    : item
                )

          const accepted = props.onDrop({
            fromContainerIndex,
            toContainerIndex: overContainerIndex,
            fromItemIndex,
            toItemIndex: overIndex,
          })

          if (accepted) {
            setItems(newItems)
          } else {
            restoreSnapshot()
          }
        }}
        onDragOver={({ active, over }) => {
          if (props.disableDragToAnotherGroup) return
          const overContainer = findChildContainer(over?.id)
          const activeContainer = findChildContainer(active.id)

          if (
            !overContainer ||
            !activeContainer ||
            activeContainer.id === overContainer.id
          ) {
            return
          }

          // Only children move between groups; a dragged group has no match here.
          const draggedChild = activeContainer.children.find(
            (i) => i.id === active.id
          )
          if (!draggedChild) return

          // Insert at the top when the hovered element lies below the dragged one.
          const translatedTop = active.rect.current.translated?.top
          const overTop = over?.rect.top
          const isCloserToTop =
            translatedTop != null && overTop != null && overTop > translatedTop

          const moved = { ...draggedChild }
          // Build new container objects instead of mutating the existing ones,
          // which may be shared with `props.items`.
          const newItems = items.map((item) => {
            if (item.id === overContainer.id) {
              return {
                ...item,
                children: isCloserToTop
                  ? [moved, ...item.children]
                  : [...item.children, moved],
              }
            }
            if (item.id === activeContainer.id) {
              return {
                ...item,
                children: item.children.filter((i) => i.id !== draggedChild.id),
              }
            }
            return item
          })
          setItems(newItems)
        }}
        onDragCancel={onDragCancel}
      >
        <SortableContext
          items={items.map((item) => item.id)}
          strategy={verticalListSortingStrategy}
        >
          <VirtualizedGroupedList
            items={items}
            renderItem={(item, active, onClick, hovering) =>
              props.renderItem(item as T, active, onClick, hovering)
            }
            activeId={activeId}
            itemHeight={props.itemHeight}
            renderPlaceholder={(item) => props.renderPlaceholder?.(item as T)}
            setExpanded={(id: string) => {
              props.onClickExpand(id)
            }}
            disableDrag={props.disableDrag}
            disableDragOnGroups={props.disableDragOnGroups}
          />
        </SortableContext>
        <DragOverlayPortal activeItem={activeItem} />
      </DndContext>
    )
  }
)

/**
 * Recursively clones a list of tree items.
 *
 * Every item is shallow-copied and its `children` replaced by a fresh,
 * recursively cloned array. Items with missing or empty `children` get a new
 * empty array. The input list and its items are left untouched.
 */
function deepCopy<T extends SortableItem>(items: T[]): T[] {
  const copies: T[] = []
  for (const item of items) {
    const nested = item.children
    const clonedChildren = nested && nested.length > 0 ? deepCopy(nested) : []
    copies.push({ ...item, children: clonedChildren })
  }
  return copies
}
