/*
** This file is part of OSPREY 3.0
** 
** OSPREY Protein Redesign Software Version 3.0
** Copyright (C) 2001-2018 Bruce Donald Lab, Duke University
** 
** OSPREY is free software: you can redistribute it and/or modify
** it under the terms of the GNU General Public License version 2
** as published by the Free Software Foundation.
** 
** You should have received a copy of the GNU General Public License
** along with OSPREY.  If not, see <http://www.gnu.org/licenses/>.
** 
** OSPREY relies on grants for its development, and since visibility
** in the scientific literature is essential for our success, we
** ask that users of OSPREY cite our papers. See the CITING_OSPREY
** document in this distribution for more information.
** 
** Contact Info:
**    Bruce Donald
**    Duke University
**    Department of Computer Science
**    Levine Science Research Center (LSRC)
**    Durham
**    NC 27708-0129
**    USA
**    e-mail: www.cs.duke.edu/brd/
** 
** <signature of Bruce Donald>, Mar 1, 2018
** Bruce Donald, Professor of Computer Science
*/

package edu.duke.cs.osprey.pruning;


import edu.duke.cs.osprey.confspace.ParametricMolecule;
import edu.duke.cs.osprey.confspace.RCTuple;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.dof.DegreeOfFreedom;
import edu.duke.cs.osprey.energy.ResInterGen;
import edu.duke.cs.osprey.energy.ResidueInteractions;
import edu.duke.cs.osprey.parallelism.TaskExecutor;
import edu.duke.cs.osprey.structure.*;
import edu.duke.cs.osprey.tools.Progress;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.linear.*;

import java.util.*;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;


/**
 * Pruning of Local Unrealistic Geometries (PLUG)
 *
 * Prunes RC tuples if probe-style clashes are unavoidable anywhere in the tuple's voxel.
 *
 * For each non-bonded atom pair whose positions depend on a continuous DoF, PLUG searches for a point
 * in the voxel where the probe violation is approximately zero and turns the local gradient there into
 * a half-space constraint on the DoFs. An LP solver then checks whether any point of the voxel satisfies
 * all of those constraints at once; if none does, the tuple is pruned.
 */
public class PLUG {

	public final SimpleConfSpace confSpace;

	/** max num iterations of generalized Newton iteration to find a "boundary" point where the violation function is zero */
	public int maxNumIterations = 30;

	/** distance threshold to claim that a violation value is close enough to zero */
	public double violationThreshold = 1e-2;

	/** factor of the voxel width used to approximate the gradient of the violation function */
	public double gradientDxFactor = 1e-4;

	private final Probe probe;
	private final AtomConnectivity connectivity;

	/**
	 * Makes a PLUG pruner for the given conformation space.
	 *
	 * Loads probe atom info for the conf space templates and builds the atom connectivity
	 * used to enumerate non-bonded atom pairs.
	 *
	 * @param confSpace the conformation space whose RC tuples will be tested
	 */
	public PLUG(SimpleConfSpace confSpace) {

		this.confSpace = confSpace;

		// load probe
		this.probe = new Probe();
		this.probe.matchTemplates(this.confSpace);

		// init atom connectivity
		this.connectivity = new AtomConnectivity.Builder()
			.set15HasNonBonded(false) // follows probe convention
			.build();
	}

	/** Same as {@link #pruneSingles(PruningMatrix, double, TaskExecutor)} using a default {@link TaskExecutor}. */
	public void pruneSingles(PruningMatrix pmat, double tolerance) {
		pruneSingles(pmat, tolerance, new TaskExecutor());
	}

	/**
	 * Tests every unpruned single with {@link #shouldPruneTuple} and prunes the ones PLUG rejects.
	 *
	 * @param pmat the pruning matrix, updated in place
	 * @param tolerance overlap tolerance passed through to the probe violation function
	 * @param tasks executor that evaluates the singles; this method waits for all tasks to finish
	 */
	public void pruneSingles(PruningMatrix pmat, double tolerance, TaskExecutor tasks) {

		// count unpruned singles
		AtomicLong numSingles = new AtomicLong(0);
		pmat.forEachUnprunedSingle((pos1, rc1) -> {
			numSingles.incrementAndGet();
			return PruningMatrix.IteratorCommand.Continue;
		});
		Progress progress = new Progress(numSingles.get());

		// try to prune each single
		pmat.forEachUnprunedSingle((pos1, rc1) -> {
			tasks.submit(
				() -> shouldPruneTuple(new RCTuple(pos1, rc1), tolerance),
				(shouldPrune) -> {
					if (shouldPrune) {
						pmat.pruneSingle(pos1, rc1);
					}
					progress.incrementProgress();
				}
			);
			return PruningMatrix.IteratorCommand.Continue;
		});

		tasks.waitForFinish();
	}

	/** Same as {@link #prunePairs(PruningMatrix, double, TaskExecutor)} using a default {@link TaskExecutor}. */
	public void prunePairs(PruningMatrix pmat, double tolerance) {
		prunePairs(pmat, tolerance, new TaskExecutor());
	}

	/**
	 * Tests every unpruned pair with {@link #shouldPruneTuple} and prunes the ones PLUG rejects.
	 *
	 * @param pmat the pruning matrix, updated in place
	 * @param tolerance overlap tolerance passed through to the probe violation function
	 * @param tasks executor that evaluates the pairs; this method waits for all tasks to finish
	 */
	public void prunePairs(PruningMatrix pmat, double tolerance, TaskExecutor tasks) {

		// count unpruned pairs
		AtomicLong numPairs = new AtomicLong(0);
		pmat.forEachUnprunedPair((pos1, rc1, pos2, rc2) -> {
			numPairs.incrementAndGet();
			return PruningMatrix.IteratorCommand.Continue;
		});
		Progress progress = new Progress(numPairs.get());

		// try to prune each pair
		pmat.forEachUnprunedPair((pos1, rc1, pos2, rc2) -> {
			tasks.submit(
				() -> shouldPruneTuple(new RCTuple(pos1, rc1, pos2, rc2), tolerance),
				(shouldPrune) -> {
					if (shouldPrune) {
						pmat.prunePair(pos1, rc1, pos2, rc2);
					}
					progress.incrementProgress();
				}
			);
			return PruningMatrix.IteratorCommand.Continue;
		});

		tasks.waitForFinish();
	}

	/** Same as {@link #pruneTriples(PruningMatrix, double, TaskExecutor)} using a default {@link TaskExecutor}. */
	public void pruneTriples(PruningMatrix pmat, double tolerance) {
		pruneTriples(pmat, tolerance, new TaskExecutor());
	}

	/**
	 * Tests every unpruned triple with {@link #shouldPruneTuple} and prunes the ones PLUG rejects.
	 *
	 * @param pmat the pruning matrix, updated in place
	 * @param tolerance overlap tolerance passed through to the probe violation function
	 * @param tasks executor that evaluates the triples; this method waits for all tasks to finish
	 */
	public void pruneTriples(PruningMatrix pmat, double tolerance, TaskExecutor tasks) {

		// count unpruned triples
		AtomicLong numTriples = new AtomicLong(0);
		pmat.forEachUnprunedTriple((pos1, rc1, pos2, rc2, pos3, rc3) -> {
			numTriples.incrementAndGet();
			return PruningMatrix.IteratorCommand.Continue;
		});
		Progress progress = new Progress(numTriples.get());

		// try to prune each triple
		pmat.forEachUnprunedTriple((pos1, rc1, pos2, rc2, pos3, rc3) -> {
			tasks.submit(
				() -> shouldPruneTuple(new RCTuple(pos1, rc1, pos2, rc2, pos3, rc3), tolerance),
				(shouldPrune) -> {
					if (shouldPrune) {
						pmat.pruneTriple(pos1, rc1, pos2, rc2, pos3, rc3);
					}
					progress.incrementProgress();
				}
			);
			return PruningMatrix.IteratorCommand.Continue;
		});

		tasks.waitForFinish();
	}

	/**
	 * The DoF box of a parametric molecule, with per-DoF bounds cached in flat arrays.
	 * DoF index d here is the index into {@code pmol.dofs}.
	 */
	private static class Voxel {

		final ParametricMolecule pmol;

		final int numDofs;
		final double[] width;
		final double[] width2;
		final double[] min;
		final double[] max;
		final double[] center;

		Voxel (ParametricMolecule pmol) {

			this.pmol = pmol;

			numDofs = pmol.dofs.size();
			width = new double[numDofs];
			width2 = new double[numDofs];
			min = new double[numDofs];
			max = new double[numDofs];
			center = new double[numDofs];

			for (int d=0; d<numDofs; d++) {
				width[d] = pmol.dofBounds.getWidth(d);
				width2[d] = width[d]*width[d];
				min[d] = pmol.dofBounds.getMin(d);
				max[d] = pmol.dofBounds.getMax(d);
				center[d] = pmol.dofBounds.getCenter(d);
			}
		}

		DegreeOfFreedom getDof(int d) {
			return pmol.dofs.get(d);
		}

		void applyDof(int d, double val) {
			getDof(d).apply(val);
		}
	}

	/**
	 * Decides whether PLUG can show that every conformation in the tuple's voxel has an unavoidable clash.
	 *
	 * @param tuple the RC tuple to test
	 * @param tolerance overlap tolerance passed through to the probe violation function
	 * @return true if the linearized clash constraints have no feasible point inside the voxel
	 */
	public boolean shouldPruneTuple(RCTuple tuple, double tolerance) {

		// make the molecule and get all the residue interactions for the tuple
		ParametricMolecule pmol = confSpace.makeMolecule(tuple);

		ResidueInteractions inters = ResInterGen.of(confSpace)
			.addIntras(tuple)
			.addInters(tuple)
			.addShell(tuple)
			.make();

		Voxel voxel = new Voxel(pmol);

		try {

			// get linear constraints for each atom pair
			List<LinearConstraint> constraints = getLinearConstraints(voxel, inters, tolerance);

			// no constraints? don't prune
			if (constraints.isEmpty()) {
				return false;
			}

			// restrict the LP to the voxel itself
			// SimplexSolver ignores SimpleBounds, so the voxel bounds must be explicit constraints,
			// otherwise feasible points outside the voxel would stop this tuple from being pruned
			constraints.addAll(getVoxelBoundConstraints(voxel));

			// use an LP solver (eg simplex) to determine if the constraints allow any feasible points
			new SimplexSolver().optimize(
				new LinearConstraintSet(constraints),
				// dummy function: don't really need to minimize, but can't call simplex phase 1 solver directly
				new LinearObjectiveFunction(new double[voxel.numDofs], 0.0)
			);

			// if we got here, simplex didn't throw an exception
			// meaning at least one feasible point exists, so don't prune this tuple
			return false;

		} catch (NoFeasibleSolutionException ex) {

			// no feasible points, prune this tuple
			return true;
		}
	}

	/**
	 * Makes the constraints min[d] <= x[d] <= max[d] for every continuously-flexible DoF of the voxel.
	 *
	 * DoFs with zero width are skipped: atom voxels never select them, so no clash constraint
	 * has a nonzero coefficient on them and their value can't affect feasibility.
	 */
	private static List<LinearConstraint> getVoxelBoundConstraints(Voxel voxel) {
		List<LinearConstraint> bounds = new ArrayList<>();
		for (int d=0; d<voxel.numDofs; d++) {
			if (voxel.width[d] <= 0.0) {
				continue;
			}
			double[] coefficients = new double[voxel.numDofs];
			coefficients[d] = 1.0;
			bounds.add(new LinearConstraint(coefficients, Relationship.GEQ, voxel.min[d]));
			bounds.add(new LinearConstraint(coefficients, Relationship.LEQ, voxel.max[d]));
		}
		return bounds;
	}

	/** An atom plus the set of continuous voxel DoFs that move it. */
	private class AtomVoxel {

		final Atom atom;
		final Probe.AtomInfo probeInfo;

		/** voxel DoF indices (into {@code pmol.dofs}) whose changes move this atom */
		final List<Integer> dofIndices = new ArrayList<>();

		AtomVoxel(Atom atom, Probe probe) {
			this.atom = atom;
			this.probeInfo = probe.getAtomInfo(atom);
		}

		AtomVoxel(Atom atom, Probe probe, Voxel voxel) {
			this(atom, probe);

			// determine which dofs affect this atom position
			for (int d=0; d<voxel.numDofs; d++) {

				// skip this dof if it's not continuously flexible
				if (voxel.width[d] <= 0.0) {
					continue;
				}

				// start at the center
				voxel.applyDof(d, voxel.center[d]);
				double[] start = atom.getCoords();

				// move a little bit away from the center
				voxel.applyDof(d, voxel.center[d] + gradientDxFactor*voxel.width[d]);
				double[] stop = atom.getCoords();

				// put the dof back at the center so probing doesn't leave the molecule perturbed
				voxel.applyDof(d, voxel.center[d]);

				// pick the dof if the positions are different by even a little
				if (!Arrays.equals(start, stop)) {
					dofIndices.add(d);
				}
			}
		}

		public boolean hasDofs() {
			return !dofIndices.isEmpty();
		}
	}

	/**
	 * A probe atom pair restricted to the DoFs that move either atom.
	 *
	 * Methods taking a DoF index d use the pair-local index (0 until numDofs);
	 * {@link #dofIndices} maps a local index to the voxel DoF index.
	 */
	private class AtomPairVoxel {

		final Voxel voxel;
		final Probe.AtomPair probePair;

		/** local DoF index -> voxel DoF index, union of both atoms' DoFs without duplicates */
		final List<Integer> dofIndices = new ArrayList<>();
		final int numDofs;

		AtomPairVoxel(Voxel voxel, AtomVoxel v1, AtomVoxel v2) {

			this.voxel = voxel;

			// make the probe atom pair
			this.probePair = probe.new AtomPair(v1.atom, v2.atom, v1.probeInfo, v2.probeInfo);

			// combine the dofs
			for (int d : v1.dofIndices) {
				if (!dofIndices.contains(d)) {
					dofIndices.add(d);
				}
			}
			for (int d : v2.dofIndices) {
				if (!dofIndices.contains(d)) {
					dofIndices.add(d);
				}
			}
			numDofs = dofIndices.size();
		}

		double min(int d) {
			d = dofIndices.get(d);
			return voxel.min[d];
		}

		double max(int d) {
			d = dofIndices.get(d);
			return voxel.max[d];
		}

		double center(int d) {
			d = dofIndices.get(d);
			return voxel.center[d];
		}

		double width(int d) {
			d = dofIndices.get(d);
			return voxel.width[d];
		}

		double width2(int d) {
			d = dofIndices.get(d);
			return voxel.width2[d];
		}

		void applyDof(int d, double val) {
			d = dofIndices.get(d);
			voxel.applyDof(d, val);
		}

		/** Applies all pair DoFs from the local-coordinate point x, then returns the probe violation. */
		double getViolation(double[] x, double tolerance) {

			// apply the dofs
			for (int d=0; d<numDofs; d++) {
				applyDof(d, x[d]);
			}

			// get the violation
			return probePair.getViolation(tolerance);
		}

		/**
		 * Returns the violation after moving local DoF d from x to x + dx, then restores it to x.
		 * Assumes the other pair DoFs were already applied by the caller.
		 */
		double getViolationAlong(int d, double x, double dx, double tolerance) {

			// move along one dof
			applyDof(d, x + dx);

			// get the violation
			double violation = probePair.getViolation(tolerance);

			// put the dof back
			applyDof(d, x);

			return violation;
		}

		boolean outOfRange(double[] x) {
			for (int d=0; d<numDofs; d++) {
				if (x[d] < min(d) || x[d] > max(d)) {
					return true;
				}
			}
			return false;
		}
	}

	/**
	 * Builds one linear constraint per non-bonded atom pair, across all residue pairs in the interactions,
	 * where at least one atom is moved by a continuous DoF.
	 *
	 * @return a mutable list of constraints over all voxel DoFs (possibly empty)
	 * @throws NoFeasibleSolutionException if some atom pair is judged to clash everywhere in the voxel
	 */
	public List<LinearConstraint> getLinearConstraints(Voxel voxel, ResidueInteractions inters, double tolerance) {

		Map<Atom,AtomVoxel> atomVoxels = new HashMap<>();
		List<LinearConstraint> constraints = new ArrayList<>();

		// for each res pair
		for (ResidueInteractions.Pair resPair : inters) {
			Residue res1 = voxel.pmol.mol.residues.getOrThrow(resPair.resNum1);
			Residue res2 = voxel.pmol.mol.residues.getOrThrow(resPair.resNum2);

			// for each atom pair
			for (int[] atomPair : connectivity.getAtomPairs(res1, res2).getPairs(AtomNeighbors.Type.NONBONDED)) {
				Atom a1 = res1.atoms.get(atomPair[0]);
				Atom a2 = res2.atoms.get(atomPair[1]);

				// get voxel info for each atom, or skip the pair if no dofs
				AtomVoxel v1 = atomVoxels.computeIfAbsent(a1, (key) -> new AtomVoxel(a1, probe, voxel));
				AtomVoxel v2 = atomVoxels.computeIfAbsent(a2, (key) -> new AtomVoxel(a2, probe, voxel));
				if (!v1.hasDofs() && !v2.hasDofs()) {
					continue;
				}

				AtomPairVoxel pairVoxel = new AtomPairVoxel(voxel, v1, v2);
				LinearConstraint constraint = getLinearConstraint(pairVoxel, tolerance);
				if (constraint != null) {
					constraints.add(constraint);
				}
			}
		}

		return constraints;
	}

	/**
	 * Makes the half-space constraint for one atom pair, expressed over ALL voxel DoFs.
	 *
	 * @return the constraint, or null if this atom pair doesn't usefully restrict the voxel
	 * @throws NoFeasibleSolutionException if no boundary point was found and the violation is positive,
	 *         i.e. the pair is assumed to clash everywhere in the voxel
	 */
	public LinearConstraint getLinearConstraint(AtomPairVoxel voxel, double tolerance) {

		BoundaryPoint p = findBoundaryNewton(voxel, tolerance);

		if (p == null) {
			// dofs don't affect this atom pair, or Newton didn't converge: don't make a constraint at all
			return null;
		}

		// did we not find a zero-point in the violation function?
		if (!p.atBoundary()) {

			// if we didn't find a zero-point, then assume all zero points lie outside the voxel
			// since all voxel points correspond to violations, this constraint is unsatisfiable
			if (p.violation > 0.0) {
				throw new NoFeasibleSolutionException();
			}

			// voxel always satisfiable for this atom pair, no constraint needed
			return null;
		}

		// the boundary point and gradient are in pair-local DoF coordinates
		int n = p.dofValues.length;

		// make the linear constraint u.x >= w, where:
		//    u = -g
		//    w = -g.x*
		//    x* is the boundary point where the atom pair overlap is approx 0
		//    g is the finite-difference gradient used for the final Newton step
		//      (at x* itself only when the voxel center already had zero violation)
		// ie, the (approximate) tangent hyperplane to the violation isosurface at this point
		// the LP is over all voxel dofs, so each local dof is mapped back to its voxel dof index;
		// voxel dofs that don't move this atom pair get a zero coefficient
		RealVector u = new ArrayRealVector(voxel.voxel.numDofs);
		double w = 0.0;
		for (int d=0; d<n; d++) {
			double g = -p.gradient[d];
			u.setEntry(voxel.dofIndices.get(d), g);
			w += p.dofValues[d]*g;
		}

		return new LinearConstraint(u, Relationship.GEQ, w);
	}

	/**
	 * Result of the boundary search for one atom pair.
	 *
	 * dofValues and gradient are indexed by the atom pair's local DoF index. When no boundary
	 * point was found, both are null and only the violation is meaningful.
	 */
	public static class BoundaryPoint {

		double[] dofValues;
		double violation;
		double[] gradient;

		public BoundaryPoint(double[] dofValues, double violation, double[] gradient) {
			this.dofValues = dofValues.clone();
			this.violation = violation;
			this.gradient = gradient.clone();
		}

		public BoundaryPoint(double violation) {
			this.dofValues = null;
			this.violation = violation;
			this.gradient = null;
		}

		public boolean atBoundary() {
			return gradient != null;
		}

		@Override
		public String toString() {
			StringBuilder buf = new StringBuilder();
			buf.append(String.format("violation: %.3f", violation));
			if (dofValues != null) {
				for (int d=0; d<dofValues.length; d++) {
					buf.append(String.format("\n\tx[%d]: %.3f", d, dofValues[d]));
				}
			}
			if (gradient != null) {
				for (int d=0; d<gradient.length; d++) {
					buf.append(String.format("\n\tg[%d]: %.3f", d, gradient[d]));
				}
			}
			return buf.toString();
		}
	}

	/**
	 * find a point in the voxel where the atom pair overlap is close to 0
	 * using a generalization of Newton's method starting at the voxel center
	 *
	 * Each step solves a linear model of the violation along the width-scaled gradient:
	 * x <- x - f(x) * W g / (g^T W g), with W = diag(width^2).
	 *
	 * @return a boundary point (with gradient) if |violation| <= {@link #violationThreshold} was reached;
	 *         a point without gradient carrying the violation of the last in-voxel iterate if a step left the voxel;
	 *         or null if the gradient vanished or {@link #maxNumIterations} ran out
	 */
	public BoundaryPoint findBoundaryNewton(AtomPairVoxel pairVoxel, double tolerance) {

		// objective function
		Function<double[],Double> f = (x) ->
			pairVoxel.getViolation(x, tolerance);

		// gradient function (approximate, forward differences)
		// NOTE: returns the same reused array every call
		double[] gout = new double[pairVoxel.numDofs];
		Function<double[],double[]> g = (x) -> {
			double baseViolation = f.apply(x);
			for (int d=0; d<pairVoxel.numDofs; d++) {
				double dx = gradientDxFactor*pairVoxel.width(d);
				gout[d] = (pairVoxel.getViolationAlong(d, x[d], dx, tolerance) - baseViolation)/dx;
			}
			return gout;
		};

		// start at the center of the voxel
		double[] x = new double[pairVoxel.numDofs];
		for (int d=0; d<pairVoxel.numDofs; d++) {
			x[d] = pairVoxel.center(d);
		}

		// get the initial violation
		double violation = f.apply(x);

		if (violation == 0.0) {
			// well, that was easy
			return new BoundaryPoint(x, violation, g.apply(x));
		}

		// iterate a newton-like method until convergence, or we hit a voxel boundary
		for (int i=0; i<maxNumIterations; i++) {

			// calculate the gradient at x
			double[] grad = g.apply(x);

			double s = 0.0;
			for (int d=0; d<pairVoxel.numDofs; d++) {
				s += pairVoxel.width2(d)*grad[d]*grad[d];
			}

			if (s == 0.0) {
				// gradient was zero (assuming no voxel width was zero), so can't find f=0 point
				// this usually means the DoFs have no effect on distance for this atom pair, and hence no boundary point exists
				return null;
			}

			// take a step along the gradient direction
			for (int d=0; d<pairVoxel.numDofs; d++) {
				x[d] -= violation*grad[d]*pairVoxel.width2(d)/s;
			}

			// did we step out of the voxel?
			if (pairVoxel.outOfRange(x)) {

				// the gradient suggests all boundary points are outside of the voxel
				// so return the violation of the last point inside the voxel, and assume that represents the entire voxel
				return new BoundaryPoint(violation);
			}

			// is the current violation close enough to 0?
			violation = f.apply(x);
			double diff = Math.abs(violation);
			if (diff <= violationThreshold) {
				return new BoundaryPoint(x, violation, grad);
			}
		}

		// ran out of iterations and didn't find a boundary
		final boolean developerIsInvestigatingFrequencyOfThisHappening = false;
		if (developerIsInvestigatingFrequencyOfThisHappening) {
			throw new Error("can't find boundary point after " + maxNumIterations + " iterations for " + pairVoxel.probePair);
		}

		// don't model this atom pair with a constraint at all, just to be conservative
		return null;
	}
}