import type React from 'react';
import { useCallback, useId, useMemo, useState } from 'react';
import { scaleLinear, scaleTime } from '@visx/scale';
import { Group } from '@visx/group';
import { AreaClosed, Bar, Line, LinePath } from '@visx/shape';
import { curveLinear, curveStepAfter } from '@visx/curve';
import { AxisBottom, AxisLeft } from '@visx/axis';
import { Grid } from '@visx/grid';
import { TooltipWithBounds, useTooltip } from '@visx/tooltip';
import { localPoint } from '@visx/event';
import { bisector, extent } from 'd3-array';

import { type DateTimeRange } from '~utils/time';

import {
  TOOLTIP_STYLES,
  formatAxisDate,
  formatTooltipDate,
  getContinuousAxisNumTicks,
  useCommonChartStyles,
} from './utils';

/**
 * Payload stored in the tooltip state while the pointer is over the chart.
 */
export type TooltipData = {
  /**
   * Timestamp the tooltip describes: the date under the pointer, or the closest
   * entry of `tooltipSnapIndex` when a non-empty snap index is supplied.
   */
  time: Date;
  /**
   * Value of every series at `time`, keyed by series name. A series without a
   * value at that time (no data, or `time` outside its data) is stored as `NaN`.
   */
  values: Record<string, number>;
};

/**
 * One line drawn by the chart.
 */
export type ChartSeries = {
  /**
   * Points of the line. The chart takes the first and last entries as the
   * series' time extent and bisects the array by time for the tooltip, so the
   * points are expected to be sorted by ascending time.
   */
  data: DataPoint[];
  /**
   * How consecutive points are connected, both for drawing and when the
   * tooltip needs a value between two points.
   */
  interpolation: 'linear' | 'step-after';
  /**
   * Label shown in the tooltip. Also used as the key into `TooltipData.values`
   * and compared against `highlightedSeriesName`, so it should be unique per chart.
   */
  name: string;
  /** Stroke color of the line and fill color of its tooltip markers. */
  color: string;
  /** Optional SVG dash pattern forwarded to the line. */
  strokeDasharray?: string;
};

/**
 * One filled band drawn between a lower and an upper bound over time.
 */
export type ChartArea = {
  /** Points describing the band's lower and upper bound at each time. */
  data: AreaDataPoint[];
  /** How consecutive points of the band outline are connected. */
  interpolation: 'linear' | 'step-after';
  /** Fill color of the band. */
  color: string;
  /**
   * Fill opacity; the chart falls back to 0.4 when omitted. While a series is
   * highlighted the chart ignores this value and draws every band at 0.05.
   */
  opacity?: number;
};

/**
 * Static decoration drawn on top of the chart.
 *
 * Note that the two `type: 'line'` variants share the same tag; the chart
 * tells them apart by checking which of `height` / `time` is present.
 */
export type ChartAnnotation =
  // Marker dot at a single (time, value) position. It is only drawn when both
  // coordinates lie inside the current axis domains.
  | {
      type: 'point';
      x: Date;
      y: number;
      color: string;
    }
  // Horizontal line spanning the full chart width at value `height`. It is
  // only drawn when `height` lies inside the current y-axis domain.
  | {
      type: 'line';
      height: number;
      color: string;
      strokeDasharray?: string;
    }
  // Vertical line spanning the full chart height at `time`. It is only drawn
  // when `time` lies inside the current time-axis domain.
  | {
      type: 'line';
      time: Date;
      color: string;
      strokeDasharray?: string;
    }
  // Filled rectangle. Every bound that is left out extends to the matching
  // edge of the chart; the rectangle is clamped to the chart area and skipped
  // when nothing of it remains. `opacity` falls back to 0.2 when omitted.
  | {
      type: 'area';
      x?: { from?: Date; to?: Date };
      y?: { from?: number; to?: number };
      color: string;
      opacity?: number;
    };

/**
 * Single sample of a line series.
 */
export type DataPoint = {
  /** Timestamp of the sample. */
  time: Date;
  /** Sampled value; non-finite values are treated as gaps when the line is drawn. */
  value: number;
};

/**
 * Single sample of a band: a lower and an upper bound at one point in time.
 * A sample is treated as a gap when either bound is non-finite.
 */
export type AreaDataPoint = {
  /** Timestamp of the sample. */
  time: Date;
  /** Lower bound of the band at `time`. */
  lower: number;
  /** Upper bound of the band at `time`. */
  upper: number;
};

/**
 * Props accepted by `MultiLineChart`.
 */
export type MultiLineChartProps = {
  /** Total width of the SVG, margins included. Nothing is rendered when it is below 10. */
  width: number;
  /** Total height of the SVG, margins included. */
  height: number;
  /** Space reserved around the plot area for the axes; defaults to `defaultMargin`. */
  margin?: { top: number; right: number; bottom: number; left: number };
  /** Lines to draw. Only the series determine the automatic axis domains. */
  series: ChartSeries[];
  /** Filled bands drawn underneath the lines. */
  areas: ChartArea[];
  /** Extra points, lines and rectangles to draw; defaults to none. */
  annotations?: ChartAnnotation[];
  /**
   * Name of the series to emphasize. When set, every other series is drawn at
   * 0.2 opacity and every band at 0.05 opacity.
   */
  highlightedSeriesName?: string | null;
  /** Unit label appended to each value in the tooltip. */
  units: string;
  /** Number of decimal digits used for tooltip values; defaults to 2. */
  tooltipDecimalDigits?: number;
  /**
   * When non-empty, the tooltip snaps to the entry closest to the pointer.
   * The array is bisected, so it is expected to be sorted ascending.
   */
  tooltipSnapIndex?: Date[];
  /** Additional dates that the time axis domain must cover. */
  xAxisInclude?: Date[];
  /** Additional values that the y-axis domain must cover. */
  yAxisInclude?: number[];
  /** The y-axis domain never starts below this value. */
  yAxisClipMin?: number;
  /** The y-axis domain never ends above this value. */
  yAxisClipMax?: number;
  /** Whether the y-axis domain is rounded to "nice" values; defaults to true. */
  niceYScale?: boolean;
  /**
   * Called with the selected time range when a drag selection is released.
   * Drag selection is disabled when this callback is not provided.
   */
  onZoom?: (range: DateTimeRange) => void;
};

// Margins used when the `margin` prop is omitted, in the same units as the
// chart's `width` / `height`. The left side is widest to fit the y-axis labels.
const defaultMargin = {
  top: 20,
  right: 15,
  bottom: 20,
  left: 40,
};

/**
 * Time-series chart that draws any number of line series and filled bands on
 * a shared time (x) and linear (y) axis, with optional annotations, a hover
 * tooltip and drag-to-zoom.
 *
 * - The axis domains are derived from `series` only (plus `xAxisInclude` /
 *   `yAxisInclude` and the y clip bounds); `areas` do not widen them.
 * - The tooltip shows the value of every series at the hovered (or snapped)
 *   time, interpolated according to each series' `interpolation`.
 * - Dragging with the mouse selects a time range that is reported through
 *   `onZoom` on release; without `onZoom` dragging is disabled.
 *
 * Returns `null` when `width` is below 10.
 */
export const MultiLineChart = ({
  width,
  height,
  margin = defaultMargin,
  series,
  areas,
  annotations = [],
  highlightedSeriesName = null,
  units,
  tooltipDecimalDigits = 2,
  tooltipSnapIndex,
  xAxisInclude,
  yAxisInclude,
  yAxisClipMin,
  yAxisClipMax,
  niceYScale = true,
  onZoom,
}: MultiLineChartProps) => {
  const {
    tooltipData,
    tooltipLeft = 0,
    tooltipTop = 0,
    tooltipOpen,
    showTooltip,
    hideTooltip,
  } = useTooltip<TooltipData>();

  // Time range currently being dragged out, or null while no drag is active.
  // `start` is where the mouse went down, `end` follows the pointer.
  const [zoomSelection, setZoomSelection] = useState<{ start: Date; end: Date } | null>(null);

  const {
    highlightColor,
    axisColor,
    axisBottomTickLabelProps,
    axisLeftTickLabelProps,
    gridColor,
    gridStrokeDasharray,
  } = useCommonChartStyles();

  // Size of the plot area inside the margins. Clamped at 0 so that an SVG
  // smaller than its margins never produces negative rect sizes or a reversed
  // scale range.
  const xMax = Math.max(0, width - margin.left - margin.right);
  const yMax = Math.max(0, height - margin.top - margin.bottom);

  const timeScale = useMemo(() => {
    // min and max timestamps across all series in one array
    // (relies on each series being sorted by time: only first and last are read)
    const minAndMaxDates = series.flatMap((s) =>
      s.data.length > 0 ? [s.data[0].time, s.data[s.data.length - 1].time] : [],
    );

    if (xAxisInclude != null) {
      minAndMaxDates.push(...xAxisInclude);
    }

    const [min, max] = extent(minAndMaxDates);

    return scaleTime({
      range: [0, xMax],
      // No dates at all: keep the scale's default domain.
      // A single instant: pad by 10 seconds on both sides so the domain has a width.
      domain:
        min == null || max == null
          ? undefined
          : min.getTime() === max.getTime()
          ? [new Date(min.getTime() - 10_000), new Date(max.getTime() + 10_000)]
          : [min, max],
    });
  }, [series, xMax, xAxisInclude]);

  const yScale = useMemo(() => {
    // min and max values across all series in one array
    // extent ignores NaN's
    const minAndMaxValues = series
      .flatMap((s) => extent(s.data, (d) => d.value))
      .filter((n): n is number => n !== undefined);

    if (yAxisInclude != null) {
      minAndMaxValues.push(...yAxisInclude);
    }

    // NOTE(review): when a clip bound lies entirely beyond the data (e.g.
    // yAxisClipMin above every value) min can end up greater than max, which
    // yields an inverted domain — confirm whether callers can hit this.
    let [min, max] = extent(minAndMaxValues);
    if (yAxisClipMin !== undefined) {
      min = min !== undefined ? Math.max(min, yAxisClipMin) : yAxisClipMin;
    }
    if (yAxisClipMax !== undefined) {
      max = max !== undefined ? Math.min(max, yAxisClipMax) : yAxisClipMax;
    }
    // Fall back to [0, 1] without data and widen a zero-height domain by 1 on
    // each side so the scale is never degenerate.
    if (min === undefined) min = 0;
    if (max === undefined) max = 1;
    if (min === max) [min, max] = [min - 1, max + 1];

    return scaleLinear({
      range: [yMax, 0],
      domain: [min, max],
      nice: niceYScale,
    });
  }, [series, yMax, yAxisInclude, yAxisClipMin, yAxisClipMax, niceYScale]);

  // Shows the tooltip for the pointer position: resolves the time under the
  // pointer (optionally snapped) and the value of every series at that time.
  const handleTooltip = useCallback(
    (event: React.TouchEvent<SVGRectElement> | React.MouseEvent<SVGRectElement>): void => {
      const { x, y } = localPoint(event) || { x: 0, y: 0 };
      // localPoint is relative to the SVG; the scale is relative to the plot area.
      const dateUnderCursor = timeScale.invert(x - margin.left);
      const tooltipDate =
        tooltipSnapIndex != null && tooltipSnapIndex.length > 0
          ? tooltipSnapIndex[bisector((d) => d).center(tooltipSnapIndex, dateUnderCursor)]
          : dateUnderCursor;
      showTooltip({
        tooltipData: {
          time: tooltipDate,
          values: Object.fromEntries(
            series.map((s) => {
              if (s.data.length === 0) return [s.name, NaN];
              const index = bisector((d: DataPoint) => d.time).left(s.data, tooltipDate);
              // tooltipDate is after the last point of this series
              if (index >= s.data.length) return [s.name, NaN];

              // date at index is at or after tooltipDate
              if (s.data[index].time.getTime() === tooltipDate.getTime()) {
                return [s.name, s.data[index].value];
              }
              // date at index must be after tooltipDate

              // tooltipDate is before the first point of this series
              if (index === 0) return [s.name, NaN];
              const indexBefore = index - 1;
              if (s.interpolation === 'step-after') {
                return [s.name, s.data[indexBefore].value];
              }

              // linear interpolation
              // dValue / DValue = dTime / DTime
              const DTime = s.data[index].time.getTime() - s.data[indexBefore].time.getTime();
              const DValue = s.data[index].value - s.data[indexBefore].value;
              const dTime = tooltipDate.getTime() - s.data[indexBefore].time.getTime();
              const dValue = DValue * (dTime / DTime);
              return [s.name, s.data[indexBefore].value + dValue];
            }),
          ),
        },
        tooltipLeft: x,
        tooltipTop: y,
      });
    },
    [showTooltip, series, tooltipSnapIndex, margin, timeScale],
  );

  // Starts a zoom selection (first call) or moves its end to the pointer
  // (subsequent calls). Does nothing when zooming is disabled.
  const handleDrag = useCallback(
    (event: React.TouchEvent<SVGRectElement> | React.MouseEvent<SVGRectElement>): void => {
      if (onZoom == null) return;

      // prevent accidental text selection
      event.preventDefault();

      const { x } = localPoint(event) || { x: 0, y: 0 };
      const dateUnderCursor = timeScale.invert(x - margin.left);

      setZoomSelection((prev) => {
        if (prev == null) {
          return { start: dateUnderCursor, end: dateUnderCursor };
        }
        return { start: prev.start, end: dateUnderCursor };
      });
    },
    [margin, timeScale, onZoom],
  );

  const clipPathId = useId();

  // Must stay below all hooks so the hook order is the same on every render.
  if (width < 10) return null;

  return (
    <div style={{ position: 'relative' }}>
      <svg width={width} height={height}>
        <defs>
          <clipPath id={clipPathId}>
            {/* Add just a bit of padding, as we don't want to
                cut into lines running along the chart edge. */}
            <rect x={-1} y={-1} width={xMax + 2} height={yMax + 2} />
          </clipPath>
        </defs>

        <Group left={margin.left} top={margin.top}>
          <Grid
            xScale={timeScale}
            numTicksColumns={getContinuousAxisNumTicks(width)}
            yScale={yScale}
            width={xMax}
            height={yMax}
            stroke={gridColor}
            strokeDasharray={gridStrokeDasharray}
          />
          <AxisLeft
            scale={yScale}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={axisLeftTickLabelProps}
          />
          <AxisBottom
            hideAxisLine
            hideTicks
            top={yMax}
            scale={timeScale}
            numTicks={getContinuousAxisNumTicks(width)}
            tickFormat={(date) => formatAxisDate(date as Date)}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={axisBottomTickLabelProps}
          />

          {/* This group contains all the series, areas and line+area annotations,
              they are clipped to restrict them to the inner chart area only.
              All lines should be drawn after all areas. */}
          <g clipPath={`url(#${clipPathId})`}>
            {/* Area annotations: missing bounds extend to the chart edge, the
                rectangle is clamped to the chart and dropped when empty.
                The y axis is inverted, hence `to` maps to the top edge (y0). */}
            {annotations
              .flatMap((a) => (a.type === 'area' ? [a] : []))
              .map((a) => ({
                x0: a.x?.from != null ? Math.max(0, timeScale(a.x.from)) : 0,
                x1: a.x?.to != null ? Math.min(xMax, timeScale(a.x.to)) : xMax,
                y0: a.y?.to != null ? Math.max(0, yScale(a.y.to)) : 0,
                y1: a.y?.from != null ? Math.min(yMax, yScale(a.y.from)) : yMax,
                color: a.color,
                opacity: a.opacity ?? 0.2,
              }))
              .filter(
                (a) =>
                  a.x0 < a.x1 && a.y0 < a.y1 && a.x0 < xMax && a.x1 > 0 && a.y0 < yMax && a.y1 > 0,
              )
              .map((a) => (
                <Bar
                  key={`annotation-area-${a.x0}-${a.x1}-${a.y0}-${a.y1}-${a.color}-${a.opacity}`}
                  x={a.x0}
                  width={a.x1 - a.x0}
                  y={a.y0}
                  height={a.y1 - a.y0}
                  fill={a.color}
                  opacity={a.opacity}
                />
              ))}

            {areas.map((a, index) => (
              <AreaClosed
                key={index}
                data={a.data}
                x={(d) => timeScale(d.time)}
                yScale={yScale}
                y0={(d) => yScale(d.lower)}
                y1={(d) => yScale(d.upper)}
                defined={(d) => Number.isFinite(d.lower) && Number.isFinite(d.upper)}
                curve={a.interpolation === 'linear' ? curveLinear : curveStepAfter}
                fill={a.color}
                opacity={highlightedSeriesName != null ? 0.05 : a.opacity ?? 0.4}
              />
            ))}

            {/* Horizontal line annotations (the `line` variant carrying `height`). */}
            {annotations
              .flatMap((a) => (a.type === 'line' && 'height' in a ? [a] : []))
              .filter((a) => yScale.domain()[0] <= a.height && a.height <= yScale.domain()[1])
              .map((a) => (
                <Line
                  key={`annotation-line-horizontal-${a.height}-${a.color}-${a.strokeDasharray}`}
                  from={{ x: 0, y: yScale(a.height) }}
                  to={{ x: xMax, y: yScale(a.height) }}
                  stroke={a.color}
                  strokeWidth={1}
                  strokeDasharray={a.strokeDasharray}
                  pointerEvents="none"
                />
              ))}

            {/* Vertical line annotations (the `line` variant carrying `time`).
                The key uses the epoch value: a Date's string form only has
                second resolution and could collide for nearby annotations. */}
            {annotations
              .flatMap((a) => (a.type === 'line' && 'time' in a ? [a] : []))
              .filter((a) => timeScale.domain()[0] <= a.time && a.time <= timeScale.domain()[1])
              .map((a) => (
                <Line
                  key={`annotation-line-vertical-${a.time.getTime()}-${a.color}-${a.strokeDasharray}`}
                  from={{ x: timeScale(a.time), y: 0 }}
                  to={{ x: timeScale(a.time), y: yMax }}
                  stroke={a.color}
                  strokeWidth={1}
                  strokeDasharray={a.strokeDasharray}
                  pointerEvents="none"
                />
              ))}

            {series.map((s, index) => (
              <LinePath
                key={index}
                stroke={s.color}
                strokeWidth={1}
                strokeDasharray={s.strokeDasharray}
                opacity={
                  highlightedSeriesName != null && highlightedSeriesName !== s.name ? 0.2 : 1
                }
                data={s.data}
                curve={s.interpolation === 'linear' ? curveLinear : curveStepAfter}
                x={(d) => timeScale(d.time)}
                y={(d) => yScale(d.value)}
                defined={(d) => Number.isFinite(d.value)}
              />
            ))}
          </g>

          {/* Point annotations live outside the clipped group so that markers
              on the chart edge are not cut in half; they are filtered to the
              axis domains instead. */}
          {annotations
            .flatMap((a) => (a.type === 'point' ? [a] : []))
            .filter(
              (a) =>
                timeScale.domain()[0] <= a.x &&
                a.x <= timeScale.domain()[1] &&
                yScale.domain()[0] <= a.y &&
                a.y <= yScale.domain()[1],
            )
            .map((a) => (
              <circle
                key={`annotation-point-${a.x.getTime()}-${a.y}-${a.color}`}
                cx={timeScale(a.x)}
                cy={yScale(a.y)}
                r={4}
                fill={a.color}
                stroke="white"
                strokeWidth={1}
              />
            ))}

          {/* Zoom selection in progress: shaded range plus a dashed line at each end. */}
          {zoomSelection != null && (
            <>
              <Bar
                y={0}
                height={yMax}
                x={Math.min(timeScale(zoomSelection.start), timeScale(zoomSelection.end))}
                width={Math.abs(timeScale(zoomSelection.end) - timeScale(zoomSelection.start))}
                fill={highlightColor}
                opacity={0.2}
              />
              <Line
                from={{ x: timeScale(zoomSelection.start), y: 0 }}
                to={{ x: timeScale(zoomSelection.start), y: yMax }}
                stroke={highlightColor}
                strokeWidth={1}
                pointerEvents="none"
                strokeDasharray="5,2"
              />
              <Line
                from={{ x: timeScale(zoomSelection.end), y: 0 }}
                to={{ x: timeScale(zoomSelection.end), y: yMax }}
                stroke={highlightColor}
                strokeWidth={1}
                pointerEvents="none"
                strokeDasharray="5,2"
              />
            </>
          )}

          {/* Tooltip crosshair and one marker per series that has a finite
              value inside the y-axis domain at the tooltip time. */}
          {tooltipData && (
            <g>
              <Line
                from={{ x: timeScale(tooltipData.time), y: 0 }}
                to={{ x: timeScale(tooltipData.time), y: yMax }}
                stroke={highlightColor}
                strokeWidth={1}
                pointerEvents="none"
                strokeDasharray="5,2"
              />
              {series.map(
                (s) =>
                  Number.isFinite(tooltipData.values[s.name]) &&
                  yScale.domain()[0] <= tooltipData.values[s.name] &&
                  tooltipData.values[s.name] <= yScale.domain()[1] && (
                    <circle
                      key={`dot-${s.name}-${tooltipData.time.getTime()}`}
                      cx={timeScale(tooltipData.time)}
                      cy={yScale(tooltipData.values[s.name])}
                      r={4}
                      fill={s.color}
                      stroke="white"
                      strokeWidth={1}
                      pointerEvents="none"
                    />
                  ),
              )}
            </g>
          )}
        </Group>

        {/* Transparent overlay covering the plot area; it receives all pointer
            events for the tooltip and for drag-to-zoom. */}
        <rect
          x={margin.left}
          y={margin.top}
          width={xMax}
          height={yMax}
          fill="transparent"
          onTouchStart={handleTooltip}
          onTouchMove={handleTooltip}
          onMouseMove={(event) => {
            // While dragging, the move extends the selection instead of
            // updating the tooltip.
            if (zoomSelection != null) {
              handleDrag(event);
              hideTooltip();
            } else {
              handleTooltip(event);
            }
          }}
          onMouseDown={handleDrag}
          onMouseUp={() => {
            if (zoomSelection == null) return;

            const { start: a, end: b } = zoomSelection;
            setZoomSelection(null);

            // Compare timestamps rather than Date identities: after any mouse
            // move `end` is a fresh Date object, so an identity check would
            // report a zero-width range when the pointer returns to its start.
            if (onZoom == null || a.getTime() === b.getTime()) return;

            // The drag may have gone right-to-left; always report an ordered range.
            if (a < b) onZoom({ start: a, end: b });
            else onZoom({ start: b, end: a });
          }}
          onMouseLeave={() => {
            // Leaving the plot area cancels both the tooltip and any drag.
            hideTooltip();
            setZoomSelection(null);
          }}
        />
      </svg>

      {tooltipOpen && tooltipData && (
        <TooltipWithBounds
          key={`tooltip-${tooltipTop}-${tooltipLeft}-${tooltipData.time.getTime()}`}
          top={tooltipTop}
          left={tooltipLeft}
          style={TOOLTIP_STYLES}
        >
          <div style={{ color: highlightColor }}>
            <strong>{formatTooltipDate(tooltipData.time)}</strong>
          </div>
          {series.map((s) => (
            <div key={s.name} style={{ display: 'flex', alignItems: 'center' }}>
              <div style={{ marginRight: 3 }}>
                <svg width="10" height="10">
                  <g transform="translate(5, 5)">
                    <circle r="5" fill={s.color} />
                  </g>
                </svg>
              </div>
              {s.name}:{' '}
              {Number.isFinite(tooltipData.values[s.name])
                ? `${tooltipData.values[s.name].toFixed(tooltipDecimalDigits)} ${units}`
                : '-'}
            </div>
          ))}
        </TooltipWithBounds>
      )}
    </div>
  );
};
