import * as ort from 'onnxruntime-web';
import { Tensor } from 'onnxruntime-web';
import { EventEmitter } from 'events';
import PostAndReceiveWorker from 'services/PostAndReceiveWorker';
import logger from 'services/logger';
// eslint-disable-next-line import/no-webpack-loader-syntax
import onnxWorker from 'workerize-loader!./onnx.worker';
import isOffscreenCanvasSupported from 'utils/isOffscreenCanvasSupported';
import { getMessageFromError } from 'utils/errorMessage';
import { ModelConfig } from 'stores/configStore';
import mixpanel from 'services/mixpanel';
import imageDataToTensor, { stableSigmoid } from './utils/imageDataToTensor';

/** Image sources accepted by `AgePredictor`; each is drawn onto a canvas before inference. */
type InputImage = HTMLCanvasElement | HTMLImageElement | OffscreenCanvas;

/**
 * Estimates a person's age from an image with an ONNX model.
 *
 * `init()` picks one of the configured models (weighted by `traffic`) and loads it:
 * inside a web worker when OffscreenCanvas is supported (inference then goes through
 * `predictAge`), otherwise on the main thread (inference then goes through
 * `predictAgeWithoutWebWorker`).
 *
 * Emits `'progress'` events shaped `{ total, loaded, progress }` while the model loads.
 */
class AgePredictor extends EventEmitter {
  // Message-passing wrapper around the ONNX web worker.
  onnxWorker: PostAndReceiveWorker<any>;

  // Main-thread inference session; only assigned by `initWithoutWebWorker`.
  session: ort.InferenceSession | null;

  isOffscreenCanvasSupported: boolean;

  // Candidate models, and the one chosen by `init`.
  models: ModelConfig[];
  model: ModelConfig | null;

  constructor(models: ModelConfig[]) {
    super();
    this.onnxWorker = new PostAndReceiveWorker({ worker: onnxWorker as Worker, name: 'Onnx' });
    this.session = null;
    this.isOffscreenCanvasSupported = isOffscreenCanvasSupported();
    this.models = models;
    this.model = null;
    // Bound so it can be handed to the worker wrapper as a progress callback.
    this.onProgress = this.onProgress.bind(this);
  }

  /**
   * Chooses a model, reports the choice to logging/analytics, and loads it.
   *
   * @returns The worker's response to the `loadModel` request, or the result of
   *   `initWithoutWebWorker()` when OffscreenCanvas is not supported.
   * @throws Error if no models are configured.
   */
  async init() {
    const model = this.selectModel();
    this.model = model;
    logger.info('Selected model', { ...model });
    mixpanel.trackEvent({ event: 'Selected Model', ...model });
    mixpanel.register({ model: model.name });

    if (!this.isOffscreenCanvasSupported) {
      logger.info('Initializing Onnx AgePredictor without web worker');
      return this.initWithoutWebWorker();
    }

    logger.info('Initializing Onnx AgePredictor');
    return this.onnxWorker.postAndReceiveProgress({
      type: 'loadModel',
      ...model,
    }, this.onProgress);
  }

  /** Re-emits worker download progress as a `'progress'` event. */
  onProgress(progressData: { total: number, loaded: number, progress: number }) {
    const { total, loaded, progress } = progressData;
    this.emit('progress', { total, loaded, progress });
  }

  /** Terminates the underlying web worker. */
  terminate() {
    logger.info('Terminating Onnx AgePredictor');
    return this.onnxWorker.worker.terminate();
  }

  /**
   * Replaces the worker wrapper with a new one and re-runs `init()`, which also
   * selects a model again.
   *
   * NOTE(review): the previous worker is not terminated here; presumably callers
   * invoke `terminate()` first — confirm.
   */
  async freshInit() {
    this.onnxWorker = new PostAndReceiveWorker({ worker: onnxWorker as Worker, name: 'Onnx' });
    return this.init();
  }

  /**
   * Predicts an age by resizing `image` on the main thread and running the model
   * in the web worker.
   *
   * @returns The predicted value, or `null` if resizing failed or the worker
   *   reported an error.
   */
  async predictAge(image: InputImage) {
    const resizedImage = await this.resizeImage(image);
    if (!resizedImage) {
      // No 2D context was available, so there is no pixel data to send.
      logger.error('Error resizing image');
      return null;
    }

    const { result, time, error } = await this.onnxWorker.postAndReceiveMessage({
      type: 'predictAge',
      imageData: resizedImage,
    });

    if (error) {
      logger.error('Error predicting age:', error);
      return null;
    }
    logger.debug('Age prediction results:', { result, time });
    return result as number;
  }

  /**
   * Downloads the selected model, creates a main-thread inference session, and
   * warms it up with one inference on an all-zero image of the model's dimensions.
   *
   * @throws Error if `init()` has not selected a model yet.
   */
  async initWithoutWebWorker() {
    const { model } = this;
    if (!model) {
      throw new Error('Model is not selected yet');
    }
    const buffer = await this.downloadModel(model.url);
    this.session = await ort.InferenceSession.create(buffer, {
      executionProviders: model.executionProviders,
    });

    const [width, height] = model.dimensions;
    const { tensorShape, normalizePixelData } = model;

    // Blank RGBA frame: the first inference is only run to warm up the session.
    const pixelData = new Uint8ClampedArray(width * height * 4);
    const imageData = new ImageData(pixelData, width, height);
    const imageTensor = imageDataToTensor(imageData, width, height, tensorShape, normalizePixelData);
    await this.runInference(imageTensor);
  }

  /**
   * Downloads the file at `path` as an ArrayBuffer.
   *
   * Emits `'progress'` events (percentage rounded to an integer) whenever the
   * response size is known.
   *
   * @returns A promise that rejects on a network error or a non-200 status.
   */
  async downloadModel(path: string): Promise<ArrayBuffer> {
    return new Promise((resolve, reject) => {
      const xhr = new XMLHttpRequest();
      xhr.open('GET', path, true);
      xhr.responseType = 'arraybuffer';

      xhr.onprogress = (evt) => {
        if (evt.lengthComputable) {
          const progress = Math.round((evt.loaded / evt.total) * 100);
          this.emit('progress', { total: evt.total, loaded: evt.loaded, progress });
        }
      };

      xhr.onload = () => {
        if (xhr.status === 200) {
          // responseType is 'arraybuffer', so the response is already the buffer.
          resolve(xhr.response);
        } else {
          reject(new Error(`Loading failed with status ${xhr.status}`));
        }
      };

      xhr.onerror = () => {
        reject(new Error('Loading model failed'));
      };

      xhr.send();
    });
  }

  /**
   * Runs the main-thread session on `preprocessedData`.
   *
   * @returns A tuple of (a) an object mapping each output name to the first
   *   element of that output's data, plus a `sigmoid` key, and (b) the inference
   *   time in seconds.
   * @throws Error if the session has not been created yet.
   */
  async runInference(preprocessedData: Tensor): Promise<[any, number]> {
    // Snapshot the session so it cannot change underneath us across the await.
    const { session } = this;
    if (session === null) {
      throw new Error('Model is not loaded yet');
    }

    // Monotonic clock: unaffected by system clock adjustments.
    const start = performance.now();

    // The tensor is fed through the model's first declared input.
    const feeds: Record<string, Tensor> = { [session.inputNames[0]]: preprocessedData };
    // NOTE(review): this double cast looks unnecessary — confirm and remove.
    const outputData = await session.run(feeds as unknown as ort.InferenceSession.FeedsType);

    const inferenceTime = (performance.now() - start) / 1000;

    const output: Record<string, any> = {};
    for (const outputName of session.outputNames) {
      output[outputName] = outputData[outputName]?.data?.at(0);
      // Overwritten on each iteration: with several outputs, `sigmoid` reflects the last one.
      output.sigmoid = stableSigmoid(output[outputName]);
    }

    return [output, inferenceTime];
  }

  /**
   * Predicts an age using the main-thread session.
   *
   * @returns The model output stored under the selected model's `outputKey`, or
   *   `null` if anything fails (errors are logged, not thrown).
   */
  async predictAgeWithoutWebWorker(image: InputImage) {
    try {
      // Snapshot the model so the same config is used across the awaits below.
      const { model } = this;
      if (!model) {
        throw new Error('Model is not selected yet');
      }
      const resizedImage = await this.resizeImage(image);

      if (!resizedImage) {
        logger.error('Error resizing image');
        return null;
      }

      const imageTensor = imageDataToTensor(
        resizedImage,
        resizedImage.width,
        resizedImage.height,
        model.tensorShape,
        model.normalizePixelData,
      );
      const [results, time] = await this.runInference(imageTensor);
      logger.debug('Age prediction results without worker:', { results, time });

      return results[model.outputKey] as number;
    } catch (error) {
      const errorMessage = getMessageFromError(error);
      logger.error('Error in predicting age without worker', { errorMessage });
      return null;
    }
  }

  /**
   * Picks a model at random, weighted by each model's `traffic` share.
   *
   * A single configured model is always returned as-is. Models with non-positive
   * traffic are never picked by the weighted draw. If the draw lands past the
   * cumulative traffic (for example, shares summing to less than 1), the first
   * model is returned.
   *
   * @throws Error if no models are configured.
   */
  selectModel() {
    if (this.models.length === 0) {
      throw new Error('No models available');
    }

    if (this.models.length === 1) {
      return this.models[0];
    }

    const random = Math.random();
    let current = 0;
    for (const model of this.models) {
      // Zero/negative traffic should already be filtered out when the config is
      // loaded; skip such entries defensively.
      if (model.traffic <= 0) {
        continue;
      }

      current += model.traffic;
      if (random < current) {
        return model;
      }
    }
    return this.models[0];
  }

  /**
   * Draws `image` onto a new canvas scaled to the selected model's dimensions
   * (128x128 if no model is selected yet) and returns the scaled pixels.
   *
   * @returns The full width x height pixel data, or `undefined` if a 2D context
   *   is unavailable.
   */
  async resizeImage(image: InputImage) {
    const [width, height] = this.model?.dimensions ?? [128, 128];

    const canvas = document.createElement('canvas');
    canvas.width = width;
    canvas.height = height;

    const ctx = canvas.getContext('2d');
    ctx?.drawImage(image, 0, 0, width, height);

    // Read back the whole canvas; non-square models need the real height here.
    return ctx?.getImageData(0, 0, width, height);
  }
}

// The predictor class is this module's default export.
export { AgePredictor as default };
