import { compact } from "lodash";
import type { AssertionLevel } from "../components/ResultsComparison/SidePane/AssertionDetails";
import type { Metric } from "../models/GeneralMetricsResponse";
import type {
  LMChecklistAssertionScore,
  LMChecklistAssertionsResponse,
} from "../models/LMChecklistAssertionsResponse";
import type {
  LMChecklistDetails,
  LMChecklistSydneyReply,
} from "../models/LMChecklistDetailsResponse";
import type { PassFailRate } from "../models/LMChecklistPassFailRates";
import type { MetricsData } from "./metricsHelper";

/**
 * The two arms of a comparison. The values double as the property names used
 * to read a per-arm score (see `getPassFailRate`, which indexes
 * `assertion.score` with this value).
 */
export type ExperimentType = "control" | "experiment";

/**
 * Assertion view selector.
 * NOTE(review): not referenced elsewhere in this file; presumably "All" shows
 * every assertion and "Regressions" only the regressed ones — confirm against
 * the consuming component.
 */
export type LMChecklistAssertionViewType = "All" | "Regressions";

/**
 * Pass/fail tallies for a single query, as produced by
 * `getPassFailRatesByQuery`. Every rate is optional: it is left `undefined`
 * when the query has no assertions in the corresponding bucket.
 */
export type PassFailRateByQuery = {
  // Identification of the query, copied from the details record.
  query: string;
  queryHash: string;
  segment: string;
  // Tallies restricted to assertions whose level is "critical".
  criticalControl?: PassFailRate;
  criticalExperiment?: PassFailRate;
  // Tallies restricted to assertions whose level is "expected".
  expectedControl?: PassFailRate;
  expectedExperiment?: PassFailRate;
  // Tallies restricted to assertions whose level is "aspirational".
  aspirationalControl?: PassFailRate;
  aspirationalExperiment?: PassFailRate;
  // Tallies over all of the query's assertions, regardless of level.
  tControl?: PassFailRate;
  tExperiment?: PassFailRate;
  // Reply payload copied from the details record.
  sydneyReply: LMChecklistSydneyReply;
};

/**
 * One flattened, all-string row of the assertions export (built by
 * `exportedAssertionsData`). `downloadData` uses the property names of the
 * first row as the CSV header, so the key names and their order define the
 * exported columns.
 */
export type ExportedAssertionsData = {
  query: string;
  segment: string;
  assertion: string;
  level: string;
  // Stringified score per arm, or "N/A" when the score is undefined.
  score_control: string;
  score_experiment: string;
  // Rationale extracted from the raw response by `parseRationale`.
  rationale_control: string;
  rationale_experiment: string;
  // Reply text per arm, or "N/A" when it is undefined.
  sydney_reply_control: string;
  sydney_reply_experiment: string;
};

/**
 * Builds the pass/fail summary for the query described by `record`.
 *
 * Only assertions whose `queryHash` equals `record.queryHash` are considered.
 * They are tallied once per level ("critical", "expected", "aspirational")
 * and once across all levels (`tControl` / `tExperiment`), for both arms.
 * A tally is `undefined` when its bucket contains no assertions.
 *
 * @param assertionsResponse All assertions, for any query.
 * @param record Details record supplying the query identity and reply.
 * @returns The per-level and overall tallies together with the query info.
 */
export const getPassFailRatesByQuery = (
  assertionsResponse: LMChecklistAssertionsResponse,
  record: LMChecklistDetails,
): PassFailRateByQuery => {
  const { query, queryHash, segment, sydneyReply } = record;

  const forQuery = assertionsResponse.filter(
    (candidate) => candidate.queryHash === queryHash,
  );
  const atLevel = (wanted: string) =>
    forQuery.filter((candidate) => candidate.level === wanted);

  const critical = atLevel("critical");
  const expected = atLevel("expected");
  const aspirational = atLevel("aspirational");

  return {
    query,
    queryHash,
    segment,
    criticalControl: getPassFailRate(critical, "control"),
    criticalExperiment: getPassFailRate(critical, "experiment"),
    expectedControl: getPassFailRate(expected, "control"),
    expectedExperiment: getPassFailRate(expected, "experiment"),
    aspirationalControl: getPassFailRate(aspirational, "control"),
    aspirationalExperiment: getPassFailRate(aspirational, "experiment"),
    tControl: getPassFailRate(forQuery, "control"),
    tExperiment: getPassFailRate(forQuery, "experiment"),
    sydneyReply,
  };
};

/**
 * Tallies pass/fail counts for every assertion of the given level, across
 * all queries, separately for the control and experiment arms.
 *
 * @param assertionsResponse All assertions to consider.
 * @param level Only assertions whose `level` equals this value are counted.
 * @returns `{ control, experiment }`; each entry is `undefined` when no
 *   assertion has the requested level.
 */
export const getPassFailRatesByLevel = (
  assertionsResponse: LMChecklistAssertionsResponse,
  level: AssertionLevel,
) => {
  const matching = assertionsResponse.filter(
    (assertion) => assertion.level === level,
  );
  const control = getPassFailRate(matching, "control");
  const experiment = getPassFailRate(matching, "experiment");
  return { control, experiment };
};

/**
 * Counts, for one arm, how many of the given assertions passed and how many
 * have no score at all.
 *
 * An assertion counts as passed when its score for `exp` is exactly 2, and as
 * missing when that score is `undefined`.
 *
 * @param assertions Assertions to tally.
 * @param exp Which arm's score to read.
 * @returns `{ passed, total, missing }`, or `undefined` when `assertions` is
 *   empty.
 */
const getPassFailRate = (
  assertions: LMChecklistAssertionsResponse,
  exp: ExperimentType,
) => {
  const total = assertions.length;
  if (!total) return undefined;

  let passed = 0;
  let missing = 0;
  assertions.forEach(({ score }) => {
    const value = score[exp];
    if (value === 2) {
      passed += 1;
    } else if (value === undefined) {
      missing += 1;
    }
  });

  return { passed, total, missing };
};

/**
 * Formats a pass/fail tally for display.
 *
 * @param pfRate Tally to format, or `undefined` when there is none.
 * @returns An empty string for `undefined`; "Missing" when at least one score
 *   is missing and the missing count reaches the total; otherwise
 *   "<passed> of <scored>", where the denominator excludes missing scores.
 */
export const getPFValue = (pfRate: PassFailRate | undefined) => {
  if (pfRate === undefined) return "";

  const { passed, total, missing } = pfRate;
  const nothingScored = missing > 0 && missing >= total;

  return nothingScored ? `Missing` : `${passed} of ${total - missing}`;
};

/**
 * Puts placeholder text into a metric row while the pass/fail rates are not
 * yet available.
 *
 * The row is modified in place: both score columns become "Loading..." and
 * the delta and p-value columns become "N/A".
 *
 * @param record Metric row to update.
 * @returns The same `record` object that was passed in.
 */
export const displayLMChecklistPassFailRatesAreLoading = (record: Metric) => {
  const loadingLabel = "Loading...";
  const notApplicable = "N/A";

  record.score_control = loadingLabel;
  record.score_experiment = loadingLabel;
  record.score_delta = notApplicable;
  record.p_value = notApplicable;

  return record;
};

/**
 * Fills a metric row with the formatted pass/fail rates of one assertion
 * level.
 *
 * The row is modified in place: the score columns receive the formatted
 * control and experiment tallies, and the delta and p-value columns are set
 * to "N/A".
 *
 * @param record Metric row to update.
 * @param assertions Assertions from which the rates are computed.
 * @param level Assertion level whose rates should be shown.
 * @returns The same `record` object that was passed in.
 */
export const useLMChecklistPassFailRates = (
  record: Metric,
  assertions: LMChecklistAssertionsResponse,
  level: AssertionLevel,
) => {
  const notApplicable = "N/A";
  const { control, experiment } = getPassFailRatesByLevel(assertions, level);

  record.score_control = getPFValue(control);
  record.score_experiment = getPFValue(experiment);
  record.score_delta = notApplicable;
  record.p_value = notApplicable;

  return record;
};

// Patterns tried, in descending order of priority, to extract the rationale
// from a response. Each one captures the rationale in group 1, and the `s`
// flag lets `.` match line breaks so multi-line responses are handled:
//   1. the response starts with "#": the whole response is captured;
//   2. the response contains "rationale: " preceded by at least one
//      character: the text after its last occurrence is captured (the
//      leading `.+` is greedy);
//   3. fallback: any non-empty response is captured in full.
export const rationaleRegexes = [/^(#.+)$/s, /^.+rationale: (.+)$/s, /^(.+)$/s];

/**
 * Extracts the rationale text from a raw response.
 *
 * The trimmed response is matched against `rationaleRegexes` in order; the
 * first pattern that matches supplies the rationale through its first capture
 * group. Every "# " sequence is then removed and the result is trimmed.
 *
 * @param response Raw response text, possibly absent.
 * @returns The cleaned rationale, or "N/A" when `response` is empty or
 *   undefined, or when no pattern matches (e.g. whitespace-only input).
 */
export const parseRationale = (response: string | undefined) => {
  const fallback = "N/A";
  if (!response) return fallback;

  const text = response.trim();
  for (const pattern of rationaleRegexes) {
    const found = text.match(pattern);
    if (!found) continue;
    return found[1].replace(/# /g, "").trim();
  }

  return fallback;
};

/**
 * Flattens assertions into all-string rows suitable for export.
 *
 * For each assertion, scores are stringified, rationales are extracted from
 * the raw responses with `parseRationale`, and replies are copied through.
 * A score or reply that is `undefined` is exported as "N/A".
 *
 * @param assertions Assertions to export.
 * @returns One export row per assertion, in the same order.
 */
export const exportedAssertionsData = (
  assertions: LMChecklistAssertionsResponse,
): ExportedAssertionsData[] =>
  assertions.map(
    ({ query, segment, assertion, level, score, response, sydneyReply }) => ({
      query,
      segment,
      assertion,
      level,
      score_control:
        score.control === undefined ? "N/A" : score.control.toString(),
      score_experiment:
        score.experiment === undefined ? "N/A" : score.experiment.toString(),
      rationale_control: parseRationale(response.control),
      rationale_experiment: parseRationale(response.experiment),
      sydney_reply_control:
        sydneyReply.control === undefined ? "N/A" : sydneyReply.control,
      sydney_reply_experiment:
        sydneyReply.experiment === undefined ? "N/A" : sydneyReply.experiment,
    }),
  );

/**
 * Serialises the exported assertion rows to CSV and triggers a browser
 * download named `[Assertions]_<jobName>.csv`.
 *
 * Does nothing when `assertionsData` is empty. The header row is taken from
 * the property names of the first entry and is written unquoted; every data
 * cell is wrapped in double quotes with embedded double quotes doubled.
 * Cells are separated by "," and rows by "\n".
 *
 * @param assertionsData Rows to export.
 * @param jobName Used to build the name of the downloaded file.
 */
export const downloadData = (
  assertionsData: ExportedAssertionsData[],
  jobName: string,
) => {
  if (assertionsData.length === 0) return;

  const quoteCell = (cell: string) => `"${cell.replace(/"/g, '""')}"`;

  const headerLine = Object.keys(assertionsData[0]).join(",");
  const dataLines = assertionsData.map((row) =>
    Object.values(row)
      .map((cell) => quoteCell(cell))
      .join(","),
  );
  const csvContent = [headerLine, ...dataLines].join("\n");

  // Hand the CSV to the browser through a temporary object URL and a
  // programmatically clicked anchor element, then release the URL.
  const objectUrl = window.URL.createObjectURL(
    new Blob([csvContent], { type: "text/csv" }),
  );
  const anchor = document.createElement("a");
  anchor.href = objectUrl;
  anchor.download = `[Assertions]_${jobName}.csv`;
  anchor.click();
  window.URL.revokeObjectURL(objectUrl);
};

/**
 * Tells whether any critical assertion regressed from a full pass to a full
 * fail.
 *
 * An assertion is a critical regression when its level is "critical", its
 * control score is exactly 2 and its experiment score is exactly 0. Strict
 * equality is used so that a missing or non-numeric experiment score (for
 * example `undefined`, `null`, `""` or `false`) is never mistaken for a
 * score of 0, matching the strict `=== 2` pass check used for the pass/fail
 * tallies in this file.
 *
 * @param assertions Assertions to inspect.
 * @returns `true` when at least one critical regression exists.
 */
export const hasCriticalRegression = (
  assertions: LMChecklistAssertionsResponse,
) =>
  assertions.some(
    ({ level, score }) =>
      level === "critical" && score.control === 2 && score.experiment === 0,
  );

// Names of the failure metrics inspected by `metricsFailureAlert`. Each name
// is looked up in the `metric` field of the metrics rows, and the part before
// "_failure" is used as the display name of the affected metric.
const listedFailureMetrics = [
  "NDCG_LLM_labeler_failure_rate",
  "sbsleo_failure",
  "groundleo_claimbreak_failure",
  "acrueleo_failure",
  "codeleo_failure",
];
/**
 * Builds a warning when any metric in `listedFailureMetrics` exceeds 1% for
 * either the control or the treatment value.
 *
 * The warning names the affected metrics (the part of each failure metric's
 * name before "_failure") and lists the failure metrics that tripped, using
 * their exact names as they appear in `metricsData`. Previously a literal
 * "_failure_rate" was appended after the joined list, which doubled the
 * suffix (e.g. "sbsleo_failure_failure_rate") and applied it to the last
 * name only.
 *
 * @param metricsData Metric rows to inspect.
 * @returns The warning text, or `undefined` when no listed failure metric
 *   exceeds the threshold (including when its row or values are absent).
 */
export const metricsFailureAlert = (metricsData: MetricsData[]) => {
  // Failure rates above this fraction (1%) make the metric unreliable.
  const failureThreshold = 0.01;

  const failedMetrics = (metricFailure: string) => {
    const row = metricsData.find(({ metric }) => metric === metricFailure);
    // Number(undefined) is NaN and every comparison with NaN is false, so a
    // missing row or a missing value never raises the alert.
    const isInvalidMetrics =
      Number(row?.control) > failureThreshold ||
      Number(row?.treatment) > failureThreshold;
    if (!isInvalidMetrics) return undefined;

    return {
      failureMetrics: metricFailure,
      metrics: metricFailure.split("_failure")[0],
    };
  };

  // `compact` drops the `undefined` entries of metrics that did not trip.
  const metricsList = compact(
    listedFailureMetrics.map((metric) => failedMetrics(metric)),
  );
  if (metricsList.length === 0) return undefined;

  const unreliableNames = metricsList.map((entry) => entry.metrics).join(",");
  const failureNames = metricsList
    .map((entry) => entry.failureMetrics)
    .join(",");

  return `Warning: ${unreliableNames} is not reliable, with ${failureNames} is larger than 1%.`;
};
