import { useMemo } from 'react';

import { useSelector } from 'react-redux';

import { ForecastParameterType } from 'shared/lib/types';
import { selectTrial } from 'shared/state/slices/trialSlice';

import useForecastParamEditorRows from './useForecastParamEditorRows';

/**
 * A data row of the patient-averages grid: the actual and forecasted value
 * of one averaged forecast parameter, optionally scoped to a region.
 */
type PatientAverageDataRow = {
  row_title: string;
  actual_average: number;
  forecasted_average: number;
  /** In this file, set to `true` only on the global row that follows a group of regional rows. */
  isTotal?: boolean;
  trialCurrency?: string;
  parameter_type?: ForecastParameterType;
  parameter_trace_id?: string;
  region_name?: string;
};

/** Separator row placed between groups of data rows. */
type PatientAverageSpacerRow = { isEmpty: boolean };

/** Any row produced by `usePatientAveragesGridRows`. */
type PatientAverageRow = PatientAverageDataRow | PatientAverageSpacerRow;

/** `region_name` value that identifies the trial-wide (all-regions) editor row. */
const GLOBAL_REGION = 'Global';

/** One row as returned by `useForecastParamEditorRows`. */
type ForecastEditorRow = ReturnType<typeof useForecastParamEditorRows>[number];

/** Returns the editor row whose region is `GLOBAL_REGION`, or `undefined` if there is none. */
function findGlobalRow(
  sourceRows: readonly ForecastEditorRow[],
): ForecastEditorRow | undefined {
  return sourceRows.find(({ region_name }) => region_name === GLOBAL_REGION);
}

/**
 * Turns the editor rows of one parameter type into grid rows: one row per
 * non-global region (in input order), followed by the global row flagged
 * `isTotal` when a global editor row is present.
 *
 * @param sourceRows - editor rows for a single forecast parameter type
 * @param titlePrefix - label shared by the group; the region name (or
 *   "global" for the total row) is appended after a space
 * @param trialCurrency - value copied onto every produced row's `trialCurrency`
 * @returns a new array of data rows (no separator rows)
 */
function buildRegionalRows(
  sourceRows: readonly ForecastEditorRow[],
  titlePrefix: string,
  trialCurrency: string | undefined,
): PatientAverageRow[] {
  const groupRows = sourceRows
    .filter((row) => row.region_name !== GLOBAL_REGION)
    .map(
      (row): PatientAverageRow => ({
        row_title: `${titlePrefix} ${row.region_name}`,
        actual_average: row.actual,
        forecasted_average: row.forecast,
        parameter_type: row.parameter_type,
        parameter_trace_id: row.parameter_trace_id,
        region_name: row.region_name,
        trialCurrency,
      }),
    );

  const globalRow = findGlobalRow(sourceRows);
  if (globalRow) {
    groupRows.push({
      row_title: `${titlePrefix} global`,
      actual_average: globalRow.actual,
      forecasted_average: globalRow.forecast,
      parameter_type: globalRow.parameter_type,
      parameter_trace_id: globalRow.parameter_trace_id,
      region_name: globalRow.region_name,
      isTotal: true,
      trialCurrency,
    });
  }

  return groupRows;
}

/**
 * Builds the rows of the patient-averages grid, in this order:
 *
 * 1. Average total cost per patient: one row per non-global region, then the
 *    global row (flagged `isTotal`) if present.
 * 2. A separator row (`{ isEmpty: true }`).
 * 3. Average total procedure cost per patient, laid out the same way.
 * 4. A separator row.
 * 5. Global average treatment length in days. No `trialCurrency` is attached.
 * 6. Global average lab cost.
 *
 * Rows 5 and 6 are always emitted. When no global editor row exists for them,
 * their averages fall back to `0` and their trace id falls back to `''`.
 *
 * @returns memoized rows, recomputed only when the trial currency or one of
 *   the four editor-row arrays changes identity.
 */
function usePatientAveragesGridRows(): PatientAverageRow[] {
  const trial = useSelector(selectTrial);

  const avgCostPerPatientRows = useForecastParamEditorRows(
    ForecastParameterType.AVERAGE_COST_PER_PATIENT,
  );
  const avgProcedureCostRows = useForecastParamEditorRows(
    ForecastParameterType.AVERAGE_PROCEDURE_COST,
  );
  const avgTreatmentLengthRows = useForecastParamEditorRows(
    ForecastParameterType.AVERAGE_TREATMENT_LENGTH,
  );
  const avgLabCostRows = useForecastParamEditorRows(
    ForecastParameterType.AVERAGE_LAB_COST,
  );

  const trialCurrency = trial.currency;

  return useMemo((): PatientAverageRow[] => {
    const avgTreatmentLengthRow = findGlobalRow(avgTreatmentLengthRows);
    const avgLabCostRow = findGlobalRow(avgLabCostRows);

    return [
      ...buildRegionalRows(
        avgCostPerPatientRows,
        'Avg total cost per patient',
        trialCurrency,
      ),
      { isEmpty: true },
      // Shared prefix keeps regional and global titles consistent. The global
      // title previously read "Avg total procedure per patient global".
      ...buildRegionalRows(
        avgProcedureCostRows,
        'Avg total procedure cost per patient',
        trialCurrency,
      ),
      { isEmpty: true },
      {
        // Measured in days, so no currency is attached.
        row_title: 'Avg treatment length (days)',
        actual_average: avgTreatmentLengthRow?.actual ?? 0,
        forecasted_average: avgTreatmentLengthRow?.forecast ?? 0,
        parameter_type: avgTreatmentLengthRow?.parameter_type,
        parameter_trace_id: avgTreatmentLengthRow?.parameter_trace_id ?? '',
      },
      {
        row_title: 'Avg lab cost',
        actual_average: avgLabCostRow?.actual ?? 0,
        forecasted_average: avgLabCostRow?.forecast ?? 0,
        parameter_type: avgLabCostRow?.parameter_type,
        parameter_trace_id: avgLabCostRow?.parameter_trace_id ?? '',
        trialCurrency,
      },
    ];
  }, [
    trialCurrency,
    avgCostPerPatientRows,
    avgTreatmentLengthRows,
    avgLabCostRows,
    avgProcedureCostRows,
  ]);
}

export default usePatientAveragesGridRows;
