import React, { useState, useEffect, createRef } from 'react';
import { DataSet } from 'vis-data';
import { Network } from 'vis-network/standalone';
import { createTheme } from '@mui/material';

import { Edge, Node, CausalData, Option } from '../model';

// MUI theme created with no overrides. In this file it is only used for its
// spacing helper (`theme.spacing(2)`) on the plot container's top margin.
const theme = createTheme();

// Options object passed to the vis-network `Network` constructor by
// NetworkPlot. Shared by every instance of the component.
const options = {
  edges: {
    // Arrowhead configuration for the `to` end of each edge.
    arrows: {
      to: {
        enabled: true,
        type: 'arrow',
      },
    },
    length: 300, // Longer edges between nodes.
    color: {
      // NOTE(review): per the vis-network docs, `inherit: false` stops edges
      // from taking their colour from the connected nodes — confirm.
      inherit: false,
    },
  },
  physics: {
    // Stabilization is switched off in the options; NetworkPlot calls
    // `network.stabilize()` explicitly right after constructing the network.
    stabilization: false,
    barnesHut: {
      springLength: 200,
    },
  },
  interaction: {
    dragNodes: true,
  },
};

/** Props accepted by the NetworkPlot component. */
type NetworkPlotProps = {
  /** Graph to draw; its `nodes` and `edges` arrays seed the vis-data DataSets. */
  data: CausalData;
  /**
   * Selection that drives the highlighting. The component reads its
   * `treatment`, `outcome`, `causal_effect` and `paths` fields.
   */
  option: Option;
  /** When true, edges on `option.paths` are emphasised and nodes off those paths are hidden. */
  highlightPath: boolean;
  /** Applied verbatim as the CSS `height` of the plot container. */
  height: string;
  /** Click callback taking an element name and a value, both strings. */
  onClick: (element: string, value: string) => void;
};

/**
 * Shape of the component's state.
 *
 * NOTE(review): this type is not referenced anywhere in the visible code, and
 * it disagrees with the hooks in NetworkPlot, which keep `selectedNode` and
 * `selectedEdge` as strings rather than `Node` / `Edge`. Presumably a leftover
 * from an earlier class-based version — confirm before relying on it.
 */
type NetworkPlotState = {
  network: any;
  appRef: any;
  selectedNode: Node;
  selectedEdge: Edge;
  dataset: any;
};

/**
 * Renders the causal graph in `props.data` as an interactive vis-network plot.
 *
 * The vis-data DataSets are built once, from the `props.data` seen on the
 * first render. Whenever `props.data`, `props.option` or `props.highlightPath`
 * changes, node/edge styling is reset, the treatment and outcome nodes are
 * recoloured and, if `props.highlightPath` is set, the edges on
 * `props.option.paths` are emphasised while all nodes off those paths are
 * hidden.
 *
 * Clicking a node or edge stores its id in local state and calls
 * `props.onClick('selectedNode' | 'selectedEdge', id)`.
 *
 * @param props - Graph data, highlighting options, container height and click callback.
 * @returns A full-width `div` that hosts the vis-network canvas.
 */
function NetworkPlot(props: NetworkPlotProps) {
  // Container element the vis-network canvas is mounted into.
  const appRef = React.useRef<HTMLDivElement>(null);
  // Holds the most recent `onClick` prop. The click listener is registered
  // once on mount, so it must read the callback through this ref to avoid
  // calling a stale closure after the parent re-renders.
  const onClickRef = React.useRef(props.onClick);
  // Ids of the last clicked node / edge (not read inside this component).
  const [selectedNode, setSelectedNode] = useState<string>('');
  const [selectedEdge, setSelectedEdge] = useState<string>('');
  // Lazy initialiser: the DataSets are constructed once, not on every render.
  const [dataset] = useState(() => ({
    nodes: new DataSet(props.data.nodes),
    edges: new DataSet(props.data.edges),
  }));

  /** Restores the default look: no edge `value`, every node visible and blue-grey. */
  const resetView = () => {
    // remove all colorings from the edges
    const edges: Edge[] = dataset.edges.get();
    edges.forEach((edge: Edge) => {
      // @ts-ignore -- `value` may not be declared optional on Edge
      delete edge.value;
    });
    dataset.edges.updateOnly(edges);

    // remove all colorings from the nodes
    const nodes: Node[] = dataset.nodes.get();
    nodes.forEach((node: Node) => {
      node.hidden = false;
      node.color = '#82A0BC';
      node.title = '';
    });
    dataset.nodes.updateOnly(nodes);
  };

  /**
   * Emphasises every edge that joins two consecutive nodes of any path in
   * `props.option.paths` (in either direction) and hides all nodes that are
   * not an endpoint of such an edge. Does nothing when there are no paths.
   */
  const highlightPathes = () => {
    const paths = props.option.paths;
    if (!paths || paths.length === 0) {
      return;
    }

    // True when `edge` connects two neighbouring entries of `path`.
    const isOnPath = (edge: Edge, path: string[]): boolean => {
      for (let i = 0; i < path.length - 1; i++) {
        const current = path[i];
        const next = path[i + 1];
        if (
          (edge.from === current && edge.to === next) ||
          (edge.from === next && edge.to === current)
        ) {
          return true;
        }
      }
      return false;
    };

    const updatedEdges: Edge[] = [];
    // Ids of all nodes touched by a highlighted edge; a Set de-duplicates them
    // and gives constant-time membership checks in the node filter below.
    const pathNodeIds = new Set<string>();
    dataset.edges.get().forEach((edge: Edge) => {
      if (paths.some((path: string[]) => isOnPath(edge, path))) {
        pathNodeIds.add(edge.from);
        pathNodeIds.add(edge.to);
        edge.value = props.option.causal_effect * 100;
        updatedEdges.push(edge);
      }
    });
    dataset.edges.updateOnly(updatedEdges);

    // remove nodes which are not part of the causal path
    const offPathNodes: Node[] = dataset.nodes.get({
      filter: function (item: { id: string }) {
        return !pathNodeIds.has(item.id);
      },
    });
    offPathNodes.forEach((node: Node) => {
      node.hidden = true;
    });
    dataset.nodes.updateOnly(offPathNodes);
  };

  /**
   * Recolours the treatment node (red for a positive causal effect, green
   * otherwise) and the outcome node (dark blue). Ids that are not present in
   * the node DataSet are skipped instead of crashing on a null lookup.
   */
  const highlightStartEndNode = () => {
    const treatmentId = props.option.treatment;
    const outcomeId = props.option.outcome;
    if (!treatmentId || !outcomeId) {
      return;
    }

    const updates: Node[] = [];

    // `DataSet.get(id)` yields null for an unknown id.
    const treatment = dataset.nodes.get(treatmentId) as Node | null;
    if (treatment) {
      treatment.color = props.option.causal_effect > 0 ? '#AE1917' : '#4E8341';
      treatment.font = { color: 'white' };
      updates.push(treatment);
    }

    const outcome = dataset.nodes.get(outcomeId) as Node | null;
    if (outcome) {
      outcome.color = '#1a3761';
      outcome.font = { color: 'white' };
      updates.push(outcome);
    }

    if (updates.length > 0) {
      dataset.nodes.updateOnly(updates);
    }
  };

  /**
   * Handler for the network's `click` event. A clicked node takes precedence
   * over a clicked edge; clicks that hit neither are ignored.
   */
  const onElementClick = (properties: any) => {
    if (properties.nodes.length > 0) {
      const nodeId = properties.nodes[0];
      setSelectedNode(nodeId);
      onClickRef.current?.('selectedNode', nodeId);
    } else if (properties.edges.length > 0) {
      const edgeId = properties.edges[0];
      setSelectedEdge(edgeId);
      onClickRef.current?.('selectedEdge', edgeId);
    }
  };

  // Keep the ref in sync with the latest callback after every render.
  useEffect(() => {
    onClickRef.current = props.onClick;
  });

  // Create the network once on mount and tear it down on unmount.
  useEffect(() => {
    const container = appRef.current;
    if (!container) {
      return undefined;
    }

    const instance = new Network(container, dataset, options);
    instance.stabilize();
    instance.on('click', onElementClick);

    // Without this cleanup the network and its listeners outlive the
    // component, and a remount would stack a second network in the same
    // container.
    return () => {
      instance.destroy();
    };
  }, []);

  // Re-apply highlighting whenever the inputs that drive it change.
  useEffect(() => {
    resetView();
    highlightStartEndNode();

    // only highlight if necessary
    if (props.highlightPath) {
      highlightPathes();
    }
  }, [props.data, props.option, props.highlightPath]);

  return (
    <div
      style={{
        width: '100%',
        height: props.height,
        marginTop: theme.spacing(2),
      }}
      ref={appRef}
    />
  );
}

export default NetworkPlot;
