import { useEffect, useMemo } from "react";
import { useRecoilValue } from "recoil";
import {
  costLayerVariablesAtom,
  costLayerRangeAtom,
  costLayerFilterAtom,
  lowerRightMenuActiveModeAtom,
  CostLayerVariables,
} from "../state/layer";
import { mapRefAtom } from "../state/map";
import { CustomLayerInterface, RasterSource } from "mapbox-gl";
import { LAYER_DEBUG_PRINT } from "../state/debug";
import { scream } from "../utils/sentry";

/**
 * A mapbox-gl raster source description that also carries the id it is
 * registered under, plus either an explicit `tiles` list or a `url`.
 */
export type RasterSourceLiteral = RasterSource & { id: string } & (
  | { tiles: string[] }
  | { url: string }
);

/** Tile size passed to every cost raster source. */
const tileSize = 512;
/** `maxzoom` passed to every cost raster source. */
const maxzoom = 6;

/** Id under which the custom cost layer is registered with the map. */
export const CostLayerId = "custom-combination-layer";

/** Bathymetry tiles; decoded by `getDepth` in the cost fragment shader. */
const bathymetrySource: RasterSourceLiteral = {
  id: "bathymetry-source-cost",
  type: "raster",
  tiles: ["/tiles/gebco-terrarium-2023/{z}/{x}/{y}.png"],
};

/** Distance-to-shore tiles, sampled as the shoreline distance raster. */
const shoreDistanceSource: RasterSourceLiteral = {
  id: "shoredistance-cost",
  type: "raster",
  tiles: ["/tiles/shore/{z}/{x}/{y}.png"],
};

/** Capacity tiles, sampled as the `u_weibull_raster` texture. */
const weibullSource: RasterSourceLiteral = {
  id: "weibull-source-cost",
  type: "raster",
  tiles: ["/tiles/gwa/capacity-iec2/{z}/{x}/{y}.png"],
};

/**
 * Mapbox GL custom layer that colours each visible tile by an estimated cost
 * per MWh, computed per pixel in a fragment shader from three raster sources:
 * depth (bathymetry), distance to shore, and a capacity raster.
 *
 * The raster sources must already be added to the map under the ids passed to
 * the constructor. This layer reads their tile caches through private
 * mapbox-gl internals (`style._otherSourceCaches`, `SourceCache.update`,
 * `map.transform`), so it may need adjusting when mapbox-gl is upgraded.
 */
class Layer implements CustomLayerInterface {
  id: string;
  type: "custom";
  depthSourceId: string;
  input: CostLayerVariables | undefined;
  shoreDistanceSourceId: string;
  weibullSourceId: string;
  /** [green, red]: cost values mapped to the green and red ends of the ramp. */
  range: number[];
  /** [min, max): only pixels whose cost falls in this interval are drawn. */
  filter: number[];
  // mapbox-gl Map; typed `any` because update() reads its private `transform`.
  map: any;
  // Private mapbox-gl SourceCache instances (no public typings).
  depthSourceCache: any;
  weibullSourceCache: any;
  shoreDistanceSourceCache: any;
  vertexArray: Int16Array | undefined;
  vertexBuffer: WebGLBuffer | null = null;
  indexArray: Uint16Array | undefined;
  indexBuffer: WebGLBuffer | null = null;
  program: WebGLProgram | null = null;
  aPos = -1;
  /** 1x1 transparent texture bound when a shore/capacity tile is missing. */
  private emptyTexture: WebGLTexture | null = null;
  /** Uniform locations of `program`, looked up once instead of every frame. */
  private uniformLocations = new Map<string, WebGLUniformLocation | null>();

  constructor(
    depthSourceId: string,
    shoreDistanceSourceId: string,
    weibullSourceId: string,
  ) {
    this.id = CostLayerId;
    this.type = "custom";

    this.depthSourceId = depthSourceId;
    this.shoreDistanceSourceId = shoreDistanceSourceId;
    this.weibullSourceId = weibullSourceId;

    this.range = [0, 100];
    this.filter = [0, 100];
  }

  /** Stores new shader inputs; they take effect on the next render. */
  updateVariables(
    input: CostLayerVariables,
    filter: number[],
    range: number[],
  ) {
    this.input = input;
    this.range = range;
    this.filter = filter;
  }

  /** The three raster source caches, in depth / shore / capacity order. */
  private sourceCaches(): any[] {
    return [
      this.depthSourceCache,
      this.shoreDistanceSourceCache,
      this.weibullSourceCache,
    ];
  }

  /**
   * Called by mapbox-gl when the layer is added: looks up the raster source
   * caches, pauses them, and creates the GL program and geometry buffers.
   *
   * @throws when any of the three raster sources has not been added yet.
   */
  onAdd(map: mapboxgl.Map, gl: WebGLRenderingContext) {
    LAYER_DEBUG_PRINT && console.log("CostLayer.onAdd");
    this.map = map;
    const style = (map as any).style;
    this.depthSourceCache = style._otherSourceCaches[this.depthSourceId];
    this.weibullSourceCache = style._otherSourceCaches[this.weibullSourceId];
    this.shoreDistanceSourceCache =
      style._otherSourceCaches[this.shoreDistanceSourceId];

    if (this.sourceCaches().some((cache) => !cache))
      throw scream("CostLayer: raster sources must be added before the layer");

    // Caches stay paused between frames; update() un-pauses each one only
    // around its explicit update call.
    for (const cache of this.sourceCaches()) cache.pause();

    this.prepareShaders(gl);
    this.prepareBuffers(gl);
  }

  /**
   * Called by mapbox-gl when the layer is removed: frees every GL object this
   * layer created so remounting the layer does not leak GPU resources.
   */
  onRemove(_map: mapboxgl.Map, gl: WebGLRenderingContext) {
    LAYER_DEBUG_PRINT && console.log("CostLayer.onRemove");
    gl.deleteBuffer(this.vertexBuffer);
    gl.deleteBuffer(this.indexBuffer);
    gl.deleteTexture(this.emptyTexture);
    gl.deleteProgram(this.program);
    this.vertexBuffer = null;
    this.indexBuffer = null;
    this.emptyTexture = null;
    this.program = null;
    this.uniformLocations.clear();
  }

  /**
   * Runs `SourceCache.update` (private mapbox-gl API) on each raster cache
   * for the current view. Each cache is briefly un-paused and marked used,
   * updated, then paused again.
   */
  update() {
    const transform = this.map.transform.clone();
    // Grow the viewport height by the pitch-induced offset. NOTE(review):
    // presumably so tiles towards the horizon of a pitched view are covered.
    const pitchOffset =
      transform.cameraToCenterDistance * Math.sin(transform._pitch);
    transform.height = transform.height + pitchOffset;

    for (const cache of this.sourceCaches()) {
      cache._paused = false;
      cache.used = true;
      cache.update(transform);
      cache.pause();
    }
  }

  /**
   * Builds an n x n vertex grid spanning one tile (0..extent on both axes)
   * and the triangle index list covering it, and uploads both to the GPU.
   */
  prepareBuffers(gl: WebGLRenderingContext) {
    const n = 64;
    // Must match the 8192.0 divisor in the vertex shader.
    const extent = 8192;

    this.vertexArray = new Int16Array(n * n * 2);
    for (let i = 0; i < n; i++) {
      for (let j = 0; j < n; j++) {
        const offset = (i * n + j) * 2;
        // Multiply before dividing and round: Int16Array truncates, so
        // j * (extent / (n - 1)) could fall one unit short of the tile edge.
        this.vertexArray[offset] = Math.round((j * extent) / (n - 1));
        this.vertexArray[offset + 1] = Math.round((i * extent) / (n - 1));
      }
    }

    this.vertexBuffer = gl.createBuffer();
    gl.bindBuffer(gl.ARRAY_BUFFER, this.vertexBuffer);
    gl.bufferData(gl.ARRAY_BUFFER, this.vertexArray, gl.STATIC_DRAW);

    // Two triangles per grid cell; max index n*n-1 = 4095 fits Uint16.
    this.indexArray = new Uint16Array((n - 1) * (n - 1) * 6);
    let offset = 0;
    for (let i = 0; i < n - 1; i++) {
      for (let j = 0; j < n - 1; j++) {
        const index = i * n + j;
        this.indexArray.set(
          [
            index,
            index + 1,
            index + n,
            index + n,
            index + 1,
            index + n + 1,
          ],
          offset,
        );
        offset += 6;
      }
    }

    this.indexBuffer = gl.createBuffer();
    gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, this.indexBuffer);
    gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, this.indexArray, gl.STATIC_DRAW);
  }

  /**
   * Compiles and links the cost shaders into `this.program` and looks up the
   * `a_pos` attribute.
   *
   * @throws (via `scream`) when a shader cannot be created, compiled or linked.
   */
  prepareShaders(gl: WebGLRenderingContext) {
    const vertexSource = `
        uniform mat4 u_matrix;
        attribute vec2 a_pos;
        varying vec2 v_pos;
        void main() {
            // Tile-local texture coordinate in [0, 1].
            v_pos = vec2(a_pos / 8192.0);
            gl_Position = u_matrix * vec4(a_pos, .0, 1.0);
        }`;

    const fragmentSource = `
        precision highp float;
        const float years = 25.;
        const float MwPerTurbine = 15.;
        const float discount_rate = 0.05;

        uniform sampler2D u_depth_raster;
        uniform sampler2D u_shoreline_distance_raster;
        uniform sampler2D u_weibull_raster;

        uniform float u_turbine_cost;
        uniform float u_fixed_cost;
        uniform float u_fixed_cost_depth;
        uniform float u_floating_cost;
        uniform float u_floating_cost_depth;
        uniform float u_fixed_to_floating_depth;
        uniform float u_export_cable_cost_shore_distance;
        uniform float u_opex_per_mw;
        uniform vec2 u_range;
        uniform vec2 u_filter;
        varying vec2 v_pos;

        // Decodes a terrarium-encoded elevation and flips its sign so that
        // depth below sea level is positive.
        float getDepth(vec2 coord) {
            vec4 color = texture2D(u_depth_raster, coord);
            float R = color.r * 255.0;
            float G = color.g * 255.0;
            float B = color.b * 255.0;
            return ((R * 256.0 + G + B / 256.0) - 32768.0) * -1.;
        }

        float getShorelineDistanceKM(vec2 coord) {
            return texture2D(u_shoreline_distance_raster, coord).r * 256.;
        }

        void main() {
            float depth = getDepth(v_pos);
            float shoreLineDistance = getShorelineDistanceKM(v_pos);

            vec4 gwa = texture2D(u_weibull_raster, v_pos);
            float capacity = gwa.r * 0.8;
            float power = capacity * MwPerTurbine;

            // 1 for fixed foundations (depth <= limit), 0 for floating. Using
            // the complement guarantees exactly one foundation cost applies;
            // two separate step() calls both return 1 when depth == limit.
            float is_fixed = step(depth, u_fixed_to_floating_depth);
            float capex = u_turbine_cost +
              (u_fixed_cost + u_fixed_cost_depth * depth) * is_fixed +
              (u_floating_cost + u_floating_cost_depth * depth) * (1. - is_fixed) +
              u_export_cable_cost_shore_distance * shoreLineDistance;

            // Yearly opex discounted over the project lifetime (annuity factor).
            float opex_per_year = u_opex_per_mw;
            float opex = opex_per_year * (1. - pow(1. + discount_rate, -years)) / discount_rate;
            float eur = 1000. * (capex + opex);
            float mwh = power * 8766. * years;
            float y = MwPerTurbine * eur / mwh;

            float green_limit = u_range.x;
            float red_limit = u_range.y;

            float x = (y - green_limit) / (red_limit - green_limit);

            float r = 2.* x;
            float g = 2. * (1.-x);
            float b = 0.;

            // Transparent within 1 unit of shore, where the capacity raster is
            // transparent, and outside the [u_filter.x, u_filter.y) interval.
            float alpha = step(1., shoreLineDistance) * gwa.a;
            alpha *= step(u_filter.x, y);
            alpha *= 1. - step(u_filter.y, y);

            gl_FragColor = vec4(r, g, b, alpha);
        }`;

    const vertexShader = gl.createShader(gl.VERTEX_SHADER);
    if (!vertexShader)
      throw scream("Failed to create vertex shader in costLayer");
    this.compileShader(gl, vertexShader, vertexSource, "vertex");
    const fragmentShader = gl.createShader(gl.FRAGMENT_SHADER);
    if (!fragmentShader)
      throw scream("Failed to create fragmentShader shader in costLayer");
    this.compileShader(gl, fragmentShader, fragmentSource, "fragment");

    const program = gl.createProgram();
    if (!program) throw scream("Failed to create program in costLayer");
    gl.attachShader(program, vertexShader);
    gl.attachShader(program, fragmentShader);
    gl.linkProgram(program);
    // Once linked, the shaders are only flagged for deletion and are freed
    // together with the program.
    gl.deleteShader(vertexShader);
    gl.deleteShader(fragmentShader);
    if (!gl.getProgramParameter(program, gl.LINK_STATUS)) {
      const log = gl.getProgramInfoLog(program);
      gl.deleteProgram(program);
      throw scream(`Failed to link program in costLayer: ${log}`);
    }

    this.program = program;
    this.uniformLocations.clear();
    this.aPos = gl.getAttribLocation(program, "a_pos");
  }

  /** Uploads and compiles `source`; throws with the GLSL info log on failure. */
  private compileShader(
    gl: WebGLRenderingContext,
    shader: WebGLShader,
    source: string,
    label: string,
  ) {
    gl.shaderSource(shader, source);
    gl.compileShader(shader);
    if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
      const log = gl.getShaderInfoLog(shader);
      gl.deleteShader(shader);
      throw scream(`Failed to compile ${label} shader in costLayer: ${log}`);
    }
  }

  /** Returns the (cached) location of uniform `name` in `this.program`. */
  private uniform(
    gl: WebGLRenderingContext,
    name: string,
  ): WebGLUniformLocation | null {
    if (!this.program) return null;
    let location = this.uniformLocations.get(name);
    if (location === undefined) {
      location = gl.getUniformLocation(this.program, name);
      this.uniformLocations.set(name, location);
    }
    return location;
  }

  /** Creates a 1x1 fully transparent RGBA texture. */
  private createEmptyTexture(gl: WebGLRenderingContext): WebGLTexture | null {
    const texture = gl.createTexture();
    gl.bindTexture(gl.TEXTURE_2D, texture);
    gl.texImage2D(
      gl.TEXTURE_2D,
      0,
      gl.RGBA,
      1,
      1,
      0,
      gl.RGBA,
      gl.UNSIGNED_BYTE,
      new Uint8Array([0, 0, 0, 0]),
    );
    return texture;
  }

  /**
   * Binds `texture` to texture unit `unit` with clamp-to-edge wrapping,
   * LINEAR minification and the given magnification filter.
   */
  private bindRasterTexture(
    gl: WebGLRenderingContext,
    unit: number,
    texture: WebGLTexture | null,
    magFilter: number,
  ) {
    gl.activeTexture(unit);
    gl.bindTexture(gl.TEXTURE_2D, texture);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, magFilter);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
  }

  /**
   * Draws the cost grid once per visible depth tile, sampling the matching
   * shore-distance and capacity tiles. No-op until inputs are set and the
   * GL resources exist.
   */
  render(gl: WebGLRenderingContext, _matrix: number[]) {
    if (!this.input || !this.indexArray || !this.program) return;
    LAYER_DEBUG_PRINT && console.time("CostLayer.render");
    const input = this.input;
    const indexCount = this.indexArray.length;

    gl.useProgram(this.program);
    // Created once and reused, instead of a new texture per tile per frame.
    if (!this.emptyTexture) this.emptyTexture = this.createEmptyTexture(gl);

    gl.enable(gl.BLEND);
    gl.blendFunc(gl.SRC_ALPHA, gl.ONE_MINUS_SRC_ALPHA);

    // Bind vertex buffer
    gl.bindBuffer(gl.ARRAY_BUFFER, this.vertexBuffer);
    gl.enableVertexAttribArray(this.aPos);
    gl.vertexAttribPointer(this.aPos, 2, gl.SHORT, false, 0, 0);

    // Bind index buffer
    gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, this.indexBuffer);

    this.update();

    // Uniforms shared by every tile in this frame.
    gl.uniform2fv(this.uniform(gl, "u_range"), this.range);
    gl.uniform2fv(this.uniform(gl, "u_filter"), this.filter);
    gl.uniform1f(this.uniform(gl, "u_turbine_cost"), input.turbinesPerMw);
    gl.uniform1f(this.uniform(gl, "u_fixed_cost"), input.fixedFoundationPerMw);
    gl.uniform1f(
      this.uniform(gl, "u_fixed_cost_depth"),
      input.fixedFoundationPerMwDepth,
    );
    gl.uniform1f(
      this.uniform(gl, "u_floating_cost"),
      input.floatingFoundationPerMw,
    );
    gl.uniform1f(
      this.uniform(gl, "u_floating_cost_depth"),
      input.floatingFoundationPerMwDepth,
    );
    gl.uniform1f(
      this.uniform(gl, "u_fixed_to_floating_depth"),
      input.fixedFoundationMaxDepth,
    );
    gl.uniform1f(
      this.uniform(gl, "u_export_cable_cost_shore_distance"),
      input.exportCablePerShoreDistance,
    );
    gl.uniform1f(this.uniform(gl, "u_opex_per_mw"), input.opexPerMw);

    // Sampler -> texture unit assignments are the same for every tile.
    gl.uniform1i(this.uniform(gl, "u_depth_raster"), 0);
    gl.uniform1i(this.uniform(gl, "u_shoreline_distance_raster"), 1);
    gl.uniform1i(this.uniform(gl, "u_weibull_raster"), 2);

    const coords = this.depthSourceCache.getVisibleCoordinates().reverse();
    for (const coord of coords) {
      const depthTile = this.depthSourceCache.getTile(coord);
      // Skip rather than throw when the depth tile has no texture yet.
      if (!depthTile?.texture) continue;
      const shoreLineDistanceTile =
        this.shoreDistanceSourceCache.getTile(coord);
      const weibullTile = this.weibullSourceCache.getTile(coord);

      this.bindRasterTexture(
        gl,
        gl.TEXTURE0,
        depthTile.texture.texture,
        gl.LINEAR,
      );
      this.bindRasterTexture(
        gl,
        gl.TEXTURE1,
        shoreLineDistanceTile?.texture?.texture ?? this.emptyTexture,
        gl.LINEAR,
      );
      this.bindRasterTexture(
        gl,
        gl.TEXTURE2,
        weibullTile?.texture?.texture ?? this.emptyTexture,
        gl.NEAREST,
      );

      gl.uniformMatrix4fv(
        this.uniform(gl, "u_matrix"),
        false,
        coord.projMatrix,
      );

      gl.drawElements(gl.TRIANGLES, indexCount, gl.UNSIGNED_SHORT, 0);
    }
    LAYER_DEBUG_PRINT && console.timeEnd("CostLayer.render");
  }
}

/**
 * Mounts the cost overlay only while the lower-right menu is in "cost" mode;
 * renders nothing otherwise.
 */
const CostLayer = () => {
  const activeMode = useRecoilValue(lowerRightMenuActiveModeAtom);
  return activeMode === "cost" ? <CostLayerActive /> : null;
};

/**
 * While mounted, registers the three cost raster sources and the custom cost
 * layer with the map, and pushes the current cost inputs, filter and range
 * into the layer (triggering a repaint) whenever they change. Renders nothing.
 */
const CostLayerActive = () => {
  const map = useRecoilValue(mapRefAtom);
  const costLayerVariables = useRecoilValue(costLayerVariablesAtom);
  const costFilter = useRecoilValue(costLayerFilterAtom);
  const costRange = useRecoilValue(costLayerRangeAtom);

  // One Layer per mount; the source ids are module constants, so no deps.
  const layer = useMemo(
    () =>
      new Layer(bathymetrySource.id, shoreDistanceSource.id, weibullSource.id),
    [],
  );

  useEffect(() => {
    if (!map) return;
    layer.updateVariables(costLayerVariables, costFilter, costRange);
    map.triggerRepaint();
  }, [map, costLayerVariables, layer, costFilter, costRange]);

  useEffect(() => {
    if (!map) return;

    const sources = [bathymetrySource, shoreDistanceSource, weibullSource];
    for (const source of sources) {
      map.addSource(source.id, {
        type: source.type,
        tiles: source.tiles,
        tileSize,
        maxzoom,
      });
    }
    // Insert beneath "building" when the style has it. mapbox-gl rejects a
    // beforeId that does not exist, which would leave the cost layer missing
    // and make the cleanup's removeLayer fail.
    map.addLayer(layer, map.getLayer("building") ? "building" : undefined);

    return () => {
      map.removeLayer(layer.id);
      for (const source of sources) map.removeSource(source.id);
    };
  }, [map, layer]);

  return null;
};

// Default export: the wrapper that mounts the cost layer only in "cost" mode.
export default CostLayer;
