import { clamp } from '@motion/utils/math'

import { type KeyboardEvent, useMemo, useRef } from 'react'

import { findAllFocusableNodes, isFocusable } from '../focus'

/**
 * CSS selector lists used to locate the structural parts of a table.
 * Each entry matches both the ARIA-role form and the native HTML table
 * element form of the same part.
 */
const tableSelectors = {
  // Any single cell, header cells included.
  cell: '[role="cell"],[role="gridcell"],[role="columnheader"],[role="rowheader"],td,th',
  // A row of cells.
  row: '[role="row"],tr',
  // A group of rows (header, body or footer section).
  rowGroup: '[role="rowgroup"],thead,tbody,tfoot',
} as const

/**
 * Keyboard navigation for grid-like tables.
 *
 * Returns a memoized `{ listeners: { onKeyDown } }` object. The handler looks
 * at the nearest cell/row ancestor of the event target's parent to decide
 * which mode applies:
 *
 * - The target sits inside a cell: `Escape` moves focus back to the enclosing
 *   cell, provided that cell is focusable.
 * - The nearest such ancestor is a row: `Enter` focuses the first focusable
 *   node found for the target, and the arrow keys move focus between cells.
 *
 * @see https://www.w3.org/WAI/ARIA/apg/patterns/grid/#gridNav_inside
 */
export function useTableNavigation() {
  // Column the user "wants" to be in while moving vertically. It is recorded
  // when a vertical move had to land on a different (clamped) index, and is
  // cleared by `Enter` and by horizontal moves.
  const idealCellIndexRef = useRef<number | undefined>(undefined)

  return useMemo(() => {
    // Handles keys pressed while focus is on something nested inside a cell.
    const handleInsideCell = (event: KeyboardEvent) => {
      const origin = event.target
      if (!(origin instanceof HTMLElement)) return
      if (event.key !== 'Escape') return

      const owningCell = origin.closest(tableSelectors.cell)
      if (owningCell instanceof HTMLElement && isFocusable(owningCell)) {
        owningCell.focus()
      }
    }

    // Handles keys pressed while focus is at grid level (not nested in a cell).
    const handleOnGrid = (event: KeyboardEvent) => {
      const origin = event.target
      if (!(origin instanceof HTMLElement)) return

      switch (event.key) {
        case 'Enter': {
          const focusables = findAllFocusableNodes(origin)
          const [firstFocusable] = focusables

          firstFocusable?.focus()
          idealCellIndexRef.current = undefined

          // With exactly one focusable element the default action is left
          // alone; in every other case it is suppressed so that nothing is
          // triggered automatically.
          if (focusables.length !== 1) {
            event.preventDefault()
          }
          return
        }

        case 'ArrowLeft':
        case 'ArrowRight': {
          const step = event.key === 'ArrowLeft' ? -1 : 1
          focusNextHorizontalCell(origin, step)
          idealCellIndexRef.current = undefined
          document.activeElement?.scrollIntoView({
            block: 'nearest',
            inline: 'nearest',
          })
          // Keep the arrow key from scrolling the container.
          event.preventDefault()
          return
        }

        case 'ArrowDown':
        case 'ArrowUp': {
          const step = event.key === 'ArrowDown' ? 1 : -1
          const moved = focusNextVerticalCell(
            origin,
            step,
            idealCellIndexRef.current
          )
          // Remember the wanted column only when the move landed elsewhere.
          if (
            moved != null &&
            moved.previousCellIndex !== moved.currentCellIndex
          ) {
            idealCellIndexRef.current = moved.previousCellIndex
          }
          // Keep the arrow key from scrolling the container.
          event.preventDefault()
          return
        }

        default:
          return
      }
    }

    const onKeyDown = (event: KeyboardEvent) => {
      if (!(event.target instanceof HTMLElement)) return

      // Start from the parent so that a focused cell resolves to its row
      // rather than to itself.
      const container = event.target.parentElement?.closest(
        `${tableSelectors.cell},${tableSelectors.row}`
      )
      if (!container) return

      const handle = container.matches(tableSelectors.cell)
        ? handleInsideCell
        : handleOnGrid
      handle(event)
    }

    return { listeners: { onKeyDown } }
  }, [])
}

/**
 * Moves focus to the cell immediately before (`-1`) or after (`1`) the cell
 * that contains `target`, staying within the same row.
 *
 * @param target - Element inside (or equal to) the currently focused cell.
 * @param direction - `-1` for the previous cell, `1` for the next cell.
 * @returns `true` when a neighbouring cell was focused; `false` when `target`
 *   has no enclosing row or cell, the cell cannot be found in the row, or the
 *   cell is already at the edge of the row.
 */
function focusNextHorizontalCell(
  target: HTMLElement,
  direction: -1 | 1
): boolean {
  const containingRow = target.closest(tableSelectors.row)
  const containingCell = target.closest(tableSelectors.cell)
  if (containingRow === null || containingCell === null) return false

  const rowCells = Array.from(
    containingRow.querySelectorAll(tableSelectors.cell)
  )
  const currentIndex = rowCells.indexOf(containingCell)
  if (currentIndex === -1) return false

  // Clamping keeps the index inside the row; an unchanged index means the
  // cell is already the first/last one and there is nowhere to go.
  const neighbourIndex = clamp(currentIndex + direction, 0, rowCells.length - 1)
  if (neighbourIndex === currentIndex) return false

  const neighbour = rowCells[neighbourIndex] as HTMLElement | undefined
  if (!neighbour) return false

  neighbour.focus()
  return true
}

/**
 * Moves focus to a cell in the row above (`-1`) or below (`1`) the row that
 * contains `target`, staying within the same row group.
 *
 * The column is `idealCellIndex` when provided, otherwise the index of the
 * current cell within its row. That index is clamped to the cells available
 * in the destination row.
 *
 * @param target - Element inside (or equal to) the currently focused cell.
 * @param direction - `-1` for the previous row, `1` for the next row.
 * @param idealCellIndex - Preferred column index, overriding the current one.
 * @returns The requested (`previousCellIndex`) and actually focused
 *   (`currentCellIndex`) column indices, or `null` when nothing was focused:
 *   `target` is not inside a row group/row/cell, the row or the cell cannot be
 *   located, the row is already the first/last one, or the destination row
 *   has no cell to focus.
 */
function focusNextVerticalCell(
  target: HTMLElement,
  direction: -1 | 1,
  idealCellIndex?: number
): { previousCellIndex: number; currentCellIndex: number } | null {
  const rowGroup = target.closest(tableSelectors.rowGroup)
  const row = target.closest(tableSelectors.row)
  const cell = target.closest(tableSelectors.cell)

  if (!rowGroup || !row || !cell) {
    return null
  }

  const rows = Array.from(rowGroup.querySelectorAll(tableSelectors.row))
  const rowIndex = rows.indexOf(row)

  if (rowIndex < 0) {
    return null
  }

  // An unchanged index after clamping means there is no row in that direction.
  const nextRowIndex = clamp(rowIndex + direction, 0, rows.length - 1)
  if (nextRowIndex === rowIndex) {
    return null
  }

  const nextRow = rows[nextRowIndex]
  if (!nextRow) {
    return null
  }

  let cellIndexCandidate = idealCellIndex
  if (cellIndexCandidate == null) {
    const cells = Array.from(row.querySelectorAll(tableSelectors.cell))
    cellIndexCandidate = cells.indexOf(cell)

    // Same guard as the horizontal helper: if the cell is not part of the
    // row (e.g. the nearest cell is an ancestor of the nearest row), bail out
    // instead of reporting -1 as a column index.
    if (cellIndexCandidate < 0) {
      return null
    }
  }

  const newRowCells = Array.from(nextRow.querySelectorAll(tableSelectors.cell))

  // The destination row may have fewer cells than the requested column.
  const finalIndex = clamp(cellIndexCandidate, 0, newRowCells.length - 1)
  const newCell = newRowCells[finalIndex] as HTMLElement | undefined

  if (!newCell) {
    return null
  }

  newCell.focus()
  return {
    currentCellIndex: finalIndex,
    previousCellIndex: cellIndexCandidate,
  }
}

/**
 * Boolean-returning wrapper around `focusNextVerticalCell`, called without a
 * preferred column index.
 *
 * @param target - Element inside (or equal to) the currently focused cell.
 * @param direction - `-1` for the previous row, `1` for the next row.
 * @returns `true` when a cell in the adjacent row was focused, `false`
 *   otherwise.
 * @deprecated
 */
export function focusNextCellVertically(
  target: HTMLElement,
  direction: -1 | 1
): boolean {
  return focusNextVerticalCell(target, direction) != null
}
