import React, { useState } from "react";
import getAriaColors from "../common/aria-colors";
import BarChart from "../charts/BarChart";
import _ from "lodash";
import { useDateState } from "../../zustand/useDateState";
import { useClusterDriftScoreQuery } from "../../hooks/tanstack-query";
import { cn, type PropsWithCn } from "../../shared-ui/frontend/cn";

// Drift-level descriptors from getAriaColors(), sorted by *descending* `value`
// (highest level first). ClusterDriftScore keeps parallel arrays indexed in
// this order. The result is copied before sorting so the array returned by
// getAriaColors() is never mutated in place.
const DRIFT_LEVELS = [...getAriaColors()].sort((a, b) => b.value - a.value);

/**
 * Card that charts, for each day returned by the cluster-drift-score query
 * over the current axis date range, how many of the selected clusters had each
 * `maxLevel` (0, 1 or 2) that day. Each level is drawn as one coloured segment.
 *
 * Header checkboxes (one per entry of DRIFT_LEVELS) let the user hide levels.
 * The counts of the hidden levels are summed into one extra trailing segment
 * drawn in white, and are also passed to BarChart separately as `hiddenData`.
 *
 * @param selectedClusters cluster keys to count; per-day entries whose key is
 *   not in this list are ignored.
 * @param className extra classes merged onto the outer card via `cn`.
 */
export function ClusterDriftScore({
  selectedClusters,
  className,
}: {
  selectedClusters: string[];
} & PropsWithCn) {
  // Parallel to DRIFT_LEVELS: excludedAnomLevels[i] is true when the level at
  // DRIFT_LEVELS[i] is hidden. Every level starts visible.
  const [excludedAnomLevels, setExcludedAnomLevels] = useState(() =>
    DRIFT_LEVELS.map(() => false)
  );

  const $ds = useDateState();
  const cdsQuery = useClusterDriftScoreQuery(
    $ds.axisRangeFrom.dateString,
    $ds.axisRangeTo.dateString
  );
  const cdsData = cdsQuery.data ?? {};
  // Set lookup instead of Array#includes inside the per-day, per-cluster loop.
  const selectedClusterSet = new Set(selectedClusters);

  // One entry per day key (sorted as strings). value[k] is the number of
  // selected clusters whose maxLevel that day equals k, for k = 0, 1, 2.
  // NOTE(review): assumes day keys are date strings accepted by `new Date`;
  // an unparseable key makes toISOString() throw a RangeError.
  const allLevelCounts = Object.keys(cdsData)
    .sort()
    .map((day) => {
      const data = Object.fromEntries(
        Object.entries(cdsData[day] || {}).filter(([key]) =>
          selectedClusterSet.has(key)
        )
      );
      const clusterStats = Object.values(data);
      return {
        value: [0, 1, 2].map(
          (level) => clusterStats.filter((l) => l.maxLevel === level).length
        ),
        key: new Date(day).toISOString().split("T")[0],
        stripes: false,
      };
    });

  // Per-day counts are indexed by level (ascending). excludedAnomLevels
  // follows DRIFT_LEVELS, which is sorted by descending value, so the counts
  // are reversed to line the two up. This assumes DRIFT_LEVELS holds exactly
  // levels 2, 1, 0 (TODO confirm against getAriaColors).
  // It reverses a *copy*: allLevelCounts is read by both the `data` and the
  // `hiddenData` props and must stay unmodified.
  const toDriftLevelOrder = (counts: number[]) => [...counts].reverse();

  /**
   * Bar segments for one day: the visible level counts in ascending level
   * order, followed by the summed count of all hidden levels.
   */
  const shownLevelCounts = (levels: number[]) => {
    const shown = toDriftLevelOrder(levels).filter(
      (_count, i) => !excludedAnomLevels[i]
    );
    const hiddenTotal = _.sum(levels) - _.sum(shown);
    shown.reverse(); // back to ascending level order; `shown` is our own array
    shown.push(hiddenTotal);
    return shown;
  };

  /** Summed count, for one day, of the levels currently hidden by the checkboxes. */
  const hiddenLevels = (levels: number[]) =>
    _.sum(
      toDriftLevelOrder(levels).filter((_count, i) => excludedAnomLevels[i])
    );

  /**
   * Segment colours matching shownLevelCounts: the visible levels' colours in
   * ascending level order, then white for the trailing hidden-levels segment.
   * White must come last because the blocks at the top are white.
   */
  const shownColorBars = (levels: typeof DRIFT_LEVELS) => [
    ...levels
      .map((c) => c.rgb)
      .filter((_rgb, i) => !excludedAnomLevels[i])
      .reverse(),
    "#fff",
  ];

  return (
    <div
      className={cn(
        "card border border-zinc-300 h-full rounded-md overflow-visible",
        className
      )}
    >
      <h2 className="card-title bg-base-100 py-1 pr-3 pl-2 flex flex-row justify-between rounded-t-md items-start">
        <span>Cluster Drift Score</span>
        <div className="flex flex-row text-[1.2rem] mr-2 mt-1">
          {DRIFT_LEVELS.map(({ value, color, rgb, label }, idx) => (
            <div className="border-b-4" style={{ borderColor: rgb }} key={rgb}>
              <input
                id={`exclAnomLevel${value + 1}`}
                type="checkbox"
                className={`checkbox checkbox-sm mr-1 ml-2 aria-checkbox-${color} z-50 hover:border hover:border-neutral`}
                checked={!excludedAnomLevels[idx]}
                onChange={(e) => {
                  // Capture the new state now rather than reading the DOM
                  // inside the (possibly deferred) state updater.
                  const excluded = !e.target.checked;
                  setExcludedAnomLevels((curr) =>
                    curr.map((b, i) => (i === idx ? excluded : b))
                  );
                }}
              />
              <span className="text-[0.9rem] relative bottom-1">{label}</span>
            </div>
          ))}
        </div>
      </h2>

      <div className="bg-white h-full rounded-b-md pb-4">
        <BarChart
          classes={undefined}
          onclick={undefined}
          selectedIndex={undefined}
          showTitle={undefined}
          unit={undefined}
          withStripes
          chartTitle="Daily Breakdown"
          xAxisLabel={"Date"}
          yAxisLabel={"Number of variables"}
          // NOTE(review): both ends are axisRangeTo, while the query above
          // spans axisRangeFrom..axisRangeTo. Confirm whether the start
          // should be axisRangeFrom.
          selectedRange={[
            $ds.axisRangeTo.dateString,
            $ds.axisRangeTo.dateString,
          ]}
          chartKeys={["level1", "level2", "level3"]} // todo why is this needed?
          barColors={shownColorBars(DRIFT_LEVELS)}
          data={allLevelCounts.map((s) => ({
            ...s,
            value: shownLevelCounts(s.value),
          }))}
          // kinda hacky way of getting tooltip value and positioning to account for hidden levels
          hiddenData={allLevelCounts.map((s) => ({
            value: hiddenLevels(s.value),
            key: s.key,
          }))}
          range={undefined}
        />
      </div>
    </div>
  );
}
