package ch.idsia.blip.core.learn.scorer; import ch.idsia.blip.core.utils.data.ArrayUtils; import ch.idsia.blip.core.utils.data.map.ArrayHashingStrategy; import ch.idsia.blip.core.utils.data.map.TCustomHashSet; import java.lang.management.ManagementFactory; import java.lang.management.ThreadMXBean; import java.util.Arrays; import java.util.Set; import java.util.TreeSet; import java.util.logging.Logger; public class GreedyScorer extends BaseScorer { private static final Logger log = Logger.getLogger( GreedyScorer.class.getName()); /** * Queue size limit (for memory!) */ private final long max_queue_size = (long) Math.pow(2, 20); /** * Maximum size for queue */ private long queue_size; @Override protected String getName() { return "Greedy scoring"; } public GreedyScorer() { super(); } @Override public void prepare() { super.prepare(); queue_size = (long) Math.pow(dat.n_var, 3); if (queue_size > max_queue_size) { queue_size = max_queue_size; } } @Override public GreedySearcher getNewSearcher(int n) { return new GreedySearcher(n); } /** * Entry of a parent set in the linked-list queue. */ public static class ParentSetEntry implements Comparable<ParentSetEntry> { public final int[] s; public final double sk; public final int[][] p_values; /** * Default constructor. * * @param pset hash of the parent set base * @param sk score of the parent set */ ParentSetEntry(int[] pset, double sk, int[][] p_values) { this.s = pset; this.sk = sk; this.p_values = p_values; } @Override public int compareTo(ParentSetEntry other) { if (sk > other.sk) { return 1; } return -1; } public String toString() { return String.format("(%s %.3f)", Arrays.toString(s), sk); } } private class GreedySearcher extends BaseSearcher { /** * Queue for parent set to examine */ private TreeSet<ParentSetEntry> open; /** * Set of already considered parent sets */ private Set<int[]> closed; /** * Holder of the currently worst score saved in queue for evaluation */ private double worstQueueScore; GreedySearcher(int in_n) { super(in_n); } /** * Evaluate the parent sets of the variable in the available time, following an heuristic ordering. */ @Override public void run() { ThreadMXBean bean = ManagementFactory.getThreadMXBean(); double start = bean.getCurrentThreadCpuTime(); double elapsed = 0; prepare(); if (verbose > 2) { log.info( String.format("Starting with: %d, max time: %.2f", n, max_exec_time)); } // int arity = dat.l_n_arity[n]; // Initialize everything closed = new TCustomHashSet<int[]>(new ArrayHashingStrategy()); // Parent set already seen open = new TreeSet<ParentSetEntry>(); // Parent set to evaluate // Compute one-scores for (int p = 0; p < dat.n_var; p++) { if (p == n) { continue; } double sk; sk = oneScores.get(p); addScore(p, sk); addParentSetToEvaluate(new int[] { p}, sk, null); } if (max_exec_time == 0) { max_exec_time = Integer.MAX_VALUE; } // Consider all the parent set for evaluation! while (!open.isEmpty() && (elapsed < max_exec_time)) { ParentSetEntry pset = open.pollLast(); if (pset == null) { continue; } for (int p2 = 0; (p2 < dat.n_var) && (elapsed < max_exec_time); p2++) { if (p2 == n) { continue; } evaluteParentSet(n, pset.s, p2, pset.p_values); elapsed = (bean.getCurrentThreadCpuTime() - start) / 1000000000; } } if (verbose > 2) { log.info( String.format( "ending with: %d, elapsed: %.2f, num evaluated %d", n, elapsed, score.numEvaluated)); } // synchronized (scorer) { if (verbose > 0) { System.out.println("... finishing " + n); } conclude(); } private void evaluteParentSet(int n, int[] old, int p2, int[][] p_values) { if (Arrays.binarySearch(old, p2) >= 0) { return; } int[] pars = ArrayUtils.expandArray(old, p2); if (max_pset_size > 0 && pars.length >= max_pset_size) { return; } if (closed.contains(pars)) { return; } closed.add(pars); if (p_values == null) { p_values = score.computeParentSetValues(pars); } else { p_values = score.expandParentSetValues(pars, p_values, p2); } double sk = score.computeScore(n, pars, p_values); // System.out.println(Arrays.toString(pars)); if ((sk > voidSk)) { addScore(pars, sk); addParentSetToEvaluate(pars, sk, p_values); } } private void addParentSetToEvaluate(int[] p, double sk, int[][] p_values) { boolean toDropWorst = false; if (open.size() > queue_size) { if (sk < worstQueueScore) { // log.conclude("pruned"); return; } toDropWorst = true; } // Drop worst element in queue, to make room! if (toDropWorst) { open.pollLast(); worstQueueScore = open.last().sk; } else // If we didn't drop any element, check if we have to update the current // worst score! if (sk < worstQueueScore) { worstQueueScore = sk; } open.add(new ParentSetEntry(p, sk, p_values)); } } }