import { useDependantState } from '@motion/react-core/hooks'
import { keys } from '@motion/utils/object'

import {
  type Row,
  type RowSelectionState,
  type Table,
  type Updater,
} from '@tanstack/react-table'
import { useCallback, useMemo } from 'react'

import { useBulkOpsState } from '../../bulk-ops'
import { type GroupedNode, type Tree } from '../../grouping'
import { type TreeListRowValueType } from '../types'
import { canRowSelect } from '../utils'

/** Inputs accepted by `useRowSelection`. */
interface Options {
  /** Grouped tree whose rows are rendered (and possibly selected) by the list. */
  tree: Tree<GroupedNode<TreeListRowValueType>>
  /** Turns row selection on; `useRowSelection` treats a missing value as `false`. */
  enableSelection?: boolean
}

/**
 * Visitor used while walking the grouped tree: receives the selection being
 * built and the node currently visited.
 */
type TraverseCallback = (
  selection: RowSelectionState,
  node: GroupedNode<TreeListRowValueType>
) => void

/**
 * Bridges TanStack row selection for a grouped tree list with the bulk-ops
 * selection state, which stores plain item keys (not qualified row keys).
 *
 * - `rowSelection` is derived from the bulk-ops `selectedIds`. Every leaf whose
 *   `key` is selected is marked under its `qualifiedKey`. A group is marked when
 *   all of its selectable direct children are marked.
 * - `onRowSelectionChange` converts a TanStack selection update back into a
 *   deduplicated list of item keys. Only non-expandable task/project rows are
 *   kept. Deselecting an item in one group deselects it in every group.
 * - `canSelectRow` tells whether a row can be selected. Results are cached per
 *   `qualifiedKey` until `tree.values` changes.
 *
 * @param tree - grouped tree rendered by the list
 * @param enableSelection - when false, nothing is selected and no row is selectable
 */
export function useRowSelection({ tree, enableSelection = false }: Options) {
  const { setSelectedIds: setBulkOpsSelectedIds, selectedIds } =
    useBulkOpsState()

  const rowSelection = useMemo<RowSelectionState>(() => {
    if (!enableSelection) return {}

    // Set lookup keeps this linear in the number of rows instead of rows × selected ids
    const selectedIdSet = new Set(selectedIds)

    /**
     * Post-order walk: children are visited before their group, so the `group`
     * callback can read its children's state from `acc`. One accumulator is
     * shared by the whole walk. This assumes qualified keys are unique across the
     * tree; they are used as `rowsById` keys below.
     */
    function traverseTree(
      nodes: Options['tree']['values'],
      callbacks: { leaf?: TraverseCallback; group?: TraverseCallback },
      acc: RowSelectionState = {}
    ): RowSelectionState {
      for (const node of nodes) {
        if (node.children != null) {
          traverseTree(node.children, callbacks, acc)
          callbacks.group?.(acc, node)
        } else {
          callbacks.leaf?.(acc, node)
        }
      }
      return acc
    }

    return traverseTree(tree.values, {
      leaf: (acc, node) => {
        if (selectedIdSet.has(node.key)) {
          acc[node.qualifiedKey] = true
        }
      },
      group: (acc, node) => {
        if (!node.children || node.children.length === 0) return

        // Children that cannot be selected do not prevent the group from being selected
        const selectableChildren = node.children.filter(canRowSelect)
        if (selectableChildren.length === 0) return

        if (selectableChildren.every((child) => acc[child.qualifiedKey])) {
          acc[node.qualifiedKey] = true
        }
      },
    })
  }, [selectedIds, tree.values, enableSelection])

  const onRowSelectionChange = useCallback(
    (
      table: Table<GroupedNode<TreeListRowValueType>>,
      updater: Updater<RowSelectionState>
    ) => {
      setBulkOpsSelectedIds((prev) => {
        const prevRowSelection = convertSelectedItemsToRowSelectionState(
          prev,
          table
        )
        const newRowSelection =
          typeof updater === 'function' ? updater(prevRowSelection) : updater

        const { rowsById } = table.getRowModel()

        // Cleanup the new selection of the groups: only keep rows that are
        // actually selected, not expandable, and are tasks or projects
        const newLeafQualifiedKeys = keys(newRowSelection).filter(
          (qualifiedKey) => {
            if (!newRowSelection[qualifiedKey]) return false

            const row = rowsById[qualifiedKey]
            if (!row || row.getCanExpand()) return false

            return (
              row.original.value.type === 'task' ||
              row.original.value.type === 'project'
            )
          }
        )
        const newLeafKeySet = new Set(newLeafQualifiedKeys)

        // Rows selected before but absent from the new selection were deselected.
        // Keep their item keys (not qualified keys) so the removal hits every group.
        const removedKeys = new Set(
          keys(prevRowSelection)
            .filter((qualifiedKey) => !newLeafKeySet.has(qualifiedKey))
            .map(qualifiedKeyToKey)
        )

        // The same item can be listed in several groups under different qualified
        // keys: deselecting it in one group deselects it in all of them
        const nextSelectedIds = newLeafQualifiedKeys
          .map(qualifiedKeyToKey)
          .filter((key) => !removedKeys.has(key))

        // An item selected in several groups must only be reported once
        return Array.from(new Set(nextSelectedIds))
      })
    },
    [setBulkOpsSelectedIds]
  )

  // Using a cache for `canSelectRow` because in some cases, this check can be recursive for all children.
  // To avoid computing this many times, we're keeping a cache
  const [canSelectRowCache] = useDependantState(
    () => {
      return new Map<string, boolean>()
    },
    // When the tree changes, we want to start fresh with a clean cache
    // eslint-disable-next-line react-hooks/exhaustive-deps
    [tree.values]
  )

  const canSelectRow = useCallback<
    (row: Row<GroupedNode<TreeListRowValueType>>) => boolean
  >(
    (row) => {
      if (!enableSelection) return false

      const valueFromCache = canSelectRowCache.get(row.original.qualifiedKey)
      if (valueFromCache != null) {
        return valueFromCache
      }

      const canSelect = canRowSelect(row.original)
      canSelectRowCache.set(row.original.qualifiedKey, canSelect)
      return canSelect
    },
    [enableSelection, canSelectRowCache]
  )

  return {
    rowSelection,
    onRowSelectionChange,
    canSelectRow,
  }
}

/**
 * Extracts the item key from a qualified row key: the segment after the last
 * `/`, or the whole string when it contains no `/`.
 *
 * @example qualifiedKeyToKey('group-a/group-b/task-1') // 'task-1'
 */
function qualifiedKeyToKey(qualifiedKey: string): string {
  // `lastIndexOf` returns -1 when there is no separator, so `slice(0)` keeps the whole key
  return qualifiedKey.slice(qualifiedKey.lastIndexOf('/') + 1)
}

/**
 * Builds a TanStack row selection from a list of selected item keys. Every row
 * in the table's row model whose `original.key` is in `ids` is marked selected
 * under its `original.qualifiedKey`. An item listed in several groups therefore
 * marks all of its rows.
 *
 * @param ids - selected item keys
 * @param table - table whose `getRowModel().flatRows` are scanned
 * @returns selection keyed by qualified row key
 */
function convertSelectedItemsToRowSelectionState(
  ids: string[],
  table: Table<GroupedNode<TreeListRowValueType>>
): RowSelectionState {
  // Set lookup keeps this linear in rows + ids instead of rows × ids
  const selectedIds = new Set(ids)
  const selection: RowSelectionState = {}

  for (const row of table.getRowModel().flatRows) {
    if (selectedIds.has(row.original.key)) {
      selection[row.original.qualifiedKey] = true
    }
  }

  return selection
}
