import { DefaultMap } from "lib/DefaultMap";
import { UnionFind } from "lib/UnionFind";
import { Tree } from "lib/tree";
import { sum } from "utils/utils";
import { CableType } from "../../services/cableTypeService";
import {
  CableFeature,
  SubstationFeature,
  TurbineFeature,
} from "../../types/feature";
import { SimpleTurbineType } from "../../types/turbines";
import { isDefined } from "../../utils/predicates";

/**
 * Thrown by `addCableLoads` when `checkForCycles` finds a cable that closes a
 * cycle in the cable network.
 */
export class CableCycleError extends Error {
  constructor(public msg?: string) {
    super(msg);
    // Without this the error reports itself as a plain "Error" in logs/stacks.
    this.name = "CableCycleError";
    // Keeps `instanceof CableCycleError` working if compiled to an ES5 target.
    Object.setPrototypeOf(this, new.target.prototype);
  }
}

/**
 * Builds one tree per substation that shows how turbines are cabled to it.
 *
 * Each root holds a substation. Every other node holds a turbine together
 * with the cable that links it to its parent node. Trees are grown
 * recursively from each substation in `substations` order.
 *
 * Each cable is followed at most once across the whole forest (tracked by
 * cable id), so a cable already claimed by an earlier tree or branch is not
 * revisited. A cable whose far end is not a turbine in `turbines` (e.g. a
 * substation-to-substation cable) is marked as visited but adds no child.
 * Cables never reached from a substation do not appear in the result.
 *
 * Cycles are not detected here. On a cyclic network a turbine can appear
 * under more than one parent, which is why `addCableLoads` runs
 * `checkForCycles` first.
 *
 * @returns One tree per substation, in the same order as `substations`.
 */
export function makeCableForest(
  cables: CableFeature[],
  substations: SubstationFeature[],
  turbines: TurbineFeature[],
): Tree<
  SubstationFeature | { cable: CableFeature; turbine: TurbineFeature }
>[] {
  type TurbineLink = { cable: CableFeature; turbine: TurbineFeature };
  type ForestNode = SubstationFeature | TurbineLink;

  // Map the id of a turbine or substation to a list of the cables connected to it.
  const node2cables = new DefaultMap<string, CableFeature[]>(() => []);
  for (const c of cables) {
    node2cables.get(c.properties.fromId).push(c);
    node2cables.get(c.properties.toId).push(c);
  }

  // Index turbines by id so each endpoint lookup is O(1) rather than a scan of
  // `turbines` per cable. If ids repeat, the first turbine in `turbines` wins.
  const turbineById = new Map<string, TurbineFeature>();
  for (const t of turbines) {
    if (!turbineById.has(t.id)) turbineById.set(t.id, t);
  }

  const seen = new Set<string>(); // Ids of cables already placed in the forest.
  function makeTreeNode(f: ForestNode): Tree<ForestNode> {
    const id = "id" in f ? f.id : f.turbine.id;
    const children: TurbineLink[] = [];

    for (const cable of node2cables.get(id)) {
      if (seen.has(cable.id)) continue;
      seen.add(cable.id);

      // The endpoint of this cable that is not the current node.
      const otherId =
        cable.properties.fromId === id
          ? cable.properties.toId
          : cable.properties.fromId;
      const turbine = turbineById.get(otherId);
      if (turbine) children.push({ cable, turbine });
    }

    // Recurse only after all of this node's own cables are marked as seen.
    return new Tree(
      f,
      children.map((c) => makeTreeNode(c)),
    );
  }

  return substations.map((sub) => makeTreeNode(sub));
}

/**
 * Looks for a cable that closes a cycle in the cable graph.
 *
 * Cables are processed in input order. Each cable's endpoints are merged
 * into a `UnionFind` keyed by endpoint id. A cable whose endpoints already
 * share a representative is reported as closing a cycle; this includes a
 * cable that starts and ends at the same id.
 * Assumes `UnionFind.find` accepts ids it has not seen before (it is called
 * before `union` for new ids) — confirm against lib/UnionFind.
 *
 * @returns The first cycle-closing cable, or `undefined` if the cables form a forest.
 */
export function checkForCycles(
  cables: CableFeature[],
): CableFeature | undefined {
  const uf = new UnionFind<string>();
  for (const c of cables) {
    const { fromId, toId } = c.properties;
    // Endpoints are already connected through earlier cables: this one closes a loop.
    if (uf.find(fromId) === uf.find(toId)) return c;
    uf.union(fromId, toId);
  }
  // Explicit return keeps the function valid under `noImplicitReturns`.
  return undefined;
}

/**
 * Returns copies of `cables` with `properties.cableTypeId` set to the id of
 * the lowest-rated cable type that can carry the cable's `powerLoad`.
 *
 * A type qualifies when `powerRating / 1e6 >= powerLoad`. This presumably
 * compares a rating in W against a load in MW — confirm. Among types with
 * equal ratings, the one listed first in `cableTypes` wins, because the sort
 * is stable.
 *
 * `cableTypeId` is `undefined` when the load is missing or falsy (a load of 0
 * included), or when no type is strong enough. Neither input array is mutated.
 */
export function addCableTypes(
  cables: CableFeature[],
  cableTypes: CableType[],
): CableFeature[] {
  // Weakest first, so the first match found below is the smallest adequate type.
  const byRating = cableTypes
    .slice()
    .sort((x, y) => x.powerRating - y.powerRating);

  // Resolve each distinct load value only once.
  const typeIdByLoad = new Map<number, string | undefined>();
  const definedLoads = cables
    .map((c) => c.properties.powerLoad)
    .filter(isDefined);
  for (const load of definedLoads) {
    if (typeIdByLoad.has(load)) continue;
    typeIdByLoad.set(
      load,
      byRating.find((ct) => ct.powerRating / 1e6 >= load)?.id,
    );
  }

  return cables.map((cable) => {
    const load = cable.properties.powerLoad;
    // NOTE(review): a load of 0 is treated the same as a missing load here.
    const cableTypeId = load ? typeIdByLoad.get(load) : undefined;
    return {
      ...cable,
      properties: { ...cable.properties, cableTypeId },
    };
  });
}

/**
 * Returns copies of `cables` with `properties.powerLoad` filled in. The
 * returned cables are in the same order as the input cables.
 *
 * Each cable that appears in the forest built by `makeCableForest` gets a raw
 * load. That load is the `ratedPower` of the turbine at its far end plus the
 * loads of all cables hanging below that turbine. The raw value is divided by
 * 1000 and rounded to one decimal place (presumably kW → MW — confirm the
 * unit of `ratedPower`). Cables absent from the forest get a load of 0.
 *
 * @throws CableCycleError if `checkForCycles` finds a cycle in `cables`.
 * @throws Error if a turbine's `turbineTypeId` is not among `turbineTypes`.
 */
export function addCableLoads(
  cables: CableFeature[],
  substations: SubstationFeature[],
  turbines: TurbineFeature[],
  turbineTypes: SimpleTurbineType[],
): CableFeature[] {
  const hasCycle = checkForCycles(cables) !== undefined;
  if (hasCycle) {
    throw new CableCycleError(
      "Substations are connected with cables; this is not supported.",
    );
  }

  const forest = makeCableForest(cables, substations, turbines);
  const ratedPowerByTypeId = new Map(
    turbineTypes.map((tt) => [tt.id, tt.ratedPower]),
  );

  // Each node is mapped to [cableId, load]. The callback reads the already
  // transformed children's loads via `child.data[1]`.
  const perTree = forest.map((tree) =>
    tree.transformUp<[string, number]>((node, below) => {
      const downstream = sum(below, (child) => child.data[1]);
      // Substation roots have no incoming cable; "" is only a placeholder key.
      if ("id" in node) return ["", downstream];
      const rating = ratedPowerByTypeId.get(
        node.turbine.properties.turbineTypeId,
      );
      if (rating === undefined) throw new Error("Illegal turbine type id");
      return [node.cable.id, downstream + rating];
    }),
  );
  const loadByCableId = new Map(perTree.flatMap((t) => t.flatten()));

  return cables.map((cable) => {
    const rawLoad = loadByCableId.get(cable.id) ?? 0;
    const powerLoad = Math.round(rawLoad / 100) / 10;
    return {
      ...cable,
      properties: { ...cable.properties, powerLoad },
    };
  });
}
