import { EventStreamCodec, Message } from '@aws-sdk/eventstream-codec';
import * as utilUTF8 from '@aws-sdk/util-utf8';
import { Buffer } from 'buffer';
import crypto from 'crypto';
import { debounce } from 'lodash';
import querystring from 'query-string';
import { Client } from 'urql';

import { GenerateTranscribeCredentialsDocument } from '../graphql';
import { createPresignedURL } from '../helpers/sigv4Helpers';

/**
 * Taps a media stream with the Web Audio API and re-broadcasts its raw samples.
 *
 * A script processor node (1 input channel, 1 output channel, buffer size left
 * for the browser to choose) is wired between the stream source and the audio
 * context destination. Every processed block of channel 0 is dispatched on
 * `window` as a `CustomEvent` named `data` whose `detail` is the `Float32Array`
 * of samples.
 *
 * @param mediaStream - Stream whose audio is captured; its audio tracks are
 *   stopped when the returned `stop` runs.
 * @param onClose - Invoked once the audio graph has been torn down by `stop`.
 * @returns An object with a `stop` function that does nothing when the audio
 *   context already reports the `closed` state.
 */
export const createPCMStream = (mediaStream: MediaStream, onClose: () => void) => {
  const audioContext = new AudioContext();
  const processor = audioContext.createScriptProcessor(undefined, 1, 1);
  processor.connect(audioContext.destination);

  const source = audioContext.createMediaStreamSource(mediaStream);
  source.connect(processor);

  // Forward each captured block to whoever listens for `data` on window.
  processor.onaudioprocess = ({ inputBuffer }) => {
    const samples = inputBuffer.getChannelData(0);
    window.dispatchEvent(new CustomEvent('data', { detail: samples }));
  };

  const stop = () => {
    if (audioContext.state !== 'closed') {
      for (const track of mediaStream.getAudioTracks()) {
        track.stop();
      }

      // Tear the graph down before closing the context, then notify the owner.
      processor.disconnect();
      source.disconnect();
      audioContext.close();
      onClose();
    }
  };

  return { stop };
};

/**
 * Handle for a running audio capture; `streamAudioToWebSocket` resolves to
 * this shape.
 */
export interface AudioInput {
  /** Ends the capture. */
  stop: () => void;
}

/** Inputs `createWebsocketUrl` needs to build the pre-signed streaming URL. */
export interface ClientParams {
  /** AWS region; used both in the endpoint host name and as the signing region. */
  region: string;
  /** Passed to the signer as `key`. */
  accessKeyId: string;
  /** Passed to the signer as `secret`. */
  secretAccessKey: string;
  /** Passed to the signer as `sessionToken`. */
  sessionToken: string;
  /** Sent as the `language-code` query parameter. */
  locale: string;
}

/**
 * Runs the `GenerateTranscribeCredentials` mutation and returns the resulting
 * credentials under the property names the URL signer expects.
 *
 * @param client - urql client used to execute the mutation.
 * @returns `accessKeyId`, `secretAccessKey` and `sessionToken` taken from the
 *   mutation result.
 * @throws Error with a fixed message when the response carries no data or when
 *   anything inside the request/mapping throws; the underlying cause is not
 *   propagated.
 */
export const getCredentials = async (client: Client) => {
  const failureMessage =
    'Could not get temporary sts credentials, to connect to transcribe websocket';

  try {
    const { data } = await client
      .mutation(GenerateTranscribeCredentialsDocument, undefined)
      .toPromise();

    if (!data) {
      throw new Error(failureMessage);
    }

    const issued = data.generateTranscribeCredentials;

    // The API names the key id `accessKey`; expose it as `accessKeyId`.
    return {
      accessKeyId: issued.accessKey,
      secretAccessKey: issued.secretAccessKey,
      sessionToken: issued.sessionToken,
    };
  } catch {
    // Every failure mode collapses into the same generic error.
    throw new Error(failureMessage);
  }
};

/**
 * Builds the pre-signed `wss` URL for the Transcribe streaming WebSocket.
 *
 * The request is signed as a `GET` on `/stream-transcription-websocket` for the
 * `transcribe` service, with the language, `pcm` media encoding and a 16000
 * sample rate carried in the query string.
 *
 * @param clientParams - Region, temporary credentials and locale to sign with.
 * @returns Whatever `createPresignedURL` returns (presumably the URL string —
 *   confirm against the helper).
 */
export const createWebsocketUrl = async (clientParams: ClientParams) => {
  const { region, accessKeyId, secretAccessKey, sessionToken, locale } = clientParams;

  const host = `transcribestreaming.${region}.amazonaws.com:8443`;

  // SHA-256 of the empty string; presumably the payload hash the signer expects.
  const emptyBodyHash = crypto.createHash('sha256').update('', 'utf8').digest('hex');

  const query = querystring.stringify({
    'language-code': locale,
    'media-encoding': 'pcm',
    'sample-rate': 16000,
  });

  // The resulting pre-authenticated URL is what the WebSocket connects to.
  return createPresignedURL(
    'GET',
    host,
    '/stream-transcription-websocket',
    'transcribe',
    emptyBodyHash,
    {
      key: accessKeyId,
      secret: secretAccessKey,
      sessionToken,
      protocol: 'wss',
      expires: 60,
      region,
      query,
    },
  );
};

/** Sample rate (Hz) assumed for incoming audio when the caller does not supply one. */
const INPUT_SAMPLE_RATE = 44100;

/**
 * Resamples a mono sample buffer to `outputSampleRate` by averaging the input
 * samples that fall into each output sample's window.
 *
 * @param buffer - Input samples.
 * @param outputSampleRate - Target rate in Hz (default 16000).
 * @param inputSampleRate - Rate in Hz at which `buffer` was captured. Defaults
 *   to 44100; pass the real capture rate (e.g. the audio context's
 *   `sampleRate`) when it differs, otherwise the output is pitch/time shifted.
 * @returns `buffer` itself when both rates are equal, otherwise a new
 *   `Float32Array` of `Math.round(buffer.length / (inputSampleRate / outputSampleRate))`
 *   samples.
 */
export const downsampleBuffer = (
  buffer: Float32Array,
  outputSampleRate = 16000,
  inputSampleRate = INPUT_SAMPLE_RATE,
) => {
  if (outputSampleRate === inputSampleRate) {
    return buffer;
  }

  const sampleRateRatio = inputSampleRate / outputSampleRate;
  const newLength = Math.round(buffer.length / sampleRateRatio);
  const result = new Float32Array(newLength);

  let offsetBuffer = 0;
  for (let offsetResult = 0; offsetResult < result.length; offsetResult++) {
    const nextOffsetBuffer = Math.round((offsetResult + 1) * sampleRateRatio);
    const windowEnd = Math.min(nextOffsetBuffer, buffer.length);

    let accum = 0;
    for (let i = offsetBuffer; i < windowEnd; i++) {
      accum += buffer[i];
    }

    const count = windowEnd - offsetBuffer;
    // When the target rate exceeds the input rate some windows are empty;
    // reuse the nearest input sample instead of emitting 0 / 0 = NaN.
    result[offsetResult] =
      count > 0 ? accum / count : buffer[Math.min(offsetBuffer, buffer.length - 1)];

    offsetBuffer = nextOffsetBuffer;
  }

  return result;
};

/**
 * Converts floating-point samples to signed 16-bit little-endian PCM.
 *
 * Each sample is clamped to [-1, 1]; negative values are scaled by 0x8000 and
 * non-negative values by 0x7fff before being written as an Int16.
 *
 * @param input - Samples to encode.
 * @returns An `ArrayBuffer` holding two bytes per input sample.
 */
export const pcmEncode = (input: Float32Array) => {
  const bytesPerSample = 2;
  const encoded = new ArrayBuffer(input.length * bytesPerSample);
  const writer = new DataView(encoded);

  input.forEach((sample, index) => {
    const clamped = Math.min(1, Math.max(-1, sample));
    // Asymmetric scale so both -1 and +1 map onto the Int16 extremes.
    const scale = clamped < 0 ? 0x8000 : 0x7fff;
    writer.setInt16(index * bytesPerSample, clamped * scale, true);
  });

  return encoded;
};

/**
 * Wraps encoded audio bytes in an event-stream message tagged as an
 * `AudioEvent`.
 *
 * @param buffer - Audio payload; used as the message body as-is.
 * @returns A message whose `:message-type` header is `event` and whose
 *   `:event-type` header is `AudioEvent`.
 */
export const getAudioEventMessage = (buffer: Buffer): Message => {
  // Both headers are string-typed event-stream header values.
  const stringHeader = (value: string) => ({ type: 'string' as const, value });

  return {
    headers: {
      ':message-type': stringHeader('event'),
      ':event-type': stringHeader('AudioEvent'),
    },
    body: buffer,
  };
};

// Shared event-stream codec, constructed with the UTF-8 helpers from
// @aws-sdk/util-utf8. This file uses it to encode outgoing audio events and to
// decode inbound WebSocket messages.
export const eventStreamCodec = new EventStreamCodec(utilUTF8.toUtf8, utilUTF8.fromUtf8);

/**
 * Turns one block of raw samples into a binary event-stream frame.
 *
 * Pipeline: downsample -> 16-bit PCM encode -> wrap in an `AudioEvent` message
 * -> encode with the shared event-stream codec.
 *
 * @param raw - Samples as delivered by the audio capture.
 * @returns The encoded frame, or `undefined` when `raw` is null/undefined.
 */
export const convertAudioToBinaryMessage = (raw: Float32Array) => {
  // Defensive: callers may hand over an event without a payload.
  if (raw == null) {
    return undefined;
  }

  // Resample, then convert the float samples to PCM bytes.
  const pcmBytes = Buffer.from(pcmEncode(downsampleBuffer(raw)));

  // Attach the event-stream headers and serialize the message to binary.
  return eventStreamCodec.encode(getAudioEventMessage(pcmBytes));
};

/**
 * Shape of the transcription result handed to `onChunk` by
 * `streamAudioToWebSocket` (the first entry of `Transcript.Results` in an
 * inbound message).
 * NOTE(review): field meanings below follow the field names; confirm against
 * the Amazon Transcribe streaming response reference.
 */
interface TranscribeData {
  // Candidate transcriptions for this result.
  Alternatives: {
    // Individual transcribed items (presumably words/punctuation) with timing.
    Items: {
      Content: string;
      // Presumably offsets into the audio stream — TODO confirm unit.
      StartTime: number;
      EndTime: number;
    }[];
    // Full text of this alternative.
    Transcript: string;
  }[];
  // Presumably true while the result may still be revised by later messages.
  IsPartial: boolean;
}

/**
 * Streams audio from a media stream to the Transcribe streaming WebSocket and
 * reports transcription results through callbacks.
 *
 * Flow: fetch temporary credentials, build a pre-signed URL (region is fixed
 * to `eu-central-1`), open the WebSocket, start capturing PCM, and forward
 * every captured block as a binary event-stream frame while the socket is open.
 *
 * @param userMediaStream - Audio source to capture.
 * @param client - urql client used to fetch temporary credentials.
 * @param locale - Language code sent to the service.
 * @param onChunk - Called with the first result of every inbound `event`
 *   message that contains at least one result.
 * @param onOpen - Called when the WebSocket has opened.
 * @param onStop - Called after a manual `stop()` has torn down the capture.
 * @param onError - Called with a message for service exceptions, socket errors
 *   and abnormal closes (close code other than 1000).
 * @returns A handle whose `stop` ends the capture and, after a delay, closes
 *   the socket.
 * @throws Error when the credentials cannot be obtained.
 */
export const streamAudioToWebSocket = async (
  userMediaStream: MediaStream,
  client: Client,
  locale: string,
  onChunk: (transcribeData: TranscribeData) => void,
  onOpen: () => void,
  onStop: () => void,
  onError: (e: string) => void,
): Promise<AudioInput> => {
  const credentials = await getCredentials(client);
  // Pre-signed URLs are a way to authenticate a request (or WebSocket connection, in this case)
  // via Query Parameters. Learn more: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html
  const url = await createWebsocketUrl({
    ...credentials,
    locale,
    region: 'eu-central-1',
  });

  // State shared between the socket handlers; all start out false so that the
  // close handler never has to reason about `undefined`.
  let socketError = false;
  let transcribeException = false;
  let manuallyStopped = false;

  const websocket = new WebSocket(url);

  // Once the PCM stream has shut down, close the socket 3000 ms later
  // (presumably so trailing results can still arrive — confirm).
  const pcmStream = createPCMStream(
    userMediaStream,
    debounce(() => {
      websocket.close();
    }, 3000),
  );

  // Encodes each captured block and sends it while the socket is open.
  const handleOnData = ((e: CustomEvent<Float32Array>) => {
    const binary = convertAudioToBinaryMessage(e.detail);
    if (websocket.readyState === websocket.OPEN && binary) {
      websocket.send(binary);
    }
  }) as EventListener;

  websocket.binaryType = 'arraybuffer';

  // Stop function we expose for manually closing the transcription
  // This will be called when the audio file ends or the microphone stream
  // is stopped, so give it a short while before cutting off the PCM stream
  const stop = () => {
    manuallyStopped = true;
    setTimeout(() => {
      window.removeEventListener('data', handleOnData);
      pcmStream.stop();
      onStop();
    }, 500);
  };

  // when we get audio data from the mic, send it to the WebSocket if possible
  websocket.onopen = () => {
    onOpen();
    window.addEventListener('data', handleOnData);
  };

  // handle messages, errors, and close events
  // handle inbound messages from Amazon Transcribe
  websocket.onmessage = (message) => {
    // convert the binary event stream message to JSON
    const decoder = new TextDecoder('utf-8');
    // binaryType is 'arraybuffer', so message.data is raw bytes. For an
    // ArrayBuffer the second argument of Buffer.from is a byte offset, not an
    // encoding, so none is passed.
    const messageWrapper = eventStreamCodec.decode(Buffer.from(message.data));

    const decodedString = decoder.decode(messageWrapper.body);

    const messageBody = JSON.parse(decodedString);

    if (messageWrapper.headers[':message-type'].value === 'event') {
      const results = messageBody.Transcript.Results;
      if (results.length > 0) {
        onChunk(results[0]);
      }
    } else {
      // Any non-event message is treated as a service exception.
      transcribeException = true;
      onError(messageBody.Message);
    }
  };

  websocket.onerror = () => {
    socketError = true;
    onError('WebSocket connection error. Try again.');
  };

  websocket.onclose = (closeEvent) => {
    // A close that follows a manual stop is expected and not reported.
    if (manuallyStopped) return;
    // the close event immediately follows the error event; only handle one.
    if (!socketError && !transcribeException) {
      if (closeEvent.code !== 1000) {
        onError('Streaming Exception: ' + closeEvent.reason);
      }
    }
  };
  return { stop };
};
