import { useEffect, useMemo, useRef, useState } from "react";
import * as d3 from "d3";
import { Badge } from "../../../shared-ui/frontend/badge";
import { getScoreCategory, SCORE_CATEGORIES } from "../utils";
import { useVariablesArrayQuery } from "../../../hooks/tanstack-query";
import { ellipsify } from "../../utils/stylable";

/** Maximum label length handed to `ellipsify` for x-axis tick labels. */
const X_MAX_LABEL_LENGTH = 20;
/** Maximum label length handed to `ellipsify` for y-axis tick labels. */
const Y_MAX_LABEL_LENGTH = 50;

/** Props for the relationship heatmap components. */
type Props = {
  /** One entry per scored tag pair; `score` can be positive or negative. */
  scores: Array<{ tag_id1: string; tag_id2: string; score: number }>;
};

/**
 * Relationship heatmap section: a header with a tag-count badge, filter
 * toggles per score category and for positive/negative scores, and the D3
 * chart of the scores that pass those filters.
 *
 * @param scores - Pairwise tag scores; a negative `score` is shown as a
 *   negative relationship (marked with "✕" in the chart).
 */
export function HeatMap({ scores }: Props) {
  // excludedAnomLevels[i] === true hides scores whose category is SCORE_CATEGORIES[i].
  const [excludedAnomLevels, setExcludedAnomLevels] = useState(() =>
    SCORE_CATEGORIES.map(() => false)
  );
  const [positives, setPositives] = useState(true);
  const [negatives, setNegatives] = useState(true);

  // Split scores into shown/hidden in a single pass. Memoized so both arrays
  // keep their identity across unrelated re-renders: HeatMapChart redraws the
  // whole SVG whenever either array identity changes.
  const { visibleScores, hiddenScores } = useMemo(() => {
    const visible: Props["scores"] = [];
    const hidden: Props["scores"] = [];
    for (const d of scores) {
      const categoryIdx = SCORE_CATEGORIES.indexOf(
        getScoreCategory(Math.abs(d.score)).start
      );
      const shown =
        !excludedAnomLevels[categoryIdx] &&
        (positives || d.score < 0) &&
        (negatives || d.score > 0);
      (shown ? visible : hidden).push(d);
    }
    return { visibleScores: visible, hiddenScores: hidden };
  }, [scores, excludedAnomLevels, positives, negatives]);

  // Number of distinct tags that appear in at least one visible score.
  const visibleTagCount = useMemo(
    () => new Set(visibleScores.flatMap((d) => [d.tag_id1, d.tag_id2])).size,
    [visibleScores]
  );

  const checkboxClass =
    "checkbox checkbox-xs mr-1 ml-2 z-50 hover:border hover:border-neutral";

  return (
    <>
      <div className="flex flex-row justify-between mr-4">
        <div className="flex flex-row gap-2 place-items-center h-fit mr-16">
          <h2 className="text-xl font-medium">Relationship Heatmap</h2>
          <Badge variant={"secondary"}>{visibleTagCount} tags</Badge>
        </div>
        <div className="flex flex-row gap-2">
          <div className="flex flex-row h-fit">
            {SCORE_CATEGORIES.map((cat, idx) => {
              const color = colorScale(cat);
              const { label } = getScoreCategory(cat);
              const inputId = `exclAnomLevel${cat}`;
              const included = !excludedAnomLevels[idx];
              // While checked, the checkbox is filled with its category color.
              const inputStyle = included
                ? { borderColor: color, background: color }
                : {};
              return (
                <div
                  className="border-b-4"
                  style={{ borderColor: color }}
                  key={cat}
                >
                  <input
                    id={inputId}
                    type="checkbox"
                    className={checkboxClass}
                    style={inputStyle}
                    checked={included}
                    onChange={(e) =>
                      setExcludedAnomLevels((curr) =>
                        curr.map((b, i) => (i === idx ? !e.target.checked : b))
                      )
                    }
                  />
                  {/* A real <label> so clicking the text toggles the box, like the sign toggles. */}
                  <label
                    className="text-xs relative bottom-1"
                    htmlFor={inputId}
                  >
                    {label}
                  </label>
                </div>
              );
            })}
          </div>
          <div className="flex flex-col gap-2">
            <div className="flex flex-row">
              <input
                id={"positives"}
                type="checkbox"
                className={checkboxClass}
                checked={positives}
                onChange={(e) => setPositives(e.target.checked)}
              />
              <label className="text-xs" htmlFor="positives">
                Positive
              </label>
            </div>
            <div className="flex flex-row">
              <input
                id={"negatives"}
                type="checkbox"
                className={checkboxClass}
                checked={negatives}
                onChange={(e) => setNegatives(e.target.checked)}
              />
              <label className="text-xs" htmlFor="negatives">
                Negative (✕)
              </label>
            </div>
          </div>
        </div>
      </div>
      <HeatMapChart scores={visibleScores} hiddenScores={hiddenScores} />
    </>
  );
}
/** Score given to a tag paired with itself when `scores` has no entry for that pair. */
const HEATMAP_SELF_PAIR_SCORE = 100;

/** Order-independent lookup key for the unordered tag pair {a, b}. */
function heatmapPairKey(a: string, b: string): string {
  return a < b ? `${a}\u0000${b}` : `${b}\u0000${a}`;
}

/**
 * Indexes scores by unordered tag pair. If a pair occurs more than once (in
 * either orientation) the first entry wins, matching a linear `Array.find`.
 */
function indexHeatmapScores(list: Props["scores"]): Map<string, number> {
  const index = new Map<string, number>();
  for (const { tag_id1, tag_id2, score } of list) {
    const key = heatmapPairKey(tag_id1, tag_id2);
    if (!index.has(key)) index.set(key, score);
  }
  return index;
}

/**
 * Escapes `&`, `<`, `>`, `"` and `'` so arbitrary text can be interpolated
 * into the markup given to d3's `selection.html` (which assigns innerHTML).
 */
function escapeTooltipHtml(value: unknown): string {
  return String(value)
    .replace(/&/g, "&amp;")
    .replace(/</g, "&lt;")
    .replace(/>/g, "&gt;")
    .replace(/"/g, "&quot;")
    .replace(/'/g, "&#39;");
}

/**
 * D3-rendered square heatmap of pairwise tag scores.
 *
 * Tags are ordered by their average |score| over `scores`. Cells are colored
 * by |score| via `colorScale`; pairs without a score are white, a tag's
 * self-pair defaults to HEATMAP_SELF_PAIR_SCORE, and negative scores get a
 * "✕" overlay. Hovering a cell bolds its axis labels and shows a tooltip; for
 * cells with no visible score the tooltip falls back to `hiddenScores`.
 *
 * @param scores - Scores to draw (already filtered by the caller).
 * @param hiddenScores - Scores the caller filtered out; used only by the tooltip.
 */
export function HeatMapChart({
  scores,
  hiddenScores,
}: Props & { hiddenScores: Props["scores"] }) {
  const svgRef = useRef<SVGSVGElement>(null);
  const tooltipRef = useRef<HTMLDivElement>(null);
  const tagsQuery = useVariablesArrayQuery();

  useEffect(() => {
    // Tag metadata: names for axis labels, descriptions for the tooltip.
    const allTags = tagsQuery?.data || [];
    const findTag = (id: string) => allTags.find((t) => t._id === id);

    // margin/height/width are recomputed after a first axis render, once we
    // know how much room the tick labels need.
    const margin = { top: 2, right: 2, bottom: 50, left: 50 };
    let width = 800 - margin.right - margin.left;
    let height = 800 - margin.top - margin.bottom;
    const svg = d3
      .select(svgRef.current)
      .attr("width", width + margin.left + margin.right)
      .attr("height", height + margin.top + margin.bottom);
    svg.selectAll("*").remove();
    const container = svg
      .append("g")
      .attr("transform", `translate(${margin.left},${margin.top})`);

    // Order tags by their average |score| across every pair they appear in.
    const tagTotals = new Map<string, { total: number; count: number }>();
    for (const d of scores) {
      for (const tag of [d.tag_id1, d.tag_id2]) {
        const entry = tagTotals.get(tag) ?? { total: 0, count: 0 };
        entry.total += Math.abs(d.score);
        entry.count += 1;
        tagTotals.set(tag, entry);
      }
    }
    const sortedTags = Array.from(tagTotals, ([tag, { total, count }]) => ({
      tag,
      avgScore: total / count,
    }))
      .sort((a, b) => a.avgScore - b.avgScore)
      .map((d) => d.tag);

    // One cell per ordered tag pair (self-pairs included). Scores come from a
    // pair index rather than scanning `scores` for every cell; a self-pair
    // without an explicit score gets HEATMAP_SELF_PAIR_SCORE, any other
    // missing pair stays undefined (drawn white).
    const scoreByPair = indexHeatmapScores(scores);
    const heatmapData = sortedTags.flatMap((tag_id1) =>
      sortedTags.map((tag_id2) => ({
        tag_id1,
        tag_id2,
        score:
          scoreByPair.get(heatmapPairKey(tag_id1, tag_id2)) ??
          (tag_id1 === tag_id2 ? HEATMAP_SELF_PAIR_SCORE : undefined),
      }))
    );

    // Both ranges are reversed, so the highest-average tags sit at the left/top.
    const x = d3
      .scaleBand<string>()
      .range([width, 0])
      .domain(sortedTags)
      .padding(0.01);
    const y = d3
      .scaleBand<string>()
      .range([height, 0])
      .domain(sortedTags)
      .padding(0.01);

    const tagLabel = (id: string, maxLength: number) =>
      ellipsify(findTag(id)?.trimmedName || id, maxLength);
    // With many scores, x labels are rotated vertical and both axes use smaller text.
    const rotateText = scores.length > 50;

    const drawAxis = () => {
      const xAxisGroup = container
        .append("g")
        .attr("transform", `translate(0, ${height})`)
        .call(
          d3.axisBottom(x).tickFormat((d) => tagLabel(d, X_MAX_LABEL_LENGTH))
        );
      // Expose each tick's tag id as a DOM attribute.
      xAxisGroup.selectAll(".tick").attr("data-tag-id", (d) => d as string);
      if (rotateText) {
        xAxisGroup
          .selectAll("text")
          .attr("transform", "rotate(-90)")
          .attr("x", -9)
          .attr("y", 2.5)
          .attr("dy", "0")
          .style("text-anchor", "end")
          .style("font-size", "10px");
      }
      const yAxisGroup = container
        .append("g")
        .call(
          d3.axisLeft(y).tickFormat((d) => tagLabel(d, Y_MAX_LABEL_LENGTH))
        );
      yAxisGroup.selectAll(".tick").attr("data-tag-id", (d) => d as string);
      if (rotateText) {
        yAxisGroup.selectAll("text").style("font-size", "10px");
      }
      return { xAxis: xAxisGroup, yAxis: yAxisGroup };
    };

    // Widest tick label on an axis, measured in bold so hover-bolding never
    // overflows the margin.
    const maxLabelWidth = (axis: ReturnType<typeof drawAxis>["xAxis"]) =>
      d3.max(
        axis
          .selectAll("text")
          .style("font-weight", 900)
          .nodes()
          .map((node) => (node as SVGTextElement).getBBox().width)
      ) ?? 0;

    // First pass: draw the axes only to measure their labels, then resize.
    const probe = drawAxis();
    margin.left = Math.max(margin.left, maxLabelWidth(probe.yAxis) + 10);
    margin.bottom = Math.max(margin.bottom, maxLabelWidth(probe.xAxis) + 10);
    height = 800 - margin.top - margin.bottom;
    width = height; // keep the plot square
    x.range([width, 0]);
    y.range([height, 0]);
    svg.attr("width", width + margin.left + margin.right);
    container.attr("transform", `translate(${margin.left},${margin.top})`);
    container.selectAll("g").remove();
    const { xAxis, yAxis } = drawAxis();

    // Tick-label selections for the two tags of a cell.
    const axisLabelsFor = (d: { tag_id1: string; tag_id2: string }) => [
      xAxis
        .selectAll(".tick")
        .filter((tick) => tick === d.tag_id1)
        .select("text"),
      yAxis
        .selectAll(".tick")
        .filter((tick) => tick === d.tag_id2)
        .select("text"),
    ];

    // Tooltip starts hidden; hover handlers toggle its opacity.
    const tooltip = d3.select(tooltipRef.current).style("opacity", 0);
    const hiddenScoreByPair = indexHeatmapScores(hiddenScores);

    container
      .selectAll()
      .data(heatmapData)
      .enter()
      .append("rect")
      .attr("x", (d) => x(d.tag_id1)!)
      .attr("y", (d) => y(d.tag_id2)!)
      .attr("width", x.bandwidth())
      .attr("height", y.bandwidth())
      .style("fill", (d) =>
        d.score !== undefined ? colorScale(Math.abs(d.score)) : "#fff"
      )
      .on("mouseover", function (_event, d) {
        tooltip.style("opacity", 1);
        d3.select(this).style("stroke", "black").style("opacity", 0.8);
        for (const label of axisLabelsFor(d)) label.style("font-weight", 900);
      })
      .on("mousemove", function (event, d) {
        const [mouseX, mouseY] = d3.pointer(event);
        // A cell with no visible score may still have one the caller filtered
        // out; show that one, and derive the category from the same value.
        const shownScore =
          d.score ?? hiddenScoreByPair.get(heatmapPairKey(d.tag_id1, d.tag_id2));
        // Tag text is escaped because it is injected via innerHTML.
        const describe = (id: string) =>
          escapeTooltipHtml(findTag(id)?.nameWithDescription ?? id);

        tooltip
          .html(
            `<div class="mb-2 flex items-center border-b border-inherit pb-2 gap-2">
              <span class="text-md leading-2 whitespace-nowrap font-semibold">${
                shownScore ?? "N/A"
              }</span>
              <span class="col-start-1 justify-self-start whitespace-nowrap uppercase text-amber-500">
                  ${getScoreCategory(Math.abs(shownScore ?? 0)).label}
                </span>
            </div>
            <div class="flex flex-col min-w-max">
              <div>
                ${describe(d.tag_id2)}
              </div>
              <div>
                ${describe(d.tag_id1)}
              </div>
            </div>
            `
          )
          .style("left", mouseX + margin.left + 15 + "px")
          .style("top", mouseY + margin.top - 28 + "px");
      })
      .on("mouseleave", function (_event, d) {
        tooltip.style("opacity", 0);
        d3.select(this).style("stroke", "none").style("opacity", 1);
        for (const label of axisLabelsFor(d)) label.style("font-weight", null);
      });

    // Mark cells with a negative score with a "✕".
    container
      .selectAll()
      .data(heatmapData.filter((d) => (d.score || 0) < 0))
      .enter()
      .append("text")
      .style("font-size", "10px")
      .attr("x", (d) => x(d.tag_id1)! + x.bandwidth() / 2)
      .attr("y", (d) => y(d.tag_id2)! + y.bandwidth() / 2)
      .attr("text-anchor", "middle")
      .attr("dominant-baseline", "central")
      // NOTE(review): only negative scores reach here, so `score > 0.5` is
      // never true and the mark is always black — confirm intended contrast rule.
      .style("fill", (d) =>
        d.score !== undefined && d.score > 0.5 ? "white" : "black"
      )
      .style("pointer-events", "none")
      .text("✕");
  }, [scores, hiddenScores, tagsQuery.data]);

  return (
    <div className="relative mx-auto">
      <svg ref={svgRef} />
      <div
        ref={tooltipRef}
        className="absolute pointer-events-none flex flex-col rounded-md border border-xslate-6 bg-zinc-950/70 p-2 text-xs text-white"
      />
    </div>
  );
}

// option 1
// export const colorScale = d3
//   .scaleSequential(d3.interpolateRdYlGn)
//   .domain([0, 100]);
// option 2
// const scoreScale = d3
//   .scalePow()
//   // .exponent(1) // Adjust the exponent to tweak the non-linearity
//   .domain([0, 100])
//   .range([0, 100]);
// export const colorScale = d3
//   .scaleSequential<string>()
//   .domain([0, 100])
//   .interpolator((t: number) => d3.interpolateRdYlGn(scoreScale(t)));
// option 3
/**
 * Upper (exclusive) bounds on `t` paired with the RdYlGn position used below
 * that bound; any `t` at or above the last bound maps to position 1.
 */
const SCORE_COLOR_BANDS: ReadonlyArray<readonly [number, number]> = [
  [40, 0],
  [70, 0.25],
  [90, 0.5],
  [95, 0.75],
];

/**
 * Maps `t` to one of five discrete RdYlGn colors (red for low, green for high).
 * The bounds are in raw score units (40/70/90/95), not a normalized 0..1 range.
 */
function customInterpolator(t: number): string {
  const band = SCORE_COLOR_BANDS.find(([upper]) => t < upper);
  return d3.interpolateRdYlGn(band === undefined ? 1 : band[1]);
}
/**
 * Color for a score magnitude. The domain is [0, 1] and clamping is left off,
 * so the input is passed through unchanged to `customInterpolator`, whose
 * bands are expressed in raw score units.
 */
export const colorScale = d3
  .scaleSequential<string>(customInterpolator)
  .domain([0, 1]);
