import type { FullSingleLayerStack } from 'domains/viewer/ViewportsConfigurations/types';
import type { FrameData } from 'modules/viewer/imageloaders/BaseWebSocketImageLoader';
import type { BaseImageLoader, OrderedFrame } from './imageloaders/BaseImageLoader';
import Deferred from 'utils/deferred';
import { FULL_MEMORY_USAGE } from 'config/constants';
import { DicomWebSocketImageLoader } from 'modules/viewer/imageloaders/DicomWebSocketImageLoader';
import { ParallelPixelWebSocketImageLoader } from 'modules/viewer/imageloaders/ParallelPixelWebSocketImageLoader';
import { HttpImageLoader } from './imageloaders/HttpImageLoader';
import {
  getDirectionForCacheLoad,
  openDB,
  readAllFromDB,
  readFromDB,
  writeAllToDB,
} from './pixelCache';
import type { SupportedTextureTypes, SupportedTexturesMap } from 'utils/textureUtils';
import { logger } from '../logger';
import type {
  HandleFrameMessage,
  HandleInitialFrameMessage,
} from './workers/PixelWorkerConnection';
import { PixelDataSharedWorkerError } from './workers/PixelWorkerConnection';
import { mapToOrderedFrame } from 'utils/frameMapper';
import { maxPrecision } from 'utils/math';

/**
 * Payload delivered to a `LoaderCallback`.
 *
 * - `'initial-frame'` / `'frame'`: `data` carries one frame's pixels. Within this
 *   file the object passed is `{ frameSmid, pixels, range, fromCache }`.
 * - `'complete'`: the stack has finished loading; `data` is `null`.
 * - `'error'`: loading failed; `data` describes the failure.
 *   NOTE(review): `PixelDataLoader.loadStack` forwards whatever value was caught,
 *   which is not guaranteed to be a string despite this type.
 */
export type LoaderCallbackArg =
  | {
      stackSmid: string;
      type: 'initial-frame' | 'frame' | 'complete';
      data: HandleFrameMessage | HandleInitialFrameMessage | FrameData | null;
    }
  | {
      stackSmid: string;
      type: 'error';
      data: string;
    };

/** Transfer types that are served by the WebSocket-based image loaders in `LOADERS`. */
export type WSTransferType = 'dicom' | 'url-pixels';
/** Transport protocol identifier (not referenced elsewhere in this file). */
export type TransferProtocol = 'http' | 'ws';
/** Key into `LOADERS`; selects which image-loader implementation is constructed. */
export type TransferType = WSTransferType | 'http';

/** Receives every progress, completion and error notification from `PixelDataLoader`. */
type LoaderCallback = (arg: LoaderCallbackArg) => void;

/**
 * A record of the IndexedDB `'cache'` object store as read back by
 * `PixelDataLoader#loadFromCache` (via `readFromDB` / `readAllFromDB`).
 */
type CacheObject = {
  stackSmid: string;
  frameSmid: string;
  data: {
    pixels: SupportedTextureTypes;
  };
};

/**
 * Image-loader constructors keyed by `TransferType`. `PixelDataLoader` picks one
 * of these by key and instantiates it with the supported-texture map.
 */
const LOADERS = {
  http: HttpImageLoader,
  dicom: DicomWebSocketImageLoader,
  'url-pixels': ParallelPixelWebSocketImageLoader,
} as const;

/**
 * Loads pixel data for image stacks, preferring the IndexedDB pixel cache and
 * falling back to the configured network image loader. Frames fetched over the
 * network are written back to the cache once their transfer completes, or
 * earlier when in-memory usage reaches `FULL_MEMORY_USAGE`.
 *
 * Results are reported through a `LoaderCallback`:
 * - one `'initial-frame'` or `'frame'` call per frame,
 * - `'complete'` when a stack is done,
 * - `'error'` on failure.
 */
export class PixelDataLoader {
  #imageLoader: BaseImageLoader;
  #pixelDataCache: IDBDatabase;
  // Resolved once `openDB()` succeeds; the load methods await it before touching the cache.
  #pixelDataCacheOpenPromise: Deferred<undefined> = new Deferred();

  /**
   * @param transferType - Key into `LOADERS` selecting the image-loader implementation.
   * @param SUPPORTED_TEXTURES - Forwarded to the image loader's constructor.
   */
  constructor({
    transferType,
    SUPPORTED_TEXTURES,
  }: {
    transferType: TransferType;
    SUPPORTED_TEXTURES: SupportedTexturesMap;
  }) {
    this.#imageLoader = new LOADERS[transferType](SUPPORTED_TEXTURES);
    openDB()
      .then((db) => {
        this.#pixelDataCache = db;
        // @ts-expect-error [EN-7967] - TS2554 - Expected 1 arguments, but got 0.
        this.#pixelDataCacheOpenPromise.resolve();
      })
      .catch((error) => {
        // Without this handler the rejection would go unhandled. The open promise is
        // not resolved on failure, so loads keep waiting on it; log so the cause is visible.
        logger.error('pixel cache open error', error);
      });
  }

  /** Passes `stackSmid` (optional) through to the image loader's `cancel`. */
  async cancel(stackSmid?: string) {
    this.#imageLoader.cancel(stackSmid);
  }

  /** Delegates to the image loader's `refreshLoader`. */
  refreshLoader() {
    this.#imageLoader.refreshLoader();
  }

  /**
   * Loads every frame of `stack`, starting with the frame at `initialFrameIndex`,
   * then reports `'complete'` for the stack.
   *
   * Anything thrown while loading is reported through `callback` as `'error'`
   * instead of rejecting the returned promise.
   *
   * NOTE(review): if the first frame's network fetch fails, `#loadFrames` returns
   * early after an `'error'` was already emitted, and this method still emits
   * `'complete'`. `loadInitialStacks` does not emit `'complete'` in the same
   * situation. Confirm which behaviour consumers expect.
   */
  async loadStack({
    stack,
    initialFrameIndex,
    stackPriority,
    isDropped,
    callback,
  }: {
    stack: FullSingleLayerStack;
    initialFrameIndex: number;
    stackPriority: number;
    isDropped: boolean;
    callback: LoaderCallback;
  }) {
    // Ensure the pixel cache is opened before attempting to load a series
    await this.#pixelDataCacheOpenPromise.promise;
    try {
      if ((stack?.frames?.length ?? 0) > 0) {
        const frames = this.#generateOrderedFrames(stack, initialFrameIndex);
        await this.#loadFrames({ frames, initialFrameIndex, stackPriority, isDropped, callback });
      }

      callback({
        stackSmid: stack.smid,
        type: 'complete',
        data: null,
      });
    } catch (error: any) {
      callback({
        stackSmid: stack.smid,
        type: 'error',
        data: error,
      });
    }
  }

  /**
   * Loads several stacks together. The initial frame of every stack is loaded
   * first, as one batch. The remaining frames are then interleaved across stacks,
   * ordered by distance from each stack's `initialFrameIndex`, and loaded from
   * cache or network. Finally `'complete'` is reported for every stack.
   *
   * If the batch of initial frames fails to load from the network, this returns
   * early without emitting `'complete'` (the error was already reported).
   */
  async loadInitialStacks({
    initialStacks,
    stackPriority,
    isDropped,
    callback,
  }: {
    initialStacks: Array<{
      stack: FullSingleLayerStack;
      initialFrameIndex: number;
    }>;
    stackPriority: number;
    isDropped: boolean;
    callback: LoaderCallback;
  }): Promise<void> {
    // Ensure the pixel cache is opened before attempting to load a series
    await this.#pixelDataCacheOpenPromise.promise;

    // Step 1: Load the initial instances, specified by `initialFrameIndex`, for each stack
    const { initialFrames, remainingFrameStacks } = initialStacks.reduce<{
      initialFrames: OrderedFrame[];
      remainingFrameStacks: Array<OrderedFrame[]>;
    }>(
      (acc, { stack, initialFrameIndex }) => {
        // Skip stacks without frames (mirrors the guard in `loadStack`). Otherwise
        // `initialFrame` would be undefined and a bogus frame holding only
        // `sortIndex` would be queued for cache and network loading.
        if ((stack?.frames?.length ?? 0) === 0) {
          return acc;
        }

        const [initialFrame, ...rest] = this.#generateOrderedFrames(stack, initialFrameIndex);

        // The batch position replaces the frame's own sortIndex for the initial batch.
        acc.initialFrames.push({
          ...initialFrame,
          sortIndex: acc.initialFrames.length,
        });
        acc.remainingFrameStacks.push(rest);

        return acc;
      },
      { initialFrames: [], remainingFrameStacks: [] }
    );

    const { framesToFetch } = await this.#loadFromCache(
      initialFrames,
      callback,
      true,
      initialStacks
    );

    if (framesToFetch.length > 0) {
      const { networkSuccess } = await this.#loadFromNetwork({
        frames: framesToFetch,
        initialFrameIndex: 0,
        stackPriority,
        isInitialFrame: true,
        isDropped,
        callback,
      });
      if (networkSuccess === false) {
        // Loading the first frames from the network has failed even with retries.
        // The callback has been called with the error already so we just need
        // to stop the loading process as more frames will fail without the first
        // frame's special calculations.
        return;
      }
    }

    // Step 2: Interleave the remaining instances from each stack and check the cache for each
    const maxLength = Math.max(...remainingFrameStacks.map((frameStack) => frameStack.length));
    const interleavedFrames: Array<OrderedFrame> = [];

    for (let i = 0; i < maxLength; i++) {
      remainingFrameStacks.forEach((frameStack) => {
        const frame = frameStack[i];

        if (frame != null) {
          interleavedFrames.push(frame);
        }
      });
    }

    // Precompute each stack's focus index so the comparator does an O(1) lookup
    // instead of scanning `initialStacks` on every comparison. Only the first
    // entry per smid is kept, matching `Array.prototype.find`.
    const focusByStack = new Map<string, number>();
    initialStacks.forEach(({ stack, initialFrameIndex }) => {
      if (!focusByStack.has(stack.smid)) {
        focusByStack.set(stack.smid, initialFrameIndex);
      }
    });
    const distanceFromFocus = (frame: OrderedFrame) =>
      Math.abs(frame.sortIndex - (focusByStack.get(frame.stackSmid) ?? Infinity));
    // Array.prototype.sort is stable, so frames at equal distance keep their interleaved order.
    interleavedFrames.sort((a, b) => distanceFromFocus(a) - distanceFromFocus(b));

    // Step 3: If the cache contains the instance, load it from the cache.
    //         Otherwise, queue it up to be fetched over the network.
    const { framesToFetch: interleavedFramesToFetch } = await this.#loadFromCache(
      interleavedFrames,
      callback,
      false,
      initialStacks
    );

    // Step 4: If any instances need fetched over the network, fetch them and
    //         then store in the cache.
    if (interleavedFramesToFetch.length > 0) {
      const stacks = interleavedFramesToFetch.reduce<{
        [key: string]: OrderedFrame[];
      }>((acc, frame) => {
        if (!acc[frame.stackSmid]) {
          acc[frame.stackSmid] = [];
        }

        acc[frame.stackSmid].push(frame);

        return acc;
      }, {});
      const networkPromises = Object.values(stacks).map((frames) => {
        const initialStack = initialStacks.find((s) => s.stack.smid === frames[0].stackSmid);

        if (initialStack == null) return Promise.resolve();
        return this.#loadFromNetwork({
          frames,
          initialFrameIndex: initialStack.initialFrameIndex,
          stackPriority,
          isInitialFrame: false,
          isDropped,
          callback,
        });
      });

      await Promise.all(networkPromises);
    }

    initialStacks.forEach(({ stack }) => {
      callback({
        stackSmid: stack.smid,
        type: 'complete',
        data: null,
      });
    });
  }

  /**
   * Loads the single frame at `frameIndex` of `stack`. The cache is tried first,
   * then the network. This method does not emit `'complete'`.
   *
   * NOTE(review): the cache read passes `isInitialFrame = true`, so a cache hit is
   * reported as `'initial-frame'`. The network fetch passes `false`, so a network
   * result is reported as `'frame'`. Confirm this asymmetry is intended.
   */
  async loadFrame({
    stack,
    frameIndex,
    stackPriority,
    callback,
  }: {
    stack: FullSingleLayerStack;
    frameIndex: number;
    stackPriority: number;
    callback: LoaderCallback;
  }) {
    const frame = stack.frames[frameIndex];
    const mappedFrame = mapToOrderedFrame(frame, frameIndex, stack);

    const framesToLoad = [mappedFrame];

    await this.#pixelDataCacheOpenPromise.promise;
    const { framesToFetch } = await this.#loadFromCache(framesToLoad, callback, true, []);
    if (framesToFetch.length > 0) {
      await this.#loadFromNetwork({
        frames: framesToFetch,
        initialFrameIndex: 0,
        isInitialFrame: false,
        stackPriority,
        isDropped: false,
        callback,
      });
    }
  }

  /** Forwards a new focus frame and priority for `stack` to the image loader. */
  updatePriority({
    stack,
    focusFrameIndex,
    stackPriority,
  }: {
    stack: FullSingleLayerStack;
    focusFrameIndex: number;
    stackPriority: number;
  }) {
    this.#imageLoader.updatePriority({ stack, focusFrameIndex, stackPriority });
  }

  /**
   * Loads `frames` for one stack. `frames[0]` is expected to be the initial frame
   * (see `#generateOrderedFrames`). It is loaded first, from cache or network with
   * `isInitialFrame = true`, and then the remaining frames are loaded.
   *
   * Any priority or focus reported back by the initial network fetch is carried
   * into the follow-up request. Returns early if the initial network fetch fails.
   */
  async #loadFrames({
    frames,
    initialFrameIndex,
    stackPriority,
    isDropped,
    callback,
  }: {
    frames: OrderedFrame[];
    initialFrameIndex: number;
    stackPriority: number;
    isDropped: boolean;
    callback: LoaderCallback;
  }): Promise<undefined> {
    const { framesToFetch: firstFramesToFetch } = await this.#loadFromCache(
      frames.slice(0, 1),
      callback,
      true,
      []
    );

    // in the event a new priority is updated for the request
    // it should be passed to the follow-up network request
    let priority = stackPriority;
    let focus = initialFrameIndex;
    if (firstFramesToFetch.length > 0) {
      const {
        priority: updatedPriority,
        focus: updatedFocus,
        networkSuccess,
      } = await this.#loadFromNetwork({
        frames: firstFramesToFetch,
        initialFrameIndex: 0,
        stackPriority,
        isInitialFrame: true,
        isDropped,
        callback,
      });
      if (networkSuccess === false) {
        // Loading the first frame from the network has failed even with retries.
        // The callback has been called with the error already so we just need
        // to stop the loading process as more frames will fail without the first
        // frame's special calculations.
        return;
      }
      priority = updatedPriority;
      // only update the focus if a new one arrived from a priority-update
      // the initial frame fetch locks the focus at 0 for a 1-frame request
      // so we need to ignore that
      if (updatedFocus !== 0) {
        focus = updatedFocus;
      }
    }

    const { framesToFetch } = await this.#loadFromCache(
      frames.slice(1), // the initial slice was moved to the front by #generateOrderedFrames
      callback,
      false,
      []
    );
    if (framesToFetch.length > 0) {
      await this.#loadFromNetwork({
        frames: framesToFetch,
        initialFrameIndex: focus,
        stackPriority: priority,
        isInitialFrame: false,
        isDropped,
        callback,
      });
    }
  }

  /**
   * Reports every frame in `frames` that is present in the pixel cache through
   * `callback`, as `'initial-frame'` or `'frame'` with `fromCache: true`. Returns
   * the frames that still need to be fetched.
   *
   * How the cache is read depends on `isInitialFrame`:
   * - `true`: each frame is read individually by its `frameSmid`.
   * - `false`: every cached record of the involved stacks is streamed (direction
   *   chosen by `getDirectionForCacheLoad`) and filtered down to `frames`.
   *
   * Cache errors are logged, not thrown. Any frame not successfully read is then
   * returned for fetching.
   */
  async #loadFromCache(
    frames: OrderedFrame[],
    callback: LoaderCallback,
    isInitialFrame: boolean,
    initialStacks: Array<{
      stack: FullSingleLayerStack;
      initialFrameIndex: number;
    }>
  ): Promise<{
    framesToFetch: OrderedFrame[];
  }> {
    // Index the requested frames by smid so per-record lookups are O(1) instead of
    // scanning `frames` for every cached record. Only the first frame per smid is
    // kept, matching `Array.prototype.find`.
    const framesBySmid = new Map<string, OrderedFrame>();
    frames.forEach((frame) => {
      if (!framesBySmid.has(frame.frameSmid)) {
        framesBySmid.set(frame.frameSmid, frame);
      }
    });

    const framesToFetch: Set<OrderedFrame> = new Set();
    const stackSmids = Array.from(new Set(frames.map((frame) => frame.stackSmid)));
    const readFromCache = new Set<string>();

    const processFrame = (
      object: CacheObject | OrderedFrame,
      pixels?: SupportedTextureTypes | null
    ) => {
      if (pixels == null) {
        const frame = framesBySmid.get(object.frameSmid);

        if (frame == null) {
          throw new Error(
            `${PixelDataSharedWorkerError.FrameNotFound} frameSmid: ${object.frameSmid}`
          );
        }

        framesToFetch.add(frame);
      } else {
        const range = getRange(pixels);

        readFromCache.add(object.frameSmid);

        callback({
          stackSmid: object.stackSmid,
          type: isInitialFrame ? 'initial-frame' : 'frame',
          data: {
            frameSmid: object.frameSmid,
            pixels,
            range,
            fromCache: true,
          },
        });
      }
    };

    // Called for every record in the stack; only records for requested frames are processed.
    const onRead = (object: CacheObject) => {
      if (framesBySmid.has(object.frameSmid)) {
        processFrame(object, object.data.pixels);
      }
    };

    try {
      if (!isInitialFrame) {
        await Promise.all(
          stackSmids.map(async (stackSmid) => {
            const initialStack = initialStacks.find((s) => s.stack.smid === stackSmid);
            return await readAllFromDB<CacheObject>(
              this.#pixelDataCache,
              'cache',
              stackSmid,
              onRead,
              getDirectionForCacheLoad(initialStack)
            );
          })
        );
      } else {
        const framesData = await Promise.all(
          frames.map((frame) => {
            return Promise.all([
              frame,
              readFromDB<CacheObject | null | undefined>(
                this.#pixelDataCache,
                'cache',
                frame.frameSmid
              ),
            ]);
          })
        );
        framesData.forEach(([frame, item]: [any, any]) => {
          processFrame(frame, item?.data.pixels);
        });
      }
    } catch (cacheError: any) {
      logger.error('cache error', cacheError);
    }

    // Anything not delivered from the cache (misses, or frames skipped because of
    // an error) must be fetched from the network.
    frames
      .filter((frame) => !readFromCache.has(frame.frameSmid))
      .forEach((frame) => {
        if (!framesToFetch.has(frame)) {
          framesToFetch.add(frame);
        }
      });

    return { framesToFetch: Array.from(framesToFetch) };
  }

  /**
   * Fetches `frames` through the image loader and reports each received frame
   * through `callback` with `fromCache: false`.
   *
   * Received frames are buffered in a pixel map registered in the module-level
   * `pixelMaps` set:
   * - On `'transfer-complete'` the map is written to the cache and then
   *   unregistered.
   * - On `'transfer-error'` the error is reported for each affected stack and the
   *   map is unregistered without being written.
   *
   * Every `MEMORY_USAGE_COUNTER_THRESHOLD` received frames, total buffered memory
   * is checked. At or above `FULL_MEMORY_USAGE`, the loader is paused while all
   * registered pixel maps are written to the cache.
   *
   * @returns The ending priority/focus reported by the loader (falling back to the
   *   inputs) and whether the transfer succeeded.
   */
  async #loadFromNetwork({
    frames,
    initialFrameIndex,
    stackPriority,
    isInitialFrame,
    isDropped,
    callback,
  }: {
    frames: OrderedFrame[];
    initialFrameIndex: number;
    stackPriority: number;
    isInitialFrame: boolean;
    isDropped: boolean;
    callback: LoaderCallback;
  }): Promise<{
    priority: number;
    focus: number;
    networkSuccess: boolean;
  }> {
    const pixelMap = new Map<string, PixelMapEntry>();
    pixelMaps.add(pixelMap);
    return new Promise((resolve) => {
      this.#imageLoader.loadFrames({
        orderedFrames: frames,
        initialFocus: initialFrameIndex,
        isInitialFrame,
        stackPriority,
        isDropped,
        messageReceivedCallback: (message) => {
          const event = message.event;
          switch (event) {
            case 'data-received': {
              const { frameSmid, pixels } = message;
              const frameInfo = frames.find((frame) => frame.frameSmid === frameSmid);

              if (frameSmid == null || frameInfo == null || pixels == null) return;

              const range = getRange(pixels);

              const entry = {
                frameInfo,
                data: { pixels },
                status: 'received',
                // @ts-expect-error [EN-7967] - TS2339 - Property 'BYTES_PER_ELEMENT' does not exist on type 'Function'.
                size: pixels.constructor.BYTES_PER_ELEMENT * pixels.length,
              } as const;

              // Immediately execute the callback so the client can
              // get the data loaded for viewing.
              callback({
                stackSmid: frameInfo.stackSmid,
                type: isInitialFrame ? 'initial-frame' : 'frame',
                data: { frameSmid, pixels, range, fromCache: false },
              });

              pixelMap.set(frameSmid, entry);

              /*
               * Every MEMORY_USAGE_COUNTER_THRESHOLD frames, check the memory usage
               * to see if we are over the defined memory limit. The limit is soft
               * and has plenty of buffer room to go over, so we do not need to check
               * memory usage after every frame which would be a lot of calculations.
               * Instead, we just check at a frequency that works out to every few seconds.
               * Frames can range in size from 20 kB to 20 MB, but there's still enough
               * overhead even if we went 2 GB over (100 * 20 MB in an extreme case)
               */
              if (
                ++checkMemoryUsageCounter >= MEMORY_USAGE_COUNTER_THRESHOLD &&
                !earlyWriteTriggered
              ) {
                checkMemoryUsageCounter -= MEMORY_USAGE_COUNTER_THRESHOLD;
                const memoryUsage = calculateMemoryUsageForPixelMaps();
                if (memoryUsage >= FULL_MEMORY_USAGE) {
                  earlyWriteTriggered = true;
                  this.#imageLoader.pause();
                  logger.info(
                    `Early Write to DB to clear ${maxPrecision(memoryUsage / (1024 * 1024 * 1024), 2)} GB across ${pixelMaps.size} pixel maps`
                  );

                  Promise.all(
                    [...pixelMaps.values()].map((pixelMap) =>
                      writeAllToDB(this.#pixelDataCache, 'cache', pixelMap).catch((e) => {
                        logger.error('early write error', e);
                      })
                    )
                  ).finally(() => {
                    earlyWriteTriggered = false;
                    this.#imageLoader.resume();
                  });
                }
              }

              break;
            }

            case 'transfer-complete': {
              resolve({
                priority: message.endingPriority ?? stackPriority,
                focus: message.endingFocus ?? initialFrameIndex,
                networkSuccess: true,
              });

              // Limiting transactions and committing the cache after we have completed
              writeAllToDB(this.#pixelDataCache, 'cache', pixelMap)
                .catch((e) => {
                  logger.error('transfer-complete write error', e);
                })
                .finally(() => {
                  pixelMaps.delete(pixelMap);
                });
              break;
            }
            case 'transfer-error': {
              const { stacks, error } = message;
              if (stacks == null || error == null) return;
              stacks.forEach((stackSmid) => {
                callback({
                  stackSmid,
                  type: 'error',
                  data: error,
                });
              });
              resolve({
                priority: stackPriority,
                focus: initialFrameIndex,
                networkSuccess: false,
              });
              pixelMaps.delete(pixelMap);
              break;
            }
            default:
              break;
          }
        },
      });
    });
  }

  /**
   * Maps the stack's frames to `OrderedFrame`s and moves the frame at
   * `initialFrameIndex` to the front, leaving the others in their original order.
   *
   * This follows `Array.prototype.splice` semantics: an index at or past the end
   * moves nothing, and a negative index counts from the end.
   */
  #generateOrderedFrames(stack: FullSingleLayerStack, initialFrameIndex: number): OrderedFrame[] {
    const frames = stack.frames.map((frame, index) => {
      return mapToOrderedFrame(frame, index, stack);
    });

    const initialFrame = frames.splice(initialFrameIndex, 1);
    return [...initialFrame, ...frames];
  }
}

/**
 * Computes the minimum and maximum of a pixel array in a single pass.
 *
 * @param arr - Pixel values to scan.
 * @returns `[min, max]`. An empty array yields `[0, 0]`; without that guard the
 *   seed reads of `arr[0]` would produce `[undefined, undefined]`, violating the
 *   declared return type.
 */
function getRange(arr: SupportedTextureTypes): [number, number] {
  const len = arr.length;
  if (len === 0) {
    return [0, 0];
  }

  let min = arr[0];
  let max = arr[0];

  // Index 0 already seeded min/max, so the scan starts at 1.
  for (let i = 1; i < len; i++) {
    const value = arr[i];
    // `else if` is safe: once seeded, a value cannot be both below min and above max.
    if (value < min) {
      min = value;
    } else if (value > max) {
      max = value;
    }
  }
  return [min, max];
}

/** Lifecycle state of a pixel-map entry. Entries created in this file are always `'received'`. */
export type PixelMapEntryState = 'received' | 'written';
/**
 * One network-received frame buffered in memory until its pixel map is written
 * to the IndexedDB cache.
 */
export type PixelMapEntry = {
  frameInfo: OrderedFrame;
  data: {
    pixels: SupportedTextureTypes;
  };
  status: PixelMapEntryState;
  // Byte size of `data.pixels` (BYTES_PER_ELEMENT * length when created in this file);
  // summed by calculateMemoryUsageForPixelMaps.
  size: number;
  // Never set in this file; presumably assigned by writeAllToDB — TODO confirm.
  writeId?: string;
};
/** Buffered frames of one network load, keyed by frameSmid. */
export type PixelMap = Map<string, PixelMapEntry>;

// Pixel maps of network loads that have not yet been released. Each one is added when
// PixelDataLoader#loadFromNetwork starts, and removed after its transfer-complete cache
// write settles or on transfer-error.
const pixelMaps = new Set<PixelMap>();

/**
 * Totals the recorded byte `size` of every entry in every registered pixel map.
 *
 * @returns The summed size, in bytes, of all buffered frames.
 */
export function calculateMemoryUsageForPixelMaps(): number {
  let total = 0;
  // Visit the collections in place rather than copying them; a single map may hold thousands of frames.
  pixelMaps.forEach((pixelMap) => {
    pixelMap.forEach(({ size }) => {
      total += size;
    });
  });
  return total;
}

// Received frames between memory-usage checks. At the 20 kB–20 MB per-frame sizes
// this loader anticipates, that is roughly 2 MB–2 GB of new data per check.
const MEMORY_USAGE_COUNTER_THRESHOLD = 100;
// Frames received (shared across all loads) since the last memory-usage check.
let checkMemoryUsageCounter = 0;
// True while an early flush of all pixel maps to IndexedDB is in progress;
// memory checks are skipped until that flush settles.
let earlyWriteTriggered = false;
