import { useEffect, useRef, useState } from 'react';
import { useTheme } from '@mui/material/styles';
import { scaleBand, scaleLinear } from '@visx/scale';
import { Group } from '@visx/group';
import { AxisBottom, AxisLeft } from '@visx/axis';
import { GridRows } from '@visx/grid';
import { TooltipWithBounds, useTooltip } from '@visx/tooltip';
import { localPoint } from '@visx/event';
import { extent } from 'd3-array';

import { TOOLTIP_STYLES, getDiscreteAxisNumTicks, useCommonChartStyles } from './utils';

/** One item counted in a histogram bucket and listed in that bar's tooltip. */
export type UnitInfo = {
  /** Used as the React key of the unit's tooltip row, so it should be unique within a bucket. */
  id: string;
  /** Text displayed for the unit in the tooltip. */
  name: string;
};

/** Props for {@link HistogramBarChart}. */
export type HistogramBarChartProps = {
  /** Total chart width in pixels, margins included. Nothing is rendered when it is below 10. */
  width: number;
  /** Total chart height in pixels, margins included. */
  height: number;
  /** Space kept around the plot area for the axes. Defaults to `DEFAULT_MARGIN`. */
  margin?: {
    top: number;
    right: number;
    bottom: number;
    left: number;
  };
  /**
   * Units grouped by bucket key. A band is laid out for every integer from the smallest to the
   * largest key, and each bar is as tall as the number of units under its key. Keys are expected
   * to be integers.
   */
  data: Map<number, UnitInfo[]>;
  /**
   * Fill color of a bar, given its key and unit count. Defaults to the theme's chart element
   * color. A hovered bar is drawn with the highlight color instead.
   */
  barColoring?: (key: number, value: number) => string;
  /** Formats a bucket key for the tooltip heading. Defaults to `key.toString()`. */
  formatTooltipLabel?: (key: number) => string;
  /** Formats a bucket key for the bottom-axis tick labels. Defaults to `key.toString()`. */
  formatBottomAxisLabel?: (key: number) => string;
  /**
   * Minimum gap between bottom-axis ticks, forwarded to `getDiscreteAxisNumTicks`.
   * Defaults to 40 (presumably pixels — confirm against the helper).
   */
  minBottomAxisTickGap?: number;
};

/** Payload held by the tooltip while a bar is hovered. */
export type TooltipData = {
  /** Bucket key of the hovered bar. */
  key: number;
  /** Units stored under that key; these are listed in the tooltip. */
  units: UnitInfo[];
};

/** Delay, in milliseconds, before the tooltip hides after the pointer leaves a bar. */
const TOOLTIP_TIMEOUT = 150;

/** Maximum unit rows in the tooltip, counting the "and n more" row used when the list is cut. */
const TOOLTIP_MAX_LINES = 9;

/** Default space, in pixels, between the SVG edge and the plot area. */
const DEFAULT_MARGIN = {
  top: 10,
  left: 30,
  right: 10,
  bottom: 20,
};

/**
 * Histogram-style bar chart for pre-bucketed data.
 *
 * A band is laid out for every integer key from the smallest to the largest key in `data`. The
 * bar in each band is as tall as the number of units stored under that key. Keys that are
 * missing, or that map to an empty list, leave an empty slot.
 *
 * Hovering a bar highlights it and shows a tooltip listing its units, capped at
 * `TOOLTIP_MAX_LINES` rows. The tooltip is hidden `TOOLTIP_TIMEOUT` ms after the pointer leaves a
 * bar, unless the pointer reaches another bar first.
 *
 * Renders nothing when `width` is less than 10.
 */
export const HistogramBarChart = ({
  width,
  height,
  margin = DEFAULT_MARGIN,
  data,
  barColoring,
  formatTooltipLabel,
  formatBottomAxisLabel,
  minBottomAxisTickGap = 40,
}: HistogramBarChartProps) => {
  const { tooltipData, tooltipLeft, tooltipTop, tooltipOpen, showTooltip, hideTooltip } =
    useTooltip<TooltipData>();
  const [hoveredBar, setHoveredBar] = useState<number | null>(null);

  // Pending "hide the tooltip" timer. It lives in a ref because hovering re-renders the chart.
  // A plain `let` in the component body is reset on every render, so a timer started by one
  // bar's mouse-out could never be cancelled by the next bar's handlers. The tooltip would then
  // disappear while the pointer is still over a bar.
  const hideTooltipTimeout = useRef<number | undefined>(undefined);
  const cancelHideTooltip = () => {
    window.clearTimeout(hideTooltipTimeout.current);
    hideTooltipTimeout.current = undefined;
  };
  // Make sure a pending timer cannot update state after the chart has unmounted.
  useEffect(() => () => window.clearTimeout(hideTooltipTimeout.current), []);

  const theme = useTheme();
  const {
    elementColor1: barColor,
    highlightColor,
    axisColor,
    axisBottomTickLabelProps,
    axisLeftTickLabelProps,
    gridColor,
    gridStrokeDasharray,
  } = useCommonChartStyles();

  // Plot-area size once the margins are taken out.
  const xMax = width - margin.left - margin.right;
  const yMax = height - margin.top - margin.bottom;

  // Enumerate every integer between the smallest and largest key so gaps in `data` stay visible.
  // NOTE(review): assumes integer keys — `Array(n)` throws a RangeError for a non-integer length.
  const [minKey, maxKey] = extent(Array.from(data.keys()));
  const keys =
    minKey != null && maxKey != null
      ? [...Array(maxKey - minKey + 1).keys()].map((i) => i + minKey)
      : [];

  const getUnits = (key: number) => data.get(key) ?? [];
  const barColoringFn = barColoring ?? (() => barColor);

  const barScale = scaleBand<number>({
    domain: keys,
    range: [0, xMax],
    padding: 0.05,
    round: true,
  });

  // Find the tallest bar without spreading into `Math.max(...)`. The key range can be very long
  // when keys are sparse, and spreading a huge array exceeds the engine's argument limit.
  let maxCount = 0;
  for (const units of data.values()) maxCount = Math.max(maxCount, units.length);
  const yScale = scaleLinear<number>({
    // With nothing to draw (no keys, or only empty buckets) fall back to a 0..10 axis instead of
    // a degenerate [0, 0] domain.
    domain: [0, maxCount > 0 ? maxCount : 10],
    range: [yMax, 0],
    round: true,
  });
  // Bar heights are unit counts, so only whole-number ticks and grid lines are meaningful. For
  // small maxima the default linear ticks would include fractions such as 0.2.
  const yTickValues = yScale.ticks().filter(Number.isInteger);

  // Tooltip rows. When the list does not fit, the last row becomes an "and n more" note.
  const tooltipUnits = tooltipData?.units ?? [];
  const visibleTooltipUnits =
    tooltipUnits.length > TOOLTIP_MAX_LINES
      ? tooltipUnits.slice(0, TOOLTIP_MAX_LINES - 1)
      : tooltipUnits;
  const hiddenTooltipUnitCount = tooltipUnits.length - visibleTooltipUnits.length;

  // All hooks have run above, so this early exit keeps the hook order stable.
  if (width < 10) return null;

  const barWidth = barScale.bandwidth();

  return (
    <div style={{ position: 'relative' }}>
      <svg width={width} height={height}>
        <GridRows
          scale={yScale}
          tickValues={yTickValues}
          width={xMax}
          stroke={gridColor}
          strokeDasharray={gridStrokeDasharray}
          top={margin.top}
          left={margin.left}
        />
        <Group top={margin.top} left={margin.left}>
          <AxisBottom
            hideAxisLine
            hideTicks
            top={yMax}
            scale={barScale}
            numTicks={getDiscreteAxisNumTicks(width, keys.length, minBottomAxisTickGap)}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={axisBottomTickLabelProps}
            tickFormat={(key) => formatBottomAxisLabel?.(key) ?? key.toString()}
          />
          <AxisLeft
            scale={yScale}
            tickValues={yTickValues}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={axisLeftTickLabelProps}
          />

          <Group>
            {keys.map((key) => {
              const value = getUnits(key).length;
              // Empty buckets keep their band on the axis but draw no bar.
              if (value === 0) return null;
              return (
                <rect
                  key={`bar-${key}`}
                  x={barScale(key)}
                  width={barWidth}
                  y={yScale(value)}
                  height={yMax - yScale(value)}
                  fill={hoveredBar === key ? highlightColor : barColoringFn(key, value)}
                  onMouseOver={() => {
                    cancelHideTooltip();
                    setHoveredBar(key);
                  }}
                  onMouseOut={() => {
                    // Hide after a short delay so moving between adjacent bars does not flicker
                    // the tooltip. The next bar's handlers cancel this timer.
                    cancelHideTooltip();
                    hideTooltipTimeout.current = window.setTimeout(() => {
                      hideTooltipTimeout.current = undefined;
                      setHoveredBar(null);
                      hideTooltip();
                    }, TOOLTIP_TIMEOUT);
                  }}
                  onMouseMove={(event) => {
                    cancelHideTooltip();
                    const coords = localPoint(event);
                    showTooltip({
                      tooltipData: { key, units: getUnits(key) },
                      tooltipLeft: coords?.x,
                      tooltipTop: coords?.y,
                    });
                  }}
                />
              );
            })}
          </Group>
        </Group>
      </svg>
      {tooltipOpen && tooltipData && (
        <TooltipWithBounds
          // A fresh key remounts the tooltip on every render — presumably so it re-measures its
          // bounds (pattern taken from the visx examples).
          key={Math.random()}
          top={tooltipTop}
          left={tooltipLeft}
          style={TOOLTIP_STYLES}
        >
          <div style={{ color: highlightColor }}>
            <strong>{formatTooltipLabel?.(tooltipData.key) ?? tooltipData.key.toString()}</strong>
          </div>
          <div>
            {visibleTooltipUnits.map((unit) => (
              <div key={unit.id}>{unit.name}</div>
            ))}
            {hiddenTooltipUnitCount > 0 && (
              <div style={{ color: theme.palette.grey[500] }}>
                and {hiddenTooltipUnitCount} more
              </div>
            )}
          </div>
        </TooltipWithBounds>
      )}
    </div>
  );
};
