import { isNoneId } from '@motion/shared/identifiers'
import { addComponentName } from '@motion/ui/helpers'
import { useTableNavigation } from '@motion/ui/utils'

import {
  type ExpandedState,
  flexRender,
  getCoreRowModel,
  getExpandedRowModel,
  type Row,
  useReactTable,
} from '@tanstack/react-table'
import {
  type Range,
  useVirtualizer,
  type VirtualItem,
} from '@tanstack/react-virtual'
import { useCallback, useEffect, useRef, useState } from 'react'

import { type TreeListColumn } from './columns'
import { ExpandableTableRow, Table, TableCell, TableRow } from './components'
import { HeaderRow } from './header'
import {
  useHighlightLastCreated,
  useRowSelection,
  useTableGridTemplateColumns,
  useTableViewState,
} from './hooks'

import { useOnBulkOpsAction } from '../bulk-ops'
import { type GroupedNode, type Tree } from '../grouping/utils/multi-group'

/** Props accepted by {@link TreeList}. */
export type TreeListProps<T extends GroupedNode<any>> = {
  /** Column definitions handed to TanStack Table and `useTableViewState`. */
  columns: TreeListColumn[]
  /**
   * Grouped tree to render. `tree.values` is the table's top-level data and
   * `tree.children` seeds the initial expanded state.
   */
  tree: Tree<T>
  /** Forwarded to `useRowSelection`. */
  enableSelection?: boolean
}

/**
 * Virtualized, grouped tree table built on TanStack Table + TanStack Virtual.
 *
 * - Every node in `props.tree` that has `children` starts expanded, and an
 *   effect on mount additionally calls `table.toggleAllRowsExpanded(true)`.
 * - Column sizing/order/visibility/pinning state comes from
 *   `useTableViewState`; row selection comes from `useRowSelection`.
 * - Responds to the bulk-ops `select-all` / `unselect-all` actions.
 * - Row heights are computed by `calculateRowHeight` rather than read from
 *   the DOM.
 */
export const TreeList = <T extends GroupedNode<any>>(
  props: TreeListProps<T>
) => {
  const { rowSelection, onRowSelectionChange, canSelectRow } = useRowSelection({
    tree: props.tree,
    enableSelection: props.enableSelection,
  })

  // Initial expanded state: every node that has `children`, recursively,
  // keyed by its `qualifiedKey` (which is also the row id, see getRowId).
  const [expanded, setExpanded] = useState<ExpandedState>(() => {
    const data: Record<string, boolean> = {}

    function expand(item: Tree<T>) {
      // Only group nodes have children. Leafs don't
      if (item.children != null) {
        data[item.qualifiedKey] = true
        item.children.forEach(expand)
      }
    }

    props.tree.children.forEach(expand)
    return data
  })

  const {
    columnSizing,
    setColumnSizing,
    columnOrder,
    setColumnOrder,
    columnVisibility,
    setColumnVisibility,
    columnPinning,
    setColumnPinning,
  } = useTableViewState(props.columns)

  const table = useReactTable({
    data: props.tree.values,
    columns: props.columns,
    state: {
      expanded,
      columnSizing,
      columnOrder,
      columnVisibility,
      columnPinning,
      rowSelection,
    },
    filterFromLeafRows: true,
    getRowId: (row) => row.qualifiedKey,
    onExpandedChange: setExpanded,
    getCoreRowModel: getCoreRowModel(),
    getSubRows: (row) => {
      return row.children as T[] | undefined
    },
    getExpandedRowModel: getExpandedRowModel(),
    enableColumnResizing: true,
    enableRowSelection: canSelectRow,
    debugTable: import.meta.env.MOTION_ENV === 'localhost',
    onColumnSizingChange: setColumnSizing,
    onColumnOrderChange: setColumnOrder,
    onColumnVisibilityChange: setColumnVisibility,
    onColumnPinningChange: setColumnPinning,
    onRowSelectionChange: (setter) => {
      onRowSelectionChange(table, setter)
    },
    columnResizeMode: 'onChange',
  })

  // Final (expanded) row model: rows nested inside collapsed groups are not
  // part of this list.
  const { rows } = table.getRowModel()

  // Add every task/project row in `rows` to the existing selection.
  // NOTE(review): rows inside collapsed groups are not in `rows`, so they are
  // not selected here — confirm that is intended.
  useOnBulkOpsAction('select-all', () =>
    onRowSelectionChange(table, (selection) => {
      const newValue = { ...selection }

      rows.forEach((row) => {
        if (
          row.original.value.type === 'task' ||
          row.original.value.type === 'project'
        ) {
          newValue[row.id] = true
        }
      })

      return newValue
    })
  )
  useOnBulkOpsAction('unselect-all', () =>
    onRowSelectionChange(table, () => {
      return {}
    })
  )

  const tableContainerRef = useRef<HTMLTableElement>(null)
  const { listeners: tableNavListeners } = useTableNavigation()

  const rowVirtualizer = useVirtualizer({
    count: rows.length,
    estimateSize: () => 36,
    getScrollElement: () => tableContainerRef.current,
    getItemKey: (index) => rows[index].id,
    // Heights are derived from the row model, not from the element's box.
    measureElement(element, entry, instance) {
      const index = instance.indexFromElement(element)
      const row = rows[index]
      const nextRowDepth = rows[index + 1]?.depth ?? 0
      return calculateRowHeight(row, nextRowDepth)
    },
    rangeExtractor: fullPageRangeExtractor,
    overscan: 15,
  })

  const gridTemplateColumns = useTableGridTemplateColumns(table)

  const toggleExpandAllRows = useCallback(
    (expand: boolean) => table.toggleAllRowsExpanded(expand),
    [table]
  )

  // Expand all by default
  useEffect(() => {
    toggleExpandAllRows(true)
  }, [toggleExpandAllRows])

  const virtualItems = rowVirtualizer.getVirtualItems()

  return (
    <Table
      ref={tableContainerRef}
      {...tableNavListeners}
      className='grid'
      style={{
        gridTemplateRows: 'auto 1fr',
        // @ts-expect-error - css
        '--col-template': gridTemplateColumns,
      }}
    >
      <div
        role='rowgroup'
        className='flex flex-col'
        style={{
          gridRow: 2,
          position: 'relative',
          height: rowVirtualizer.getTotalSize(),
        }}
      >
        {virtualItems.map((virtual) => {
          const row = rows[virtual.index]
          const nextRowDepth = rows[virtual.index + 1]?.depth ?? 0

          return (
            <VirtualTableRow
              key={row.id}
              virtual={virtual}
              row={row}
              nextRowDepth={nextRowDepth}
              toggleExpandAllRows={toggleExpandAllRows}
              range={{
                startIndex: rowVirtualizer.range?.startIndex ?? 0,
                endIndex: rowVirtualizer.range?.endIndex ?? 0,
                overscan: rowVirtualizer.options.overscan,
                count: rowVirtualizer.options.count,
              }}
              measureElement={rowVirtualizer.measureElement}
            />
          )
        })}
      </div>
      <div className='contents'>
        <HeaderRow table={table} />
      </div>
    </Table>
  )
}

/**
 * Range extractor that always starts at index 0 and runs through the end of
 * the overscanned window: `endIndex + overscan`, inclusive, clamped to the
 * last row. This matches the end bound of TanStack Virtual's
 * `defaultRangeExtractor`.
 *
 * Rows above the viewport are therefore never unmounted. `VirtualTableRow`
 * skips rendering cell contents for non-expandable rows before the window
 * instead.
 *
 * Declared at module level (it is pure), so its identity is stable across
 * renders without `useCallback`.
 */
function fullPageRangeExtractor(range: Range): number[] {
  // Inclusive last index; -1 when there are no rows, yielding an empty list.
  const lastIndex = Math.min(range.endIndex + range.overscan, range.count - 1)
  return Array.from({ length: lastIndex + 1 }, (_, i) => i)
}

/**
 * Computes the pixel height of a tree-list row.
 *
 * Rows whose type is `task-totals` or `project-totals` have a base height of
 * 36px when their value carries an `addItemValue`, and 24px otherwise. They
 * get an extra 24px when the following row is shallower than this one, as
 * measured by `getRowDepth`. Every other row type is a fixed 36px.
 *
 * @param row - The table row to size.
 * @param nextRowDepth - `depth` of the row rendered right after `row`
 *   (callers pass 0 when there is no next row).
 * @returns Height in pixels.
 */
function calculateRowHeight<T extends GroupedNode>(
  row: Row<T>,
  nextRowDepth: number
) {
  const { type: nodeType, value: payload } = row.original.value
  const totalsTypes = ['task-totals', 'project-totals']

  if (!totalsTypes.includes(nodeType)) {
    return 36
  }

  const nextRowIsShallower = nextRowDepth < getRowDepth(row)
  const baseHeight = payload.addItemValue == null ? 24 : 36

  return nextRowIsShallower ? baseHeight + 24 : baseHeight
}

/**
 * Returns the depth used to indent a row (rendered as the `--depth` CSS var).
 *
 * Expandable rows keep their own `depth`. Every other row is shifted up one
 * level, clamped so the result never goes below 0.
 */
function getRowDepth<T extends GroupedNode>(row: Row<T>) {
  if (row.getCanExpand()) {
    return row.depth
  }

  return Math.max(0, row.depth - 1)
}

/** Props for a single virtualized row rendered by `TreeList`. */
type VirtualTableRowProps<T extends GroupedNode> = {
  /** The TanStack table row being rendered. */
  row: Row<T>
  /** Virtualizer item for this row; its `index` becomes `data-index`. */
  virtual: VirtualItem
  /** `depth` of the following row (callers pass 0 for the last row). */
  nextRowDepth: number
  /**
   * Current virtualizer window. Non-expandable rows before
   * `startIndex - overscan` render without cell contents.
   */
  range: Range
  /** Virtualizer ref callback the row element is registered with. */
  measureElement: (node: Element | null) => void
  /** Expands (`true`) or collapses (`false`) every row in the table. */
  toggleExpandAllRows: (expand: boolean) => void
}

/**
 * Renders one row of the virtualized tree list.
 *
 * The row element is registered with the virtualizer via `measureElement`
 * (tagged with `data-index`), and its height is set explicitly from
 * `calculateRowHeight`. Expandable rows use `ExpandableTableRow`; all others
 * use `TableRow`. Cell contents are rendered only for expandable rows and for
 * rows at or after `range.startIndex - range.overscan`.
 */
const VirtualTableRow = <T extends GroupedNode>({
  virtual,
  row,
  nextRowDepth,
  measureElement,
  range,
  toggleExpandAllRows,
}: VirtualTableRowProps<T>) => {
  const { highlighted } = useHighlightLastCreated(
    row.original?.value?.value?.id
  )

  const isExpandable = row.getCanExpand()
  const depth = getRowDepth(row)
  const height = calculateRowHeight(row, nextRowDepth)
  const RowElement = isExpandable ? ExpandableTableRow : TableRow

  // Expandable rows are never flagged; other rows are flagged when the next
  // row is shallower than this one.
  const hasIncreasedHeight = !isExpandable && nextRowDepth < depth

  // Rows above the overscanned window keep their slot but skip their cells;
  // expandable rows always render theirs.
  const renderCells =
    isExpandable || virtual.index >= range.startIndex - range.overscan

  return (
    <RowElement
      ref={(node) => measureElement(node)}
      data-index={virtual.index}
      data-has-increased-height={hasIncreasedHeight}
      expandable={isExpandable}
      className='w-full'
      highlighted={highlighted}
      style={{
        // @ts-expect-error - css
        '--depth': depth,
        height,
      }}
      row={row}
      toggleExpandAllRows={toggleExpandAllRows}
      {...addComponentName('VirtualTableRow')}
    >
      {renderCells && <RowContents row={row} />}
    </RowElement>
  )
}
/** Props for the cell-rendering part of a virtualized row. */
type RowContentsProps<T extends GroupedNode> = {
  /** Row whose visible cells are rendered. */
  row: Row<T>
}
/**
 * Renders the cells of a single tree-list row.
 *
 * "Extended" rows render every visible cell. These are rows of type `task`,
 * `task-totals`, `project-totals` or `stage`, plus `project` rows whose key
 * is not a "none" id. All other rows render only their first visible cell,
 * followed by a filler element spanning from column 2 to the end.
 */
const RowContents = <T extends GroupedNode>({ row }: RowContentsProps<T>) => {
  const nodeValue = row.original.value
  const extendedTypes = ['task', 'task-totals', 'project-totals', 'stage']

  const isExtendedRow =
    extendedTypes.includes(nodeValue.type) ||
    (nodeValue.type === 'project' && !isNoneId(nodeValue.key))

  const cells = row.getVisibleCells()
  const cellsToRender = isExtendedRow ? cells : cells.slice(0, 1)

  return (
    <>
      {cellsToRender.map((cell, columnIndex) => (
        <TableCell
          key={cell.id}
          role='cell'
          tabIndex={-1}
          style={{ gridColumnStart: columnIndex + 1 }}
          isPinned={cell.column.getIsPinned() === 'left'}
          fromExpandableRow={cell.row.getCanExpand()}
        >
          {flexRender(cell.column.columnDef.cell, cell.getContext())}
        </TableCell>
      ))}
      {/* https://gitlab.com/usemotion/motion/-/merge_requests/13984#note_1914472489 */}
      {!isExtendedRow && (
        <div className='group/table-cell col-start-[2] col-end-[-1] border-b border-pivot-table-header-row-border' />
      )}
    </>
  )
}
