import { useEffect, useMemo } from "react";
import { useRecoilValue } from "recoil";
import {
  contourStepSizeAtom,
  getActiveMapStyleSelector,
  mapRefAtom,
} from "../state/map";
import mapboxgl from "mapbox-gl";
import { Map } from "mapbox-gl";
import { scream, sendInfo } from "../utils/sentry";
import { LAYER_DEBUG_PRINT } from "../state/debug";

/** Map-style ids used for the bathymetry raster source and its custom layer. */
const depthContourSourceId = "vind:source:depth-contour";
const depthContourLayerId = "vind:layer:depth-contour";

/** Side length (px) of the offscreen texture each depth tile is decoded into. */
const TILE_WIDTH = 512;
/**
 * Decoded depth is divided by this before being written to the offscreen
 * texture and multiplied back when sampled.
 */
const NORMALISATION_FACTOR = 1000;

/**
 * Bathymetry palette as 8-bit RGBA, shallowest first. The 0..1 stops below
 * are the same colors the contour fragment shader blends between.
 */
export const BATHYMETRY_COLORS = {
  default: (
    [
      [0.64, 0.78, 0.92],
      [0.45, 0.55, 0.7],
      [0.38, 0.47, 0.6],
      [0.31, 0.36, 0.4],
      [0.12, 0.17, 0.21],
    ] as const
  ).map(([r, g, b]) => [r * 255, g * 255, b * 255, 255]),
};

/** GL objects and locations for the contour/area-fill program. */
interface GLStuff {
  program: WebGLProgram;
  vertexShader: WebGLShader;
  fragmentShader: WebGLShader;

  /** Attribute and uniform locations looked up after linking. */
  loc: {
    /** Tile-local grid position attribute. */
    a_pos: number;
    /** Per-tile projection matrix. */
    u_matrix: WebGLUniformLocation;
    /** Sampler for the decoded depth texture. */
    u_depth_raster: WebGLUniformLocation;
    /** Spacing between contour lines / depth bands. */
    u_contourDist: WebGLUniformLocation;
    /** Boolean uniform; in the fragment shader it toggles iso-line shading. */
    u_bilinear: WebGLUniformLocation;
  };

  /** Grid geometry covering one tile. */
  buffer: {
    vertex: WebGLBuffer;
    index: WebGLBuffer;
  };
}

/**
 * GL objects for the pass that decodes a Terrarium tile into an offscreen
 * depth texture.
 */
interface GLDecodeTerrariumStuff {
  program: WebGLProgram;
  vertexShader: WebGLShader;
  fragmentShader: WebGLShader;

  /**
   * Attribute/uniform locations plus the offscreen render target resources
   * (these are not "locations" but are stored here alongside them).
   */
  loc: {
    /** Clip-space quad position attribute. */
    a_pos: number;
    /** Quad texture-coordinate attribute. */
    a_texCoord: number;
    /** Sampler for the Terrarium-encoded source tile. */
    u_depth_raster: WebGLUniformLocation;
    /** Color attachment that receives the decoded depth. */
    texture: WebGLTexture;
    /** Offscreen framebuffer the decode pass renders into. */
    framebuffer: WebGLFramebuffer;
    /** Full-screen quad positions. */
    bufferPosition: WebGLBuffer;
    /** Full-screen quad texture coordinates. */
    frameBufferTextureCoords: WebGLBuffer;
  };

  /** Quad geometry for the decode pass. */
  buffer: {
    vertex: WebGLBuffer;
    index: WebGLBuffer;
  };
}

/**
 * Builds the program that draws decoded depth as colored depth bands with
 * optional contour (iso) lines, and allocates the (still empty) vertex and
 * index buffers for the tile grid.
 *
 * The shaders use `#version 300 es`, so `gl` must actually be a WebGL2
 * context even though it is typed as `WebGLRenderingContext`.
 *
 * @param gl - The map's GL context.
 * @returns The program, its shaders, attribute/uniform locations and buffers.
 * @throws The value returned by `scream` when a GL object cannot be created,
 *   a shader fails to compile, the program fails to link, or a required
 *   attribute/uniform location is missing.
 */
const makeGlProgram = (gl: WebGLRenderingContext): GLStuff => {
  // Grid positions span [0, 8192] per tile; dividing by 8192 turns them into
  // texture coordinates for the decoded depth texture.
  const vertexSource = `#version 300 es
precision highp float;
uniform mat4 u_matrix;

in vec2 a_texCoord;
in vec2 a_pos;

out vec2 v_texCoord;

void main() {
  v_texCoord = a_pos / 8192.0;
  gl_Position = u_matrix * vec4(a_pos, 0.0, 1.0);
}
          `;

  // NOTE(review): in `depthToContour` only the first `return` (a hard step)
  // is reachable; the smoothstep variants after it are dead code. The
  // `a_texCoord` inputs are declared but unused in both stages. Depth is
  // reconstructed by multiplying the red channel by NORMALISATION_FACTOR.
  // `u_bilinear` doubles as the switch for iso-line shading (`showIso`).
  const fragmentSource = `#version 300 es
precision highp float;
uniform sampler2D u_depth_raster;
uniform bool u_bilinear;
uniform float u_contourDist;

in vec2 a_texCoord;
in vec2 v_texCoord;

layout(location=0) out vec4 outColor;

vec3 depthToColor(float depth) {
  vec3 lightBlue = vec3(0.64, 0.78, 0.92);
  vec3 medlightBlue = vec3(0.45, 0.55, 0.70);
  vec3 mediumBlue = vec3(0.38, 0.47, 0.60);
  vec3 darkBlue = vec3(0.31, 0.36, 0.40);
  vec3 darkestBlue = vec3(0.12, 0.17, 0.21);
  float l1 = 50.0;
  float l2 = 300.0;
  float l3 = 500.0;
  float l4 = 700.0;
  if (depth < l1) {
    float f = depth  / l1;
    return mix(lightBlue, medlightBlue, f);
  } else if (depth < l2) {
    float f = clamp((depth - l1) / (l2 - l1), 0., 1.);
    return mix(medlightBlue, mediumBlue, f);
  } else if (depth < l3) {
    float f = clamp((depth - l2) / (l3 - l2), 0., 1.);
    return mix(mediumBlue, darkBlue, f);
  } else {
    float f = clamp((depth - l3) / (l4 - l3), 0., 1.);
    return mix(darkBlue, darkestBlue, f);
  }
}

// vec3 debugColor(float depth, float S) {
//   float frac = fract(depth / S);
//   float Q = 3.0;
//   float bucket = floor(depth / S);
//   float b = mod(bucket, Q) / Q;
//   float r = mod(bucket, Q * Q) / Q / Q;
//   float g = mod(bucket, Q * Q * Q) / Q / Q / Q;
//   return vec3(r, g, b);
// }

// Return 0.0 if we are off the line, 1.0 if we are on the line, and somewhere in between
// if we are somewhere in between.
float depthToContour(float depth, float Step) {
  // Simply checking that  depth  is a multiple of  step  will cause
  // thick lines in flat regions and thin lines in steep regions.
  // Instead, get the derivatives to figure out how far off the the pixel
  // is off of the contour line, and color based on that.
  float frac = fract(depth / Step);
  float grace = fwidth(depth / Step);

  return 1.0 - step(grace, frac);

  float onContour = 1.0 - smoothstep(0.0, grace, frac);
  return onContour;
  return step(0.5, onContour);
}

void main() {
  float STEP = u_contourDist;
  float depth = texture(u_depth_raster, v_texCoord).r * ${NORMALISATION_FACTOR}.;

  float land = step(STEP, depth); // 1.0 if we are in water, 0.0 if we are on land.
  float contour = depthToContour(depth, STEP);

  float depthstep = floor(depth / STEP);
  vec3 areaColor = depthToColor(depthstep * STEP);
  // areaColor = debugColor(depth, STEP);

  float sign = float(depth < 200.0) * 2.0 - 1.0; // +/- 1.0, positive is shallow.
  float showIso = float(u_bilinear);
  vec3 withIsoLines = areaColor + showIso * sign * contour * areaColor * 0.15;

  outColor = vec4(withIsoLines, 1.0) * land;
}
`;

  // Creates and compiles one shader stage. On compile failure the shader is
  // deleted and its info log is reported, instead of surfacing later as a
  // generic link error.
  const compileShader = (
    type: number,
    source: string,
    label: string,
  ): WebGLShader => {
    const shader = gl.createShader(type);
    if (!shader)
      throw scream(`Failed to create ${label} shader`, {
        error: gl.getError(),
      });
    gl.shaderSource(shader, source);
    gl.compileShader(shader);
    if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
      const info = gl.getShaderInfoLog(shader);
      gl.deleteShader(shader);
      throw scream(`Failed to compile ${label} shader`, { info });
    }
    return shader;
  };

  const vertexShader = compileShader(gl.VERTEX_SHADER, vertexSource, "vertex");
  const fragmentShader = compileShader(
    gl.FRAGMENT_SHADER,
    fragmentSource,
    "fragment",
  );

  const program = gl.createProgram();
  if (!program) throw scream("Failed to create program");
  gl.attachShader(program, vertexShader);
  gl.attachShader(program, fragmentShader);
  gl.linkProgram(program);

  if (!gl.getProgramParameter(program, gl.LINK_STATUS))
    throw scream("Failed to create program", {
      program,
      vertexShader,
      fragmentShader,
      programInfo: gl.getProgramInfoLog(program),
      vertexInfo: gl.getShaderInfoLog(vertexShader),
      fragmentInfo: gl.getShaderInfoLog(fragmentShader),
    });

  const a_pos = gl.getAttribLocation(program, "a_pos");
  if (a_pos === -1) throw scream("Failed to get position location");

  // Buffers are only allocated here; onAdd fills them with the tile grid.
  const vertexBuffer = gl.createBuffer();
  if (!vertexBuffer) throw scream("Failed to create vertex buffer");

  const indexBuffer = gl.createBuffer();
  if (!indexBuffer) throw scream("Failed to create index buffer");

  const u_matrix = gl.getUniformLocation(program, "u_matrix");
  if (u_matrix === null) throw scream("Failed to get u_matrix location");

  const u_bilinear = gl.getUniformLocation(program, "u_bilinear");
  if (u_bilinear === null) throw scream("Failed to get u_bilinear location");

  const u_contourDist = gl.getUniformLocation(program, "u_contourDist");
  if (u_contourDist === null)
    throw scream("Failed to get u_contourDist location");

  const u_depth_raster = gl.getUniformLocation(program, "u_depth_raster");
  if (u_depth_raster === null)
    throw scream("Failed to get u_depth_raster location");

  return {
    program,
    vertexShader,
    fragmentShader,
    loc: {
      a_pos,
      u_matrix,
      u_depth_raster,
      u_bilinear,
      u_contourDist,
    },
    buffer: {
      vertex: vertexBuffer,
      index: indexBuffer,
    },
  };
};

/**
 * Mapbox custom layer that renders bathymetry from a Terrarium-encoded raster
 * source as colored depth bands, optionally overlaid with contour lines.
 *
 * Each visible tile is drawn in two passes:
 *  1. The tile's Terrarium RGB texture is decoded to depth, divided by
 *     NORMALISATION_FACTOR, and written to the red channel of an offscreen
 *     TILE_WIDTH x TILE_WIDTH RGBA texture.
 *  2. The contour program samples that texture while drawing a 64x64 vertex
 *     grid covering the tile.
 *
 * NOTE(review): the offscreen texture is RGBA / UNSIGNED_BYTE, so decoded
 * depth is quantised to NORMALISATION_FACTOR / 255 and clamped to
 * [0, NORMALISATION_FACTOR]; anything deeper reads as NORMALISATION_FACTOR.
 * A float render target would be needed for finer or deeper contours.
 *
 * Relies on private mapbox internals (`style._otherSourceCaches`, source
 * cache `_paused`/`used`/`update`, `map.transform`).
 */
class DepthContourLayer implements mapboxgl.CustomLayerInterface {
  id: string;
  type: "custom";
  renderingMode?: "3d";
  /** Id of the raster source holding the Terrarium-encoded tiles. */
  depthSource: string;
  map: Map | undefined;
  /** Mapbox-internal source cache for `depthSource`, driven manually in `update()`. */
  depthSourceCache: any;
  /** Number of indices (not vertices) in the contour grid's index buffer. */
  vertexCount: number = 0;
  /** All GL things */
  gl: GLStuff | undefined;

  /** Depth spacing between contour lines / color bands. */
  contourDist: number = 25.0;
  /**
   * Selects LINEAR (true) vs NEAREST (false) sampling of the decoded depth.
   * It is also passed as `u_bilinear`, which toggles iso-line shading.
   */
  bilinear: boolean = true;

  decodeTerrariumShaderData: GLDecodeTerrariumStuff | undefined;
  /** Depth attachment of the decode framebuffer, kept so `onRemove` can free it. */
  decodeDepthRenderbuffer: WebGLRenderbuffer | undefined;

  constructor(depthSource: string) {
    this.id = depthContourLayerId;
    this.type = "custom";
    this.depthSource = depthSource;
  }

  /** Sets the contour spacing used by subsequent renders. */
  updateContourDist(newContourDist: number) {
    this.contourDist = newContourDist;
  }

  /**
   * Mapbox hook: grabs the depth source cache, builds both GL programs, and
   * uploads the contour grid geometry.
   *
   * @throws The value returned by `scream` if the source cache is missing or
   *   any GL resource cannot be created.
   */
  onAdd(map: Map, gl: WebGLRenderingContext) {
    LAYER_DEBUG_PRINT && console.log("DepthLayer.onAdd");
    this.map = map;
    this.depthSourceCache = (map as any).style._otherSourceCaches[
      this.depthSource
    ];
    if (!this.depthSourceCache)
      throw scream("Depth source cache not found", {
        depthSource: this.depthSource,
      });
    // Paused here; `update()` un-pauses it only for the duration of its own
    // update call (presumably so mapbox does not update it on its own).
    this.depthSourceCache.pause();
    const g = (this.gl = makeGlProgram(gl));

    // NOTE(review): a 64x64 grid is kept from the original; a single quad may
    // suffice for a flat projection. Unverified why the density is needed.
    const n = 64;

    // Vertices: n x n lattice spanning [0, 8192] on both axes.
    const vertexArray = new Int16Array(n * n * 2);
    for (let i = 0; i < n; i++) {
      for (let j = 0; j < n; j++) {
        const vertex = [j * (8192 / (n - 1)), i * (8192 / (n - 1))];
        const offset = (i * n + j) * 2;
        vertexArray.set(vertex, offset);
      }
    }

    gl.bindBuffer(gl.ARRAY_BUFFER, g.buffer.vertex);
    gl.bufferData(gl.ARRAY_BUFFER, vertexArray.buffer, gl.STATIC_DRAW);

    // Indices: two triangles per lattice cell.
    const vertexCount = (this.vertexCount = (n - 1) * (n - 1) * 6);
    const indexArray = new Uint16Array(vertexCount);
    let offset = 0;
    for (let i = 0; i < n - 1; i++) {
      for (let j = 0; j < n - 1; j++) {
        const index = i * n + j;
        const quad = [
          index,
          index + 1,
          index + n,
          index + n,
          index + 1,
          index + n + 1,
        ];
        indexArray.set(quad, offset);
        offset += 6;
      }
    }

    gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, g.buffer.index);
    gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, indexArray.buffer, gl.STATIC_DRAW);

    this.createTerrariumDecodedFramebuffer(gl, TILE_WIDTH, TILE_WIDTH);
  }

  /**
   * Mapbox hook: releases every GL object created in `onAdd`, so re-adding
   * the layer does not leak programs, buffers, or render targets.
   */
  onRemove(_map: Map, gl: WebGLRenderingContext) {
    LAYER_DEBUG_PRINT && console.log("DepthLayer.onRemove");
    const g = this.gl;
    if (g) {
      gl.deleteBuffer(g.buffer.vertex);
      gl.deleteBuffer(g.buffer.index);
      gl.deleteProgram(g.program);
      gl.deleteShader(g.vertexShader);
      gl.deleteShader(g.fragmentShader);
      this.gl = undefined;
    }
    const d = this.decodeTerrariumShaderData;
    if (d) {
      // buffer.vertex is the same object as loc.bufferPosition.
      gl.deleteBuffer(d.buffer.vertex);
      gl.deleteBuffer(d.buffer.index);
      gl.deleteBuffer(d.loc.frameBufferTextureCoords);
      gl.deleteFramebuffer(d.loc.framebuffer);
      gl.deleteTexture(d.loc.texture);
      gl.deleteProgram(d.program);
      gl.deleteShader(d.vertexShader);
      gl.deleteShader(d.fragmentShader);
      this.decodeTerrariumShaderData = undefined;
    }
    if (this.decodeDepthRenderbuffer) {
      gl.deleteRenderbuffer(this.decodeDepthRenderbuffer);
      this.decodeDepthRenderbuffer = undefined;
    }
  }

  /**
   * Lets the (paused) depth source cache request tiles for the current view.
   * The cloned transform's height is increased by
   * `cameraToCenterDistance * sin(pitch)` before the update; presumably this
   * is so tiles toward the horizon are requested when pitched.
   */
  update() {
    const transform = (this.map as any).transform.clone();
    const pitchOffset =
      transform.cameraToCenterDistance * Math.sin(transform._pitch);
    transform.height = transform.height + pitchOffset;
    this.depthSourceCache._paused = false;
    this.depthSourceCache.used = true;
    this.depthSourceCache.update(transform);
    this.depthSourceCache.pause();
  }

  /**
   * Creates and compiles a single shader stage.
   *
   * @throws The value returned by `scream` if the shader cannot be created
   *   or fails to compile; on a compile failure the info log is included.
   */
  private compileShaderStage(
    gl: WebGLRenderingContext,
    type: number,
    source: string,
    label: string,
  ): WebGLShader {
    const shader = gl.createShader(type);
    if (!shader)
      throw scream(`Failed to create ${label} shader`, {
        error: gl.getError(),
      });
    gl.shaderSource(shader, source);
    gl.compileShader(shader);
    if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
      const info = gl.getShaderInfoLog(shader);
      gl.deleteShader(shader);
      throw scream(`Failed to compile ${label} shader`, { info });
    }
    return shader;
  }

  /**
   * Builds the program that decodes a Terrarium tile into depth:
   * depth = -((R * 256 + G + B / 256) - 32768) with R, G, B in 0..255,
   * written as depth / NORMALISATION_FACTOR into the red channel.
   *
   * @returns The program, shaders, and attribute/uniform locations (without
   *   the render-target resources, which are merged in by the caller).
   */
  decodeTerrariumShader = (gl: WebGLRenderingContext) => {
    const vertexSource = `#version 300 es
  precision highp float;
  
  in vec2 a_texCoord;
  in vec2 a_pos;
  
  out vec2 v_texCoord;
  
  void main() {
    v_texCoord = a_texCoord;
    gl_Position = vec4(a_pos, 0.0, 1.0);
  }
            `;

    const fragmentSource = `#version 300 es
  precision highp float;
  uniform sampler2D u_depth_raster;
  in vec2 v_texCoord;

  layout(location=0) out vec4 outColor;
  
  float terrariumDepth(sampler2D depth_raster, vec2 texCoord) {
    vec4 color = texture(depth_raster, texCoord);
    float R = color.r * 255.0;
    float G = color.g * 255.0;
    float B = color.b * 255.0;
    return ((R * 256.0 + G + B / 256.0) - 32768.0) * -1.;
  }
  
  void main() {
      float depth = terrariumDepth(u_depth_raster, v_texCoord); 
      outColor = vec4(depth/${NORMALISATION_FACTOR}., 0., 0., 1.);
  }
  `;

    const vertexShader = this.compileShaderStage(
      gl,
      gl.VERTEX_SHADER,
      vertexSource,
      "vertex",
    );
    const fragmentShader = this.compileShaderStage(
      gl,
      gl.FRAGMENT_SHADER,
      fragmentSource,
      "fragment",
    );

    const program = gl.createProgram();
    if (!program) throw scream("Failed to create program");
    gl.attachShader(program, vertexShader);
    gl.attachShader(program, fragmentShader);
    gl.linkProgram(program);

    if (!gl.getProgramParameter(program, gl.LINK_STATUS))
      throw scream("Failed to create program", {
        program,
        vertexShader,
        fragmentShader,
        programInfo: gl.getProgramInfoLog(program),
        vertexInfo: gl.getShaderInfoLog(vertexShader),
        fragmentInfo: gl.getShaderInfoLog(fragmentShader),
      });

    const u_depth_raster = gl.getUniformLocation(program, "u_depth_raster");
    if (u_depth_raster === null)
      throw scream("Failed to get u_depth_raster location");

    const a_pos = gl.getAttribLocation(program, "a_pos");
    if (a_pos === -1) throw scream("Failed to get position location");

    const a_texCoord = gl.getAttribLocation(program, "a_texCoord");
    if (a_texCoord === -1)
      throw scream("Failed to get texture coordinate location");

    return {
      program,
      vertexShader,
      fragmentShader,
      loc: {
        a_pos,
        a_texCoord,
        u_depth_raster,
      },
    };
  };

  /**
   * Creates the offscreen render target for the decode pass: a full-screen
   * quad (positions, texture coordinates, indices), an RGBA8 color texture of
   * `width` x `height` with a DEPTH_COMPONENT16 renderbuffer, and the decode
   * program. Everything is stored on `decodeTerrariumShaderData`, plus
   * `decodeDepthRenderbuffer`.
   *
   * @throws The value returned by `scream` if any GL object cannot be created.
   */
  createTerrariumDecodedFramebuffer(
    gl: WebGLRenderingContext,
    width: number,
    height: number,
  ) {
    const positionBuffer = gl.createBuffer();
    if (!positionBuffer) throw scream("Failed to create positionBuffer buffer");
    gl.bindBuffer(gl.ARRAY_BUFFER, positionBuffer);
    const positions = [1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0];
    gl.bufferData(gl.ARRAY_BUFFER, new Float32Array(positions), gl.STATIC_DRAW);

    const textureCoordBuffer = gl.createBuffer();
    if (!textureCoordBuffer)
      throw scream("Failed to create textureCoordBuffer buffer");

    gl.bindBuffer(gl.ARRAY_BUFFER, textureCoordBuffer);
    const textureCoordinates = [1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0];
    gl.bufferData(
      gl.ARRAY_BUFFER,
      new Float32Array(textureCoordinates),
      gl.STATIC_DRAW,
    );

    const indexBuffer = gl.createBuffer();
    if (!indexBuffer) throw scream("Failed to create index buffer");

    gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, indexBuffer);
    const indices = [0, 1, 2, 0, 2, 3];
    gl.bufferData(
      gl.ELEMENT_ARRAY_BUFFER,
      new Uint16Array(indices),
      gl.STATIC_DRAW,
    );

    const framebuffer = gl.createFramebuffer();
    if (!framebuffer) throw scream("Failed to create framebuffer");
    gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);

    const texture = gl.createTexture();
    if (!texture) throw scream("Failed to create texture");
    gl.bindTexture(gl.TEXTURE_2D, texture);
    gl.texImage2D(
      gl.TEXTURE_2D,
      0,
      gl.RGBA,
      width,
      height,
      0,
      gl.RGBA,
      gl.UNSIGNED_BYTE,
      null,
    );
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);

    const renderbuffer = gl.createRenderbuffer();
    if (!renderbuffer) throw scream("Failed to create renderbuffer");
    gl.bindRenderbuffer(gl.RENDERBUFFER, renderbuffer);
    gl.renderbufferStorage(
      gl.RENDERBUFFER,
      gl.DEPTH_COMPONENT16,
      width,
      height,
    );

    gl.framebufferTexture2D(
      gl.FRAMEBUFFER,
      gl.COLOR_ATTACHMENT0,
      gl.TEXTURE_2D,
      texture,
      0,
    );
    gl.framebufferRenderbuffer(
      gl.FRAMEBUFFER,
      gl.DEPTH_ATTACHMENT,
      gl.RENDERBUFFER,
      renderbuffer,
    );

    gl.bindTexture(gl.TEXTURE_2D, null);
    gl.bindRenderbuffer(gl.RENDERBUFFER, null);
    gl.bindFramebuffer(gl.FRAMEBUFFER, null);

    const shader = this.decodeTerrariumShader(gl);

    this.decodeDepthRenderbuffer = renderbuffer;
    this.decodeTerrariumShaderData = {
      ...shader,
      loc: {
        ...shader.loc,
        texture: texture,
        framebuffer: framebuffer,
        bufferPosition: positionBuffer,
        frameBufferTextureCoords: textureCoordBuffer,
      },
      buffer: { vertex: positionBuffer, index: indexBuffer },
    };
  }

  /**
   * Mapbox hook: for every visible depth tile, decodes it offscreen and then
   * draws the contour grid with the tile's projection matrix. Tiles without a
   * GPU texture are skipped.
   */
  render(gl: WebGLRenderingContext, _matrix: number[]) {
    const g = this.gl;
    const decode = this.decodeTerrariumShaderData;
    if (!g || !decode) return;
    LAYER_DEBUG_PRINT && console.time("DepthLayer.render");

    this.update();

    const coords = this.depthSourceCache.getVisibleCoordinates().reverse();

    for (const coord of coords) {
      const depthTile = this.depthSourceCache.getTile(coord);
      const depthTexture: WebGLTexture | undefined =
        depthTile?.texture?.texture;
      // No GPU texture for this tile (e.g. not loaded yet): nothing to decode.
      if (!depthTexture) continue;

      const prevViewport = gl.getParameter(gl.VIEWPORT) as
        | [number, number, number, number]
        | null;
      if (!prevViewport) {
        sendInfo("Viewport was falsy", {
          prevViewport,
          parameterValue: gl.VIEWPORT,
          coord,
        });
        continue;
      }

      // Pass 1: decode the Terrarium tile into the offscreen depth texture.
      gl.bindFramebuffer(gl.FRAMEBUFFER, decode.loc.framebuffer);
      gl.viewport(0, 0, TILE_WIDTH, TILE_WIDTH);
      gl.clear(gl.COLOR_BUFFER_BIT | gl.DEPTH_BUFFER_BIT);

      gl.useProgram(decode.program);
      gl.bindBuffer(gl.ARRAY_BUFFER, decode.loc.bufferPosition);
      gl.vertexAttribPointer(decode.loc.a_pos, 2, gl.FLOAT, false, 0, 0);
      gl.enableVertexAttribArray(decode.loc.a_pos);

      gl.bindBuffer(gl.ARRAY_BUFFER, decode.loc.frameBufferTextureCoords);
      gl.vertexAttribPointer(decode.loc.a_texCoord, 2, gl.FLOAT, false, 0, 0);
      gl.enableVertexAttribArray(decode.loc.a_texCoord);

      gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, decode.buffer.index);
      gl.activeTexture(gl.TEXTURE0);
      gl.bindTexture(gl.TEXTURE_2D, depthTexture);
      gl.uniform1i(decode.loc.u_depth_raster, 0);
      gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0);

      // Pass 2: back to the map's framebuffer and viewport; draw the tile grid.
      gl.bindFramebuffer(gl.FRAMEBUFFER, null);
      gl.viewport(...prevViewport);
      gl.useProgram(g.program);

      gl.enable(gl.BLEND);
      gl.blendFunc(gl.SRC_ALPHA, gl.ONE_MINUS_SRC_ALPHA);
      gl.bindBuffer(gl.ARRAY_BUFFER, g.buffer.vertex);
      gl.enableVertexAttribArray(g.loc.a_pos);
      gl.vertexAttribPointer(g.loc.a_pos, 2, gl.SHORT, false, 0, 0);
      gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, g.buffer.index);
      gl.uniform1i(g.loc.u_bilinear, this.bilinear ? 1 : 0);
      gl.uniform1f(g.loc.u_contourDist, this.contourDist);

      // Bind the decoded depth texture to unit 0.
      gl.activeTexture(gl.TEXTURE0);
      gl.bindTexture(gl.TEXTURE_2D, decode.loc.texture);
      const filter = this.bilinear ? gl.LINEAR : gl.NEAREST;
      gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, filter);
      gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, filter);
      gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
      gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
      gl.uniform1i(g.loc.u_depth_raster, 0);

      gl.uniformMatrix4fv(g.loc.u_matrix, false, coord.projMatrix);

      gl.drawElements(gl.TRIANGLES, this.vertexCount, gl.UNSIGNED_SHORT, 0);
    }
    LAYER_DEBUG_PRINT && console.timeEnd("DepthLayer.render");
  }
}

/**
 * While mounted, adds the bathymetry raster source and the depth-contour
 * custom layer to the map, and keeps the layer's contour spacing in sync with
 * `contourStepSizeAtom`. Renders nothing itself.
 */
const BathymetryActive = () => {
  const map = useRecoilValue(mapRefAtom);
  const contourStepSize = useRecoilValue(contourStepSizeAtom);

  const bathymetrySource = useMemo(
    () => ({
      id: depthContourSourceId,
      type: "raster" as const,
      tiles: [`/tiles/gebco-terrarium-2023/{z}/{x}/{y}.png`],
    }),
    [],
  );

  // One layer instance per mount; its GL resources live on the instance.
  const depthLayer = useMemo(
    () => new DepthContourLayer(bathymetrySource.id),
    [bathymetrySource.id],
  );

  // Push contour spacing changes into the layer and request a redraw.
  useEffect(() => {
    if (!map) return;
    depthLayer.updateContourDist(contourStepSize);
    map.triggerRepaint();
  }, [contourStepSize, depthLayer, map]);

  useEffect(() => {
    if (!map) return;

    map.addSource(bathymetrySource.id, {
      type: bathymetrySource.type,
      tiles: bathymetrySource.tiles,
      tileSize: 512,
      maxzoom: 7,
    });

    try {
      // Inserted before "land-structure-polygon", so that layer draws on top.
      map.addLayer(depthLayer, "land-structure-polygon");
    } catch (e) {
      // Keep the best-effort behavior (no crash), but don't fail silently.
      sendInfo("Failed to add depth contour layer", {
        error: e instanceof Error ? e.message : String(e),
        layerId: depthLayer.id,
        beforeId: "land-structure-polygon",
      });
      return () => {
        map.removeSource(bathymetrySource.id);
      };
    }
    map.triggerRepaint();

    return () => {
      map.removeLayer(depthLayer.id);
      map.removeSource(bathymetrySource.id);
    };
  }, [map, depthLayer, bathymetrySource]);

  return null;
};

/**
 * Mounts the bathymetry contour layer only while the active map style has
 * `useBathymetry` enabled.
 */
export const BathymetryContourBackgroundLayer = () => {
  const mapStyle = useRecoilValue(getActiveMapStyleSelector);
  // No spaces inside the fragment: `<> {x} </>` would render " " text nodes.
  return <>{mapStyle?.useBathymetry && <BathymetryActive />}</>;
};
