// @flow
import { useRef, useEffect } from 'react';
import { atom, useRecoilState } from 'recoil';
import type { RecoilState } from 'recoil';
import {
  asrPlexVADWebSocketService,
  asrPlexWebSocketService,
} from '../ASRPlex/ASRPlexWebSocketService';
import { logger } from 'modules/logger';

// Shape of the value stored in the dictation WebSocket atoms in this module.
export type ASRPlexWebSocketState = {
  // The connected socket, or null when no socket has been stored yet.
  socket: WebSocket | null,
  // Identifier string for the active transaction; the atoms in this module
  // default it to the empty string.
  active_transaction_id: string,
};

// Recoil atom holding the dictation WebSocket state. It starts with no socket
// and an empty transaction id.
// NOTE(review): this atom stores a live WebSocket instance; confirm whether
// Recoil's handling of mutable objects in atoms (see `dangerouslyAllowMutability`)
// is a concern here.
export const dictationWebSocketState: RecoilState<ASRPlexWebSocketState> = atom({
  key: 'dictationWebSocket',
  default: {
    socket: null,
    active_transaction_id: '',
  },
});

// Recoil atom holding the VAD dictation WebSocket state. It has the same shape
// as the non-VAD atom but is registered under its own key, and likewise starts
// with no socket and an empty transaction id.
export const dictationVadWebSocketState: RecoilState<ASRPlexWebSocketState> = atom({
  key: 'dictationVadWebSocket',
  default: {
    socket: null,
    active_transaction_id: '',
  },
});

export const useRetryDictationWebSocket = (): ?WebSocket => {
  const [socketState, setSocketState] = useRecoilState(dictationWebSocketState);
  const attemptsRef = useRef(0);

  useEffect(() => {
    let socket = asrPlexWebSocketService.connect();

    const onOpen = () => {
      // Reset the attempts counter
      attemptsRef.current = 1;
      setSocketState((prev) => ({
        ...prev,
        socket,
      }));
    };
    const onClose = () => {
      // Cleanup previous event listeners
      socket.removeEventListener('open', onOpen);
      socket.removeEventListener('close', onClose);

      // Exponential backoff, capping at 30 seconds after 5 attempts
      // 2 ** 1 = 2, 2 ** 2 = 4, 2 ** 3 = 8, 2 ** 4 = 16, 2 ** 5 = 32
      const backoff = Math.min(1000 * 2 ** attemptsRef.current, 30000);
      const jitter = Math.floor(Math.random() * 50);

      logger.info(
        `[ASRPlexDictation] WebSocket connection closed, retrying connection in ${
          backoff + jitter
        } ms...`,
        {
          attempts: attemptsRef.current,
          backoff,
          jitter,
        }
      );
      setTimeout(() => {
        socket = asrPlexWebSocketService.connect();
        attemptsRef.current += 1;
        socket.addEventListener('open', onOpen);
        socket.addEventListener('close', onClose);
      }, backoff + jitter);
    };
    socket.addEventListener('open', onOpen);
    socket.addEventListener('close', onClose);

    return () => {
      socket.removeEventListener('open', onOpen);
      socket.removeEventListener('close', onClose);
    };
  }, [setSocketState]);

  return socketState.socket;
};

export const useRetryDictationVADWebSocket = (): ?WebSocket => {
  const [socketState, setSocketState] = useRecoilState(dictationVadWebSocketState);
  const attemptsRef = useRef(0);

  useEffect(() => {
    let socket = asrPlexVADWebSocketService.connect();

    const onOpen = () => {
      // Reset the attempts counter
      attemptsRef.current = 1;
      setSocketState((prev) => ({
        ...prev,
        socket,
      }));
    };
    const onClose = () => {
      // Cleanup previous event listeners
      socket.removeEventListener('open', onOpen);
      socket.removeEventListener('close', onClose);

      // Exponential backoff, capping at 30 seconds after 5 attempts
      // 2 ** 1 = 2, 2 ** 2 = 4, 2 ** 3 = 8, 2 ** 4 = 16, 2 ** 5 = 32
      const backoff = Math.min(1000 * 2 ** attemptsRef.current, 30000);
      const jitter = Math.floor(Math.random() * 50);

      logger.info(
        `[ASRPlexVADDictation] WebSocket connection closed, retrying connection in ${
          backoff + jitter
        } ms...`,
        {
          attempts: attemptsRef.current,
          backoff,
          jitter,
        }
      );
      setTimeout(() => {
        socket = asrPlexVADWebSocketService.connect();
        attemptsRef.current += 1;
        socket.addEventListener('open', onOpen);
        socket.addEventListener('close', onClose);
      }, backoff + jitter);
    };
    socket.addEventListener('open', onOpen);
    socket.addEventListener('close', onClose);

    return () => {
      socket.removeEventListener('open', onOpen);
      socket.removeEventListener('close', onClose);
    };
  }, [setSocketState]);

  return socketState.socket;
};
