import * as D3 from 'd3';
import _ from 'lodash';

import { ClassBox } from './ClassBox';
import { setAttr } from './utils';

/**
 * Renders class boxes and the reference connectors between them into an SVG
 * appended to `wrapper`, with D3 pan/zoom support and resize tracking.
 */
export default class ClassDiagram {
  /**
   * @param {object} [options]
   * @param {Array|object} [options.classes] - class records; each becomes a `ClassBox`, keyed by its `id`.
   * @param {Array} [options.refs] - references drawn as connectors; each is read as
   *   `{ column_from: { reg_table_id }, column_to: { reg_table_id }, reg_column_id, ref_reg_column_id }`.
   * @param {HTMLElement} [options.wrapper] - element the SVG is appended to.
   * @param {number} [options.width] - SVG width; falls back to `wrapper.offsetWidth`.
   * @param {number} [options.height] - SVG height; falls back to `wrapper.offsetHeight`.
   * @param {object} [options.posOpts] - layout overrides read by `autoPositioning()`:
   *   `width`, `gap`, `padding`, `paddingX`, `paddingY`.
   */
  constructor({ classes, refs, wrapper, width, height, posOpts } = {}) {
    this.wrapper = wrapper;
    this.svg = null;
    // handleZoom / resizeHandler are arrow class fields, so they are already
    // initialized (and bound to this instance) when the constructor body runs.
    this.zoom = D3.zoom().on('zoom', this.handleZoom);
    this.g = null;
    this.classesData = classes;
    this.refs = refs;
    // Optional chaining lets a missing wrapper reach draw()'s explicit error
    // instead of failing here with a TypeError.
    this.width = width || wrapper?.offsetWidth;
    this.height = height || wrapper?.offsetHeight;
    this.boxes = {};
    this.connectors = [];
    this.posOpts = posOpts;
    this.activeBox = null;
    this.mode = null;
    this.modeDisabler = null;
    this.tempConnector = null;
    this.currentTransform = null;

    window.addEventListener('resize', this.resizeHandler);
  }

  /**
   * Removes the rendered SVG (if any) and detaches the window resize listener.
   * Safe to call before `draw()` has ever run.
   */
  destroy() {
    this.svg?.remove();
    this.svg = null;
    window.removeEventListener('resize', this.resizeHandler);
  }

  /**
   * (Re)builds the whole diagram: replaces any previous SVG, restores the last
   * zoom transform, then creates, positions and connects the class boxes.
   * @throws {Error} when no wrapper element was provided.
   */
  draw() {
    if (!this.wrapper) throw new Error('wrapper is not defined');

    this.modeDisabler?.();
    this.svg?.remove();

    this.svg = D3.select(this.wrapper)
      .append('svg')
      .attr('width', this.width)
      .attr('height', this.height)
      .call(this.zoom);

    this.g = this.svg.append('g');

    // Re-apply the previous pan/zoom so a redraw keeps the user's viewport.
    if (this.currentTransform) {
      this.zoom.transform(this.svg, this.currentTransform);
    }

    this.addMarkers(this.g.append('defs'));
    this.createClassBoxes();
    this.autoPositioning();
  }

  /** Resizes the SVG to the wrapper's current size; no-op before draw(). */
  resizeHandler = () => {
    if (!this.svg) return;

    this.width = this.wrapper.offsetWidth;
    this.height = this.wrapper.offsetHeight;

    this.svg.attr('width', this.width).attr('height', this.height);
  };

  /**
   * D3 zoom listener: applies and remembers the transform. Events whose source
   * event has Ctrl held are ignored (the transform is neither applied nor stored).
   */
  handleZoom = (e) => {
    if (e?.sourceEvent?.ctrlKey) return;

    this.currentTransform = e.transform;
    this.g.attr('transform', e.transform);
  };

  /**
   * Appends the arrowhead markers used by connectors: `#arrowhead` (black) and
   * `#arrowhead-hover` (orange, shown while a connector is hovered).
   * NOTE(review): marker ids are document-global, so two diagrams on one page
   * emit duplicate ids.
   * @param {object} defs - D3 selection of a `<defs>` element.
   */
  addMarkers(defs) {
    const arrowPath = 'M10 5 0 10 0 8.7 6.8 5.5 0 5.5 0 4.5 6.8 4.5 0 1.3 0 0Z';
    const markers = [
      { id: 'arrowhead', viewBox: '0 0 10 10', stroke: 'none', fill: 'black' },
      // NOTE(review): the 20x20 viewBox draws the same 10-unit path at half
      // scale compared to #arrowhead — confirm that is intended.
      { id: 'arrowhead-hover', viewBox: '0 0 20 20', stroke: '#ed960b', fill: '#ed960b' }
    ];

    markers.forEach(({ id, viewBox, stroke, fill }) => {
      defs
        .append('marker')
        .attr('id', id)
        .attr('viewBox', viewBox)
        .attr('refX', 10)
        .attr('refY', 5)
        .attr('markerWidth', 10)
        .attr('markerHeight', 10)
        .attr('orient', 'auto')
        .append('path')
        .attr('d', arrowPath)
        .attr('stroke', stroke)
        .attr('fill', fill)
        // NOTE(review): markerUnits is a <marker> attribute (strokeWidth |
        // userSpaceOnUse); set on <path> with 'none' it has no effect.
        .attr('markerUnits', 'none');
    });
  }

  /** Creates one 300px-wide `ClassBox` per class record, keyed by record `id`. */
  createClassBoxes() {
    this.boxes = _.chain(this.classesData)
      .keyBy('id')
      .mapValues((c) => new ClassBox(c, this.g, { width: 300 }))
      .value();
  }

  /**
   * Lays the boxes out left-to-right in rows that wrap at `width`, then draws a
   * connector for every reference whose source and target boxes both exist.
   * Linked class names are recorded on each box's `data-linked` attribute.
   */
  autoPositioning() {
    const {
      width = this.width,
      gap = 100,
      padding = 30,
      paddingX = 30,
      paddingY = 50
    } = this.posOpts || {};

    // A zero paddingX/paddingY falls back to `padding`; the same start
    // coordinates are used for the first row and for every wrapped row.
    const startX = paddingX || padding;
    const startY = paddingY || padding;

    let currRowY = startY;
    let nextRowY = 0;
    let currBoxX = startX;

    Object.values(this.boxes).forEach((box, index) => {
      // Never wrap the first box: with nothing placed yet, nextRowY is still 0
      // and an over-wide box would be pushed above the top padding.
      if (index > 0 && currBoxX + box.width() > width) {
        currRowY = nextRowY;
        currBoxX = startX;
      }

      box.x(currBoxX);
      box.y(currRowY);

      nextRowY = _.max([currRowY + box.height() + gap, nextRowY]);
      currBoxX += box.width() + gap;

      box.updateProps();
    });

    // NOTE(review): loose equality kept from the original — ids may arrive as
    // strings on one side and numbers on the other; confirm before tightening.
    const findBox = (tableId) =>
      Object.values(this.boxes).find((box) => box.data().id == tableId);

    // Adds `linkedClass` to the comma-separated `data-linked` list of `box`.
    const addLink = (box, linkedClass) => {
      const current = box.selection().attr('data-linked');
      const links = new Set(current ? current.split(',') : []);
      links.add(linkedClass);
      box.selection().attr('data-linked', [...links].join(','));
    };

    const connectors = [];
    for (const ref of this.refs || []) {
      const srcBox = findBox(ref.column_from.reg_table_id);
      const trgtBox = findBox(ref.column_to.reg_table_id);
      // Skip references to tables that have no box instead of crashing.
      if (!srcBox || !trgtBox) continue;

      const points = this.getConnectorPoints(
        srcBox.attributes()[ref.reg_column_id],
        trgtBox.attributes()[ref.ref_reg_column_id],
        gap
      );

      const from = `${srcBox.classname()}Class`;
      const to = `${trgtBox.classname()}Class`;
      addLink(srcBox, to);
      addLink(trgtBox, from);

      connectors.push({ from, to, points });
    }

    this.connectors = this.createConnectors(connectors);
    // Re-append every class group so boxes paint on top of the connector paths.
    this.g.selectAll('g.class').each(function () {
      this.parentNode.appendChild(this);
    });
  }

  /**
   * Appends one `<path>` per connector and wires a hover animation that draws
   * the stroke in (via stroke-dashoffset) and swaps to the hover arrowhead.
   * @param {Array<{from: string, to: string, points: Array<{x: number, y: number}>}>} connectors
   * @param {string} [classname='connector'] - CSS class given to the paths.
   * @returns {object} D3 selection of the newly appended paths.
   */
  createConnectors(connectors, classname = 'connector') {
    const line = D3.line()
      .x((d) => d.x)
      .y((d) => d.y);

    const connectorsPath = this.g
      .selectAll(`path.${classname}`)
      .data(connectors)
      .enter()
      .append('path')
      .attr('opacity', 1)
      .attr('data-from', (d) => d.from)
      .attr('data-to', (d) => d.to)
      .each(function (d) {
        setAttr(D3.select(this), {
          class: classname,
          d: line(d.points),
          stroke: 'black',
          'stroke-width': 1,
          fill: 'none'
        });
      })
      .attr('marker-end', 'url(#arrowhead)');

    this.g
      .selectAll(`path.${classname}`)
      .each(function () {
        // Dash length == path length and offset == length hides the stroke;
        // hovering animates the offset to 0 to reveal it.
        const totalLength = this.getTotalLength();
        D3.select(this)
          .attr('data-dash', totalLength)
          .attr('data-total', totalLength)
          .attr('data-marker', 'url(#arrowhead)')
          .attr('data-hovered-marker', 'url(#arrowhead-hover)')
          .attr('stroke-dasharray', `${totalLength} ${totalLength}`)
          .attr('stroke-dashoffset', totalLength);
      })
      .on('mouseover', function () {
        const selection = D3.select(this);
        selection.attr('marker-end', selection.attr('data-hovered-marker'));
        selection.transition().attr('stroke-dashoffset', 0).duration(400);
      })
      .on('mouseout', function () {
        const selection = D3.select(this);
        selection.attr('marker-end', selection.attr('data-marker'));
        selection
          .transition()
          .attr('stroke-dashoffset', selection.attr('data-total'))
          .duration(400);
      });

    return connectorsPath;
  }

  /**
   * Computes a 6-point orthogonal route from `source` to `target`. The route
   * leaves sideways by `elbow`, travels along a shared horizontal level, then
   * enters the target sideways.
   * @param {object} source - item exposing x(), rightX(), midY() and either y()
   *   (when a ClassBox) or class().y().
   * @param {object} target - same shape as `source`.
   * @param {number} [gap=100] - layout gap; the horizontal level sits gap/2
   *   above the lower of the two tops, but never above the higher mid point.
   * @param {number} [elbow=20] - horizontal distance of the side bends.
   * @returns {Array<{x: number, y: number}>|false} `false` when either end is
   *   missing; `[]` only when coordinates are not comparable (NaN).
   */
  getConnectorPoints(source, target, gap = 100, elbow = 20) {
    if (!source || !target) return false;

    const sourceX = source.x();
    const sourceRightX = source.rightX();
    const sourceY = source.midY();
    const targetX = target.x();
    const targetRightX = target.rightX();
    const targetY = target.midY();
    const diffX = Math.abs(sourceX - targetX);
    const diffRightX = Math.abs(sourceRightX - targetRightX);

    const topY = (item) => (item instanceof ClassBox ? item.y() : item.class().y());
    const curveMidPoint = _.max([
      _.max([topY(target), topY(source)]) - gap / 2,
      _.min([sourceY, targetY])
    ]);

    // fromDir/toDir: -1 bends out to the left of the edge, +1 to the right.
    const route = (fromX, fromDir, toX, toDir) => [
      { x: fromX, y: sourceY },
      { x: fromX + fromDir * elbow, y: sourceY },
      { x: fromX + fromDir * elbow, y: curveMidPoint },
      { x: toX + toDir * elbow, y: curveMidPoint },
      { x: toX + toDir * elbow, y: targetY },
      { x: toX, y: targetY }
    ];

    // Target entirely to the left: leave left edge, enter right edge.
    if (sourceX > targetRightX) return route(sourceX, -1, targetRightX, 1);
    // Target entirely to the right: leave right edge, enter left edge.
    if (sourceRightX <= targetX) return route(sourceRightX, 1, targetX, -1);
    // Overlapping columns: go around whichever side is closer to aligned.
    if (diffX <= diffRightX) return route(sourceX, -1, targetX, -1);
    if (diffX > diffRightX) return route(sourceRightX, 1, targetRightX, 1);
    return [];
  }
}
