import MainLayout from "../layouts/MainLayout";
import useDocumentTitle from "../common/hooks/useDocumentTitle";
import { useGroupsWithSections } from "../groups/manager/use-groups-with-sections";
import {
  CollapseAllButton,
  ExpandAllButton,
  SectionContainer,
} from "../ov/unit-ov/unit-overview";
import { ViewModeSelectors } from "../common/view-mode-selectors";
import { cn } from "../../lib/utils";
import {
  useGetUseViewModeStore,
  UseViewModeStoreProvider,
} from "../../shared-ui/time-series-2/grid-view-store";
import { useEffect, useState } from "react";
import Footer from "../nav/Footer";
import { useDateState } from "../../zustand/useDateState";
import { LinkWithQuery } from "../nav/LinkWithQuery2";
import BarChart from "../charts/BarChart";
import {
  useAriaClustersQuery,
  useClusterDriftScoreQuery,
} from "../../hooks/tanstack-query";
import getAriaColors from "../common/aria-colors";
import _ from "lodash";
import { Group } from "../../lib/api-schema/group";
import { Switch } from "../../shared-ui/frontend/switch";
import "../../styles/query-empty.scss";

/**
 * Drift-level color/label definitions, ordered by `value` from HIGHEST to
 * lowest. The `excludedAnomLevels` boolean arrays in this file are indexed
 * parallel to this order.
 *
 * Sorting is done on a copy so the array returned by `getAriaColors()` is not
 * reordered in place, in case it is shared with other callers.
 */
const DRIFT_LEVELS = [...getAriaColors()].sort((a, b) => b.value - a.value);

/**
 * ARIA overview page body.
 *
 * Renders a header with the grid view-mode selector and one checkbox per drift
 * level, then one drift chart (`OneChart`) per group:
 * - first the groups in `remainingGroups`;
 * - then each section's groups inside a `SectionContainer`.
 *
 * `Wrapped` renders this inside `UseViewModeStoreProvider`, which backs the
 * view-mode / column-count state read below.
 */
function Overview() {
  useDocumentTitle("ARIA > DRA");

  const useViewModeStore = useGetUseViewModeStore();
  const viewMode = useViewModeStore((s) => s.viewMode);
  const setViewMode = useViewModeStore((s) => s.setViewMode);
  const numCols = useViewModeStore((s) => s.numCols);
  const setNumCols = useViewModeStore((s) => s.setNumCols);

  const groupsWithSections = useGroupsWithSections();
  const showExpandCollapseAll =
    groupsWithSections && groupsWithSections.sectionsWithGroups.length > 0;
  // Total groups across all sections plus the groups outside any section.
  const numGroups =
    (groupsWithSections?.sectionsWithGroups.flatMap((s) => s.groups).length ??
      0) + (groupsWithSections?.remainingGroups.length ?? 0);

  const [showEmptyCharts, setShowEmptyCharts] = useState(false);

  // Parallel to DRIFT_LEVELS: `true` at index i hides that drift level.
  // Every level starts visible.
  const [excludedAnomLevels, setExcludedAnomLevels] = useState<boolean[]>(() =>
    DRIFT_LEVELS.map(() => false)
  );

  // Choose a column count from the viewport width (and, on wide screens, from
  // how many groups there are). This runs on mount, whenever `numGroups`
  // changes, and on every window resize. Each run overwrites any column count
  // picked through ViewModeSelectors.
  useEffect(() => {
    const handleResize = () => {
      const width = window.innerWidth;
      if (width < 768) {
        setNumCols(2);
      } else if (width < 1024) {
        setNumCols(3);
      } else {
        setNumCols(numGroups < 20 ? 4 : 5);
      }
    };
    handleResize();
    window.addEventListener("resize", handleResize);
    return () => window.removeEventListener("resize", handleResize);
  }, [numGroups, setNumCols]);

  // NOTE(review): Tailwind only generates classes it finds written literally in
  // source. This dynamically built `grid-cols-N` therefore relies on those
  // classes being generated some other way (e.g. a safelist). Confirm.
  const gridClassName = `grid grid-cols-${numCols} auto-rows-max gap-8 my-4`;

  // One chart per group. All charts share the same level filter and the same
  // "show empty" toggle.
  const renderCharts = (groups: Group[]) =>
    groups.map((g) => (
      <OneChart
        key={g._id}
        group={g}
        excludedAnomLevels={excludedAnomLevels}
        showEmpty={showEmptyCharts}
      />
    ));

  return (
    <MainLayout showDateNav={true}>
      <div className="bg-bggrey min-h-[60vh] flex flex-col">
        {/* Grid of groups + chart for each group. */}
        <div className="flex py-2 mb-1 px-4 gap-2 items-center">
          <h2 className="tracking-tight text-2xl font-semibold">
            ARIA Overview
          </h2>

          <ViewModeSelectors
            className="ml-auto"
            withLabels
            variant={"default"}
            enabledModes={["grid"]}
            viewMode={viewMode}
            setViewMode={setViewMode}
            numCols={numCols}
            setNumCols={setNumCols}
          />
          <div className="flex gap-2 items-center bg-white rounded-md px-3 py-1 border border-xslate-5">
            {DRIFT_LEVELS.map(({ value, color, rgb, label }, idx) => {
              const inputId = `exclAnomLevel${value + 1}`;
              return (
                <div key={rgb} className="flex place-items-center">
                  <input
                    id={inputId}
                    type="checkbox"
                    className={`checkbox checkbox-sm mr-1 ml-2 aria-checkbox-${color} z-15 hover:border hover:border-neutral`}
                    checked={!excludedAnomLevels[idx]}
                    onChange={(e) => {
                      // A checked box means the level is shown, so store the inverse.
                      const isShown = e.target.checked;
                      setExcludedAnomLevels((curr) =>
                        curr.map((hidden, i) => (i === idx ? !isShown : hidden))
                      );
                    }}
                  />
                  {/* Labelled like the Switch below, so clicking the text toggles the box. */}
                  <label htmlFor={inputId} className="text-[0.9rem]">
                    {label}
                  </label>
                </div>
              );
            })}
          </div>
        </div>

        {/*
          Always rendered, absolutely positioned underneath the chart container
          below. That container is a z-10 flex item with an opaque background.
          When this message actually shows through is governed by the
          QUERY_EMPTY / HIDE_IF_EMPTY rules in query-empty.scss, which are not
          visible here.
        */}
        <div
          className="absolute left-0 right-0 top-40 bg-bggrey px-8 w-fit"
          style={{ marginInline: "auto" }}
        >
          <h3 className="text-2xl font-semibold text-xslate-11">
            No clusters to display
          </h3>
        </div>
        <div className="px-8 flex flex-col gap-3 bg-bggrey z-10">
          {groupsWithSections && (
            <div className="flex flex-col gap-3">
              {groupsWithSections.remainingGroups.length > 0 && (
                <>
                  <div className={`${gridClassName} QUERY_EMPTY`}>
                    {renderCharts(groupsWithSections.remainingGroups)}
                  </div>
                  <div className="h-[1px] bg-xslate-6 HIDE_IF_EMPTY" />
                </>
              )}
              {groupsWithSections.sectionsWithGroups.map(
                ({ section, groups }, idx, sections) => (
                  <SectionContainer
                    key={section._id}
                    section={section}
                    // Every section except the last gets a separator line.
                    hasHorizontalLine={idx !== sections.length - 1}
                  >
                    <div className={gridClassName}>{renderCharts(groups)}</div>
                  </SectionContainer>
                )
              )}
            </div>
          )}
        </div>

        <div className="px-8 flex flex-row gap-3 mb-4 mt-8 place-items-center">
          {showExpandCollapseAll && (
            <>
              <ExpandAllButton />
              <CollapseAllButton />
            </>
          )}
          <label htmlFor="aria-empty-chart-toggle" className="text-sm ml-2">
            Show Groups without Clusters
          </label>
          <Switch
            id="aria-empty-chart-toggle"
            checked={showEmptyCharts}
            onCheckedChange={() => setShowEmptyCharts((prev) => !prev)}
          />
        </div>
      </div>
      <Footer className="mt-auto" />
    </MainLayout>
  );
}

/**
 * Default export for the ARIA overview page.
 *
 * Mounts `Overview` inside a `UseViewModeStoreProvider`, which supplies the
 * store that `Overview` obtains through `useGetUseViewModeStore()` (presumably
 * via React context).
 */
export default function Wrapped() {
  const page = <Overview />;
  return <UseViewModeStoreProvider>{page}</UseViewModeStoreProvider>;
}

/**
 * Card for a single group: a stacked daily bar chart counting how many of the
 * group's ARIA clusters reached each max drift level.
 *
 * A cluster is counted only if it meets all of these:
 * - it shares at least one variable with `group`;
 * - it is `aria_enabled`;
 * - it is of type "static".
 *
 * Excluded levels are collapsed into a single white segment on top of each bar,
 * so bar totals do not change when levels are filtered.
 *
 * @param group - group whose variables select the relevant clusters
 * @param excludedAnomLevels - parallel to `DRIFT_LEVELS` (highest level first);
 *   `true` hides that level
 * @param showEmpty - render the card even if every day has zero counts
 * @returns the chart card, or `null` when every count is zero and `showEmpty`
 *   is false
 */
function OneChart({
  group,
  excludedAnomLevels,
  showEmpty,
}: {
  group: Group;
  excludedAnomLevels: boolean[];
  showEmpty: boolean;
}) {
  const clustersQuery = useAriaClustersQuery();
  const clusters = clustersQuery.data ?? [];
  const $ds = useDateState();

  // IDs of the clusters relevant to this group. Sets keep both membership
  // checks constant-time instead of scanning arrays.
  const groupVariables = new Set(group.variables);
  const overlappingClusters = new Set(
    clusters
      .filter(
        (c) =>
          c.variables.some((vid) => groupVariables.has(vid)) &&
          c.aria_enabled &&
          c.type === "static"
      )
      .map((c) => c._id)
  );

  const cdsQuery = useClusterDriftScoreQuery(
    $ds.axisRangeFrom.dateString,
    $ds.axisRangeTo.dateString
  );
  const cdsData = cdsQuery.data || {};

  // One entry per day, with days in sorted key order. `value[level]` is the
  // number of relevant clusters whose `maxLevel` was that level (0, 1, 2).
  const allLevelCounts = Object.keys(cdsData)
    .sort()
    .map((day) => {
      const maxLevels = Object.entries(cdsData[day] || {})
        .filter(([clusterId]) => overlappingClusters.has(clusterId))
        .map(([, score]) => score.maxLevel);
      return {
        value: [0, 1, 2].map(
          (level) => maxLevels.filter((l) => l === level).length
        ),
        key: new Date(day).toISOString().split("T")[0],
        stripes: false,
      };
    });

  // `value` is ordered by ascending level, but `excludedAnomLevels` follows
  // DRIFT_LEVELS (highest first). Reverse a COPY so `allLevelCounts` is never
  // mutated, because it is read again for both chart props below.
  const alignWithDriftLevels = (levels: number[]) => [...levels].reverse();

  /** Total count across all excluded levels. */
  const hiddenLevels = (levels: number[]) =>
    _.sum(
      alignWithDriftLevels(levels).filter((_count, i) => excludedAnomLevels[i])
    );

  /**
   * Builds one bar's segments:
   * - the counts of the visible levels, in ascending-level order;
   * - followed by the combined count of the hidden levels, drawn as the white
   *   top segment.
   */
  const shownLevelCounts = (levels: number[]) => {
    const shown = alignWithDriftLevels(levels)
      .filter((_count, i) => !excludedAnomLevels[i])
      .reverse();
    return [...shown, hiddenLevels(levels)];
  };

  /**
   * Returns segment colors matching `shownLevelCounts`: the visible levels in
   * ascending-level order, then white for the combined hidden segment on top.
   */
  const shownColorBars = (levels: typeof DRIFT_LEVELS) => [
    ...levels
      .filter((_level, i) => !excludedAnomLevels[i])
      .map((c) => c.rgb)
      .reverse(),
    "#fff",
  ];

  if (
    !showEmpty &&
    allLevelCounts.every((s) => s.value.every((v) => v === 0))
  ) {
    return null;
  }

  const chartData = allLevelCounts.map((s) => ({
    ...s,
    value: shownLevelCounts(s.value),
  }));
  // Somewhat hacky: BarChart uses this to correct tooltip values and
  // positioning for hidden levels.
  const chartHiddenData = allLevelCounts.map((s) => ({
    value: hiddenLevels(s.value),
    key: s.key,
  }));

  return (
    <div className="overflow-visible card bg-white dark:bg-xslate-3 rounded-md border group-hover/container:border-xslate-7 border-xslate-6 group-hover/container:shadow-xl shadow-sm">
      <LinkWithQuery
        to={`details/${group._id}`}
        className="card-body pb-1 pt-2 px-2 hover:cursor-pointer hover:bg-zinc-200 hover:dark:bg-xslate-4"
        pick={{
          d: true,
          mo: true,
          y: true,
          z: true,
          cd: true,
        }}
      >
        <div className="px-2 text-xslate-11 whitespace-nowrap flex uppercase overflow-hidden text-ellipsis text-[0.85em]">
          {group.name}
        </div>
      </LinkWithQuery>
      <div className="w-full h-[15em] box-border">
        <BarChart
          classes={{
            BarChart__chart: "BarChart__chart-full-width",
          }}
          onclick={undefined}
          selectedIndex={undefined}
          showTitle={undefined}
          unit={undefined}
          withStripes
          chartTitle="Daily Breakdown"
          xAxisLabel={"Date"}
          yAxisLabel={"Number of variables"}
          // NOTE(review): both ends use axisRangeTo, so only the last day of
          // the range is selected. Confirm this is not meant to be
          // [axisRangeFrom, axisRangeTo].
          selectedRange={[
            $ds.axisRangeTo.dateString,
            $ds.axisRangeTo.dateString,
          ]}
          chartKeys={["level1", "level2", "level3"]} // todo why is this needed?
          barColors={shownColorBars(DRIFT_LEVELS)}
          data={chartData}
          hiddenData={chartHiddenData}
          range={undefined}
        />
      </div>
    </div>
  );
}
