import mapboxgl from "mapbox-gl";
import { useEffect, useMemo } from "react";
import { useRecoilValue, useSetRecoilState } from "recoil";
import { mapRefAtom } from "../state/map";
import {
  MeanSpeedGrid,
  MeanSpeedGridCustom,
  meanSpeedGridLimitsAtom,
} from "../state/windStatistics";
import { COLORS } from "./windSpeed";
import { scream } from "../utils/sentry";

/** Vertex shader: projects the grid quad and forwards per-vertex texture coordinates. */
const SPEED_GRID_VERTEX_SOURCE = `
  precision highp float;
  uniform mat4 u_matrix;
  attribute vec2 a_pos;
  attribute vec2 a_texCoord;
  varying vec2 v_texCoord;

  void main() {
    v_texCoord = a_texCoord;
    gl_Position = u_matrix * vec4(a_pos, 0.0, 1.0);
  }`;

/**
 * Fragment shader. The values texture already stores each cell's speed
 * normalised to [0, 1] in its red channel and a 0/1 data mask in alpha, so the
 * red channel is used directly as the palette lookup coordinate.
 */
const SPEED_GRID_FRAGMENT_SOURCE = `
  precision mediump float;
  uniform sampler2D u_valuesTexture;
  uniform sampler2D u_palette;
  varying vec2 v_texCoord;

  void main() {
    vec4 cell = texture2D(u_valuesTexture, v_texCoord);
    vec4 color = texture2D(u_palette, vec2(cell.r, 0.5));
    gl_FragColor = color * cell.a;
  }`;

/**
 * Pack speeds into RGBA bytes for the values texture (4 entries per cell).
 *
 * Red holds the speed linearly rescaled from [minSpeed, maxSpeed] to [0, 255]
 * (clamped and rounded), so the full 8-bit range is used however wide or narrow
 * the speed range is and nothing wraps around when stored in a Uint8Array.
 * Alpha is 255 for cells with a positive speed and 0 otherwise (no data).
 */
const encodeSpeedGridValues = (
  speeds: number[],
  minSpeed: number,
  maxSpeed: number,
): number[] => {
  const range = maxSpeed - minSpeed;
  return speeds.flatMap((v) => {
    if (!(v > 0)) return [0, 0, 0, 0];
    const t = range > 0 ? (v - minSpeed) / range : 0;
    return [Math.round(Math.min(Math.max(t, 0), 1) * 255), 0, 0, 255];
  });
};

/** Create and compile one shader stage, throwing (via `scream`) with the driver's info log on failure. */
const compileSpeedGridShader = (
  gl: WebGLRenderingContext,
  type: number,
  source: string,
  label: string,
): WebGLShader => {
  const shader = gl.createShader(type);
  if (!shader) throw scream(`speedGrid failed to create ${label}`);
  gl.shaderSource(shader, source);
  gl.compileShader(shader);
  if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
    const log = gl.getShaderInfoLog(shader);
    gl.deleteShader(shader);
    throw scream(`speedGrid failed to compile ${label}: ${log ?? ""}`);
  }
  return shader;
};

/**
 * Create a 2D RGBA texture on the currently active texture unit, set clamp
 * wrapping and the given filter, and upload `data`. The texture is left bound.
 */
const createSpeedGridTexture = (
  gl: WebGLRenderingContext,
  width: number,
  height: number,
  data: Uint8Array,
  filter: number,
  label: string,
): WebGLTexture => {
  const texture = gl.createTexture();
  if (!texture) throw scream(`speedGrid failed to create ${label}`);
  gl.bindTexture(gl.TEXTURE_2D, texture);
  // WebGL1 only samples non-power-of-two textures with CLAMP_TO_EDGE and no mipmaps.
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, filter);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, filter);
  gl.texImage2D(
    gl.TEXTURE_2D,
    0,
    gl.RGBA,
    width,
    height,
    0,
    gl.RGBA,
    gl.UNSIGNED_BYTE,
    data,
  );
  return texture;
};

/**
 * Mapbox custom layer drawing a wind-speed grid as one textured quad.
 *
 * Each cell's speed is packed into a `cols x rows` values texture (sampled with
 * NEAREST filtering), then coloured through a 1-pixel-high palette texture built
 * from `COLORS.default` (sampled with LINEAR filtering). GPU resources are
 * created and uploaded once in `onAdd` and released in `onRemove`.
 */
class SpeedGridLayer {
  id: string;
  type: "custom";
  source: string;
  points: number[];
  values: number[];
  meanSpeed: number;
  minSpeed: number;
  maxSpeed: number;
  cols: number;
  rows: number;
  program: WebGLProgram | undefined;
  aPos: number | undefined;
  aValue: number | undefined;
  buffer: WebGLBuffer | undefined;
  texcoordBuffer: WebGLBuffer | undefined;
  texcoordLocation?: number;
  texture: WebGLTexture | undefined;
  palette: Uint8Array | undefined;
  paletteTex: WebGLTexture | undefined;
  colors: number[][];

  /**
   * @param points x/y pairs (already projected) for the six vertices of the two triangles drawn.
   * @param speeds per-cell speeds, one texel each; expected to hold `cols * rows` entries.
   *   Values <= 0 are treated as "no data" and rendered transparent.
   * @param cols width of the values texture.
   * @param rows height of the values texture.
   * @param minSpeed speed mapped to the start of the palette.
   * @param maxSpeed speed mapped to the end of the palette.
   */
  constructor(
    points: number[],
    speeds: number[],
    cols: number,
    rows: number,
    minSpeed: number,
    maxSpeed: number,
  ) {
    this.id = "speedgrid";
    this.type = "custom";
    this.source = "turbines";
    this.points = points;
    this.values = encodeSpeedGridValues(speeds, minSpeed, maxSpeed);
    this.cols = cols;
    this.rows = rows;
    this.meanSpeed = speeds.reduce((acc, v) => acc + v, 0) / speeds.length;
    this.minSpeed = minSpeed;
    this.maxSpeed = maxSpeed;
    this.colors = [];
  }

  /**
   * Compile and link the shaders, create the vertex buffers and upload both
   * textures. The texture contents never change for this layer instance, so
   * they are uploaded here once instead of on every frame.
   */
  onAdd(map: mapboxgl.Map, gl: WebGLRenderingContext) {
    const vertexShader = compileSpeedGridShader(
      gl,
      gl.VERTEX_SHADER,
      SPEED_GRID_VERTEX_SOURCE,
      "vertexShader",
    );
    const fragmentShader = compileSpeedGridShader(
      gl,
      gl.FRAGMENT_SHADER,
      SPEED_GRID_FRAGMENT_SOURCE,
      "fragmentShader",
    );

    const program = gl.createProgram();
    if (!program) throw scream("speedGrid failed to create program");
    gl.attachShader(program, vertexShader);
    gl.attachShader(program, fragmentShader);
    gl.linkProgram(program);
    // Deleting an attached shader only flags it; GL frees it once the program is deleted.
    gl.deleteShader(vertexShader);
    gl.deleteShader(fragmentShader);
    if (!gl.getProgramParameter(program, gl.LINK_STATUS)) {
      const log = gl.getProgramInfoLog(program);
      gl.deleteProgram(program);
      throw scream(`speedGrid failed to link program: ${log ?? ""}`);
    }
    this.program = program;
    this.aPos = gl.getAttribLocation(program, "a_pos");
    this.texcoordLocation = gl.getAttribLocation(program, "a_texCoord");

    const buffer = gl.createBuffer();
    if (!buffer) throw scream("speedGrid failed to create buffer");
    this.buffer = buffer;
    gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
    gl.bufferData(
      gl.ARRAY_BUFFER,
      new Float32Array(this.points),
      gl.STATIC_DRAW,
    );

    const texcoordBuffer = gl.createBuffer();
    if (!texcoordBuffer)
      throw scream("speedGrid failed to create texcoordBuffer");
    this.texcoordBuffer = texcoordBuffer;
    gl.bindBuffer(gl.ARRAY_BUFFER, texcoordBuffer);
    // One (u, v) pair per vertex, matching the six vertices in `points`.
    gl.bufferData(
      gl.ARRAY_BUFFER,
      new Float32Array([
        0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0,
      ]),
      gl.STATIC_DRAW,
    );

    gl.activeTexture(gl.TEXTURE0);
    this.texture = createSpeedGridTexture(
      gl,
      this.cols,
      this.rows,
      new Uint8Array(this.values),
      gl.NEAREST,
      "texture",
    );

    this.colors = COLORS.default;
    this.palette = new Uint8Array(this.colors.flat());
    this.paletteTex = createSpeedGridTexture(
      gl,
      this.colors.length,
      1,
      this.palette,
      gl.LINEAR,
      "paletteTex",
    );
  }

  /** Draw the grid quad; a no-op until `onAdd` has created every resource. */
  render(gl: WebGLRenderingContext, matrix: number[]) {
    const {
      program,
      buffer,
      texcoordBuffer,
      aPos,
      texcoordLocation,
      texture,
      paletteTex,
    } = this;
    // getAttribLocation returns -1 for a missing attribute; 0 is a valid location.
    if (
      !program ||
      !buffer ||
      !texcoordBuffer ||
      !texture ||
      !paletteTex ||
      aPos === undefined ||
      aPos < 0 ||
      texcoordLocation === undefined ||
      texcoordLocation < 0
    )
      return;

    gl.useProgram(program);

    gl.enable(gl.BLEND);
    gl.blendFunc(gl.SRC_ALPHA, gl.ONE_MINUS_SRC_ALPHA);

    gl.uniformMatrix4fv(
      gl.getUniformLocation(program, "u_matrix"),
      false,
      matrix,
    );

    gl.enableVertexAttribArray(aPos);
    gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
    gl.vertexAttribPointer(aPos, 2, gl.FLOAT, false, 0, 0);

    gl.enableVertexAttribArray(texcoordLocation);
    gl.bindBuffer(gl.ARRAY_BUFFER, texcoordBuffer);
    gl.vertexAttribPointer(texcoordLocation, 2, gl.FLOAT, false, 0, 0);

    gl.activeTexture(gl.TEXTURE0);
    gl.bindTexture(gl.TEXTURE_2D, texture);
    gl.uniform1i(gl.getUniformLocation(program, "u_valuesTexture"), 0);

    gl.activeTexture(gl.TEXTURE1);
    gl.bindTexture(gl.TEXTURE_2D, paletteTex);
    gl.uniform1i(gl.getUniformLocation(program, "u_palette"), 1);

    gl.drawArrays(gl.TRIANGLES, 0, 6);
  }

  /**
   * Release every GPU resource created in `onAdd`. Mapbox invokes this
   * (CustomLayerInterface.onRemove) when the layer is removed from the map.
   */
  onRemove(_map: mapboxgl.Map, gl: WebGLRenderingContext) {
    if (this.program) gl.deleteProgram(this.program);
    if (this.buffer) gl.deleteBuffer(this.buffer);
    if (this.texcoordBuffer) gl.deleteBuffer(this.texcoordBuffer);
    if (this.texture) gl.deleteTexture(this.texture);
    if (this.paletteTex) gl.deleteTexture(this.paletteTex);
    this.program = undefined;
    this.buffer = undefined;
    this.texcoordBuffer = undefined;
    this.texture = undefined;
    this.paletteTex = undefined;
  }
}

/**
 * Draws `meanSpeedGrid` on the shared map as a custom WebGL layer and
 * publishes its colour-scale limits to `meanSpeedGridLimitsAtom` while mounted
 * (reset to `undefined` on unmount). Renders no DOM.
 */
const SpeedGrid = ({ meanSpeedGrid }: { meanSpeedGrid: MeanSpeedGrid }) => {
  const setMeanSpeedLimits = useSetRecoilState(meanSpeedGridLimitsAtom);
  const map = useRecoilValue(mapRefAtom);

  // Colour-scale limits: the smallest positive speed and the largest speed,
  // each widened by 0.01. Non-positive cells are ignored for the lower bound,
  // and the reduce seed of 15 caps the lower bound at 14.99.
  const [minSpeed, maxSpeed] = useMemo(() => {
    const speeds = meanSpeedGrid.grid.flat();
    const min =
      speeds.reduce((a, v) => (v > 0 ? Math.min(a, v) : a), 15) - 0.01;
    const max = speeds.reduce((a, v) => Math.max(a, v), 0) + 0.01;

    return [min, max];
  }, [meanSpeedGrid]);

  useEffect(() => {
    setMeanSpeedLimits([minSpeed, maxSpeed]);
    return () => setMeanSpeedLimits(undefined);
  }, [minSpeed, maxSpeed, setMeanSpeedLimits]);

  const speedLayer = useSpeedGridActive({
    meanSpeedGrid,
    minSpeed,
    maxSpeed,
  });

  useEffect(() => {
    if (!map || !speedLayer) return;
    // Insert beneath the style's "building" layer when the style has one.
    // Mapbox does not add a layer whose beforeId is missing from the style,
    // so fall back to appending on top instead of silently drawing nothing.
    const beforeId = map.getLayer("building") ? "building" : undefined;
    // SpeedGridLayer provides the custom-layer members (id, type, onAdd,
    // render) but is not declared against Mapbox's layer types.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    map.addLayer(speedLayer as any, beforeId);
    return () => {
      // Avoid Mapbox's "non-existing layer" error if the layer is already gone.
      if (map.getLayer(speedLayer.id)) map.removeLayer(speedLayer.id);
    };
  }, [map, speedLayer]);

  return null;
};

/**
 * Memoised factory for a `SpeedGridLayer` covering `meanSpeedGrid`.
 *
 * The grid's extent runs from (`xllcorner`, `yllcorner`) to
 * (`xllcorner + dx * ncols`, `yllcorner + dy * nrows`) in lng/lat. It is
 * covered by two triangles whose corners are projected with
 * `mapboxgl.MercatorCoordinate.fromLngLat`.
 *
 * @returns the layer, or `undefined` while the grid is missing or either
 *   speed limit is absent or non-finite. Zero is accepted as a limit.
 */
export const useSpeedGridActive = ({
  meanSpeedGrid,
  minSpeed,
  maxSpeed,
}: {
  meanSpeedGrid?: MeanSpeedGrid | MeanSpeedGridCustom;
  minSpeed?: number;
  maxSpeed?: number;
}) => {
  const speedLayer = useMemo(() => {
    // Check presence/finiteness rather than truthiness: 0 is a valid limit.
    if (
      !meanSpeedGrid ||
      minSpeed === undefined ||
      maxSpeed === undefined ||
      !Number.isFinite(minSpeed) ||
      !Number.isFinite(maxSpeed)
    )
      return undefined;

    const { xllcorner: x0, yllcorner: y0, dx, dy, ncols, nrows } =
      meanSpeedGrid;
    const x1 = x0 + dx * ncols;
    const y1 = y0 + dy * nrows;

    // Six vertices = two triangles spanning [x0, x1] x [y0, y1]; the order
    // must line up with the texture coordinates SpeedGridLayer uploads.
    const corners: [number, number][] = [
      [x0, y0],
      [x1, y0],
      [x1, y1],
      [x1, y1],
      [x0, y1],
      [x0, y0],
    ];

    const points = corners.flatMap(([lng, lat]) => {
      const { x, y } = mapboxgl.MercatorCoordinate.fromLngLat({ lng, lat });
      return [x, y];
    });
    const speeds = meanSpeedGrid.grid.flatMap((row) => row);

    return new SpeedGridLayer(
      points,
      speeds,
      ncols,
      nrows,
      minSpeed,
      maxSpeed,
    );
  }, [meanSpeedGrid, maxSpeed, minSpeed]);

  return speedLayer;
};

export default SpeedGrid;
