import {
  getCacheEntryValue,
  getModelCache,
  type ModelCacheCollection,
  MotionCache,
  type OptimisticUpdateValue,
} from '@motion/rpc-cache'
import { getStageTense } from '@motion/ui-logic/pm/project'
import { entries } from '@motion/utils/object'
import { logInDev } from '@motion/web-base/logging'
import { stats } from '@motion/web-common/performance'
import type { ProjectSchema, StageSchema, TaskSchema } from '@motion/zod/client'

import type { QueryClient } from '@tanstack/react-query'
import { getProjectQueryFilters, getTaskQueryFilters } from '~/global/cache'
import { DateTime } from 'luxon'

/**
 * Optimistically shifts the due dates of a project's stage tasks in the query
 * cache after the project's stage due dates have changed.
 *
 * A stage is considered when it is the modified stage (`stageDefinitionId`) or
 * when `getStageTense` does not report it as `'past'` for `project`. For each
 * such stage whose due date moved, the difference in days between the old and
 * new stage due date is added to the due date of every cached `NORMAL` task of
 * that stage belonging to this project. All changed tasks are written with a
 * single `MotionCache.upsert`; the whole computation is timed via `stats.time`.
 *
 * @param queryClient - Query client whose model cache is read and updated.
 * @param stageDefinitionId - Stage definition that was edited; always
 *   considered, even when its tense is `'past'`.
 * @param project - The project before the change; its `stages` supply the old
 *   stage due dates.
 * @param newProjectStages - The project's stages after the change, matched to
 *   the old stages by `stageDefinitionId`.
 * @returns `[]` when no task moves; otherwise a single handle whose `rollback`
 *   undoes the cache write and whose `withRollback` calls `rollback` when the
 *   given promise rejects (re-throwing the rejection).
 */
export function optimisticUpdateStageTasks(
  queryClient: QueryClient,
  stageDefinitionId: string,
  project: ProjectSchema,
  newProjectStages: StageSchema[]
): OptimisticUpdateValue[] {
  return stats.time('optimisticUpdateStageTasks', () => {
    // Only update the modified stage and future stages (skip past stages)
    const affectedStages = project.stages.filter(
      (stage) =>
        getStageTense(project, stage.stageDefinitionId) !== 'past' ||
        stage.stageDefinitionId === stageDefinitionId
    )

    // Pair old and new stages by stageDefinitionId rather than by array
    // position: positional pairing applies the wrong stage's delta as soon as
    // the two stage lists differ in order or length.
    const newStagesByDefinitionId = new Map<string, StageSchema>()
    for (const stage of newProjectStages) {
      newStagesByDefinitionId.set(stage.stageDefinitionId, stage)
    }

    const allTasks = getModelCache(queryClient).tasks
    const taskUpdates: Record<string, TaskSchema> = {}

    for (const oldStage of affectedStages) {
      const newStage = newStagesByDefinitionId.get(oldStage.stageDefinitionId)

      // Skip if either side has no due date
      if (!oldStage.dueDate || !newStage?.dueDate) continue

      const delta = DateTime.fromISO(newStage.dueDate).diff(
        DateTime.fromISO(oldStage.dueDate),
        'days'
      ).days

      // Skip when the stage did not move. Luxon reports NaN days for a
      // Duration built from an unparseable date; never feed that to `plus`.
      if (delta === 0 || !Number.isFinite(delta)) continue

      const stageTasks = getTasksForStageDefinitionId(
        allTasks,
        oldStage.stageDefinitionId
      )

      for (const task of stageTasks) {
        // The model cache holds tasks of every loaded project, so only shift
        // tasks of this project. The `type` re-check narrows `task` to the
        // NORMAL variant (getTasksForStageDefinitionId already guarantees it).
        // NOTE(review): assumes a stage definition can be shared by several
        // projects (e.g. created from the same template) — confirm.
        if (task.type !== 'NORMAL' || task.projectId !== project.id) continue
        if (!task.dueDate) continue

        const newDueDate = DateTime.fromISO(task.dueDate)
          .plus({ days: delta })
          .toISO()

        // toISO() returns null for an invalid DateTime; keep the task as-is
        // instead of clearing its due date.
        if (newDueDate == null) continue

        taskUpdates[task.id] = { ...task, dueDate: newDueDate }
      }
    }

    // Nothing moved: no cache write and nothing to roll back
    if (Object.keys(taskUpdates).length === 0) {
      return []
    }

    // Perform single cache update with all task updates
    const { rollback } = MotionCache.upsert(
      queryClient,
      getTaskQueryFilters(),
      {
        models: {
          tasks: taskUpdates,
        },
      }
    )

    return [
      {
        withRollback<T>(p: Promise<T>) {
          return p.catch((ex) => {
            rollback()
            throw ex
          })
        },
        rollback,
      },
    ]
  })
}

/**
 * Collects the cached `NORMAL` tasks whose `stageDefinitionId` equals the
 * given one, in the iteration order of `entries(allTasks)`.
 *
 * @param allTasks - Model-cache task collection; each entry's `value` is a task.
 * @param stageDefinitionId - Stage definition to match.
 * @returns The matching tasks; an empty array when none match.
 */
function getTasksForStageDefinitionId(
  allTasks: ModelCacheCollection<TaskSchema>,
  stageDefinitionId: string
): TaskSchema[] {
  const matching: TaskSchema[] = []

  for (const [, { value: task }] of entries(allTasks)) {
    // Only NORMAL tasks are considered
    if (task.type !== 'NORMAL') continue
    if (task.stageDefinitionId !== stageDefinitionId) continue
    matching.push(task)
  }

  return matching
}

/**
 * Optimistically replaces one stage of a cached project.
 *
 * The stage in `project.stages` whose `id` equals `updatedStage.id` is swapped
 * for `updatedStage`, and the patched project is written back with
 * `MotionCache.upsert` using the project's query filters.
 *
 * @param queryClient - Query client holding the project cache entry.
 * @param projectId - Id of the project to patch.
 * @param updatedStage - Replacement stage, matched by `id`.
 * @returns `null` (after a dev-only log) when the project is not cached;
 *   otherwise a handle whose `rollback` undoes the cache write and whose
 *   `withRollback` calls `rollback` when the given promise rejects
 *   (re-throwing the rejection).
 */
export function optimisticUpdateProjectStage(
  queryClient: QueryClient,
  projectId: ProjectSchema['id'],
  updatedStage: StageSchema
): OptimisticUpdateValue | null {
  const cachedProject = getCacheEntryValue(queryClient, 'projects', projectId)
  if (cachedProject == null) {
    logInDev('Project not found in cache, returning', { projectId })
    return null
  }

  // Swap in the updated stage, keeping every other stage as-is
  const nextStages = cachedProject.stages.map((existingStage) =>
    existingStage.id !== updatedStage.id ? existingStage : updatedStage
  )

  const { rollback } = MotionCache.upsert(
    queryClient,
    getProjectQueryFilters(projectId),
    {
      models: {
        projects: {
          [projectId]: { ...cachedProject, stages: nextStages },
        },
      },
    }
  )

  const handle = {
    withRollback<T>(p: Promise<T>) {
      return p.catch((ex) => {
        rollback()
        throw ex
      })
    },
    rollback,
  }
  return handle
}
