import dagre from '@dagrejs/dagre';

import causeMapNodeCompareFn from './causeMapNodeCompareFn';
import getCauseMapEdgeCompareFn from './getCauseMapEdgeCompareFn';

/** Tunable parameters for the cause-map auto layout. */
interface LayoutOptions {
  /** Forwarded to dagre as `nodesep`. */
  horizontalSpacing: number;
  /** Forwarded to dagre as `ranksep`. */
  verticalSpacing: number;
  /** Optional value forwarded to dagre as `marginx`. */
  minX?: number;
  /** Optional value forwarded to dagre as `marginy`. */
  minY?: number;
}

/** Spacing used when the caller does not supply its own layout options. */
export const DEFAULT_LAYOUT_OPTIONS: LayoutOptions = {
  horizontalSpacing: 200,
  verticalSpacing: 250,
};

/**
 * Computes positions for cause-map nodes using a dagre left-to-right layout.
 *
 * The input arrays are never reordered or otherwise mutated: sorting happens
 * on shallow copies, and every returned node is a new object.
 *
 * @param nodes - Nodes to position; `width`/`height` are passed to dagre and
 *   the incoming `x`/`y` are ignored.
 * @param edges - Directed edges (`fromId` -> `toId`) between the nodes.
 * @param options - Spacing/margin settings; defaults to
 *   `DEFAULT_LAYOUT_OPTIONS`.
 * @returns Copies of the nodes with `x`/`y` replaced by the laid-out
 *   position, ordered by `causeMapNodeCompareFn`. When there are zero or one
 *   nodes the input array itself is returned unchanged.
 */
const layoutCauseMapNodes = <
  Node extends {
    id: string;
    width: number;
    height: number;
    x: number;
    y: number;
  },
  Edge extends { fromId: string; toId: string }
>(
  nodes: Node[],
  edges: Edge[],
  options: LayoutOptions = DEFAULT_LAYOUT_OPTIONS
): Node[] => {
  // Nothing to arrange relative to anything else.
  if (nodes.length <= 1) {
    return nodes;
  }

  // NOTE(review): with rankdir 'LR', dagre documents `ranksep` as the gap
  // between ranks and `nodesep` as the gap between nodes of the same rank, so
  // the verticalSpacing/horizontalSpacing names look swapped here. Left as-is
  // because callers may depend on the current layout — confirm intent.
  const dagreGraph = new dagre.graphlib.Graph()
    .setGraph({
      rankdir: 'LR',
      ranksep: options.verticalSpacing,
      nodesep: options.horizontalSpacing,
      marginx: options.minX,
      marginy: options.minY,
    })
    .setDefaultEdgeLabel(() => ({}));

  // Create a mapping from node id to node for quick lookup
  const nodeMap = new Map(nodes.map((node) => [node.id, node]));

  // Sort copies so the caller's arrays keep their original order; the sort
  // fixes the order in which nodes and edges are registered with dagre.
  const sortedNodes = [...nodes].sort(causeMapNodeCompareFn);
  const sortedEdges = [...edges].sort(getCauseMapEdgeCompareFn(nodeMap));

  sortedNodes.forEach(({ id, width, height }) =>
    dagreGraph.setNode(id, { width, height })
  );
  sortedEdges.forEach(({ fromId, toId }) => dagreGraph.setEdge(fromId, toId));

  dagre.layout(dagreGraph);

  // Shift each dagre coordinate by half the node's size (dagre's docs describe
  // its x/y as the node centre, so this yields the top-left corner).
  return sortedNodes.map((node) => {
    const { x, y } = dagreGraph.node(node.id);
    return {
      ...node,
      x: x - node.width / 2,
      y: y - node.height / 2,
    };
  });
};

export default layoutCauseMapNodes;
