/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.ltr.model;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;

/**
 * A scoring model that computes scores based on the summation of multiple weighted trees.
 * Example models are LambdaMART and Gradient Boosted Regression Trees (GBRT).
 * <p>
 * Example configuration:
<pre>{
   "class" : "org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
   "name" : "multipleadditivetreesmodel",
   "features":[
       { "name" : "userTextTitleMatch"},
       { "name" : "originalScore"}
   ],
   "params" : {
       "trees" : [
           {
               "weight" : "1",
               "root": {
                   "feature" : "userTextTitleMatch",
                   "threshold" : "0.5",
                   "left" : {
                       "value" : "-100"
                   },
                   "right" : {
                       "feature" : "originalScore",
                       "threshold" : "10.0",
                       "left" : {
                           "value" : "50"
                       },
                       "right" : {
                           "value" : "75"
                       }
                   }
               }
           },
           {
               "weight" : "2",
               "root" : {
                   "value" : "-10"
               }
           }
       ]
   }
}</pre>
 * <p>
 * Training libraries:
 * <ul>
 * <li> <a href="http://sourceforge.net/p/lemur/wiki/RankLib/">RankLib</a>
 * </ul>
 * <p>
 * Background reading:
 * <ul>
 * <li> <a href="http://research.microsoft.com/pubs/132652/MSR-TR-2010-82.pdf">
 * Christopher J.C. Burges. From RankNet to LambdaRank to LambdaMART: An Overview.
 * Microsoft Research Technical Report MSR-TR-2010-82.</a>
 * </ul>
 * <ul>
 * <li> <a href="https://papers.nips.cc/paper/3305-a-general-boosting-method-and-its-application-to-learning-ranking-functions-for-web-search.pdf">
 * Z. Zheng, H. Zha, T. Zhang, O. Chapelle, K. Chen, and G. Sun. A General Boosting Method and its Application to Learning Ranking Functions for Web Search.
 * Advances in Neural Information Processing Systems (NIPS), 2007.</a>
 * </ul>
 */
public class MultipleAdditiveTreesModel extends LTRScoringModel {

  /**
   * Maps each model feature name to its index in the feature vector passed to
   * {@link #score(float[])}.
   * fname2index is filled from constructor arguments
   * (that are already part of the base class hashCode)
   * and therefore here it does not individually
   * influence the class hashCode, equals, etc.
   */
  private final HashMap<String,Integer> fname2index;
  /**
   * The weighted trees whose scores are summed by {@link #score(float[])}.
   * Populated by {@link #setTrees(Object)}; while it is null,
   * {@link #validate()} rejects the model.
   * trees is part of the LTRScoringModel params map
   * and therefore here it does not individually
   * influence the class hashCode, equals, etc.
   */
  private List<RegressionTree> trees;

  /**
   * Creates a tree and hands the map entries to SolrPluginUtils.invokeSetters
   * (presumably calling the matching setXxx method per key, e.g. setWeight/setRoot).
   * A null map yields an unconfigured tree, which fails {@link RegressionTree#validate()}.
   */
  private RegressionTree createRegressionTree(Map<String,Object> map) {
    final RegressionTree rt = new RegressionTree();
    if (map != null) {
      SolrPluginUtils.invokeSetters(rt, map.entrySet());
    }
    return rt;
  }

  /**
   * Creates a tree node and hands the map entries to SolrPluginUtils.invokeSetters
   * (presumably calling setValue/setFeature/setThreshold/setLeft/setRight per key).
   * A null map yields an empty node, which behaves as a leaf with value 0.
   */
  private RegressionTreeNode createRegressionTreeNode(Map<String,Object> map) {
    final RegressionTreeNode rtn = new RegressionTreeNode();
    if (map != null) {
      SolrPluginUtils.invokeSetters(rtn, map.entrySet());
    }
    return rtn;
  }

  /**
   * A node of a regression tree.
   * A node without a feature is a leaf that returns its value. Otherwise it
   * splits on the feature's value against the threshold and descends left (<=) or right (>).
   * Non-static because feature names are resolved through the enclosing
   * model's fname2index, and children are built with the model's factory method.
   */
  public class RegressionTreeNode {
    /**
     * Added to every configured threshold, so feature values up to
     * (nominally) threshold + 1e-6 take the left branch.
     */
    private static final float NODE_SPLIT_SLACK = 1E-6f;

    private float value = 0f;
    private String feature;
    private int featureIndex = -1;
    private Float threshold;
    private RegressionTreeNode left;
    private RegressionTreeNode right;

    /** Sets the value returned when this node is a leaf. */
    public void setValue(float value) {
      this.value = value;
    }

    /** Sets the leaf value from its string form; throws NumberFormatException if unparsable. */
    public void setValue(String value) {
      this.value = Float.parseFloat(value);
    }

    /**
     * Sets the feature this node splits on, which makes the node a non-leaf.
     * Unknown feature names are tolerated: they get index -1 and the node then scores 0.
     */
    public void setFeature(String feature) {
      this.feature = feature;
      final Integer idx = fname2index.get(this.feature);
      // this happens if the tree specifies a feature that does not exist
      // this could be due to lambdaSmart building off of pre-existing trees
      // that use a feature that is no longer output during feature extraction
      featureIndex = (idx == null) ? -1 : idx;
    }

    /** Sets the split threshold; NODE_SPLIT_SLACK is added before storing. */
    public void setThreshold(float threshold) {
      this.threshold = threshold + NODE_SPLIT_SLACK;
    }

    /** Sets the split threshold from its string form; NODE_SPLIT_SLACK is added before storing. */
    public void setThreshold(String threshold) {
      this.threshold = Float.parseFloat(threshold) + NODE_SPLIT_SLACK;
    }

    /** Builds the left child from its configuration map (a null map gives an empty leaf). */
    @SuppressWarnings({"unchecked"})
    public void setLeft(Object left) {
      this.left = createRegressionTreeNode((Map<String,Object>) left);
    }

    /** Builds the right child from its configuration map (a null map gives an empty leaf). */
    @SuppressWarnings({"unchecked"})
    public void setRight(Object right) {
      this.right = createRegressionTreeNode((Map<String,Object>) right);
    }

    /** Returns true when no split feature has been set. */
    public boolean isLeaf() {
      return feature == null;
    }

    /**
     * Walks the subtree for the given feature vector and returns the reached leaf value.
     * A split on a feature that is unknown or outside the vector contributes 0.
     */
    public float score(float[] featureVector) {
      if (isLeaf()) {
        return value;
      }

      // unsupported feature (tree is looking for a feature that does not exist)
      if  ((featureIndex < 0) || (featureIndex >= featureVector.length)) {
        return 0f;
      }

      if (featureVector[featureIndex] <= threshold) {
        return left.score(featureVector);
      } else {
        return right.score(featureVector);
      }
    }

    /** Describes the path {@link #score(float[])} takes for the given feature vector. */
    public String explain(float[] featureVector) {
      if (isLeaf()) {
        return "val: " + value;
      }

      // unsupported feature (tree is looking for a feature that does not exist)
      if  ((featureIndex < 0) || (featureIndex >= featureVector.length)) {
        return  "'" + feature + "' does not exist in FV, Return Zero";
      }

      // could store extra information about how much training data supported
      // each branch and report
      // that here

      if (featureVector[featureIndex] <= threshold) {
        String rval = "'" + feature + "':" + featureVector[featureIndex] + " <= "
            + threshold + ", Go Left | ";
        return rval + left.explain(featureVector);
      } else {
        String rval = "'" + feature + "':" + featureVector[featureIndex] + " > "
            + threshold + ", Go Right | ";
        return rval + right.explain(featureVector);
      }
    }

    @Override
    public String toString() {
      final StringBuilder sb = new StringBuilder();
      if (isLeaf()) {
        sb.append(value);
      } else {
        sb.append("(feature=").append(feature);
        sb.append(",threshold=");
        if (threshold == null) {
          // not yet validated; must not throw, because validate() builds
          // its error messages from the children's toString()
          sb.append("null");
        } else {
          // report the configured threshold, i.e. without the split slack
          sb.append(threshold.floatValue()-NODE_SPLIT_SLACK);
        }
        sb.append(",left=").append(left);
        sb.append(",right=").append(right);
        sb.append(')');
      }
      return sb.toString();
    }

    public RegressionTreeNode() {
    }

    /**
     * Checks the subtree's structure recursively.
     * A leaf must have no children. A split node must have a threshold and both children.
     *
     * @throws ModelException if the node or any descendant is malformed
     */
    public void validate() throws ModelException {
      if (isLeaf()) {
        if (left != null || right != null) {
          throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left="+left+" and right="+right);
        }
        return;
      }
      if (null == threshold) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
      }
      if (null == left) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
      } else {
        left.validate();
      }
      if (null == right) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
      } else {
        right.validate();
      }
    }

  }

  /** A single tree of the ensemble: a root node and the weight applied to its output. */
  public class RegressionTree {

    private Float weight;
    private RegressionTreeNode root;

    /** Sets the multiplier applied to this tree's output. */
    public void setWeight(float weight) {
      this.weight = weight;
    }

    /** Sets the tree weight from its string form; throws NumberFormatException if unparsable. */
    public void setWeight(String weight) {
      this.weight = Float.valueOf(weight);
    }

    /** Builds the root node from its configuration map. */
    @SuppressWarnings({"unchecked"})
    public void setRoot(Object root) {
      this.root = createRegressionTreeNode((Map<String,Object>)root);
    }

    /** Returns weight times the leaf value reached for the given feature vector. */
    public float score(float[] featureVector) {
      return weight.floatValue() * root.score(featureVector);
    }

    /** Describes the path through the tree; the reported leaf value is unweighted. */
    public String explain(float[] featureVector) {
      return root.explain(featureVector);
    }

    @Override
    public String toString() {
      final StringBuilder sb = new StringBuilder();
      sb.append("(weight=").append(weight);
      sb.append(",root=").append(root);
      sb.append(")");
      return sb.toString();
    }

    public RegressionTree() {
    }

    /**
     * Checks that the tree has a weight and a root, then validates the root recursively.
     *
     * @throws ModelException if the weight or root is missing or the nodes are malformed
     */
    public void validate() throws ModelException {
      if (weight == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight");
      }
      if (root == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
      } else {
        root.validate();
      }
    }
  }

  /**
   * Builds the ensemble from the "trees" parameter, a list of per-tree
   * configuration maps. A null argument leaves the trees unset, so
   * {@link #validate()} reports a ModelException instead of an NPE surfacing here.
   */
  @SuppressWarnings({"unchecked"})
  public void setTrees(Object trees) {
    if (trees == null) {
      this.trees = null;
      return;
    }
    final List<Object> treeConfigs = (List<Object>) trees;
    this.trees = new ArrayList<RegressionTree>(treeConfigs.size());
    for (final Object o : treeConfigs) {
      final RegressionTree rt = createRegressionTree((Map<String,Object>) o);
      this.trees.add(rt);
    }
  }

  public MultipleAdditiveTreesModel(String name, List<Feature> features,
      List<Normalizer> norms,
      String featureStoreName, List<Feature> allFeatures,
      Map<String,Object> params) {
    super(name, features, norms, featureStoreName, allFeatures, params);

    // feature i of the model reads position i of the feature vector
    fname2index = new HashMap<String,Integer>();
    for (int i = 0; i < features.size(); ++i) {
      final String key = features.get(i).getName();
      fname2index.put(key, i);
    }
  }

  /**
   * Validates the base model, then every tree.
   *
   * @throws ModelException if no trees were declared or any tree is malformed
   */
  @Override
  protected void validate() throws ModelException {
    super.validate();
    if (trees == null) {
      throw new ModelException("no trees declared for model "+name);
    }
    for (RegressionTree tree : trees) {
      tree.validate();
    }
  }

  /** Returns the sum of the weighted scores of all trees. */
  @Override
  public float score(float[] modelFeatureValuesNormalized) {
    float score = 0;
    for (final RegressionTree t : trees) {
      score += t.score(modelFeatureValuesNormalized);
    }
    return score;
  }

  // /////////////////////////////////////////
  // Produces an explanation with one (weighted) score per tree, roughly like
  // (illustrative values, tree weights of 1 assumed):
  // -10.0 = MultipleAdditiveTreesModel(name=...,trees=[...])
  //         model applied to features, sum of:
  //   0.0 = tree 0 | 'matchedTitle':1.0 > 0.500001, Go Right |
  //         'this_feature_doesnt_exist' does not exist in FV, Return Zero
  //   -10.0 = tree 1 | val: -10.0
  // Note that the "val:" of a leaf is unweighted, while each tree's score is weighted.
  @Override
  public Explanation explain(LeafReaderContext context, int doc,
      float finalScore, List<Explanation> featureExplanations) {
    // rebuild the feature vector from the per-feature explanations
    final float[] fv = new float[featureExplanations.size()];
    int index = 0;
    for (final Explanation featureExplain : featureExplanations) {
      fv[index] = featureExplain.getValue().floatValue();
      index++;
    }

    final List<Explanation> details = new ArrayList<>();
    index = 0;

    for (final RegressionTree t : trees) {
      final float score = t.score(fv);
      final Explanation p = Explanation.match(score, "tree " + index + " | "
          + t.explain(fv));
      details.add(p);
      index++;
    }

    return Explanation.match(finalScore, toString()
        + " model applied to features, sum of:", details);
  }

  @Override
  public String toString() {
    final StringBuilder sb = new StringBuilder(getClass().getSimpleName());
    sb.append("(name=").append(getName());
    sb.append(",trees=");
    if (trees == null) {
      // trees not configured; describe the model instead of throwing an NPE
      sb.append("null");
    } else {
      sb.append('[');
      for (int ii = 0; ii < trees.size(); ++ii) {
        if (ii > 0) {
          sb.append(',');
        }
        sb.append(trees.get(ii));
      }
      sb.append(']');
    }
    sb.append(')');
    return sb.toString();
  }

}