import { useOnce } from '@motion/react-core/hooks'

import { type Row } from '@tanstack/react-table'
import {
  type KeyboardEvent,
  type PropsWithChildren,
  useCallback,
  useMemo,
  useState,
} from 'react'

import { TreeKeyboardContext } from './tree-keyboard-context'

import { type VirtualizedTreeNode } from '../types'
import { useTreeContext } from '../virtualized-tree-context'

/**
 * Returns the index of the first row after `startIndex`, walking in direction
 * `step`, whose `original.disabled` is falsy. Returns `startIndex` unchanged
 * when no such row exists in that direction.
 *
 * If `startIndex` is at or past the end of `rows` (a stale index left behind
 * when the flattened tree shrank under the focused row), the last enabled row
 * is returned instead, so arrow-key navigation can recover rather than staying
 * stuck on a row that no longer exists.
 */
const findNextEnabledIndex = (
  rows: ReadonlyArray<{ readonly original: { readonly disabled?: unknown } }>,
  startIndex: number,
  step: 1 | -1
): number => {
  const isEnabled = (index: number): boolean => {
    const row = rows[index]
    return row != null && !row.original.disabled
  }

  if (startIndex >= rows.length) {
    // Stale index: clamp back into the tree by picking the last enabled row.
    for (let i = rows.length - 1; i >= 0; i--) {
      if (isEnabled(i)) return i
    }
    return startIndex
  }

  for (let i = startIndex + step; i >= 0 && i < rows.length; i += step) {
    if (isEnabled(i)) return i
  }
  return startIndex
}

/**
 * Adds keyboard navigation to the virtualized tree read from `useTreeContext`.
 *
 * Wraps `children` in a focusable `div` (`tabIndex={0}`) that handles:
 * - ArrowDown / ArrowUp: move focus to the next / previous non-disabled row.
 * - Enter / Tab: select the focused row if it is selectable and not disabled,
 *   otherwise toggle its expansion.
 * - ArrowRight / ArrowLeft: expand / collapse the focused row (ignored when
 *   the row is disabled).
 *
 * Exposes `activeRow`, `setActiveRow`, `focusedIndex` and `resetFocusedIndex`
 * through `TreeKeyboardContext`.
 */
export const TreeKeyboardProvider = (props: PropsWithChildren) => {
  const { children } = props

  const { flatTree, selectedId, onSelect, computeSelectable } = useTreeContext()

  // `scroll` is true for keyboard-driven moves and false when the active row
  // is set via `setActiveRow`; presumably consumers use it to decide whether
  // to scroll the row into view — confirm against TreeKeyboardContext users.
  const [focusedIndex, setFocusedIndex] = useState<{
    index: number
    scroll: boolean
  }>({ index: 0, scroll: true })

  useOnce(() => {
    // Set active value to the selected value on mount
    const index = flatTree.findIndex((node) => node.original.id === selectedId)
    setFocusedIndex({ index: index > -1 ? index : 0, scroll: true })
  })

  // `undefined` when the index points past the end of `flatTree`.
  const activeRow = useMemo(
    () => flatTree[focusedIndex.index],
    [flatTree, focusedIndex]
  )

  const handleKeyDown = (e: KeyboardEvent) => {
    // A child handler already consumed this key.
    if (e.defaultPrevented) return

    switch (e.key) {
      case 'ArrowDown': {
        e.preventDefault()
        const newIndex = findNextEnabledIndex(flatTree, focusedIndex.index, 1)
        setFocusedIndex({ index: newIndex, scroll: true })
        break
      }
      case 'ArrowUp': {
        e.preventDefault()
        const newIndex = findNextEnabledIndex(flatTree, focusedIndex.index, -1)
        setFocusedIndex({ index: newIndex, scroll: true })
        break
      }
      // Tab is treated like Enter and its default is prevented, so focus does
      // not leave this container via Tab while it handles the key.
      case 'Tab':
      case 'Enter': {
        e.preventDefault()
        if (!activeRow) return
        if (
          computeSelectable(activeRow.original) &&
          !activeRow.original.disabled
        ) {
          onSelect(activeRow)
          activeRow.toggleSelected()
        } else {
          // Non-selectable (or disabled) rows expand/collapse instead.
          activeRow.toggleExpanded()
        }

        break
      }
      case 'ArrowRight': {
        if (!activeRow || activeRow.original.disabled) return
        activeRow.toggleExpanded(true)
        break
      }
      case 'ArrowLeft': {
        if (!activeRow || activeRow.original.disabled) return
        activeRow.toggleExpanded(false)
        break
      }
    }
  }

  // Focus a specific row (matched by row id) without requesting a scroll;
  // falls back to index 0 if the row is not in `flatTree`.
  const setActiveRow = useCallback(
    (activeNode: Row<VirtualizedTreeNode>) => {
      const index = flatTree.findIndex((row) => row.id === activeNode.id)
      setFocusedIndex({ index: index > -1 ? index : 0, scroll: false })
    },
    [flatTree]
  )

  // Move focus back to the first row; a no-op when the tree is empty.
  const resetFocusedIndex = useCallback(() => {
    if (flatTree.length < 1) return
    setFocusedIndex({ index: 0, scroll: true })
  }, [flatTree.length])

  const contextValue = useMemo(
    () => ({
      activeRow,
      setActiveRow,
      focusedIndex,
      resetFocusedIndex,
    }),
    [activeRow, focusedIndex, resetFocusedIndex, setActiveRow]
  )

  return (
    <TreeKeyboardContext.Provider value={contextValue}>
      <div onKeyDown={handleKeyDown} tabIndex={0}>
        {children}
      </div>
    </TreeKeyboardContext.Provider>
  )
}
