import { gql } from "@apollo/client";
import { CashBandName, CurrencyCode } from "@asmbl/shared/constants";
import { Currency, exchangeFromTo } from "@asmbl/shared/currency";
import {
  add,
  divide,
  formatCurrency,
  Money,
  money,
  map as moneyMap,
  multiply,
  ratio,
  totalGrossEquityValue,
  unitsOfTotalGrossValue,
  zero,
} from "@asmbl/shared/money";
import { contramap } from "@asmbl/shared/sort";
import { zip } from "@asmbl/shared/utils";
import {
  Divider,
  makeStyles,
  Paper,
  Table,
  TableBody,
  TableCell,
  TableContainer,
  TableHead,
  TableRow,
} from "@material-ui/core";
import { useState } from "react";
import { useNavigate } from "react-router-dom";
import { AssembleButton } from "src/components/AssembleButton/AssembleButton";
import { useCompStructure } from "src/components/CompStructureContext";
import {
  BandUnit,
  CashBandInput,
  EditableBandPointCell_cashBandPoint as CashBandPoint,
  EquityBandInput,
  EditableBandPointCell_equityBandPoint as EquityBandPoint,
  PositionDetailEditTable_position as Position,
  PositionType,
  PositionDetailEditTable_valuation as Valuation,
} from "../../__generated__/graphql";
import { useTrack } from "../../analytics";
import { AssembleTypography } from "../../components/AssembleTypography";
import { useCurrencies } from "../../components/CurrenciesContext";
import { SaveButton } from "../../components/Form/SaveButton";
import {
  assertLocationSelected,
  useLocations,
} from "../../components/LocationsContext";
import UnsavedChangesWarning from "../../components/UnsavedChangesWarning";
import { isBandPointDefined } from "../../models/BandPoint";
import { currencySymbol } from "../../models/Currency";
import { useEmplaceBands } from "../../mutations/Position";
import { GRAY_8 } from "../../theme";
import { ArrayValue, bandNameComparator } from "../../utils";
import { CompensationHeading } from "./CompensationHeading";
import { CustomizeCurrencyModal } from "./CustomizeCurrencyModal";
import { EditableBandPointCell } from "./EditableBandPointCell";
import { EditingMarketWarning } from "./EditingMarketWarning";
import { PositionDetailTableHeader } from "./PositionDetailTableHeader";

// Styles for the edit table. `totalCell` is applied to the last cell of each
// row (the totals column), shading it and separating it from the band columns.
const useStyles = makeStyles((theme) => {
  const separator = `1px solid ${theme.palette.divider}`;
  return {
    totalCell: {
      backgroundColor: GRAY_8,
      borderLeft: separator,
    },
  };
});

/** Props for `PositionDetailEditTable`. */
type Props = {
  /** The position being edited, including its unadjusted cash/equity bands. */
  position: Position;
  /** Valuation (fdso, valuationMoney) used to recompute equity band points. */
  valuation: Valuation;
};

// Element types of the band arrays selected by the `position` fragment.
// NOTE(review): assumes `ArrayValue` extracts the element type (and strips
// nullability) of the array field — confirm against its definition in utils.
type CashBand = ArrayValue<Position["unadjustedCashBands"]>;
type EquityBand = ArrayValue<Position["unadjustedEquityBands"]>;

/**
 * Editable table of a position's unadjusted cash and equity bands for the
 * currently selected location.
 *
 * Rows are band points (names taken from the first band), columns are bands,
 * and the last column shows per-point totals converted into each currency in
 * use. Edits live in local state until saved through `useEmplaceBands`; after
 * saving, `SaveButton` calls `handleClose`, which returns to the read-only
 * position page.
 */
export function PositionDetailEditTable({
  position,
  valuation,
}: Props): JSX.Element {
  const classes = useStyles();
  const navigate = useNavigate();
  const { selectedCurrency, currencies } = useCurrencies();
  const { selectedLocation } = useLocations();
  assertLocationSelected(selectedLocation);
  const { trackEvent } = useTrack();
  const { compStructure } = useCompStructure();

  const { id: positionId } = position;
  const unsortedCashBands = position.unadjustedCashBands ?? [];
  const unsortedEquityBands = position.unadjustedEquityBands ?? [];

  // Copy before sorting so the Apollo-provided arrays are not mutated.
  const persistedCashBands = unsortedCashBands
    .slice()
    .sort(contramap((band) => band.name, bandNameComparator));
  const persistedEquityBands = unsortedEquityBands
    .slice()
    .sort(contramap((band) => band.name, bandNameComparator));

  const [cashBands, setCashBands] = useState(persistedCashBands);
  const [equityBands, setEquityBands] = useState(persistedEquityBands);
  const [touched, setTouched] = useState(false);
  const [currencyModalOpen, setCurrencyModalOpen] = useState(false);

  const bandNames = [...cashBands, ...equityBands].map((band) => band.name);
  // Row labels come from the first band's points. Fall back to the first
  // equity band so positions with only equity bands still render their rows
  // (previously they showed headers but no rows).
  const bandPoints =
    cashBands.at(0)?.bandPoints.map((bandPoint) => bandPoint.name) ??
    equityBands.at(0)?.bandPoints.map((bandPoint) => bandPoint.name) ??
    [];

  // Used to annualize hourly salary rates; defaults to 40h x 52 weeks.
  const workingHoursPerYear =
    (compStructure?.employmentHoursPerWeek ?? 40) *
    (compStructure?.employmentWeeksPerYear ?? 52);

  const emplaceBands = useEmplaceBands(
    positionId,
    selectedCurrency.code,
    selectedLocation[0].id,
    selectedLocation[1].id
  );

  const trackSaveEvent = () => {
    // Bands are compared to their persisted versions by index; the edit
    // handlers never reorder, add, or remove bands.
    const touchCounts = [
      ...zip(cashBands, persistedCashBands).map(
        ([cashBand, persistedCashBand]) =>
          cashTouchCounts(cashBand, persistedCashBand)
      ),
      ...zip(equityBands, persistedEquityBands).map(
        ([equityBand, persistedEquityBand]) =>
          equityTouchCounts(equityBand, persistedEquityBand)
      ),
    ];

    const firstCurrency = cashBands.at(0)?.currencyCode;

    trackEvent({
      object: "Edit Bands Form",
      action: "Saved",
      compComponentCount: touchCounts
        .map((c) => c.compComponentCount)
        .reduce((a, b) => a + b, 0),
      bandPointCount: touchCounts
        .map((c) => c.bandPointCount)
        .reduce((a, b) => a + b, 0),
      // If all cash is in the same currency, use that code. Otherwise, 'Mixed'.
      cashCurrency: cashBands.every((b) => b.currencyCode === firstCurrency)
        ? firstCurrency
        : "Mixed",
      // Equity is always in a single currency.
      equityCurrency: equityBands.at(0)?.currencyCode,
    });
  };

  // Resolves to true on success (and clears the unsaved-changes flag),
  // false on failure.
  const onSave = () => {
    trackSaveEvent();

    return emplaceBands(
      cashBands.map(cashBandToInput),
      equityBands.map(equityBandToInput)
    ).then(
      () => {
        setTouched(false);
        return true;
      },
      () => false
    );
  };

  const handleCustomizeCurrencies = () => {
    trackEvent({ object: "Band Currency Modal", action: "Opened" });
    setCurrencyModalOpen(true);
  };

  const handleClose = () => {
    // Remove /edit from the URL.
    navigate(`/positions/${positionId}`, { replace: true });
  };

  const handleCancel = () => {
    trackEvent({ object: "Edit Bands Form", action: "Canceled" });
    handleClose();
  };

  const handleCashBandPointChange = (
    prevBandPoint: CashBandPoint,
    newValue: number
  ) => {
    // Hourly positions enter salary as an hourly rate; everything else is
    // entered as an annual amount.
    const isHourly =
      position.type === PositionType.HOURLY &&
      prevBandPoint.bandName === CashBandName.SALARY;

    const newRate = money(newValue, prevBandPoint.value.currencyCode);
    const annualRate = isHourly
      ? multiply(newRate, workingHoursPerYear)
      : newRate;

    const newBandPoint = {
      ...prevBandPoint,
      value: {
        ...prevBandPoint.value,
        rate: newRate,
        annualRate,
      },
      annualCashEquivalent: annualRate,
    };
    setCashBands(replacePoint(prevBandPoint, newBandPoint));
    setTouched(true);
  };

  const handleEquityBandPointChange = (
    prevBandPoint: EquityBandPoint,
    newRawValue: number
  ) => {
    const newBandPoint = calculateNewEquityBandPoint(
      valuation,
      prevBandPoint,
      newRawValue
    );
    setEquityBands(replacePoint(prevBandPoint, newBandPoint));
    setTouched(true);
  };

  const handleChangeCurrencies = (
    newCashBands: {
      name: string;
      currencyCode: CurrencyCode;
    }[]
  ) => {
    // Bands are in the same order as before, so they are matched by index.
    setCashBands((cashBands) =>
      cashBands.map((prevCashBand, index) => {
        const newCashBand = newCashBands[index];
        // Keep the band unchanged if the modal returned no entry for it
        // (instead of throwing) or left its currency as-is.
        if (
          newCashBand === undefined ||
          prevCashBand.currencyCode === newCashBand.currencyCode
        ) {
          return prevCashBand;
        }

        const fromCurrency = currencies.get(prevCashBand.currencyCode);
        const toCurrency = currencies.get(newCashBand.currencyCode);
        if (!fromCurrency || !toCurrency) {
          return prevCashBand;
        }

        // Convert every amount into the new currency, rounded to whole units.
        // Missing rates become zero in the target currency.
        return {
          ...prevCashBand,
          currencyCode: newCashBand.currencyCode,
          bandPoints: prevCashBand.bandPoints.map((prevBandPoint) => ({
            ...prevBandPoint,
            value: {
              __typename: "CashValue" as const,
              rate: moneyMap(
                Math.round,
                prevBandPoint.value.rate !== null
                  ? exchangeFromTo(
                      prevBandPoint.value.rate,
                      fromCurrency,
                      toCurrency
                    )
                  : zero(toCurrency.code)
              ),
              annualRate: moneyMap(
                Math.round,
                prevBandPoint.value.annualRate !== null
                  ? exchangeFromTo(
                      prevBandPoint.value.annualRate,
                      fromCurrency,
                      toCurrency
                    )
                  : zero(toCurrency.code)
              ),
              currencyCode: toCurrency.code,
            },
            annualCashEquivalent: moneyMap(
              Math.round,
              exchangeFromTo(
                prevBandPoint.annualCashEquivalent,
                fromCurrency,
                toCurrency
              )
            ),
          })),
        };
      })
    );
    setTouched(true);
    setCurrencyModalOpen(false);
  };

  return (
    <>
      <UnsavedChangesWarning pageEdited={touched} />
      {currencyModalOpen && (
        <CustomizeCurrencyModal
          cashBands={cashBands}
          equityBands={equityBands}
          onSave={handleChangeCurrencies}
          onClose={() => setCurrencyModalOpen(false)}
        />
      )}
      <CompensationHeading
        actions={
          <>
            <AssembleButton
              onClick={handleCustomizeCurrencies}
              variant="outlined"
              size="medium"
              label="Customize Currency"
            />

            <Divider orientation="vertical" flexItem />
            <AssembleButton
              onClick={handleCancel}
              variant="outlined"
              size="medium"
              label="Cancel"
            />
            <SaveButton
              onSave={onSave}
              onAfterSave={handleClose}
              cooldown={500}
              hideEndIcon
            />
          </>
        }
      />
      <EditingMarketWarning />
      <TableContainer component={Paper} elevation={0}>
        <Table>
          <TableHead>
            <PositionDetailTableHeader bandNames={bandNames} />
          </TableHead>
          <TableBody>
            {bandPoints.map((bandPointName) => (
              <TableRow key={bandPointName}>
                <TableCell>{bandPointName}</TableCell>
                {cashBands.map((cashBand) => {
                  const bandPoint = cashBand.bandPoints.find(
                    (bp) => bp.name === bandPointName
                  );

                  if (!bandPoint) {
                    return null;
                  }

                  // The currency is part of the key, so the cell gets a new
                  // identity (and fresh internal state) when it changes.
                  return (
                    <EditableBandPointCell
                      key={`${cashBand.name}:${cashBand.currencyCode}`}
                      position={position}
                      bandPoint={bandPoint}
                      onChange={(value) =>
                        handleCashBandPointChange(bandPoint, value)
                      }
                    />
                  );
                })}
                {equityBands.map((equityBand) => {
                  const bandPoint = equityBand.bandPoints.find(
                    (bp) => bp.name === bandPointName
                  );

                  if (!bandPoint) {
                    return null;
                  }

                  return (
                    <EditableBandPointCell
                      key={equityBand.name}
                      position={position}
                      bandPoint={bandPoint}
                      onChange={(value) =>
                        handleEquityBandPointChange(bandPoint, value)
                      }
                    />
                  );
                })}
                <TableCell className={classes.totalCell} align="right">
                  {totalsForPoint(
                    cashBands,
                    equityBands,
                    bandPointName,
                    currencies
                  ).map((total) => (
                    <AssembleTypography key={total.currency}>
                      {total.value
                        ? formatCurrency(total.value)
                        : `${currencySymbol(total.currency)} -`}
                    </AssembleTypography>
                  ))}
                </TableCell>
              </TableRow>
            ))}
          </TableBody>
        </Table>
      </TableContainer>
    </>
  );
}

/**
 * GraphQL fragments this component needs from its parent query.
 *
 * NOTE(review): the `position` fragment passes `$marketId` to both band
 * fields, so any operation that spreads it must declare a `$marketId`
 * variable.
 */
PositionDetailEditTable.fragments = {
  // Position type plus its unadjusted cash and equity bands, including the
  // fields required by EditableBandPointCell and CustomizeCurrencyModal.
  position: gql`
    ${EditableBandPointCell.fragments.position}
    ${CustomizeCurrencyModal.fragments.cashBand}
    ${EditableBandPointCell.fragments.cashBandPoint}
    ${CustomizeCurrencyModal.fragments.equityBand}
    ${EditableBandPointCell.fragments.equityBandPoint}
    fragment PositionDetailEditTable_position on Position {
      ...EditableBandPointCell_position
      id
      type
      unadjustedCashBands(marketId: $marketId) {
        id
        name
        currencyCode
        ...CustomizeCurrencyModal_cashBand
        bandPoints {
          id
          name
          bandUnit
          annualCashEquivalent
          ...EditableBandPointCell_cashBandPoint
        }
      }
      unadjustedEquityBands(marketId: $marketId) {
        id
        name
        currencyCode
        ...CustomizeCurrencyModal_equityBand
        bandPoints {
          id
          name
          annualCashEquivalent
          vestingMonths
          value {
            ... on PercentValue {
              decimalValue
            }
          }
          ...EditableBandPointCell_equityBandPoint
        }
      }
    }
  `,
  // Valuation inputs read by calculateNewEquityBandPoint when an equity
  // point is edited.
  valuation: gql`
    fragment PositionDetailEditTable_valuation on Valuation {
      fdso
      valuationMoney
    }
  `,
};

/**
 * Builds a state updater that swaps `prevPoint` for `newPoint`.
 *
 * The owning band is found by matching `band.name` to `prevPoint.bandName`,
 * and the point within it by `name`. Non-matching bands and points are kept
 * as the same object references; only the matching band is copied.
 */
function replacePoint<
  B extends { name: string; bandPoints: BP[] },
  BP extends { bandName: string; name: string },
>(prevPoint: BP, newPoint: BP): (bands: B[]) => B[] {
  const { bandName: targetBandName, name: targetPointName } = prevPoint;

  const swapPoint = (point: BP): BP =>
    point.name === targetPointName ? newPoint : point;

  return function updateBands(bands) {
    return bands.map((band) => {
      if (band.name !== targetBandName) {
        return band;
      }
      return { ...band, bandPoints: band.bandPoints.map(swapPoint) };
    });
  };
}

/**
 * Applies a user edit to an equity band point and recomputes its derived
 * values.
 *
 * `newRawValue` is interpreted in the point's existing unit:
 *   - CashValue: a money amount in the value's currency
 *   - UnitValue: a number of units
 *   - otherwise (percentage): a percent in 0–100
 *
 * The edit is first turned into a total gross money value. Every alternate
 * form is then derived from that total the same way, whatever the unit:
 *   - annualCashEquivalent: total divided by `vestingMonths / 12`
 *   - totalUnits: `unitsOfTotalGrossValue` using the valuation's fdso
 *   - totalPercentOwnership.decimalValue: total as a ratio of valuationMoney
 * Derived fields that were `null` on the previous point stay `null`.
 *
 * NOTE(review): a `vestingMonths` of 0 presumably divides by zero here —
 * confirm the API never returns it.
 */
export function calculateNewEquityBandPoint(
  valuation: Valuation,
  prevBandPoint: EquityBandPoint,
  newRawValue: number
): EquityBandPoint {
  const { fdso, valuationMoney } = valuation;
  const prevValue = prevBandPoint.value;

  // Step 1: express the edit in the point's own unit and find the total gross
  // (money) value it represents.
  let nextValue: EquityBandPoint["value"];
  let totalGrossValue: Money;

  switch (prevValue.__typename) {
    case "CashValue":
      totalGrossValue = money(newRawValue, prevValue.currencyCode);
      nextValue = {
        ...prevValue,
        annualRate: totalGrossValue,
        rate: totalGrossValue,
      };
      break;
    case "UnitValue":
      nextValue = { ...prevValue, unitValue: newRawValue };
      totalGrossValue = totalGrossEquityValue(
        fdso,
        valuationMoney,
        newRawValue
      );
      break;
    default: {
      // Percentages are entered as 0–100 and stored alongside a decimal form.
      const fraction = newRawValue / 100;
      nextValue = {
        ...prevValue,
        percentValue: newRawValue,
        decimalValue: fraction,
      };
      totalGrossValue = multiply(valuationMoney, fraction);
    }
  }

  // Step 2: derive the remaining forms from the total gross value.
  const annualCashEquivalent = divide(
    totalGrossValue,
    prevBandPoint.vestingMonths / 12
  );
  const totalUnits = unitsOfTotalGrossValue(
    fdso,
    valuationMoney,
    totalGrossValue
  );
  const ownershipRatio = ratio(totalGrossValue, valuationMoney);

  const { totalPercentOwnership: prevOwnership } = prevBandPoint;

  return {
    ...prevBandPoint,
    value: nextValue,
    annualCashEquivalent,
    totalGrossValue,
    totalUnits: prevBandPoint.totalUnits === null ? null : totalUnits,
    totalPercentOwnership:
      prevOwnership === null
        ? null
        : { ...prevOwnership, decimalValue: ownershipRatio },
  };
}

/**
 * Sums the annual cash equivalent of one band point across every band.
 *
 * One total is returned per distinct currency the matching points are
 * denominated in (cash bands first, then equity, in first-seen order), with
 * every point converted into that currency. Points whose currency is not in
 * `currencies` are ignored. Undefined points contribute zero; if no matching
 * point is defined, each total's `value` is `null`.
 */
function totalsForPoint(
  cashBands: CashBand[],
  equityBands: EquityBand[],
  pointName: string,
  currencies: Map<CurrencyCode, Currency>
): { value: Money | null; currency: CurrencyCode }[] {
  // Cash and equity points have different types, so collect them separately.
  const cashPoints = cashBands.flatMap((band) =>
    band.bandPoints.filter((point) => point.name === pointName)
  );
  const equityPoints = equityBands.flatMap((band) =>
    band.bandPoints.filter((point) => point.name === pointName)
  );

  // Reduce each point to its Money (or null) plus the Currency needed to
  // exchange it.
  const entries: { value: Money | null; currency: Currency }[] = [];
  for (const point of [...cashPoints, ...equityPoints]) {
    const currency = currencies.get(point.annualCashEquivalent.currency);
    if (!currency) {
      continue;
    }
    entries.push({
      value: isBandPointDefined(point) ? point.annualCashEquivalent : null,
      currency,
    });
  }

  const distinctCurrencies = Array.from(
    new Set(entries.map((entry) => entry.currency))
  );
  const nothingDefined = !entries.some((entry) => entry.value !== null);

  return distinctCurrencies.map((target) => {
    if (nothingDefined) {
      return { value: null, currency: target.code };
    }
    const converted = entries.map(({ value, currency: source }) =>
      value ? exchangeFromTo(value, source, target) : zero(target.code)
    );
    return { value: converted.reduce(add), currency: target.code };
  });
}

/**
 * Converts an edited cash band into the mutation input shape.
 *
 * Each point is sent with its band unit and the amount of its `rate`
 * (0 when the rate is missing).
 */
function cashBandToInput(cashBand: CashBand): CashBandInput {
  const { name, currencyCode, bandPoints } = cashBand;

  const inputPoints = bandPoints.map(({ name: pointName, bandUnit, value }) => {
    const amount = value.rate?.value ?? 0;
    return { name: pointName, unit: bandUnit, value: amount };
  });

  return { name, currencyCode, bandPoints: inputPoints };
}

/**
 * Counts, for analytics, how much of a cash band differs from its persisted
 * version.
 *
 * A currency change marks every point as touched. Otherwise a point is
 * touched when its `annualRate` amount differs from the persisted point at
 * the same index. A point with no persisted counterpart counts as touched.
 *
 * @returns `compComponentCount` is 1 if any point was touched, else 0;
 *   `bandPointCount` is the number of touched points.
 */
function cashTouchCounts(
  cashBand: CashBand,
  persistedCashBand: CashBand
): { compComponentCount: number; bandPointCount: number } {
  if (cashBand.currencyCode !== persistedCashBand.currencyCode) {
    return {
      compComponentCount: 1,
      bandPointCount: cashBand.bandPoints.length,
    };
  }

  // Currency is the same, so only need to check values (compared by index).
  const persistedPoints = persistedCashBand.bandPoints;
  const bandPointCount = cashBand.bandPoints.filter((bandPoint, index) => {
    // Guard the index instead of dereferencing a missing persisted point,
    // which previously threw a TypeError.
    if (index >= persistedPoints.length) {
      return true;
    }
    const persistedBandPoint = persistedPoints[index];
    return (
      bandPoint.value.annualRate?.value !==
      persistedBandPoint.value.annualRate?.value
    );
  }).length;

  return {
    compComponentCount: bandPointCount > 0 ? 1 : 0,
    bandPointCount,
  };
}

/**
 * Counts, for analytics, how much of an equity band differs from its
 * persisted version.
 *
 * Equity cannot change currency or unit here, so a point is touched when its
 * `annualCashEquivalent` amount differs from the persisted point at the same
 * index. A point with no persisted counterpart counts as touched.
 *
 * @returns `compComponentCount` is 1 if any point was touched, else 0;
 *   `bandPointCount` is the number of touched points.
 */
function equityTouchCounts(
  equityBand: EquityBand,
  persistedEquityBand: EquityBand
): { compComponentCount: number; bandPointCount: number } {
  const persistedPoints = persistedEquityBand.bandPoints;
  const bandPointCount = equityBand.bandPoints.filter((bandPoint, index) => {
    // Guard the index instead of dereferencing a missing persisted point,
    // which previously threw a TypeError.
    if (index >= persistedPoints.length) {
      return true;
    }
    const persistedBandPoint = persistedPoints[index];
    return (
      bandPoint.annualCashEquivalent.value !==
      persistedBandPoint.annualCashEquivalent.value
    );
  }).length;

  return {
    compComponentCount: bandPointCount > 0 ? 1 : 0,
    bandPointCount,
  };
}

/**
 * Converts an edited equity band into the mutation input shape.
 *
 * Each point is sent in the unit its value is expressed in:
 *   - CashValue → CASH, using the `annualRate` amount
 *   - UnitValue → UNITS, using `unitValue`
 *   - otherwise → PERCENTAGE, converting `percentValue` (0–100) to a fraction
 * Missing amounts are sent as 0.
 */
function equityBandToInput(equityBand: EquityBand): EquityBandInput {
  const inputPoints = equityBand.bandPoints.map(({ name, value }) => {
    switch (value.__typename) {
      case "CashValue":
        return {
          name,
          unit: BandUnit.CASH,
          value: value.annualRate?.value ?? 0,
        };
      case "UnitValue":
        return {
          name,
          unit: BandUnit.UNITS,
          value: value.unitValue ?? 0,
        };
      default:
        return {
          name,
          unit: BandUnit.PERCENTAGE,
          value: (value.percentValue ?? 0) / 100,
        };
    }
  });

  return {
    name: equityBand.name,
    currencyCode: equityBand.currencyCode,
    bandPoints: inputPoints,
  };
}
