import {
  BidirectionalArrowIcon,
  Box,
  Button,
  ChakraTooltip,
  Column,
  Row,
  Select,
  Spinner,
  StopIcon,
  Text,
} from "@hightouchio/ui";
import groupBy from "lodash/groupBy";
import {
  FC,
  Fragment,
  useCallback,
  useEffect,
  useMemo,
  useState,
} from "react";
import { GraphSeries } from "src/components/analytics/cross-audience-graph/types";
import { formatMetricValue } from "src/components/analytics/cross-audience-graph/utils";
import { GraphTooltip } from "src/components/analytics/graph-tooltip";
import { TextWithTooltip } from "src/components/text-with-tooltip";
import { isPresent } from "ts-extras";
import { shouldUsePercentFormat } from "./utils";

/** Props accepted by the `HeatmapTable` component. */
type HeatmapTableProps = {
  /** Metric series to plot; each one becomes (at most) one heatmap cell. */
  data: GraphSeries[];
  /** While true, a spinner is rendered instead of the grid. */
  isLoading?: boolean;
};

/** Chakra color palettes, assigned round-robin to series groups by index. */
const COLOR_SCHEMES = ["electric", "ocean", "peridot"] as const;

/** One dimension a series is grouped by, e.g. a group-by column, split or audience. */
type HeatmapGroupingEntry = {
  type: string;
  dimension: string;
  value: string;
};

/** A single rendered cell: its metric value plus the groupings that produced it. */
type HeatmapCell = {
  value: number;
  heatmapGrouping?: HeatmapGroupingEntry[];
};

/** One grid row, labelled by its first-axis value. */
type HeatmapRow = {
  dimension1: string;
  cells: HeatmapCell[];
};

/** Everything needed to render the heatmap grid and its color scale. */
type HeatmapData = {
  /** Row labels (first axis). */
  dimension1Values: string[];
  /** Column labels (second axis). */
  dimension2Values: string[];
  grid: HeatmapRow[];
  /** Smallest cell value, used as the low end of the color scale. */
  minValue: number;
  /** Largest cell value, used as the high end of the color scale. */
  maxValue: number;
};

/**
 * Builds the grid rendered by `HeatmapTable` for a single series group.
 *
 * The number of axes is taken from the `heatmapGrouping` length of the first
 * series in the selected group:
 * - 0 groupings: a single cell.
 * - 1 grouping: a single row of cells (a single column when `axesFlipped`).
 * - 2 groupings: rows come from grouping 0 and columns from grouping 1
 *   (swapped when `axesFlipped`).
 *
 * Missing grouping values are treated as `""` both when collecting axis labels
 * and when looking up the series for a cell, so the two always agree. A cell
 * with no matching series (or no data points) gets the value 0, and that 0
 * takes part in the min/max computation.
 *
 * @param seriesGroups - Series groups to choose from.
 * @param selectedSeriesName - `seriesName` of the group to render.
 * @param axesFlipped - Whether to transpose the grid.
 * @returns The grid plus its min/max cell values. The result is empty when the
 *   selected group is missing, has no series, or has more than two groupings.
 */
export const generateHeatmapData = (
  seriesGroups: {
    seriesName: string;
    series: {
      data: { metricValue: number }[];
      heatmapGrouping: {
        type: string;
        dimension: string;
        value: string;
      }[];
    }[];
  }[],
  selectedSeriesName: string | undefined,
  axesFlipped: boolean,
): HeatmapData => {
  const emptyHeatmap = (): HeatmapData => ({
    dimension1Values: [],
    dimension2Values: [],
    grid: [],
    minValue: 0,
    maxValue: 0,
  });

  const seriesData = seriesGroups.find(
    (group) => group.seriesName === selectedSeriesName,
  )?.series;

  if (!seriesData?.length) {
    return emptyHeatmap();
  }

  type Series = (typeof seriesGroups)[number]["series"][number];

  // Grouping value at `index`, with missing values normalized to "" so that
  // axis labels and cell lookups use the same representation.
  const groupingValue = (series: Series, index: number) =>
    series.heatmapGrouping?.[index]?.value ?? "";

  const createCell = (matchingSeries: Series | undefined) => ({
    value: matchingSeries?.data[0]?.metricValue ?? 0,
    heatmapGrouping: matchingSeries?.heatmapGrouping,
  });

  // Index the series once so each cell lookup is O(1) rather than a scan over
  // every series. On duplicate keys the first series wins.
  const pairKey = (first: string, second: string) =>
    JSON.stringify([first, second]);
  const seriesByFirst = new Map<string, Series>();
  const seriesByPair = new Map<string, Series>();
  for (const series of seriesData) {
    const first = groupingValue(series, 0);
    const key = pairKey(first, groupingValue(series, 1));
    if (!seriesByFirst.has(first)) seriesByFirst.set(first, series);
    if (!seriesByPair.has(key)) seriesByPair.set(key, series);
  }

  // `dim1Value` labels a grid row and `dim2Value` a column. Only an omitted
  // `dim2Value` means a one-dimensional lookup. An empty string is a real
  // (blank) grouping value and must still be matched on both dimensions.
  const findMatchingSeries = (dim1Value: string, dim2Value?: string) => {
    if (dim2Value === undefined) {
      return seriesByFirst.get(dim1Value);
    }
    // When flipped, rows come from grouping 1 and columns from grouping 0.
    return seriesByPair.get(
      axesFlipped
        ? pairKey(dim2Value, dim1Value)
        : pairKey(dim1Value, dim2Value),
    );
  };

  // Distinct values of one grouping, in first-seen order.
  const getUniqueValues = (dimensionIndex: number) =>
    Array.from(
      new Set(
        seriesData.map((series) => groupingValue(series, dimensionIndex)),
      ),
    );

  // Track min/max values across all cells
  let minValue = Infinity;
  let maxValue = -Infinity;
  const updateMinMax = (value: number) => {
    minValue = Math.min(minValue, value);
    maxValue = Math.max(maxValue, value);
  };

  const numDimensions = seriesData[0]?.heatmapGrouping?.length ?? 0;

  switch (numDimensions) {
    case 0: {
      const cell = createCell(seriesData[0]);
      return {
        dimension1Values: [""],
        dimension2Values: [""],
        grid: [{ dimension1: "", cells: [cell] }],
        minValue: cell.value,
        maxValue: cell.value,
      };
    }

    case 1: {
      const values = getUniqueValues(0);
      // One single-cell row per value; the unflipped layout merges them into
      // a single row instead.
      const rows = values.map((value) => {
        const cell = createCell(findMatchingSeries(value));
        updateMinMax(cell.value);
        return { dimension1: value, cells: [cell] };
      });

      return axesFlipped
        ? {
            dimension1Values: values,
            dimension2Values: [""],
            grid: rows,
            minValue,
            maxValue,
          }
        : {
            dimension1Values: [""],
            dimension2Values: values,
            grid: [{ dimension1: "", cells: rows.flatMap((row) => row.cells) }],
            minValue,
            maxValue,
          };
    }

    case 2: {
      const dim1Values = getUniqueValues(axesFlipped ? 1 : 0);
      const dim2Values = getUniqueValues(axesFlipped ? 0 : 1);

      const gridData = dim1Values.map((dim1) => ({
        dimension1: dim1,
        cells: dim2Values.map((dim2) => {
          const cell = createCell(findMatchingSeries(dim1, dim2));
          updateMinMax(cell.value);
          return cell;
        }),
      }));

      return {
        dimension1Values: dim1Values,
        dimension2Values: dim2Values,
        grid: gridData,
        minValue,
        maxValue,
      };
    }

    default:
      // More than two groupings cannot be laid out on a 2D grid.
      return emptyHeatmap();
  }
};

/**
 * Renders grouped metric series as a color-scaled heatmap grid.
 *
 * Each series' group-by values, split name and audience name become its
 * heatmap dimensions. The first two dimensions form the grid axes. Any further
 * dimensions, together with the metric name, are folded into the series-group
 * selector shown above the grid. A "Switch axes" button transposes the grid.
 */
export const HeatmapTable: FC<HeatmapTableProps> = ({
  data,
  isLoading = false,
}) => {
  const [axesFlipped, setAxesFlipped] = useState(false);

  // Each series contributes up to three kinds of heatmap dimension, in order:
  // - its group-by columns
  // - its split (when `splitName` is set)
  // - its audience (when `audienceName` is set)
  const processedData = useMemo(() => {
    return data.map((series) => ({
      ...series,
      heatmapGrouping: [
        ...(series.grouping?.map((group) => ({
          type: "groupBy",
          dimension: group.alias ?? "",
          value: group.value ?? "",
        })) ?? []),
        series.splitName
          ? {
              dimension: "split",
              type: "split",
              value: series.splitName,
            }
          : undefined,
        series.audienceName
          ? {
              dimension: "audience",
              type: "audience",
              value: series.audienceName,
            }
          : undefined,
      ].filter(isPresent),
    }));
  }, [data]);

  // The heatmap can only show two dimensions. Series are therefore grouped by
  // their metric plus every dimension past the second, and the user picks one
  // of those groups to display.
  const seriesGroups = useMemo(() => {
    const groups = groupBy(processedData, (series) => {
      const metricPart = series.metricName;
      const groupingPart = series.heatmapGrouping
        .slice(2)
        .map((group) => group.value)
        .join(" | ");

      return groupingPart ? `${metricPart} | ${groupingPart}` : metricPart;
    });

    return Object.entries(groups).map(([seriesName, series], index) => ({
      seriesName,
      series: series.map((s) => ({
        ...s,
        // We take the first two groupings as the actual groupings used in the heatmap
        heatmapGrouping: s.heatmapGrouping.slice(0, 2),
      })),
      metricName: series?.[0]?.metricName,
      audienceName: series?.[0]?.audienceName,
      splitName: series?.[0]?.splitName,
      colorScheme: COLOR_SCHEMES[index % COLOR_SCHEMES.length],
    }));
  }, [processedData]);

  // Currently selected series group; reset to the first group by the effect below.
  const [selectedSeriesName, setSelectedSeriesName] = useState<string>();

  const heatMapData = useMemo(
    () => generateHeatmapData(seriesGroups, selectedSeriesName, axesFlipped),
    [seriesGroups, selectedSeriesName, axesFlipped],
  );

  // Maps a value onto the scheme's 100–900 shades, scaled between the grid's
  // min and max. Zero is rendered as the plain background.
  const getColorForValue = useCallback(
    (value: number, colorScheme: (typeof COLOR_SCHEMES)[number]) => {
      if (value === 0) {
        return "base.background";
      }

      if (heatMapData.minValue === heatMapData.maxValue) {
        return `${colorScheme}.500`;
      }

      const normalizedValue =
        (value - heatMapData.minValue) /
        (heatMapData.maxValue - heatMapData.minValue);

      // normalizedValue is in [0, 1], so colorIndex is in 0..8 → shades 100..900.
      const colorIndex = Math.floor(normalizedValue * 8);
      return `${colorScheme}.${(colorIndex + 1) * 100}`;
    },
    [heatMapData.minValue, heatMapData.maxValue],
  );

  // Chooses a text color that stays legible on the background picked by getColorForValue.
  const getTextColorForValue = useCallback(
    (value: number) => {
      if (value === 0) return "text.secondary";

      if (heatMapData.minValue === heatMapData.maxValue) {
        return "white";
      }

      const normalizedValue =
        (value - heatMapData.minValue) /
        (heatMapData.maxValue - heatMapData.minValue);

      return normalizedValue >= 0.5 ? "white" : "text.primary";
    },
    [heatMapData.minValue, heatMapData.maxValue],
  );

  // Fall back to the first series group when nothing is selected, or when the
  // selected group no longer exists (e.g. after `data` changes).
  useEffect(() => {
    if (
      !selectedSeriesName ||
      !seriesGroups.find((group) => group.seriesName === selectedSeriesName)
    ) {
      setSelectedSeriesName(seriesGroups[0]?.seriesName);
    }
  }, [seriesGroups, selectedSeriesName]);

  const selectedSeries = seriesGroups.find(
    (group) => group.seriesName === selectedSeriesName,
  );

  if (isLoading) {
    return (
      <Column align="center" justify="center" flex={1} minHeight={0} gap={4}>
        <Spinner size="lg" />
      </Column>
    );
  }

  return (
    <Column flex={1} minHeight={0} gap={4}>
      <Row gap={2} justifyContent="space-between">
        <Select
          size="sm"
          placeholder="Select a series..."
          value={selectedSeriesName}
          onChange={(v) => setSelectedSeriesName(v ?? "")}
          options={seriesGroups.map(({ seriesName }) => ({
            label: seriesName,
            value: seriesName,
          }))}
          optionAccessory={(option) => {
            const series = seriesGroups.find(
              (group) => group.seriesName === option.value,
            );
            return {
              type: "icon",
              icon: StopIcon,
              color: `${series?.colorScheme}.500`,
            };
          }}
        />

        <Button
          size="sm"
          variant="secondary"
          // Functional update so rapid clicks never toggle from a stale value.
          onClick={() => setAxesFlipped((flipped) => !flipped)}
          icon={BidirectionalArrowIcon}
        >
          Switch axes
        </Button>
      </Row>

      <Box flex={1} minHeight={0}>
        {!heatMapData.grid.length ? (
          <Column align="center" justify="center" height="100%" gap={4}>
            <Text color="text.secondary">No data available</Text>
          </Column>
        ) : (
          <Box
            position="relative"
            height="100%"
            overflowX="auto"
            overflowY="auto"
            sx={{
              scrollbarGutter: "stable",
            }}
          >
            <Box
              display="grid"
              gridTemplateColumns={`160px repeat(${heatMapData.dimension2Values.length}, minmax(120px, 400px))`}
              gap="2px"
              position="relative"
              bg="base.canvas"
              minWidth="fit-content"
            >
              {/* Empty cell in top-left corner */}
              <Box
                p={4}
                bg="white"
                position="sticky"
                left={0}
                top={0}
                zIndex={3}
              />

              {/* Header row */}
              {heatMapData.dimension2Values.map((header, index) => (
                <Box
                  key={index}
                  p={4}
                  textAlign="center"
                  position="sticky"
                  top={0}
                  bg="white"
                  zIndex={2}
                >
                  <TextWithTooltip
                    message={header || selectedSeries?.seriesName}
                    color="text.secondary"
                    size="sm"
                  >
                    {header || selectedSeries?.seriesName}
                  </TextWithTooltip>
                </Box>
              ))}

              {heatMapData.grid.map((row, rowIndex) => (
                // A keyed Fragment is required here: the shorthand `<>` cannot
                // take a key, and list items without keys trigger React warnings.
                <Fragment key={`row-${rowIndex}`}>
                  {/* Row header - make it sticky */}
                  <Box
                    key={`header-${rowIndex}`}
                    p={4}
                    position="sticky"
                    left={0}
                    bg="white"
                    zIndex={2}
                  >
                    <TextWithTooltip
                      message={row.dimension1 || selectedSeries?.seriesName}
                      color="text.secondary"
                      size="sm"
                    >
                      {row.dimension1 || selectedSeries?.seriesName}
                    </TextWithTooltip>
                  </Box>

                  {row.cells.map((cell, colIndex) => (
                    <ChakraTooltip
                      key={`cell-${rowIndex}-${colIndex}`}
                      label={
                        <GraphTooltip
                          title={selectedSeries?.seriesName ?? ""}
                          subtitles={[
                            cell?.heatmapGrouping
                              ?.map(
                                (group) =>
                                  `${group.dimension}: ${group.value}`,
                              )
                              .join(", "),
                          ].filter(Boolean)}
                          color={`${selectedSeries?.colorScheme}.500`}
                          value={[
                            {
                              value: formatMetricValue(cell?.value ?? 0),
                              // NOTE(review): called with constant arguments, so
                              // every cell gets the same description — confirm
                              // whether the series' metric info should be passed.
                              description: shouldUsePercentFormat(
                                undefined,
                                undefined,
                              )
                                ? "percent"
                                : undefined,
                            },
                          ]}
                        />
                      }
                    >
                      <Box
                        p={4}
                        textAlign="center"
                        bg={getColorForValue(
                          cell?.value ?? 0,
                          selectedSeries?.colorScheme ?? COLOR_SCHEMES[0],
                        )}
                      >
                        <Text
                          size="sm"
                          color={getTextColorForValue(cell?.value ?? 0)}
                        >
                          {formatMetricValue(cell?.value ?? 0)}
                        </Text>
                      </Box>
                    </ChakraTooltip>
                  ))}
                </Fragment>
              ))}
            </Box>
          </Box>
        )}
      </Box>
    </Column>
  );
};
