import type { Points, StageFromAPI } from "./types";

/**
 * Find the stage whose time domain contains the target time.
 *
 * A stage's domain runs from the first point of its first segment to the last
 * point of its last segment, both inclusive. `stages` must be sorted by time
 * with non-overlapping domains for the binary search to be valid.
 *
 * It is up to you to provide a transform function if the target time is not in
 * the same units as the points' time. At the time of writing, the points are
 * always in absolute time (we never transform them from the API), but the
 * target time comes from a d3 axis, which is sometimes manipulated so that we
 * can do things like align by stage.
 *
 * @param stages Stages sorted by time. Each stage must have at least one
 * segment, and its first and last segments must each contain at least one point.
 * @param target The time to locate, in the units produced by `transformTime`
 * (or the points' own units when no transform is given).
 * @param transformTime Optionally map a point's time (given its stage) into the
 * target's units before comparison. This is useful when the target time is
 * "in the language of the canvas/svg" but the points are still in absolute time.
 * @returns The stage containing `target`, or `undefined` when the target lies
 * before the first stage, after the last stage, or in a gap between stages.
 */
const binarySearchStage = <S extends Pick<StageFromAPI, "ptsPartitioned">>(
  stages: S[],
  target: number, // time
  transformTime?: (s: S, t: number) => number
): S | undefined => {
  let left = 0;
  let right = stages.length - 1;

  while (left <= right) {
    const mid = Math.floor((left + right) / 2);
    const s = stages[mid] as (typeof stages)[number];
    const firstPointOfStage = s.ptsPartitioned[0].pts[0];
    const lastSegment = s.ptsPartitioned[s.ptsPartitioned.length - 1] as Points;
    const lastPointOfStage = lastSegment.pts[
      lastSegment.pts.length - 1
    ] as Points["pts"][number];

    // Transform the start bound once: it decides both "go left" and the lower
    // edge of the containment check. transformTime may be non-trivial.
    const stageStart =
      transformTime?.(s, firstPointOfStage.t) ?? firstPointOfStage.t;
    if (target < stageStart) {
      right = mid - 1;
      continue;
    }

    // Only transform the end bound once we know the target is not to the left.
    const stageEnd =
      transformTime?.(s, lastPointOfStage.t) ?? lastPointOfStage.t;
    if (target <= stageEnd) return s;

    left = mid + 1;
  }

  return undefined;
};

/**
 * Find the segment whose time domain contains the target time.
 *
 * A segment's domain runs from its first point to its last point, both
 * inclusive. `segments` must be sorted by time with non-overlapping domains for
 * the binary search to be valid.
 *
 * @param segments A stage's partitioned points, sorted by time. Each segment
 * must contain at least one point.
 * @param target The time to locate, in the units produced by `transformTime`
 * (or the points' own units when no transform is given).
 * @param transformTime Optionally map a point's time into the target's units
 * before comparison.
 * @returns The segment containing `target`, or `undefined` when the target lies
 * before the first segment, after the last one, or in a gap between segments.
 */
const binarySearchSegment = (
  segments: StageFromAPI["ptsPartitioned"],
  target: number, // time
  transformTime?: (t: number) => number
): Points | undefined => {
  let left = 0;
  let right = segments.length - 1;

  while (left <= right) {
    const mid = Math.floor((left + right) / 2);
    const segment = segments[mid] as (typeof segments)[number];
    const firstPointOfSegment = segment.pts[0];
    // For a non-empty segment the last index always exists; the `??` fallback
    // only satisfies indexed-access typing.
    const lastPointOfSegment =
      segment.pts[segment.pts.length - 1] ?? firstPointOfSegment;

    // Transform the start bound once: it decides both "go left" and the lower
    // edge of the containment check.
    const segmentStart =
      transformTime?.(firstPointOfSegment.t) ?? firstPointOfSegment.t;
    if (target < segmentStart) {
      right = mid - 1;
      continue;
    }

    // Only transform the end bound once we know the target is not to the left.
    const segmentEnd =
      transformTime?.(lastPointOfSegment.t) ?? lastPointOfSegment.t;
    if (target <= segmentEnd) return segment;

    left = mid + 1;
  }

  return undefined;
};

/**
 * Linearly interpolate the value at time `t` on the line through `(t0, v0)`
 * and `(t1, v1)`. Callers only pass `t0 < t1`, so the division is safe.
 */
const interpolateLinearly = (
  t0: number,
  v0: number,
  t1: number,
  v1: number,
  t: number
): number => {
  const slope = (v1 - v0) / (t1 - t0);
  return slope * (t - t0) + v0;
};

/**
 * Evaluate a segment at `targetTime`.
 *
 * If a point sits exactly at the target time, its value is returned. Otherwise
 * the value is linearly interpolated between the two neighbouring points that
 * strictly bracket the target.
 *
 * @param points The segment to evaluate; its points must be sorted by
 * (transformed) time.
 * @param targetTime The time to evaluate at, in the units produced by
 * `transformTime` (or the points' own units when no transform is given).
 * @param transformTime Optionally map a point's time into the target's units
 * before comparison.
 * @returns A new `{ t, v }` point whose `t` is in the target's units. Only `t`
 * and `v` are populated.
 * @throws Error when `targetTime` lies outside the points' domain (including
 * when `points.pts` is empty), because no exact match or bracketing pair exists.
 */
const binarySearchPointWithLinearInterpolation = (
  points: Points,
  targetTime: number, // time
  transformTime?: (t: number) => number
): Points["pts"][number] => {
  const timeOf = (p: Points["pts"][number]) => transformTime?.(p.t) ?? p.t;

  let left = 0;
  let right = points.pts.length - 1;

  while (left <= right) {
    const mid = Math.floor((left + right) / 2);
    const point = points.pts[mid] as Points["pts"][number];
    const pointTime = timeOf(point);

    if (targetTime === pointTime) return { t: pointTime, v: point.v };

    if (targetTime < pointTime) {
      // The target is left of mid. If the previous point is strictly before it,
      // the target falls between the two and can be interpolated right away.
      const prev = points.pts[mid - 1];
      if (prev) {
        const prevTime = timeOf(prev);
        if (prevTime < targetTime) {
          return {
            t: targetTime,
            v: interpolateLinearly(
              prevTime,
              prev.v,
              pointTime,
              point.v,
              targetTime
            ),
          };
        }
      }
      right = mid - 1;
    } else {
      // The target is right of mid (or NaN). If the next point is strictly
      // after it, the target falls between the two and can be interpolated.
      const next = points.pts[mid + 1];
      if (next) {
        const nextTime = timeOf(next);
        if (nextTime > targetTime) {
          return {
            t: targetTime,
            v: interpolateLinearly(
              pointTime,
              point.v,
              nextTime,
              next.v,
              targetTime
            ),
          };
        }
      }
      left = mid + 1;
    }
  }

  throw new Error(
    "Make sure to only call this if target time is within the domain of the points. Then, with linear interpolation you should always find a point. "
  );
};

// Public API of this module. The point search is exposed under the shorter
// name `binarySearchPoint`; note that it linearly interpolates between
// neighbouring points rather than requiring an exact time match.
export {
  binarySearchStage,
  binarySearchSegment,
  binarySearchPointWithLinearInterpolation as binarySearchPoint,
};
