package com.github.steveash.jopenfst.operations;

import com.carrotsearch.hppc.IntObjectOpenHashMap;
import com.carrotsearch.hppc.IntOpenHashSet;
import com.github.steveash.jopenfst.Arc;
import com.github.steveash.jopenfst.Fst;
import com.github.steveash.jopenfst.IndexWeight;
import com.github.steveash.jopenfst.MutableFst;
import com.github.steveash.jopenfst.MutableState;
import com.github.steveash.jopenfst.State;
import com.github.steveash.jopenfst.semiring.Semiring;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;

/**
* N-shortest paths operation.
*
* See: M. Mohri, M. Riley, "An Efficient Algorithm for the n-best-strings problem", Proceedings of the International
* Conference on Spoken Language Processing 2002 (ICSLP ’02).
*
* @author John Salatas [email protected]
*/
public class NShortestPaths {

/**
 * Calculates the shortest distance from each state to the final states of {@code fst}.
 *
 * <p>The fst is reversed and the generic single-source shortest-distance algorithm is run from the reversed fst's
 * start state. Pending state ids are processed in FIFO order. For each dequeued state, the weight accumulated since
 * its last visit ({@code r}) is extended along every outgoing arc. Any successor whose distance estimate changes is
 * put back on the queue, unless it is already waiting there.
 *
 * <p>Termination relies on the distance estimates eventually becoming exactly equal (the loop compares doubles with
 * {@code !=}); NOTE(review): for semirings whose {@code plus} does not stabilize exactly on cyclic fsts this could
 * iterate for a long time — consider an approximate-equality check if that matters.
 *
 * See: M. Mohri, "Semiring Framework and Algorithms for Shortest-Distance Problems", Journal of Automata, Languages
 * and Combinatorics, 7(3), pp. 321-350, 2002.
 *
 * @param fst the fst to calculate the shortest distances
 * @return the array containing the shortest distances, indexed by state id of the reversed fst (NOTE(review): the
 *     caller indexes it with ids of the original/extended fst, which presumes {@code Reverse} keeps state ids
 *     aligned — confirm)
 */
private static double[] shortestDistance(Fst fst) {
  Fst reversed = Reverse.reverse(fst);
  Semiring semiring = reversed.getSemiring();
  int stateCount = reversed.getStateCount();

  // d[q]: current shortest-distance estimate for q.
  // r[q]: weight added to d[q] since q was last relaxed (still to be pushed along q's arcs).
  double[] d = new double[stateCount];
  double[] r = new double[stateCount];
  for (int i = 0; i < stateCount; i++) {
    d[i] = semiring.zero();
    r[i] = semiring.zero();
  }

  IntObjectOpenHashMap<State> stateMap = new IntObjectOpenHashMap<>();
  Deque<Integer> queue = new ArrayDeque<>();
  // Ids currently sitting in the queue, so a state is never enqueued twice at the same time.
  IntOpenHashSet enqueuedStateIds = new IntOpenHashSet();

  State start = reversed.getStartState();
  int startId = start.getId();
  stateMap.put(startId, start);
  queue.addLast(startId);
  enqueuedStateIds.add(startId);
  d[startId] = semiring.one();
  r[startId] = semiring.one();

  while (!queue.isEmpty()) {
    int thisStateId = queue.removeFirst();
    enqueuedStateIds.remove(thisStateId);
    State thisState = stateMap.get(thisStateId);
    double rnew = r[thisStateId];
    r[thisStateId] = semiring.zero();

    for (int i = 0; i < thisState.getArcCount(); i++) {
      Arc arc = thisState.getArc(i);
      State nextState = arc.getNextState();
      int nextStateId = nextState.getId();
      double extension = semiring.times(rnew, arc.getWeight());
      double dnext = d[nextStateId];
      double dnextnew = semiring.plus(dnext, extension);
      if (dnext != dnextnew) {
        d[nextStateId] = dnextnew;
        r[nextStateId] = semiring.plus(r[nextStateId], extension);
        // The successor's distance changed, so its own arcs must be relaxed (again).
        if (!enqueuedStateIds.contains(nextStateId)) {
          queue.addLast(nextStateId);
          enqueuedStateIds.add(nextStateId);
          stateMap.put(nextStateId, nextState);
        }
      }
    }
  }
  return d;
}

/**
* Calculates the n-best shortest path from the initial to the final state.
*
* @param fst         the fst to calculate the nbest shortest paths
* @param topk           number of best paths to return
* @return an fst containing the n-best shortest paths
*/
public static MutableFst apply(Fst fst, int topk) {
fst.throwIfInvalid();
final Semiring semiring = fst.getSemiring();
final double[] d = shortestDistance(fst);

MutableFst res = MutableFst.emptyWithCopyOfSymbols(fst);
MutableFst copy = ExtendFinal.apply(fst);

int[] r = new int[copy.getStateCount()];
for (int i = 0; i < r.length; i++) {
r[i] = 0;
}

MinMaxPriorityQueue<IndexWeight> qq = MinMaxPriorityQueue.orderedBy(new Ordering<IndexWeight>() {
@Override
public int compare(IndexWeight left, IndexWeight right) {
double dx = d[left.getIndex()];
double dy = d[right.getIndex()];
double wx = semiring.times(dx, left.getWeight());
double wy = semiring.times(dy, right.getWeight());
if (semiring.naturalLess(wx, wy)) {
return -1;
}
if (semiring.naturalLess(wy, wx)) {
return +1;
}
return 0;
}
}).create();
HashMap<IndexWeight, IndexWeight> previous = new HashMap<>(copy.getStateCount());
// source -> res id
HashMap<IndexWeight, Integer> stateMap = new HashMap<>(copy.getStateCount());

State start = copy.getStartState();
IndexWeight first = new IndexWeight(start.getId(), semiring.one());
previous.put(first, null);

while (!qq.isEmpty()) {
IndexWeight pair = qq.removeFirst();
State prevOld = copy.getState(pair.getIndex());
double pairWeight = pair.getWeight();

MutableState resNext = new MutableState(prevOld.getFinalWeight());
stateMap.put(pair, resNext.getId());
IndexWeight prevEntry = previous.get(pair);
if (prevEntry == null) {
// this is the start state
res.setStart(resNext);
} else {
// add the incoming arc from previous to current
MutableState previousStateNew = res.getState(stateMap.get(prevEntry));
State previousOldState = copy.getState(prevEntry.getIndex());
int numArcs = previousOldState.getArcCount();
for (int j = 0; j < numArcs; j++) {
Arc a = previousOldState.getArc(j);
if (a.getNextState().getId() == prevOld.getId()) {
}
}
}

int stateIndex = prevOld.getId();
r[stateIndex]++;

if ((r[stateIndex] == topk) && (res.getSemiring().isNotZero(prevOld.getFinalWeight()))) {
break;
}

if (r[stateIndex] <= topk) {
int numArcs = prevOld.getArcCount();
for (int j = 0; j < numArcs; j++) {
Arc a = prevOld.getArc(j);
double cnew = semiring.times(pairWeight, a.getWeight());
IndexWeight next = new IndexWeight(a.getNextState().getId(), cnew);
previous.put(next, pair);