import { useEffect, useMemo } from "react";
import { windLayerHeightAtom, windLayerSourceAtom } from "../state/layer";
import {
  defaultWindSpeedRasterMinMax,
  windSpeedRasterMinMaxAtom,
} from "../state/map";
import { CustomLayerInterface, RasterSourceSpecification } from "mapbox-gl";
import { LAYER_DEBUG_PRINT } from "../state/debug";
import { scream } from "../utils/sentry";
import { Map } from "mapbox-gl";
import { useAtomValue } from "jotai";
import { MesoWindDataSource } from "types/metocean";
import { COLORS } from "./windSpeed";

// Pixel size of each raster tile; passed as `tileSize` on the speed raster source below.
const tileSize = 512;

/**
 * Mapbox GL custom layer that colours the wind-speed raster tiles of
 * `speedSourceId`.
 *
 * The tiles are loaded by the map's own source cache, reached through the
 * private `style._otherSourceCaches` map. This layer keeps that cache paused
 * and drives its `update` manually from `render`. It draws every visible tile
 * with its own shader, looking up each pixel's colour in a 1-D palette texture
 * built from `colors`.
 */
class LocalSpeedLayer implements CustomLayerInterface {
  id: string;
  type: "custom";
  speedSourceId: string;
  /**
   * Palette entries. Each entry is uploaded as one RGBA / UNSIGNED_BYTE texel,
   * so every entry is expected to hold four 0-255 components.
   */
  colors: number[][];
  map: any;
  sourceCache: any;
  vertexArray: Int16Array | undefined;
  vertexBuffer: any;
  indexArray: Uint16Array | undefined;
  indexBuffer: any;
  program: any;
  aPos: any;
  /** `[min, max]` speed range stretched across the palette. */
  minMax: [number, number] = defaultWindSpeedRasterMinMax;

  /** Palette texture. Created lazily in `render` and reused across frames. */
  private paletteTexture: WebGLTexture | null = null;
  /** The `colors` array `paletteTexture` was built from; it is rebuilt if `colors` is reassigned. */
  private paletteColors: number[][] | undefined;
  /** Uniform locations, resolved once after the program links. */
  private uniforms:
    | {
        matrix: WebGLUniformLocation | null;
        raster: WebGLUniformLocation | null;
        palette: WebGLUniformLocation | null;
        numberOfColors: WebGLUniformLocation | null;
        minMax: WebGLUniformLocation | null;
      }
    | undefined;

  /**
   * @param speedSourceId id of the raster source whose tiles are drawn; it
   *   must already be added to the map before this layer is added.
   * @param minMax initial `[min, max]` speed range for the palette.
   */
  constructor(speedSourceId: string, minMax: [number, number]) {
    this.id = "gwa-speed-layer";
    this.type = "custom";
    this.speedSourceId = speedSourceId;
    this.colors = COLORS.default;
    this.minMax = minMax;
  }

  /** Sets the palette range. It takes effect on the next `render`; callers should trigger a repaint. */
  setMinMax(minMax: [number, number]) {
    this.minMax = minMax;
  }

  /** Called by Mapbox when the layer is added: grabs the source cache and builds GL resources. */
  onAdd(map: mapboxgl.Map, gl: WebGLRenderingContext) {
    LAYER_DEBUG_PRINT && console.log("LocalSpeedLayer.onAdd");
    this.map = map;
    const style = (map as any).style;
    this.sourceCache = style._otherSourceCaches[this.speedSourceId];
    if (!this.sourceCache) {
      throw scream(
        `windSpeed: no source cache for source "${this.speedSourceId}"`,
      );
    }
    // Tile loading is driven manually from `update()`.
    this.sourceCache.pause();

    this.prepareShaders(gl);
    this.prepareBuffers(gl);
  }

  /**
   * Called by Mapbox when the layer is removed. Frees every GL resource
   * created by `onAdd`/`render`, so the same instance can be added again
   * without leaking.
   */
  onRemove(_map: mapboxgl.Map, gl: WebGLRenderingContext) {
    LAYER_DEBUG_PRINT && console.log("LocalSpeedLayer.onRemove");
    gl.deleteBuffer(this.vertexBuffer);
    gl.deleteBuffer(this.indexBuffer);
    gl.deleteProgram(this.program);
    if (this.paletteTexture) gl.deleteTexture(this.paletteTexture);

    this.vertexBuffer = undefined;
    this.indexBuffer = undefined;
    this.program = undefined;
    this.paletteTexture = null;
    this.paletteColors = undefined;
    this.uniforms = undefined;
    this.vertexArray = undefined;
    // `render` early-returns while this is undefined.
    this.indexArray = undefined;
    this.sourceCache = undefined;
    this.map = undefined;
  }

  /**
   * Runs the paused source cache's tile update against a clone of the map
   * transform, then re-pauses it. The clone's height is extended by the pitch
   * offset (`cameraToCenterDistance * sin(pitch)`).
   */
  update() {
    const transform = this.map.transform.clone();
    const pitchOffset =
      transform.cameraToCenterDistance * Math.sin(transform._pitch);
    transform.height = transform.height + pitchOffset;

    this.sourceCache._paused = false;
    this.sourceCache.used = true;
    this.sourceCache.update(transform);
    this.sourceCache.pause();
  }

  /**
   * Uploads a regular `gridSize x gridSize` vertex grid covering tile
   * coordinates 0..8192, plus the triangle indices (two per cell).
   */
  prepareBuffers(gl: WebGLRenderingContext) {
    const gridSize = 64;
    const extent = 8192;
    const cells = gridSize - 1;

    this.vertexArray = new Int16Array(gridSize * gridSize * 2);
    for (let row = 0; row < gridSize; row++) {
      for (let col = 0; col < gridSize; col++) {
        const offset = (row * gridSize + col) * 2;
        // Round explicitly; Int16Array would otherwise truncate the fraction.
        this.vertexArray[offset] = Math.round((col * extent) / cells);
        this.vertexArray[offset + 1] = Math.round((row * extent) / cells);
      }
    }

    this.vertexBuffer = gl.createBuffer();
    gl.bindBuffer(gl.ARRAY_BUFFER, this.vertexBuffer);
    gl.bufferData(gl.ARRAY_BUFFER, this.vertexArray.buffer, gl.STATIC_DRAW);

    // 64 * 64 = 4096 vertices, so 16-bit indices are sufficient.
    this.indexArray = new Uint16Array(cells * cells * 6);
    let offset = 0;
    for (let row = 0; row < cells; row++) {
      for (let col = 0; col < cells; col++) {
        const index = row * gridSize + col;
        this.indexArray.set(
          [
            index,
            index + 1,
            index + gridSize,
            index + gridSize,
            index + 1,
            index + gridSize + 1,
          ],
          offset,
        );
        offset += 6;
      }
    }

    this.indexBuffer = gl.createBuffer();
    gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, this.indexBuffer);
    gl.bufferData(
      gl.ELEMENT_ARRAY_BUFFER,
      this.indexArray.buffer,
      gl.STATIC_DRAW,
    );
  }

  /**
   * Compiles and links the speed program and caches its attribute and uniform
   * locations.
   * @throws when a shader cannot be created or compiled, or the program fails to link.
   */
  prepareShaders(gl: WebGLRenderingContext) {
    const vertexSource = `
      uniform mat4 u_matrix;
      attribute vec2 a_pos;
      varying vec2 v_pos;
      void main() {
          v_pos = vec2(a_pos / 8192.0);
          gl_Position = u_matrix * vec4(a_pos, .0, 1.0);
    }`;

    const fragmentSource = `
  precision highp float;
  uniform sampler2D u_raster;
  varying vec2 v_pos;
  uniform sampler2D u_palette;
  uniform float u_number_of_colors;
  uniform vec2 u_minMax;

  void main() {
    vec4 value = texture2D(u_raster, v_pos);
    float speed = value.r * 10. + 3.;
    float range = u_minMax.y - u_minMax.x;
    // Guard against min == max, which would otherwise divide by zero.
    float c = abs(range) < 1e-6 ? 0. : (speed - u_minMax.x) / range;
    vec4 color = texture2D(u_palette, vec2(c, .5));
    gl_FragColor = vec4(color.rgb, color.a * value.a);
  }`;

    const vertexShader = this.compileShader(
      gl,
      gl.VERTEX_SHADER,
      vertexSource,
      "vertexShader",
    );
    const fragmentShader = this.compileShader(
      gl,
      gl.FRAGMENT_SHADER,
      fragmentSource,
      "fragmentShader",
    );

    const program = gl.createProgram();
    if (!program) throw scream("windSpeed: failed to create program");
    gl.attachShader(program, vertexShader);
    gl.attachShader(program, fragmentShader);
    gl.linkProgram(program);
    // Attached shaders are only flagged for deletion; they are freed with the program.
    gl.deleteShader(vertexShader);
    gl.deleteShader(fragmentShader);
    if (!gl.getProgramParameter(program, gl.LINK_STATUS)) {
      const log = gl.getProgramInfoLog(program);
      gl.deleteProgram(program);
      throw scream(`windSpeed: failed to link program: ${log}`);
    }

    this.program = program;
    this.aPos = gl.getAttribLocation(program, "a_pos");
    this.uniforms = {
      matrix: gl.getUniformLocation(program, "u_matrix"),
      raster: gl.getUniformLocation(program, "u_raster"),
      palette: gl.getUniformLocation(program, "u_palette"),
      numberOfColors: gl.getUniformLocation(program, "u_number_of_colors"),
      minMax: gl.getUniformLocation(program, "u_minMax"),
    };
  }

  /** Creates and compiles one shader; throws with the driver's info log on failure. */
  private compileShader(
    gl: WebGLRenderingContext,
    shaderType: number,
    source: string,
    label: string,
  ): WebGLShader {
    const shader = gl.createShader(shaderType);
    if (!shader) throw scream(`windSpeed: failed to create ${label}`);
    gl.shaderSource(shader, source);
    gl.compileShader(shader);
    if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
      const log = gl.getShaderInfoLog(shader);
      gl.deleteShader(shader);
      throw scream(`windSpeed: failed to compile ${label}: ${log}`);
    }
    return shader;
  }

  /**
   * Binds the palette texture to the currently active texture unit. The
   * texture is (re)built only when missing or when `colors` has been
   * reassigned, instead of allocating a new texture on every frame.
   */
  private bindPaletteTexture(gl: WebGLRenderingContext) {
    if (this.paletteTexture && this.paletteColors === this.colors) {
      gl.bindTexture(gl.TEXTURE_2D, this.paletteTexture);
      return;
    }
    if (this.paletteTexture) gl.deleteTexture(this.paletteTexture);

    this.paletteTexture = gl.createTexture();
    this.paletteColors = this.colors;
    gl.bindTexture(gl.TEXTURE_2D, this.paletteTexture);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
    gl.texImage2D(
      gl.TEXTURE_2D,
      0,
      gl.RGBA,
      this.colors.length,
      1,
      0,
      gl.RGBA,
      gl.UNSIGNED_BYTE,
      new Uint8Array(this.colors.flat()),
    );
  }

  /** Draws every visible, loaded tile of the speed source through the palette. */
  render(gl: WebGLRenderingContext) {
    const uniforms = this.uniforms;
    if (!this.indexArray || !uniforms) return;
    LAYER_DEBUG_PRINT && console.time("LocalSpeedLayer.render");
    gl.useProgram(this.program);

    gl.enable(gl.BLEND);
    gl.blendFunc(gl.SRC_ALPHA, gl.ONE_MINUS_SRC_ALPHA);

    // Bind vertex buffer
    gl.bindBuffer(gl.ARRAY_BUFFER, this.vertexBuffer);
    gl.enableVertexAttribArray(this.aPos);
    gl.vertexAttribPointer(this.aPos, 2, gl.SHORT, false, 0, 0);

    // Bind index buffer
    gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, this.indexBuffer);

    // Palette on texture unit 2.
    gl.activeTexture(gl.TEXTURE2);
    this.bindPaletteTexture(gl);
    gl.uniform1i(uniforms.palette, 2);
    // The fragment shader declares but does not currently read this uniform.
    gl.uniform1f(uniforms.numberOfColors, this.colors.length);

    this.update();

    gl.uniform2fv(uniforms.minMax, this.minMax);

    // Speed rasters on texture unit 0; these settings are invariant across tiles.
    gl.activeTexture(gl.TEXTURE0);
    gl.uniform1i(uniforms.raster, 0);

    const vertexCount = this.indexArray.length;
    const coords = this.sourceCache.getVisibleCoordinates().reverse();

    for (const coord of coords) {
      const tile = this.sourceCache.getTile(coord);
      // Skip tiles whose raster texture is not available yet.
      if (!tile?.texture) continue;

      gl.bindTexture(gl.TEXTURE_2D, tile.texture.texture);
      gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
      gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
      gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
      gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);

      gl.uniformMatrix4fv(uniforms.matrix, false, coord.projMatrix);

      gl.drawElements(gl.TRIANGLES, vertexCount, gl.UNSIGNED_SHORT, 0);
    }
    LAYER_DEBUG_PRINT && console.timeEnd("LocalSpeedLayer.render");
  }
}

/**
 * Adds the wind-speed raster source and a `LocalSpeedLayer` to `inputMap` and
 * keeps them in sync with the selected wind source, layer height and colour
 * range. The layer is inserted below the "building" layer when that layer
 * exists; otherwise it goes on top.
 *
 * `inputWindLayerHeight` and `source` override the values held in global
 * state. Renders nothing itself; all output is drawn by the map.
 */
export const ActiveSpeedLayer2D = ({
  inputMap,
  inputWindLayerHeight,
  source,
}: {
  inputMap: Map | undefined;
  inputWindLayerHeight?: number;
  source?: MesoWindDataSource;
}) => {
  const stateWindLayerHeight = useAtomValue(windLayerHeightAtom);
  const windSource = useAtomValue(windLayerSourceAtom);
  const windSpeedRasterMinMax = useAtomValue(windSpeedRasterMinMaxAtom);

  const map = inputMap;
  const windLayerHeight = inputWindLayerHeight ?? stateWindLayerHeight;
  const speedSource: RasterSourceSpecification & {
    id: string;
    tiles: string[];
    tileSize: number;
  } = useMemo(
    () => ({
      id: "speed-source-speed-layer",
      type: "raster",
      tiles: [
        `/tiles/${source || windSource}/speed/${windLayerHeight}/{z}/{x}/{y}.png`,
      ],
      tileSize,
    }),
    [windLayerHeight, source, windSource],
  );

  // One layer instance per source id. The colour range is NOT a dependency.
  // The effect below pushes it in through `setMinMax`, so changing it only
  // repaints instead of tearing down and re-adding the source and layer,
  // which would refetch every tile.
  const speedLayer = useMemo(
    () => new LocalSpeedLayer(speedSource.id, defaultWindSpeedRasterMinMax),
    [speedSource.id],
  );

  // Declared before the add/remove effect so the current range is already set
  // when the layer is first added.
  useEffect(() => {
    speedLayer.setMinMax(windSpeedRasterMinMax);
    if (!map) return;
    map.triggerRepaint();
  }, [map, windSpeedRasterMinMax, speedLayer]);

  useEffect(() => {
    if (!map) return;
    map.addSource(speedSource.id, {
      type: speedSource.type,
      tiles: speedSource.tiles,
      tileSize: speedSource.tileSize,
      maxzoom: 7,
    });
    // Mapbox refuses to add a layer whose `beforeId` does not exist, so fall
    // back to adding it on top when the style has no "building" layer.
    const beforeId = map.getLayer("building") ? "building" : undefined;
    map.addLayer(speedLayer, beforeId);
    map.triggerRepaint();

    return () => {
      // The layer uses the source, so it must be removed first.
      if (map.getLayer(speedLayer.id)) map.removeLayer(speedLayer.id);
      if (map.getSource(speedSource.id)) map.removeSource(speedSource.id);
    };
  }, [map, speedLayer, speedSource]);
  return null;
};
