import { Radio } from 'components/GraphView/AddNodeDialog';
import { Graph } from 'components/ModelTabs/Graph';
import { Model } from 'components/ModelTabs/Model';
import { getNodesElligibleForDistribution, updateDateAndVersion, useRerender } from 'components/Utils/funcs';
import { Dispatch, SetStateAction, useEffect, useState } from 'react';
import { BaseAnswer, BaseQuery, CausalModel, Distribution } from 'components/openapi';
import Latex from 'react-latex';
import { postBaseQueryGetAnswer } from 'components/Utils/BackendApi';
import { Collapse, Alert, IconButton } from '@mui/material';
import CloseIcon from '@mui/icons-material/Close';
import { textWithLineBreak } from 'components/QueryTabs/QueryTab';
import { NodeIterationDisplay } from 'components/MechanismTabs/NodeIterationDisplay';
import { CustomIcons, setDistributionIcons, setMechanismIcons } from 'components/GraphView/MechDistIcons';
import { useTranslation } from 'react-i18next';
import distributions, {
  DEFAULT_DISTRIBUTION,
  Distribution as PredefinedDist,
  continuousDistributions,
  discreteDistributions,
} from './DistributionTypes';
import { ApexAreaChart } from './ApexArea';
import { parseDistString } from './DistributionUtils';

// Module-level cache of density ("PDF") queries and their answers.
// Outer key: model id; inner key: node id. Each entry holds the query that is
// sent to the backend and, once one has been received, the answer used to draw
// the density chart. Living at module scope, it is shared by every render of
// the component in this file.
const queries: Record<string, Record<number, { densityQuery: BaseQuery; densityResponse?: BaseAnswer }>> = {};

// Raw form state for the distribution of a single node.
type SingleDistInput = {
  // Answer to the "is the distribution known?" radio group.
  hasDistribution?: boolean;
  // Selected variable type; decides which distribution <select> is shown.
  distType?: 'discrete' | 'continuous';
  // True when the user picked the 'own' option and enters the distribution as free text.
  hasUserDefinedDist?: boolean;
  // Name of the selected predefined distribution; used as a key into `distributions`.
  distAbbr?: string;
  // Parameter values of the predefined distribution, kept as the raw text the user typed.
  parametersInput?: string[];
  // Free-text distribution expression used when `hasUserDefinedDist` is true.
  userDefinedCode?: string;
};

// Maps model.id -> node.id -> form input of that node.
type ModelToDistInput = Record<string, Record<number, SingleDistInput>>;
// Module-level store of the form inputs, filled and pruned while the tab renders.
const distributionsInput: ModelToDistInput = {};

/* const isDefaultDist = (dist: SingleDistInput | undefined): boolean => {
  return (
    dist !== undefined &&
    dist.hasUserDefinedDist !== true &&
    dist.distAbbr === DEFAULT_DISTRIBUTION.name &&
    equals(
      dist.parametersInput?.map((str) => Number(str)),
      DEFAULT_DISTRIBUTION.params.map((par) => par.value),
    )
  );
}; */

/**
 * Serializes the form input of a node into the distribution string stored in the model.
 *
 * - User-defined distributions are returned verbatim (`userDefinedCode`).
 * - Predefined distributions are rendered as `abbr(p1, p2, ...)`, but only when the
 *   number of entered parameters matches the number of parameters the distribution declares.
 *
 * @param dist form input of the node, or `undefined` when there is none
 * @returns the distribution string, or `undefined` when the node has no distribution,
 *          the input is incomplete, or the abbreviation is not a known distribution
 */
const distInputToString = (dist: SingleDistInput | undefined): string | undefined => {
  if (dist === undefined || !dist.hasDistribution) return undefined;

  // Free-text distribution: the code itself is the serialized form (may still be undefined).
  if (dist.hasUserDefinedDist === true) return dist.userDefinedCode;

  if (dist.distAbbr === undefined || dist.parametersInput === undefined) return undefined;

  // Optional chaining: an abbreviation that is not a key of `distributions` yields
  // `undefined` here instead of throwing a TypeError on `.params`.
  const expectedParamCount = distributions[dist.distAbbr]?.params.length;
  if (expectedParamCount !== dist.parametersInput.length) return undefined;

  return `${dist.distAbbr}(${dist.parametersInput.join(', ')})`;
};

/**
 * Builds the initial form input for a node.
 *
 * If the model already stores a distribution for the node, its type string is parsed
 * with `parseDistString`; a string that cannot be parsed is treated as a user-defined
 * distribution. If the model stores nothing for the node, `DEFAULT_DISTRIBUTION` is used.
 *
 * @param nodeId id of the node the input is created for
 * @param model  model state whose `distributions` are searched for the node
 * @returns the pre-filled form input
 */
const createDefaultDistributionInput = (nodeId: number, model: CausalModel): SingleDistInput => {
  const stored = Array.from(model.distributions).find((candidate) => candidate.node === nodeId);
  const parsed = stored === undefined ? undefined : parseDistString(stored.type);

  // A stored distribution that is not one of the predefined ones: keep its text as own code.
  if (stored !== undefined && parsed === undefined) {
    return {
      hasDistribution: true,
      distType: 'continuous',
      hasUserDefinedDist: true,
      userDefinedCode: stored.type,
    };
  }

  // Either the parsed stored distribution or, when nothing is stored, the default one.
  const preset = parsed === undefined ? DEFAULT_DISTRIBUTION : parsed;
  const parametersInput = preset.params.map((param) => param.value?.toString() ?? '');

  return {
    hasDistribution: true,
    distType: preset.type,
    hasUserDefinedDist: false,
    distAbbr: preset.name,
    parametersInput,
  };
};

/**
 * Tab in which the user assigns a probability distribution to every node that is
 * eligible for one (as reported by `getNodesElligibleForDistribution`).
 *
 * The raw form input and the density queries are kept in the module-level stores
 * `distributionsInput` and `queries` (keyed by model id and node id), so the component
 * triggers re-renders explicitly through `useRerender` after mutating them.
 *
 * @param model         wrapper around the causal model (`state()`, `setState()`, `save()`)
 * @param graph         wrapper around the graph; source of the nodes shown in this tab
 * @param nodeId        id of the node selected outside of this tab; when it changes, the tab
 *                      follows the selection or, if the node is not eligible, calls
 *                      `openMechTab` and shows an info alert
 * @param update        when this value changes (and is defined), the currently selected node
 *                      is reported again through `onNodeSelect`
 * @param setNodeIcons  state setter handed to `setMechanismIcons` / `setDistributionIcons`
 * @param onNextSection called after the "forward" action has stored all complete inputs
 * @param openMechTab   called when `nodeId` refers to a node that cannot have a distribution
 * @param onNodeSelect  called with the node id whenever this tab changes the selection
 */
export const DistributionTab = ({
  model,
  graph,
  nodeId,
  update,
  setNodeIcons,
  onNextSection: nextSection,
  openMechTab,
  onNodeSelect: selectNode,
}: {
  model: Model;
  graph: Graph;
  update?: number;
  nodeId?: number;
  setNodeIcons: Dispatch<SetStateAction<Record<string, CustomIcons>>>;
  onNextSection: () => void;
  openMechTab: () => void;
  onNodeSelect: (id?: number) => void;
}): JSX.Element => {
  const { t } = useTranslation();

  const rerender = useRerender();
  const distNodes = getNodesElligibleForDistribution(graph.state());
  // Index into `distNodes` of the node currently being edited.
  const [selectedNodeIndex, setSelectedNodeIndex] = useState<number | undefined>(undefined);
  // Info alert shown when a node without distribution support was selected.
  const [alert, setAlert] = useState<{ node: string; opened: boolean }>({ node: '', opened: false });

  // Density-query initialization. Entries are created per node (not only once per
  // model) so that nodes added after the first render also get a query object;
  // otherwise `getDensityGraph` would dereference an undefined entry for them.
  if (queries[model.state().id] === undefined) {
    queries[model.state().id] = {};
  }
  distNodes.forEach((node) => {
    if (queries[model.state().id][node.id] === undefined) {
      queries[model.state().id][node.id] = {
        densityQuery: {
          object: { type: 'model', id: model.state().id },
          body: '',
        },
      };
    }
  });

  if (distributionsInput[model.state().id] === undefined) {
    distributionsInput[model.state().id] = {};
  }

  // Input initialization
  distNodes.forEach((node) => {
    if (distributionsInput[model.state().id][node.id] === undefined) {
      distributionsInput[model.state().id][node.id] = createDefaultDistributionInput(node.id, model.state());
    }
  });

  // Remove input nodes if they are removed through the editor
  const nodesToRemove = Object.keys(distributionsInput[model.state().id]).filter(
    (id) => !distNodes.map((node) => node.id).includes(Number(id)),
  );
  nodesToRemove.forEach((id) => {
    delete distributionsInput[model.state().id][id];
  });

  // Selects the following node, wrapping around to the first one after the last.
  const nextNode = () => {
    const newIndex = ((selectedNodeIndex ?? -1) + 1) % distNodes.length;
    setSelectedNodeIndex(newIndex);
    selectNode(distNodes[newIndex]?.id);
    setAlert({ node: alert.node, opened: false });
    rerender();
  };
  // Selects the preceding node, wrapping around to the last one before the first.
  const previousNode = () => {
    const newIndex = (selectedNodeIndex ?? 1) - 1 < 0 ? distNodes.length - 1 : (selectedNodeIndex ?? 1) - 1;
    setSelectedNodeIndex(newIndex);
    selectNode(distNodes[newIndex]?.id);
    setAlert({ node: alert.node, opened: false });
    rerender();
  };

  // Returns the stored form input of the node at `index` (default: the selected node).
  const getNodeInput = (index?: number): SingleDistInput | undefined => {
    const nodeIndex = index ?? selectedNodeIndex;
    if (nodeIndex === undefined) return undefined;
    return distributionsInput[model.state().id][distNodes[nodeIndex]?.id] ?? undefined;
  };
  // Writes one property of the selected node's form input in place and re-renders.
  const setNodeInput = (property: string, value: string | boolean | PredefinedDist | string[] | undefined) => {
    if (selectedNodeIndex === undefined || property === undefined) return;
    const currentInput = distributionsInput[model.state().id][distNodes[selectedNodeIndex]?.id] ?? undefined;
    if (currentInput === undefined) return;

    currentInput[property] = value;
    rerender();
  };
  // Returns the density query/answer entry of the node at `index` (default: the selected node).
  const getNodeDensityInput = (
    index?: number,
  ): { densityQuery: BaseQuery; densityResponse?: BaseAnswer } | undefined => {
    const nodeIndex = index ?? selectedNodeIndex;
    if (nodeIndex === undefined) return undefined;
    return queries[model.state().id][distNodes[nodeIndex]?.id] ?? undefined;
  };

  // Tells whether the input of the node at `index` (default: the selected node) is complete:
  // either "no distribution", a non-empty own distribution, or a known predefined
  // distribution with every parameter filled in.
  const hasAllData = (index?: number): boolean => {
    const input = getNodeInput(index);
    if (input === undefined || input.distType === undefined) return false;
    if (input.hasDistribution === false) return true;
    if (input.hasDistribution !== true) return false;

    if (input.hasUserDefinedDist === true) {
      return input.userDefinedCode !== undefined && input.userDefinedCode !== '';
    }

    if (input.distAbbr === undefined) return false;
    // An abbreviation that is not a key of `distributions` counts as incomplete
    // instead of throwing; the explicit check also keeps `undefined === undefined`
    // from being mistaken for a match.
    const expectedParamCount = distributions[input.distAbbr]?.params.length;
    return (
      expectedParamCount !== undefined &&
      input.parametersInput?.filter((val) => val !== '').length === expectedParamCount
    );
  };

  // Recomputes the mechanism and distribution icons of the graph nodes.
  const refreshNodeIcons = (): void => {
    setMechanismIcons(model, graph, setNodeIcons);
    setDistributionIcons(model, graph, setNodeIcons);
  };

  // Applies the form input of one node to `currentModel` (which is modified in place,
  // so callers pass a copy) and reports whether the model's distributions were touched.
  const setModelChanges = (
    inputNodeIndex: number,
    inputNodeId: number,
    currentModel: CausalModel,
  ): { modelChanged: boolean; newModel: CausalModel } => {
    const newModelState = currentModel;
    let modelChanged = false;

    const currentDistributions = Array.from(newModelState.distributions);

    // "No distribution" selected: drop an existing entry for the node, if any.
    if (getNodeInput(inputNodeIndex)?.hasDistribution === false) {
      const existingDistributionIndex = currentDistributions.findIndex((dist) => dist.node === inputNodeId);
      if (existingDistributionIndex !== -1) {
        currentDistributions.splice(existingDistributionIndex, 1);
        newModelState.distributions = new Set(currentDistributions);
        modelChanged = true;
      }
      return { modelChanged, newModel: newModelState };
    }

    // Otherwise insert or overwrite the node's entry with the serialized input.
    const distString = distInputToString(getNodeInput(inputNodeIndex));
    if (distString !== undefined) {
      modelChanged = true;
      let newDistribution: Distribution;

      let distributionIndex = currentDistributions.findIndex((dist) => dist.node === inputNodeId);
      if (distributionIndex === -1) {
        newDistribution = {
          node: inputNodeId,
          type: '',
        };
        distributionIndex = currentDistributions.length;
      } else {
        newDistribution = { ...currentDistributions[distributionIndex] };
      }
      newDistribution.type = distString;
      currentDistributions[distributionIndex] = newDistribution;
      newModelState.distributions = new Set(currentDistributions);
    }
    return { modelChanged, newModel: newModelState };
  };

  // Stores the input of one node (default: the selected node) in the model, saves graph
  // and model when something changed, and refreshes the node icons.
  const handleSubmit = async (event?: { preventDefault: () => void }, nodeIndex?: number): Promise<void> => {
    const currentNodeIndex = nodeIndex ?? selectedNodeIndex;
    event?.preventDefault();
    if (!hasAllData(currentNodeIndex) || currentNodeIndex === undefined) return;
    const inputNodeId = distNodes[currentNodeIndex]?.id;
    if (inputNodeId === undefined) return;

    const { modelChanged, newModel } = setModelChanges(currentNodeIndex, inputNodeId, {
      ...model.state(),
    });
    if (modelChanged) {
      model.setState(updateDateAndVersion(newModel));
      graph.setState(updateDateAndVersion({ ...graph.state() }), ['GraphView']);

      // The model is only saved once the graph has been saved successfully.
      if (await graph.save()) model.save();
    }
    refreshNodeIcons();
  };

  // Stores the input of every node whose data is complete, saves once, and moves on
  // to the next section.
  const handleAllRemainingSubmit = async (): Promise<void> => {
    let newModelState = { ...model.state() };
    let modelHasChanged = false;
    distNodes.forEach((node, index) => {
      if (hasAllData(index)) {
        const { newModel, modelChanged } = setModelChanges(index, node.id, { ...newModelState });
        if (modelChanged) modelHasChanged = true;
        newModelState = newModel;
      }
    });

    if (modelHasChanged) {
      model.setState(updateDateAndVersion(newModelState));
      graph.setState(updateDateAndVersion({ ...graph.state() }), ['GraphView']);

      if (await graph.save()) model.save();
    }
    refreshNodeIcons();
    nextSection();
  };

  // Requests the density of the node's distribution from the backend and stores the
  // answer; with `updateAll`, every node that has the same distribution string shares it.
  const getDensityGraph = async (index?: number, updateAll = true) => {
    const nodeIndex = index ?? selectedNodeIndex;
    const distString = distInputToString(getNodeInput(nodeIndex));

    if (distString === undefined || nodeIndex === undefined) return;

    // Updates all graphs with the same formula
    const nodesWithSameGraph = updateAll
      ? distNodes.filter((node, i) => distInputToString(getNodeInput(i)) === distString).map((node) => node.id)
      : [distNodes[nodeIndex].id];

    nodesWithSameGraph.forEach((id) => {
      queries[model.state().id][id].densityQuery.body = `PDF(${distString})`;
    });
    const nodeDensity = getNodeDensityInput(nodeIndex);
    if (nodeDensity) {
      const baseAnswer = await postBaseQueryGetAnswer(nodeDensity.densityQuery);
      if (baseAnswer) {
        nodesWithSameGraph.forEach((id) => {
          queries[model.state().id][id].densityResponse = baseAnswer;
        });

        rerender();
      }
    }
  };

  // Value for the distribution <select>: 'own', the chosen abbreviation (if it belongs
  // to the current variable type), or the disabled placeholder 'DEFAULT'.
  const getSelectedValue = (): string => {
    const currentInput = getNodeInput();
    if (currentInput?.hasUserDefinedDist === true) return 'own';
    if (currentInput?.distAbbr === undefined) return 'DEFAULT';
    const relevantDists = Object.values(distributions)
      .filter((dist) => dist.type === currentInput?.distType)
      .map((dist) => dist.name);
    if (relevantDists.includes(currentInput?.distAbbr)) return currentInput?.distAbbr;
    return 'DEFAULT';
  };

  // Falls back to the first node when there is no valid selection and the externally
  // selected node (if any) is not one of the eligible nodes.
  useEffect(() => {
    if (
      (selectedNodeIndex === undefined || selectedNodeIndex >= distNodes.length) &&
      (nodeId === undefined || !distNodes.map((node) => node.id).includes(nodeId)) &&
      distNodes.length > 0
    ) {
      setSelectedNodeIndex(0);
      if (nodeId === undefined) selectNode(distNodes[0]?.id);
    }
  }, [distNodes]);

  // Re-announces the current selection whenever `update` changes.
  useEffect(() => {
    if (update !== undefined) {
      if (selectedNodeIndex !== undefined) {
        selectNode(distNodes[selectedNodeIndex]?.id);
      }
    }
  }, [update]);

  // Follows an external selection: jump to the node if it is eligible, otherwise
  // switch to the mechanism tab and show the info alert naming the node.
  useEffect(() => {
    setAlert({ node: alert.node, opened: false });
    if (nodeId !== undefined) {
      const newIndex = distNodes.findIndex((node) => node.id === nodeId);
      if (newIndex !== -1) {
        if (newIndex !== selectedNodeIndex) {
          setSelectedNodeIndex(newIndex);
        }
      } else {
        const selectedNode = Array.from(graph.state().nodes).find((node) => node.id === nodeId);
        openMechTab();
        if (selectedNode) setAlert({ node: selectedNode.name, opened: true });
      }
    }
  }, [nodeId]);
  // No eligible nodes left: clear the selection so the "no available nodes" text is shown.
  if (distNodes.length === 0 && selectedNodeIndex !== undefined) setSelectedNodeIndex(undefined);

  return (
    <div className='distribution-tab-container'>
      {selectedNodeIndex === undefined && <div>{t('distributionTab.noAvailableNodes')}</div>}
      {selectedNodeIndex !== undefined && (
        <div>
          <NodeIterationDisplay
            previousButtonAction={() => {
              handleSubmit();
              previousNode();
            }}
            nodeOnClickAction={() => selectNode(distNodes[selectedNodeIndex ?? 0]?.id)}
            nextButtonAction={() => {
              handleSubmit();
              nextNode();
            }}
            forwardButtonAction={handleAllRemainingSubmit}
            allNodes={distNodes}
            currentNodeIndex={selectedNodeIndex}
            forwardButtonTitle={t('distributionTab.forwardButtonTitle')}
          />

          <Collapse in={alert.opened}>
            <Alert
              action={
                <IconButton
                  aria-label='close'
                  color='inherit'
                  size='small'
                  onClick={() => {
                    setAlert({ opened: false, node: alert.node });
                  }}
                >
                  <CloseIcon fontSize='inherit' />
                </IconButton>
              }
              severity='info'
              sx={{ mb: 2, mt: 2 }}
            >
              {t('distributionTab.alertFirstHalf')}
              <strong>{alert.node}</strong>
              {t('distributionTab.alertSecondHalf')}
            </Alert>
          </Collapse>
          <form
            onSubmit={(event) => {
              handleSubmit(event);
              nextNode();
            }}
          >
            <p>{t('distributionTab.distKnownQuestion')}</p>
            <div className='addNodeDialog__attribute-wrapper'>
              <Radio
                onClick={() => {
                  setNodeInput('hasDistribution', true);
                }}
                enabled={getNodeInput()?.hasDistribution === true}
                label={t('distributionTab.distKnownYes')}
              />
              <Radio
                onClick={() => {
                  setNodeInput('hasDistribution', false);
                }}
                enabled={getNodeInput()?.hasDistribution === false}
                label={t('distributionTab.distKnownNo')}
              />
            </div>
            <div className={getNodeInput()?.hasDistribution === true ? '' : 'd-none'}>
              {/* Dispatch (discrete / continuous variable) */}
              <p>{t('distributionTab.variableTypeQuestion')}</p>
              <div className='addNodeDialog__attribute-wrapper'>
                <Radio
                  onClick={() => {
                    setNodeInput('distType', 'discrete');
                  }}
                  enabled={false /* TODO: getNodeInput()?.distType === 'discrete' */}
                  label={t('distributionTab.variableTypeDiscrete')}
                  disabled
                />
                <Radio
                  onClick={() => {
                    setNodeInput('distType', 'continuous');
                  }}
                  enabled={getNodeInput()?.distType === 'continuous'}
                  label={t('distributionTab.variableTypeContinuous')}
                />
              </div>
              <div className={getNodeInput()?.distType === 'discrete' ? '' : 'd-none'}>
                <select
                  id='discrete-dists-select'
                  className='form-select single-select'
                  value={getSelectedValue()}
                  onChange={(e) => {
                    setNodeInput('hasUserDefinedDist', e.target.value === 'own');
                    if (e.target.value !== 'own' && e.target.value !== 'DEFAULT') {
                      setNodeInput('distAbbr', e.target.value);
                      setNodeInput(
                        'parametersInput',
                        distributions[e.target.value]?.params.map((param) => param.value?.toString() ?? '') ?? [],
                      );
                    } else {
                      setNodeInput('distAbbr', undefined);
                    }
                    rerender();
                  }}
                  required
                >
                  <option value='DEFAULT' disabled>
                    {t('distributionTab.variableSelectDefault')}
                  </option>
                  {[...discreteDistributions, { name: 'own', label: 'distributionTab.variableSelectOwn' }].map(
                    (dist) => (
                      <option value={dist.name} key={`sel${dist.name}`}>
                        {dist.label !== undefined ? t(dist.label) : dist.name}
                      </option>
                    ),
                  )}
                </select>
              </div>
              <div className={getNodeInput()?.distType === 'continuous' ? '' : 'd-none'}>
                <select
                  id='continuous-dists-select'
                  className='form-select single-select'
                  value={getSelectedValue()}
                  onChange={(e) => {
                    setNodeInput('hasUserDefinedDist', e.target.value === 'own');
                    if (e.target.value !== 'own' && e.target.value !== 'DEFAULT') {
                      setNodeInput('distAbbr', e.target.value);
                      setNodeInput(
                        'parametersInput',
                        distributions[e.target.value]?.params.map((param) => param.value?.toString() ?? '') ?? [],
                      );
                    } else {
                      setNodeInput('distAbbr', undefined);
                    }
                    rerender();
                  }}
                  required
                >
                  <option value='DEFAULT' disabled>
                    {t('distributionTab.variableSelectDefault')}
                  </option>
                  {[...continuousDistributions, { name: 'own', label: 'distributionTab.variableSelectOwn' }].map(
                    (dist) => (
                      <option value={dist.name} key={dist.name}>
                        {dist.label !== undefined ? t(dist.label) : dist.name}
                      </option>
                    ),
                  )}
                </select>
              </div>
              {getSelectedValue() !== 'own' && getSelectedValue() !== 'DEFAULT' && (
                <div id='parameters-container'>
                  <p className='text-muted'>{t('distributionTab.parametersQuestion')}</p>
                  <div className='addNodeDialog__attribute-wrapper'>
                    {distributions[getNodeInput()?.distAbbr ?? ''].params.map((param, index) => (
                      <div key={param.name}>
                        <label htmlFor={`param-continuous-${index}`}>
                          <Latex>{param.label !== undefined ? t(param.label) : param.name}</Latex>
                        </label>
                        <input
                          id={`param-continuous-${index}`}
                          type='text'
                          className='addNodeDialog__edit-text'
                          value={getNodeInput()?.parametersInput?.[index] ?? ''}
                          onChange={(e) => {
                            const oldInputs = [...(getNodeInput()?.parametersInput ?? [])];
                            oldInputs[index] = e.target.value;
                            setNodeInput('parametersInput', oldInputs);
                          }}
                        />
                      </div>
                    ))}
                  </div>
                </div>
              )}
              {getSelectedValue() === 'own' && (
                <div id='own-distribution-container'>
                  <p className='text-muted'>
                    {t('distributionTab.ownDistributionQuestion')}
                    <a target='_blank' href='https://docs.scipy.org/doc/scipy/reference/stats.html' rel='noreferrer'>
                      {t('distributionTab.ownDistributionQuestionLink')}
                    </a>
                  </p>
                  <div className='addNodeDialog__attribute-wrapper'>
                    <input
                      type='text'
                      placeholder={t('distributionTab.ownDistributionPlaceholder')}
                      className='addNodeDialog__edit-text'
                      value={getNodeInput()?.userDefinedCode ?? ''}
                      onChange={(e) => {
                        setNodeInput('userDefinedCode', e.target.value);
                      }}
                    />
                  </div>
                </div>
              )}
              <br />
              <button
                type='button'
                onClick={() => getDensityGraph(selectedNodeIndex)}
                disabled={!hasAllData()}
                className={`mb-2 ${
                  hasAllData() ? 'addNodeDialog__add-button' : 'addNodeDialog__add-button disabled text-muted'
                }`}
              >
                {t('distributionTab.buttonShowDist')}
              </button>
              {getNodeDensityInput(selectedNodeIndex)?.densityResponse?.error_message &&
                textWithLineBreak(
                  getNodeDensityInput(selectedNodeIndex)?.densityResponse?.error_message ?? '',
                  'pb-3 distribution-tab-error-message',
                )}

              {getNodeDensityInput(selectedNodeIndex)?.densityResponse?.body?.scatter && (
                <div className='densityGraph'>
                  <ApexAreaChart
                    xLabel={distNodes[selectedNodeIndex]?.name}
                    scatter={getNodeDensityInput(selectedNodeIndex)?.densityResponse?.body?.scatter}
                  />
                </div>
              )}
            </div>
            <br />

            <button
              type='submit'
              disabled={!hasAllData()}
              className={hasAllData() ? 'addNodeDialog__add-button' : 'addNodeDialog__add-button disabled text-muted'}
            >
              {t('distributionTab.buttonNextNode')}
            </button>
          </form>
        </div>
      )}
    </div>
  );
};
