import { useEffect, useRef, useState } from 'react';
import { scaleBand, scaleLinear } from '@visx/scale';
import { Group } from '@visx/group';
import { BarGroup } from '@visx/shape';
import { AxisBottom, AxisLeft } from '@visx/axis';
import { GridRows } from '@visx/grid';
import { TooltipWithBounds, useTooltip } from '@visx/tooltip';
import { localPoint } from '@visx/event';

import { TOOLTIP_STYLES, getDiscreteAxisNumTicks, useCommonChartStyles } from './utils';

/**
 * Delay, in milliseconds, between the cursor leaving a bar and the tooltip / bar
 * highlight being cleared. A move onto another bar within this window cancels the hide.
 */
const TOOLTIP_TIMEOUT = 150;

/** Describes the hovered bar; handed to the tooltip text callbacks. */
export type TooltipData = {
  /** 0-based index of the group the hovered bar belongs to. */
  groupIndex: number;
  /** 0-based position of the hovered bar inside its group. */
  barIndex: number;
  /** Data value of the hovered bar. */
  value: number;
};

/** Props of the `GroupedBarChart` component. */
export type GroupedBarChartProps = {
  /** Total width of the rendered svg; the plot area is this minus the left/right margins. */
  width: number;
  /** Total height of the rendered svg; the plot area is this minus the top/bottom margins. */
  height: number;
  /**
   * Space between the svg edges and the plot area.
   * When omitted, the chart uses { top: 20, left: 40, right: 10, bottom: 40 }.
   */
  margin?: { top: number; right: number; bottom: number; left: number };
  /**
   * Flat list of bar values. Consecutive runs of `groupSize` values form one group;
   * the last group may be shorter. Non-finite values are not drawn. Only read, never mutated.
   */
  data: readonly number[];
  /** Number of bars per group. */
  groupSize: number;
  /**
   * Number of decimals the y-axis limits are rounded (outwards) to; each limit is also
   * padded by 10^(-yAxisLimitGranularity).
   */
  yAxisLimitGranularity: number;
  /**
   * Returns the fill color of a bar. When omitted, groups alternate between two theme
   * colors. The hovered bar is always drawn in the theme's highlight color instead.
   */
  barColoring?: (groupIndex: number, barIndex: number, value: number) => string;
  /** First (bold, highlight-colored) line of the tooltip for the hovered bar. */
  tooltipUpperText: (data: TooltipData) => string;
  /** Second line of the tooltip for the hovered bar. */
  tooltipLowerText: (data: TooltipData) => string;
};

/**
 * One datum handed to visx `BarGroup`: `group` is the group's 0-based index and each
 * numeric key `i` holds the value of the group's i-th bar. In the last group, keys for
 * bars past the end of `data` are simply absent.
 *
 * NOTE(review): kept as an object-literal type alias on purpose. Converting it to an
 * interface would drop the implicit index signature, which BarGroup's `Datum` generic
 * constraint may rely on. Verify against @visx/shape typings before changing.
 */
type GroupType = {
  [key: number]: number;
  group: number;
};

/** Plot margins used when the `margin` prop is not supplied. */
const DEFAULT_MARGIN = {
  top: 20,
  left: 40,
  right: 10,
  bottom: 40,
};

/**
 * Computes the y-axis domain for the chart.
 *
 * The domain spans the smallest and largest finite value in `values`. Each limit is
 * rounded outwards to the nearest 10^(-granularity) and then padded by another
 * 10^(-granularity) of empty space. Non-finite entries are ignored. When there is no
 * finite value at all, the band [-10^(-granularity), 10^(-granularity)] is returned
 * instead of an Infinity/NaN domain.
 *
 * A single pass is used instead of `Math.min(...values)` so very large inputs cannot
 * exceed the engine's argument-count limit.
 */
const computeValueDomain = (
  values: readonly number[],
  granularity: number,
): [number, number] => {
  const multiplier = Math.pow(10, granularity);
  const emptySpace = Math.pow(10, -granularity);

  let min = Infinity;
  let max = -Infinity;
  for (const value of values) {
    if (value == null || !Number.isFinite(value)) continue;
    if (value < min) min = value;
    if (value > max) max = value;
  }

  // no finite values: keep the scale well-defined
  if (min > max) {
    return [-emptySpace, emptySpace];
  }

  return [
    Math.floor(min * multiplier) / multiplier - emptySpace,
    Math.ceil(max * multiplier) / multiplier + emptySpace,
  ];
};

/**
 * Bar chart that splits `data` into consecutive groups of `groupSize` bars.
 *
 * Value `i` of `data` is drawn in group `Math.floor(i / groupSize)` at position
 * `i % groupSize`. Hovering a bar highlights it and shows a tooltip whose lines come
 * from `tooltipUpperText` / `tooltipLowerText`. The tooltip and highlight are cleared
 * `TOOLTIP_TIMEOUT` ms after the cursor leaves a bar, unless it enters another bar first.
 * Renders nothing when `width < 10`.
 */
export const GroupedBarChart = ({
  width,
  height,
  margin = DEFAULT_MARGIN,
  data,
  groupSize,
  yAxisLimitGranularity, // number of decimals to round the y-axis limits to
  barColoring,
  tooltipUpperText,
  tooltipLowerText,
}: GroupedBarChartProps) => {
  const { tooltipData, tooltipLeft, tooltipTop, tooltipOpen, showTooltip, hideTooltip } =
    useTooltip<TooltipData>();
  const [hoveredBarIndex, setHoveredBarIndex] = useState<number | null>(null);

  // Id of the pending "hide tooltip" timer. It must live in a ref: a render-local `let`
  // is reset on every re-render (e.g. the one triggered by setHoveredBarIndex on
  // mouseover). In that case a timer scheduled by one bar's mouseout could no longer be
  // cancelled by the next bar's handlers, and it would hide the tooltip while the
  // cursor is still on a bar.
  const tooltipTimeoutRef = useRef<number | undefined>(undefined);

  // Don't let a pending timer update state after the chart unmounts. Reading the ref's
  // latest value inside the cleanup is intentional.
  useEffect(() => () => window.clearTimeout(tooltipTimeoutRef.current), []);

  const {
    elementColor1: barColor1,
    elementColor2: barColor2,
    highlightColor,
    axisColor,
    axisBottomTickLabelProps,
    axisLeftTickLabelProps,
    gridColor,
    gridStrokeDasharray,
  } = useCommonChartStyles();

  const cancelPendingHide = () => {
    window.clearTimeout(tooltipTimeoutRef.current);
    tooltipTimeoutRef.current = undefined;
  };

  const defaultColoring = (groupIndex: number) => (groupIndex % 2 === 0 ? barColor1 : barColor2);
  const barColoringFn = barColoring ?? defaultColoring;

  const nOfGroups = Math.ceil(data.length / groupSize);

  const groupKeys = [...Array(nOfGroups).keys()];
  const barKeys = [...Array(groupSize).keys()];

  // One object per group: { group, 0: v0, 1: v1, ... }. In an incomplete last group
  // the keys for the missing bars are absent.
  const groupedData: GroupType[] = groupKeys.map((group) => ({
    group,
    ...Object.fromEntries(
      data.slice(group * groupSize, (group + 1) * groupSize).map((value, index) => [index, value]),
    ),
  }));

  // apply rounding if possible because it makes more crisp, better looking bars,
  // trying to round when the chart is too narrow leads to bars with 0 width
  const roundScales = width > 2 * data.length;

  // each band in this scale corresponds to a group
  const groupScale = scaleBand<number>({
    domain: groupKeys,
    padding: 0.05,
    round: roundScales,
  });

  // the entire scale spans a single group, each band is a single bar
  const barScale = scaleBand<number>({
    domain: barKeys,
    paddingInner: 0.35,
    paddingOuter: 0.3,
    round: roundScales,
  });

  const heightScale = scaleLinear<number>({
    domain: computeValueDomain(data, yAxisLimitGranularity),
  });

  const xMax = width - margin.left - margin.right;
  const yMax = height - margin.top - margin.bottom;

  groupScale.range([0, xMax]);
  barScale.range([0, groupScale.bandwidth()]);
  heightScale.range([yMax, 0]);

  // all hooks are called above, so this early return keeps hook order stable
  if (width < 10) return null;

  return (
    <div style={{ position: 'relative' }}>
      <svg width={width} height={height}>
        <Group top={margin.top} left={margin.left}>
          <GridRows
            scale={heightScale}
            width={xMax}
            stroke={gridColor}
            strokeDasharray={gridStrokeDasharray}
          />
          <AxisBottom
            hideAxisLine
            hideTicks
            top={yMax}
            scale={groupScale}
            numTicks={getDiscreteAxisNumTicks(width, groupKeys.length, 25)}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={axisBottomTickLabelProps}
          />
          <AxisLeft
            scale={heightScale}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={axisLeftTickLabelProps}
          />

          <BarGroup
            data={groupedData}
            keys={barKeys}
            height={yMax}
            x0={(obj: GroupType) => obj.group}
            x0Scale={groupScale}
            x1Scale={barScale}
            yScale={heightScale}
            // bar.color is required but never used,
            // the color is set based on the groupIndex later
            color={() => barColor1}
          >
            {(barGroups) =>
              barGroups.map((barGroup, groupIndex) => (
                <Group key={`bar-group-${barGroup.index}-${barGroup.x0}`} left={barGroup.x0}>
                  {barGroup.bars.map((bar) => {
                    // missing (incomplete last group) or non-finite values are not drawn
                    if (bar.value == null || !Number.isFinite(bar.value)) return null;

                    // position of this bar in the flat `data` array
                    const flatIndex = groupIndex * groupSize + bar.index;

                    return (
                      <rect
                        key={`bar-group-bar-${barGroup.index}-${bar.index}-${bar.value}`}
                        x={bar.x}
                        y={bar.y}
                        width={bar.width}
                        height={bar.height}
                        fill={
                          flatIndex === hoveredBarIndex
                            ? highlightColor
                            : barColoringFn(groupIndex, bar.index, bar.value)
                        }
                        onMouseOver={() => {
                          cancelPendingHide();
                          setHoveredBarIndex(flatIndex);
                        }}
                        onMouseOut={() => {
                          // replace (not stack) any previously scheduled hide
                          cancelPendingHide();
                          tooltipTimeoutRef.current = window.setTimeout(() => {
                            tooltipTimeoutRef.current = undefined;
                            setHoveredBarIndex(null);
                            hideTooltip();
                          }, TOOLTIP_TIMEOUT);
                        }}
                        onMouseMove={(event) => {
                          cancelPendingHide();
                          const coords = localPoint(event);
                          showTooltip({
                            tooltipData: {
                              groupIndex,
                              barIndex: bar.index,
                              value: bar.value,
                            },
                            tooltipLeft: coords?.x,
                            tooltipTop: coords?.y,
                          });
                        }}
                      />
                    );
                  })}
                </Group>
              ))
            }
          </BarGroup>
        </Group>
      </svg>
      {tooltipOpen && tooltipData && (
        <TooltipWithBounds
          // A fresh key remounts the tooltip on every render. This is presumably so
          // TooltipWithBounds re-measures its bounds (the pattern used in the visx
          // examples). Verify before removing.
          key={Math.random()}
          top={tooltipTop}
          left={tooltipLeft}
          style={TOOLTIP_STYLES}
        >
          <div style={{ color: highlightColor }}>
            <strong>{tooltipUpperText(tooltipData)}</strong>
          </div>
          <div>{tooltipLowerText(tooltipData)}</div>
        </TooltipWithBounds>
      )}
    </div>
  );
};
