import { isNoneId } from '@motion/shared/identifiers'
import { addComponentName } from '@motion/ui/helpers'
import { useTableNavigation } from '@motion/ui/utils'
import { useHasTreatment } from '@motion/web-common/flags'

import {
  type ExpandedState,
  flexRender,
  getCoreRowModel,
  getExpandedRowModel,
  type Row,
  useReactTable,
} from '@tanstack/react-table'
import { type Range, useVirtualizer } from '@tanstack/react-virtual'
import {
  memo,
  type NamedExoticComponent,
  type ReactNode,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from 'react'

import { type TreeListColumn } from './columns'
import { ExpandableTableRow, Table, TableCell, TableRow } from './components'
import { HeaderRow } from './header'
import {
  useHighlightLastCreated,
  useRowSelection,
  useTableGridTemplateColumns,
  useTableViewState,
} from './hooks'

import { useOnBulkOpsAction } from '../bulk-ops'
import { type GroupedNode, type Tree } from '../grouping/utils/multi-group'
import { type ViewStateColumn, type ViewStateSortBy } from '../view-state'

/**
 * Props accepted by `TreeList`.
 */
export type TreeListProps<T extends GroupedNode<any>> = {
  /**
   * Grouped tree to display. Its `values` are handed to the table as row
   * data and its `children` seed the initial expanded state.
   */
  tree: Tree<T>

  /** Column definitions passed to the table and to the view-state hook. */
  columns: TreeListColumn[]
  /** Forwarded to `useRowSelection`. */
  enableSelection?: boolean

  /** Forwarded to the header row as `sortedBy`. */
  sortBy?: ViewStateSortBy

  /** Column view state forwarded to `useTableViewState`. */
  columnState: ViewStateColumn[]
  /** Change callback forwarded to `useTableViewState` alongside `columnState`. */
  onColumnStateChange: (state: ViewStateColumn[]) => void
}

/**
 * Memoized table view of a grouped tree.
 *
 * Builds a TanStack table from `props.tree.values` (sub rows come from each
 * node's `children`), wires column view state and row selection through the
 * local hooks, and renders the resulting row model through a row virtualizer.
 */
export const TreeList = memo(
  <T extends GroupedNode<any>>(props: TreeListProps<T>) => {
    const { rowSelection, onRowSelectionChange, canSelectRow } =
      useRowSelection({
        tree: props.tree,
        enableSelection: props.enableSelection,
      })

    // Initial expanded state: every node that has children is marked as
    // expanded, keyed by its `qualifiedKey` (the same value used as row id).
    const [expanded, setExpanded] = useState<ExpandedState>(() => {
      const data: Record<string, boolean> = {}

      function expand(item: Tree<T>) {
        // Only nodes with children are recorded; nodes without children are
        // skipped.
        if (item.children != null) {
          data[item.qualifiedKey] = true
          item.children.forEach(expand)
        }
      }

      props.tree.children.forEach(expand)
      return data
    })

    // Column sizing / order / visibility / pinning derived from (and reported
    // back through) the `columnState` props.
    const {
      columnSizing,
      setColumnSizing,
      columnOrder,
      setColumnOrder,
      columnVisibility,
      setColumnVisibility,
      columnPinning,
      setColumnPinning,
    } = useTableViewState({
      columnState: props.columnState,
      allColumns: props.columns,
      onColumnStateChange: props.onColumnStateChange,
    })

    const table = useReactTable({
      data: props.tree.values,
      columns: props.columns,
      state: {
        expanded,
        columnSizing,
        columnOrder,
        columnVisibility,
        columnPinning,
        rowSelection,
      },
      filterFromLeafRows: true,
      // Row ids are the node's qualified key, matching the keys used in the
      // initial `expanded` map above.
      getRowId: (row) => row.qualifiedKey,
      onExpandedChange: setExpanded,
      getCoreRowModel: getCoreRowModel(),
      getSubRows: (row) => {
        return row.children as T[] | undefined
      },
      getExpandedRowModel: getExpandedRowModel(),
      enableColumnResizing: true,
      enableRowSelection: canSelectRow,
      // Table debug output is only enabled when MOTION_ENV is 'localhost'.
      debugTable: import.meta.env.MOTION_ENV === 'localhost',
      onColumnSizingChange: setColumnSizing,
      onColumnOrderChange: setColumnOrder,
      onColumnVisibilityChange: setColumnVisibility,
      onColumnPinningChange: setColumnPinning,
      // The selection hook needs the table instance, so the updater is
      // forwarded together with it.
      onRowSelectionChange: (setter) => {
        onRowSelectionChange(table, setter)
      },
      columnResizeMode: 'onChange',
    })

    const { rows } = table.getRowModel()

    // "select-all" adds every row of the current row model whose value type
    // is 'task' or 'project' to the existing selection.
    useOnBulkOpsAction('select-all', () =>
      onRowSelectionChange(table, (selection) => {
        const newValue = { ...selection }

        rows.forEach((row) => {
          if (
            row.original.value.type === 'task' ||
            row.original.value.type === 'project'
          ) {
            newValue[row.id] = true
          }
        })

        return newValue
      })
    )
    // "unselect-all" replaces the selection with an empty object.
    useOnBulkOpsAction('unselect-all', () =>
      onRowSelectionChange(table, () => {
        return {}
      })
    )

    const tableContainerRef = useRef<HTMLTableElement>(null)
    const { listeners: tableNavListeners } = useTableNavigation()

    // Custom range extractor: always starts at index 0 and yields
    // min(endIndex + overscan, count) consecutive indices, so rows above the
    // current window are kept in the rendered set as well.
    const rangeExtractor = useCallback((range: Range) => {
      // Full page of rows
      return Array.from(
        { length: Math.min(range.endIndex + range.overscan, range.count) },
        (_, i) => i
      )
    }, [])

    const rowVirtualizer = useVirtualizer({
      count: rows.length,
      estimateSize: () => 36,
      getScrollElement: () => tableContainerRef.current,
      getItemKey: (index) => rows[index].id,
      // Sizes come from the row data (via calculateRowHeight) instead of
      // reading the element's rendered size; the element is only used to
      // look up its index.
      measureElement(element, entry, instance) {
        const index = instance.indexFromElement(element)
        const row = rows[index]
        const nextRowDepth = rows[index + 1]?.depth ?? 0
        return calculateRowHeight(row, nextRowDepth)
      },
      rangeExtractor,
      overscan: 15,
    })

    const gridTemplateColumns = useTableGridTemplateColumns(table)

    const toggleExpandAllRows = useCallback(
      (expand: boolean) => table.toggleAllRowsExpanded(expand),
      [table]
    )

    // Expand all by default
    useEffect(() => {
      toggleExpandAllRows(true)
    }, [toggleExpandAllRows])

    const virtualItems = rowVirtualizer.getVirtualItems()

    // Snapshot of the virtualizer's range/options handed to each row, which
    // uses it to decide whether to render its contents.
    // NOTE(review): this memo depends only on `rowVirtualizer`. If the
    // virtualizer instance keeps the same identity across renders, the
    // snapshot will not refresh as the scroll range changes - confirm whether
    // that is intended.
    const range = useMemo(() => {
      return {
        startIndex: rowVirtualizer.range?.startIndex ?? 0,
        endIndex: rowVirtualizer.range?.endIndex ?? 0,
        overscan: rowVirtualizer.options.overscan,
        count: rowVirtualizer.options.count,
      }
    }, [rowVirtualizer])

    // The body rowgroup is rendered first in DOM order but is placed on grid
    // row 2 and sized to the virtualizer's total height; the header row is
    // rendered after it inside a `contents` wrapper.
    return (
      <Table
        ref={tableContainerRef}
        {...tableNavListeners}
        className='grid'
        style={{
          gridTemplateRows: 'auto 1fr',
          // @ts-expect-error - css
          '--col-template': gridTemplateColumns,
        }}
      >
        <div
          role='rowgroup'
          className='flex flex-col'
          style={{
            gridRow: 2,
            position: 'relative',
            height: rowVirtualizer.getTotalSize(),
          }}
        >
          {virtualItems.map((virtual) => {
            const row = rows[virtual.index]
            const nextRowDepth = rows[virtual.index + 1]?.depth ?? 0

            return (
              <VirtualTableRow
                key={row.id}
                index={virtual.index}
                row={row}
                nextRowDepth={nextRowDepth}
                toggleExpandAllRows={toggleExpandAllRows}
                range={range}
                measureElement={rowVirtualizer.measureElement}
              />
            )
          })}
        </div>
        <div className='contents'>
          <HeaderRow table={table} sortedBy={props.sortBy} />
        </div>
      </Table>
    )
  }
)
TreeList.displayName = 'TreeList'
/**
 * Computes the pixel height used for a row.
 *
 * Rows whose value type is `task-totals` or `project-totals` are 36px when
 * their value carries an `addItemValue` and 24px otherwise, plus an extra
 * 24px when the following row sits at a shallower depth than this row.
 * Every other row is 36px.
 *
 * @param row - The table row to size.
 * @param nextRowDepth - Depth of the row that follows `row`.
 * @returns The row height in pixels.
 */
function calculateRowHeight<T extends GroupedNode>(
  row: Row<T>,
  nextRowDepth: number
) {
  const node = row.original.value
  const depth = getRowDepth(row)
  const isTotalsRow = ['task-totals', 'project-totals'].includes(node.type)

  if (!isTotalsRow) {
    return 36
  }

  const baseHeight = node.value.addItemValue != null ? 36 : 24
  // Extra space is added when the next row is less deeply nested.
  const trailingGap = nextRowDepth < depth ? 24 : 0

  return baseHeight + trailingGap
}

/**
 * Returns the depth used for layout of a row: the row's own depth when it
 * can expand, otherwise one level less (never below 0).
 */
function getRowDepth<T extends GroupedNode>(row: Row<T>) {
  if (row.getCanExpand()) {
    return row.depth
  }

  return Math.max(row.depth - 1, 0)
}

/** Props for a single virtualized row (`VirtualTableRow`). */
type VirtualTableRowProps<T extends GroupedNode<any>> = {
  /** Index of the row; also written to the element as `data-index`. */
  index: number
  /** The table row to render. */
  row: Row<T>
  /** Depth of the row that follows this one; used for height calculation. */
  nextRowDepth: number
  /** Called with the row element (or null) through the element's ref. */
  measureElement: (node: Element | null) => void
  /** Range info used to decide whether the row's contents are rendered. */
  range: Range
  /** Forwarded to the row component to expand or collapse all rows. */
  toggleExpandAllRows: (expand: boolean) => void
}

/**
 * Memoized row wrapper used by the virtualized list.
 *
 * Picks `ExpandableTableRow` for rows that can expand and `TableRow`
 * otherwise, sizes the row via `calculateRowHeight`, and only renders the
 * row's cells when the row can expand or its index is at or past
 * `range.startIndex - range.overscan`.
 */
const VirtualTableRow: NamedExoticComponent<VirtualTableRowProps<any>> = memo(
  <T extends GroupedNode<any>>({
    index,
    row,
    nextRowDepth,
    measureElement,
    range,
    toggleExpandAllRows,
  }: VirtualTableRowProps<T>) => {
    const { highlighted } = useHighlightLastCreated(
      row.original?.value?.value?.id
    )

    const isExpandable = row.getCanExpand()
    const rowDepth = getRowDepth(row)
    const height = calculateRowHeight(row, nextRowDepth)

    // Memoized so the style object keeps its identity while depth and height
    // are unchanged.
    const style = useMemo(
      () => ({
        '--depth': rowDepth,
        height,
      }),
      [rowDepth, height]
    )

    // Expandable rows always render their contents; other rows do so only
    // from `overscan` rows before the start of the range onwards.
    const firstRenderedIndex = range.startIndex - range.overscan
    const showContents = isExpandable || index >= firstRenderedIndex

    // Only non-expandable rows are flagged, and only when the next row is
    // shallower than this one.
    const hasIncreasedHeight = !isExpandable && nextRowDepth < rowDepth

    const Component = isExpandable ? ExpandableTableRow : TableRow

    return (
      <Component
        ref={(el) => measureElement(el)}
        data-index={index}
        data-has-increased-height={hasIncreasedHeight}
        expandable={isExpandable}
        className='w-full'
        highlighted={highlighted}
        style={style}
        row={row}
        toggleExpandAllRows={toggleExpandAllRows}
        {...addComponentName('VirtualTableRow')}
      >
        {showContents && (
          <RowContents
            visibleCells={row.getVisibleCells()}
            rowType={row.original.value.type}
            rowKey={row.original.value.key}
          />
        )}
      </Component>
    )
  }
)
VirtualTableRow.displayName = 'VirtualTableRow'

/** Props for `RowContents`, which renders the cells of one row. */
type RowContentsProps<T extends GroupedNode<any>> = {
  /** The row's visible cells, as returned by `row.getVisibleCells()`. */
  visibleCells: ReturnType<Row<T>['getVisibleCells']>
  /** The `type` of the row's value; decides whether all cells are rendered. */
  rowType: T['value']['type']
  /** The `key` of the row's value; checked with `isNoneId` for project rows. */
  rowKey: T['value']['key']
}

/**
 * Memoized renderer for the cells of a single row.
 *
 * "Extended" rows render every visible cell. A row is extended when its type
 * is `task`, `task-totals`, `project-totals` or `stage` (plus
 * `meetingInsights` when the `notetaker-event-modal` treatment is on), or
 * when it is a `project` row whose key is not a none-id. All other rows
 * render only their first cell followed by a filler element spanning the
 * remaining columns.
 */
const RowContents: NamedExoticComponent<RowContentsProps<any>> = memo(
  <T extends GroupedNode<any>>({
    visibleCells: allCells,
    rowType,
    rowKey,
  }: RowContentsProps<T>) => {
    const noteTakerEnabled = useHasTreatment('notetaker-event-modal')

    const extendedRowTypes = [
      'task',
      'task-totals',
      'project-totals',
      'stage',
      ...(noteTakerEnabled ? ['meetingInsights'] : []),
    ]
    const isNonNoneProject = rowType === 'project' && !isNoneId(rowKey)
    const extended = extendedRowTypes.includes(rowType) || isNonNoneProject

    // Non-extended rows keep only their first cell.
    const cellsToRender = extended ? allCells : allCells.slice(0, 1)

    const renderedCells = cellsToRender.map((cell, colIndex) => (
      <TableCellWrapper
        key={cell.id}
        colIndex={colIndex}
        isPinned={cell.column.getIsPinned() === 'left'}
        fromExpandableRow={cell.row.getCanExpand()}
      >
        {flexRender(cell.column.columnDef.cell, cell.getContext())}
      </TableCellWrapper>
    ))

    return (
      <>
        {renderedCells}
        {!extended && (
          <div className='group/table-cell col-start-[2] col-end-[-1] border-b border-pivot-table-header-row-border' />
        )}
      </>
    )
  }
)
RowContents.displayName = 'RowContents'

/** Props for `TableCellWrapper`. */
type TableCellWrapperProps = {
  /** Rendered cell content. */
  children: ReactNode
  /** Zero-based column index; the cell is placed at grid column `colIndex + 1`. */
  colIndex: number
  /** Forwarded to `TableCell` as `isPinned`. */
  isPinned: boolean
  /** Forwarded to `TableCell` as `fromExpandableRow`. */
  fromExpandableRow: boolean
}

/**
 * Memoized wrapper that renders its children inside a `TableCell` with
 * `role='cell'`, placed at grid column `colIndex + 1`.
 */
const TableCellWrapper = memo(function TableCellWrapper(
  props: TableCellWrapperProps
) {
  const { children, colIndex, isPinned, fromExpandableRow } = props

  // Grid columns are 1-based while `colIndex` is 0-based.
  const gridPlacement = useMemo(() => {
    return { gridColumnStart: colIndex + 1 }
  }, [colIndex])

  return (
    <TableCell
      role='cell'
      tabIndex={-1}
      style={gridPlacement}
      isPinned={isPinned}
      fromExpandableRow={fromExpandableRow}
    >
      {children}
    </TableCell>
  )
})
