/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    FTtree.java
 *    Copyright (C) 2007 University of Porto, Porto, Portugal
 *
 */

package weka.classifiers.trees.ft;

import weka.classifiers.functions.SimpleLinearRegression;
import weka.classifiers.trees.j48.BinC45ModelSelection;
import weka.classifiers.trees.j48.BinC45Split;
import weka.classifiers.trees.j48.C45Split;
import weka.classifiers.trees.j48.ClassifierSplitModel;
import weka.classifiers.trees.j48.Distribution;
import weka.classifiers.trees.j48.ModelSelection;
import weka.classifiers.trees.j48.Stats;
import weka.classifiers.trees.lmt.LogisticBase;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;

import java.util.Vector;

/**
 * Abstract class for Functional tree structure.
 * <p>
 * Each node combines a split model ({@link #m_localModel}) with a logistic
 * model built by LogitBoost (inherited from {@link LogisticBase}); the logistic
 * model at a node refines the partial model fit at higher levels of the tree
 * ({@link #m_higherRegressions}). Concrete subclasses implement growing
 * ({@link #buildTree}), pruning ({@link #prune()}) and prediction
 * ({@link #distributionForInstance}).
 *
 * @author Jo&atilde;o Gama
 * @author Carlos Ferreira
 *
 * @version $Revision: 1.4 $
 */
public abstract class FTtree 
  extends LogisticBase {   
  
  /** for serialization */
  static final long serialVersionUID = 1862737145870398755L;
    
  /** Total number (weight) of training instances. */
  protected double m_totalInstanceWeight;
    
  /** Node id (assigned in pre-order by {@link #assignIDs(int)}). */
  protected int m_id;
    
  /** ID of logistic model at leaf (assigned by {@link #assignLeafModelNumbers(int)}). */
  protected int m_leafModelNum;
 
  /** Minimum number of instances at which a node is considered for splitting. */
  protected int m_minNumInstances;

  /** ModelSelection object (for splitting). */
  protected ModelSelection m_modelSelection;     

  /** Filter to convert nominal attributes to binary (set by {@link #getNumericData}). */
  protected NominalToBinary m_nominalToBinary;  
   
  /** Simple regression functions fit by LogitBoost at higher levels in the tree,
   *  indexed as [class][iteration]. */
  protected SimpleLinearRegression[][] m_higherRegressions;
    
  /** Number of simple regression functions fit by LogitBoost at higher levels in the tree. */
  protected int m_numHigherRegressions = 0;
    
  /** Number of instances at the node. */
  protected int m_numInstances;   

  /** The ClassifierSplitModel (for splitting). */
  protected ClassifierSplitModel m_localModel; 
    
  /** Auxiliary copy of the ClassifierSplitModel (for splitting). */
  protected ClassifierSplitModel m_auxLocalModel; 
 
  /** Array of children of the node. */
  protected FTtree[] m_sons; 
   
  /** Stores leaf class value (used for leaves without a constructor model). */ 
  protected int m_leafclass;
    
  /** True if node is leaf. */
  protected boolean m_isLeaf;
    
  /** True if node has or splits on constructor. */
  protected boolean m_hasConstr=true;
    
  /** Constructor error (number of errors used by {@link #getEtimateConstModel}). */
  protected double  m_constError=0;
    
  /** Confidence level passed to {@link Stats#addErrs} for pessimistic error estimates. */
  protected float m_CF = 0.10f;  
                       
  /**
   * Method for building a Functional Tree (only called for the root node).
   * Grows an initial Functional Tree.
   *
   * @param data the data to train with
   * @throws Exception if something goes wrong
   */
  public abstract void buildClassifier(Instances data) throws Exception;

  /**
   * Abstract method for building the tree structure.
   * Builds a logistic model, splits the node and recursively builds tree for child nodes.
   * @param data the training data passed on to this node
   * @param higherRegressions An array of regression functions produced by LogitBoost at higher 
   * levels in the tree. They represent a logistic regression model that is refined locally 
   * at this node.
   * @param totalInstanceWeight the total number of training examples
   * @param higherNumParameters effective number of parameters in the logistic regression model built
   * in parent nodes
   * @throws Exception if something goes wrong
   */
  public abstract void buildTree(Instances data, SimpleLinearRegression[][] higherRegressions, 
                                 double totalInstanceWeight, double higherNumParameters) throws Exception;
    
  /**
   * Abstract method that prunes a tree using the C4.5 pruning procedure.
   *
   * @return an error estimate produced by the pruning; its exact meaning is
   *         defined by the implementing subclass (NOTE(review): confirm)
   * @throws Exception if something goes wrong
   */
  public abstract double prune() throws Exception; 
 
  /**
   * Inserts one new attribute per class value at the front of the dataset.
   * The attributes are named "N0", "N1", ... and occupy indices
   * 0 .. numClassValues-1, shifting the existing attributes to the right.
   * The dataset is modified in place.
   *
   * @param data the dataset to extend
   * @return the same (modified) dataset object
   * @throws Exception if something goes wrong
   */
  protected Instances insertNewAttr(Instances data) throws Exception{
    for (int i = 0; i < data.classAttribute().numValues(); i++) {
      data.insertAttributeAt(new Attribute("N" + i), i); 
    }
    return data;
  }

  /**
   * Removes the attributes added by {@link #insertNewAttr}: deletes the first
   * numClassValues attributes of the dataset. The dataset is modified in place.
   *
   * @param data the dataset to shrink
   * @return the same (modified) dataset object
   * @throws Exception if something goes wrong
   */
  protected Instances removeExtAttributes(Instances  data) throws Exception{
    for (int i = 0; i < data.classAttribute().numValues(); i++) {
      data.deleteAttributeAt(0);
    }
    return data;
  }

  /**
   * Computes estimated errors for the tree rooted at this node: for a leaf the
   * estimate for its split-model distribution, otherwise the sum of the
   * children's estimates.
   *
   * @return the estimated errors
   */
  protected double getEstimatedErrors(){
    if (m_isLeaf)
      return getEstimatedErrorsForDistribution(m_localModel.distribution());

    double errors = 0;
    for (int i = 0; i < m_sons.length; i++)
      errors = errors + m_sons[i].getEstimatedErrors();
    return errors;
  }

  /**
   * Computes estimated errors for one branch when the given data is pushed
   * down the subtree rooted at this node.
   *
   * @param data the data to route through this branch
   * @return the sum of the leaf estimates for the routed data
   * @throws Exception if something goes wrong
   */
  protected double getEstimatedErrorsForBranch(Instances data)
    throws Exception {

    if (m_isLeaf)
      return getEstimatedErrorsForDistribution(new Distribution(data));

    // NOTE(review): the split model's distribution is replaced by one built
    // from 'data' and the previous distribution is NOT restored afterwards.
    // Confirm no caller relies on the distribution held before this call.
    m_localModel.resetDistribution(data);
    Instances[] localInstances = (Instances[]) m_localModel.split(data);

    double errors = 0;
    for (int i = 0; i < m_sons.length; i++)
      errors = errors + m_sons[i].getEstimatedErrorsForBranch(localInstances[i]);
    return errors;
  }

  /**
   * Computes the pessimistic estimated error rate for a leaf:
   * (Stats.addErrs(total, incorrect, m_CF) + incorrect) / total.
   *
   * @param theDistribution the class distribution at the leaf
   * @return the estimated error rate, or 0 if the distribution is empty
   */
  protected double getEstimatedErrorsForDistribution(Distribution
                                                     theDistribution){
    if (Utils.eq(theDistribution.total(),0))
      return 0;

    // Stats.addErrs returns the extra (pessimistic) errors on top of the observed ones.
    double numInc = theDistribution.numIncorrect();
    double numTotal = theDistribution.total();
    return ((Stats.addErrs(numTotal, numInc, m_CF)) + numInc) / numTotal;
  }

  /**
   * Computes the pessimistic estimated error rate of the constructor model,
   * using {@link #m_constError} as the number of errors:
   * (Stats.addErrs(total, m_constError, m_CF) + m_constError) / total.
   * (The misspelled name is kept for compatibility with subclasses.)
   *
   * @param theDistribution the distribution whose total weight is used
   * @return the estimated error rate, or 0 if the distribution is empty
   */
  protected double getEtimateConstModel(Distribution theDistribution){
    if (Utils.eq(theDistribution.total(),0))
      return 0;

    // Stats.addErrs returns the extra (pessimistic) errors on top of the observed ones.
    double numTotal = theDistribution.total();
    return ((Stats.addErrs(numTotal, m_constError, m_CF)) + m_constError) / numTotal;
  }
    

  /**
   * Method to count the number of inner nodes in the tree
   * @return the number of inner nodes
   */
  public int getNumInnerNodes(){
    if (m_isLeaf) return 0;
    int numNodes = 1;
    for (int i = 0; i < m_sons.length; i++) numNodes += m_sons[i].getNumInnerNodes();
    return numNodes;
  }

  /**
   * Returns the number of leaves in the tree.
   * Leaves are only counted if their logistic model has changed compared to the one of the parent node;
   * among the children of one node, all leaves without models are counted at most once.
   * @return the number of leaves
   */
  public int getNumLeaves(){
    if (m_isLeaf) return 1;

    int numLeaves = 0;
    int numEmptyLeaves = 0;
    for (int i = 0; i < m_sons.length; i++) {
      numLeaves += m_sons[i].getNumLeaves();
      if (m_sons[i].m_isLeaf && !m_sons[i].hasModels()) numEmptyLeaves++;
    }
    // Collapse several model-less sibling leaves into a single counted leaf.
    if (numEmptyLeaves > 1) {
      numLeaves -= (numEmptyLeaves - 1);
    }
    return numLeaves;	
  }


     
  /**
   * Merges two arrays of regression functions into one.
   * For each class row, the entries of a1 come first, followed by those of a2.
   * Row lengths are taken from a1[0] and a2[0]; m_numClasses rows are produced.
   *
   * @param a1 one array
   * @param a2 the other array
   *
   * @return an array that contains all entries from both input arrays
   */
  protected SimpleLinearRegression[][] mergeArrays(SimpleLinearRegression[][] a1,	
                                                   SimpleLinearRegression[][] a2){
    int numModels1 = a1[0].length;
    int numModels2 = a2[0].length;		

    SimpleLinearRegression[][] result =
      new SimpleLinearRegression[m_numClasses][numModels1 + numModels2];

    for (int i = 0; i < m_numClasses; i++) {
      System.arraycopy(a1[i], 0, result[i], 0, numModels1);
      System.arraycopy(a2[i], 0, result[i], numModels1, numModels2);
    }
    return result;
  }

  /**
   * Return a list of all inner nodes in the tree (pre-order).
   * @return the list of nodes
   */
  public Vector getNodes(){
    Vector nodeList = new Vector();
    getNodes(nodeList);
    return nodeList;
  }

  /**
   * Fills a list with all inner nodes in the tree (pre-order).
   * 
   * @param nodeList the list to be filled
   */
  public void getNodes(Vector nodeList) {
    if (!m_isLeaf) {
      nodeList.add(this);
      for (int i = 0; i < m_sons.length; i++) m_sons[i].getNodes(nodeList);
    }	
  }
    
  /**
   * Returns a numeric version of a set of instances.
   * All nominal attributes are replaced by binary ones, and the class variable is replaced
   * by a pseudo-class variable that is used by LogitBoost.
   * Side effect: (re)creates {@link #m_nominalToBinary}, which
   * {@link #modelDistributionForInstance} later uses.
   *
   * @param train the training data (not modified; a copy is filtered)
   * @return the numeric version of the data
   * @throws Exception if filtering fails
   */
  protected Instances getNumericData(Instances train) throws Exception{

    Instances filteredData = new Instances(train);	
    m_nominalToBinary = new NominalToBinary();			
    m_nominalToBinary.setInputFormat(filteredData);
    filteredData = Filter.useFilter(filteredData, m_nominalToBinary);	

    return super.getNumericData(filteredData);
  }

  /**
   * Computes the F-values of LogitBoost for an instance from the current logistic model at the node
   * Note that this also takes into account the (partial) logistic model fit at higher levels in 
   * the tree.
   * @param instance the instance
   * @return the array of F-values 
   * @throws Exception if a regression function cannot classify the instance
   */
  protected double[] getFs(Instance instance) throws Exception{

    double [] pred = new double [m_numClasses];

    //Need to take into account partial model fit at higher levels in the tree (m_higherRegressions) 
    //and the part of the model fit at this node (m_regressions).

    //Fs from m_regressions (use method of LogisticBase)
    double [] instanceFs = super.getFs(instance);		

    //Fs from m_higherRegressions: each iteration adds the centred predictions scaled by (J-1)/J
    for (int i = 0; i < m_numHigherRegressions; i++) {
      double predSum = 0;
      for (int j = 0; j < m_numClasses; j++) {
        pred[j] = m_higherRegressions[j][i].classifyInstance(instance);
        predSum += pred[j];
      }
      predSum /= m_numClasses;
      for (int j = 0; j < m_numClasses; j++) {
        instanceFs[j] += (pred[j] - predSum) * (m_numClasses - 1) 
          / m_numClasses;
      }
    }
    return instanceFs; 
  }
     
  /**
   * Returns the index of the largest value in the given probability array
   * (despite the name, this is an arg-max, not an error value).
   *
   * @param probsConst class probabilities from the constructor model
   * @return the index of the maximum entry (see {@link Utils#maxIndex(double[])})
   */
  public int getConstError(double[] probsConst)
  {
    return Utils.maxIndex(probsConst);
  }
    
  /**
   * Returns true if the logistic regression model at this node has changed compared to the
   * one at the parent node, i.e. if at least one regression was fit at this node.
   * @return whether it has changed
   */
  public boolean hasModels() {
    return (m_numRegressions > 0);
  }

  /**
   * Returns the class probabilities for an instance according to the logistic model at the node.
   * The instance is copied, converted with {@link #m_nominalToBinary} and attached to the
   * numeric data header before the F-values are computed.
   * @param instance the instance (not modified)
   * @return the array of probabilities
   * @throws Exception if conversion or prediction fails
   */
  public double[] modelDistributionForInstance(Instance instance) throws Exception {

    //make copy and convert nominal attributes
    instance = (Instance)instance.copy();		
    m_nominalToBinary.input(instance);
    instance = m_nominalToBinary.output();	

    //set numeric pseudo-class
    instance.setDataset(m_numericDataHeader);		

    return probs(getFs(instance));
  }

  /**
   * Returns the class probabilities for an instance given by the Functional tree.
   * @param instance the instance
   * @return the array of probabilities
   * @throws Exception if something goes wrong
   */
  public abstract double[] distributionForInstance(Instance instance) throws Exception;
  
  
    
  /**
   * Returns a description of the Functional tree (tree structure and logistic models).
   * Side effect: renumbers the leaf models via {@link #assignLeafModelNumbers(int)}.
   * @return describing string, or an error message if printing fails
   */
  public String toString(){	
    //assign numbers to logistic regression functions at leaves
    assignLeafModelNumbers(0);	
    try{
      StringBuffer text = new StringBuffer();

      if (m_isLeaf && !m_hasConstr) {
        text.append(": ");
        text.append("Class"+"="+ m_leafclass);
      } else if (m_isLeaf && m_hasConstr) {
        text.append(": ");
        text.append("FT_"+m_leafModelNum+":"+getModelParameters());
      } else {
        dumpTree(0,text);  
      }
      text.append("\n\nNumber of Leaves  : \t"+numLeaves()+"\n");
      text.append("\nSize of the Tree : \t"+numNodes()+"\n");	

      //This prints logistic models after the tree, comment out if only tree should be printed
      text.append(modelsToString());
      return text.toString();
    } catch (Exception e){
      // Best effort: toString() cannot throw checked exceptions.
      return "Can't print logistic model tree";
    }
  }
    
  /**
   * Returns the number of leaves (normal count).
   * @return the number of leaves
   */
  public int numLeaves() {	
    if (m_isLeaf) return 1;	
    int numLeaves = 0;
    for (int i = 0; i < m_sons.length; i++) numLeaves += m_sons[i].numLeaves();
    return numLeaves;
  }
    
  /**
   * Returns the number of nodes (inner nodes plus leaves).
   * @return the number of nodes
   */
  public int numNodes() {
    if (m_isLeaf) return 1;	
    int numNodes = 1;
    for (int i = 0; i < m_sons.length; i++) numNodes += m_sons[i].numNodes();
    return numNodes;
  }
   
  /**
   * Returns a string describing the number of LogitBoost iterations performed at this node, the total number
   * of LogitBoost iterations performed (including iterations at higher levels in the tree), and the number
   * of training examples at this node, formatted as "local/total (instances)".
   * @return the describing string
   */
  public String getModelParameters(){

    StringBuffer text = new StringBuffer();
    int numModels = m_numRegressions+m_numHigherRegressions;
    text.append(m_numRegressions+"/"+numModels+" ("+m_numInstances+")");
    return text.toString();
  }
       
  /**
   * Help method for printing tree structure. Appends one line per child
   * branch, indented by depth, and recurses into inner children.
   *
   * @param depth the current indentation depth
   * @param text the buffer to append to
   * @throws Exception if something goes wrong
   */
  protected void dumpTree(int depth,StringBuffer text) 
    throws Exception {

    for (int i = 0; i < m_sons.length; i++) {
      text.append("\n");
      for (int j = 0; j < depth; j++)
        text.append("|   ");

      // Nodes splitting on a constructor are tagged with their node id.
      text.append(m_localModel.leftSide(m_train));
      if (m_hasConstr)
        text.append("#" + m_id);
      text.append(m_localModel.rightSide(i, m_train) );

      FTtree son = m_sons[i];
      if (son.m_isLeaf) {
        text.append(": ");
        if (son.m_hasConstr)
          text.append("FT_"+son.m_leafModelNum+":"+son.getModelParameters());
        else
          text.append("Class"+"="+ son.m_leafclass);  
      } else {
        son.dumpTree(depth+1,text);
      }
    }
  }

  /**
   * Assigns unique IDs to all nodes in the tree, in pre-order.
   *
   * @param lastID the last ID already in use; this node receives lastID + 1
   * @return the last ID assigned within this subtree
   */
  public int assignIDs(int lastID) {

    int currLastID = lastID + 1;

    m_id = currLastID;
    if (m_sons != null) {
      for (int i = 0; i < m_sons.length; i++) {
        currLastID = m_sons[i].assignIDs(currLastID);
      }
    }
    return currLastID;
  }
    
  /**
   * Assigns numbers to the logistic regression models at the leaves of the tree.
   * Inner nodes get 0; leaves are numbered consecutively, left to right,
   * starting at leafCounter + 1.
   *
   * @param leafCounter the last leaf number already in use
   * @return the last leaf number assigned within this subtree
   */
  public int assignLeafModelNumbers(int leafCounter) {
    if (!m_isLeaf) {
      m_leafModelNum = 0;
      for (int i = 0; i < m_sons.length; i++){
        leafCounter = m_sons[i].assignLeafModelNumbers(leafCounter);
      }
    } else {
      leafCounter++;
      m_leafModelNum = leafCounter;
    } 
    return leafCounter;
  }

  /**
   * Returns an array containing the coefficients of the logistic regression function at this node.
   * @return the array of coefficients, first dimension is the class, second the attribute
   *         (index 0 is the intercept, attribute k is stored at k + 1).
   */
  protected double[][] getCoefficients(){
       
    //Need to take into account partial model fit at higher levels in the tree (m_higherRegressions) 
    //and the part of the model fit at this node (m_regressions).

    //get coefficients from m_regressions: use method of LogisticBase
    double[][] coefficients = super.getCoefficients();
    //get coefficients from m_higherRegressions:
    double constFactor = (double)(m_numClasses - 1) / (double)m_numClasses; // (J - 1)/J
    for (int j = 0; j < m_numClasses; j++) {
      for (int i = 0; i < m_numHigherRegressions; i++) {		
        double slope = m_higherRegressions[j][i].getSlope();
        double intercept = m_higherRegressions[j][i].getIntercept();
        int attribute = m_higherRegressions[j][i].getAttributeIndex();
        coefficients[j][0] += constFactor * intercept;
        coefficients[j][attribute + 1] += constFactor * slope;
      }
    }

    return coefficients;
  }
    
  /**
   * Returns a string describing the logistic regression functions in the subtree
   * rooted at this node. Nodes without a constructor model contribute no text of
   * their own; inner nodes always append the descriptions of their children.
   *
   * @return the describing string
   */
  public String modelsToString(){

    StringBuffer text = new StringBuffer();
    if (m_isLeaf) {
      if (m_hasConstr) {
        text.append("FT_"+m_leafModelNum+":"+super.toString());
      }
    } else {
      if (m_hasConstr) {
        // The split attribute index is read from the concrete split model type.
        int attIndex;
        if (m_modelSelection instanceof BinC45ModelSelection) {
          attIndex = ((BinC45Split)m_localModel).attIndex();
        } else {
          attIndex = ((C45Split)m_localModel).attIndex();
        }
        text.append("FT_N"+attIndex+"#"+m_id +":"+super.toString());
      }
      for (int i = 0; i < m_sons.length; i++) { 
        text.append("\n"+ m_sons[i].modelsToString());
      }
    }
        
    return text.toString();
  }

  /**
   * Returns a graph describing the tree in the "digraph" text format.
   * Side effect: reassigns node IDs and leaf model numbers.
   *
   * @return the graph description
   * @throws Exception if something goes wrong
   */
  public String graph() throws Exception {

    StringBuffer text = new StringBuffer();

    assignIDs(-1);
    assignLeafModelNumbers(0);
    text.append("digraph FTree {\n");
    if (m_isLeaf && m_hasConstr) {
      text.append("N" + m_id + " [label=\"FT_"+m_leafModelNum+":"+getModelParameters()+"\" " + 
                  "shape=box style=filled");
      text.append("]\n");
    }else{
      if (m_isLeaf && !m_hasConstr){
        text.append("N" + m_id + " [label=\"Class="+m_leafclass+ "\" " + 
                    "shape=box style=filled");
        text.append("]\n");
             
      }else {
        text.append("N" + m_id 
                    + " [label=\"" + 
                    m_localModel.leftSide(m_train) + "\" ");
        text.append("]\n");
        graphTree(text);
      }
    }
    return text.toString() +"}\n";
  }

  /**
   * Helper function for graph description of tree: appends one edge per child
   * and one node declaration per child, recursing into inner children.
   *
   * @param text the buffer to append to
   * @throws Exception if something goes wrong
   */
  protected void graphTree(StringBuffer text) throws Exception {

    for (int i = 0; i < m_sons.length; i++) {
      text.append("N" + m_id  
                  + "->" + 
                  "N" + m_sons[i].m_id +
                  " [label=\"" + m_localModel.rightSide(i,m_train).trim() + 
                  "\"]\n");
      if (m_sons[i].m_isLeaf && m_sons[i].m_hasConstr) {
        text.append("N" +m_sons[i].m_id + " [label=\"FT_"+m_sons[i].m_leafModelNum+":"+
                    m_sons[i].getModelParameters()+"\" " + "shape=box style=filled");
        text.append("]\n");
      } else { 
        if (m_sons[i].m_isLeaf && !m_sons[i].m_hasConstr) {
          text.append("N" +m_sons[i].m_id + " [label=\"Class="+m_sons[i].m_leafclass+"\" " + "shape=box style=filled");
          text.append("]\n");
        }else{
          text.append("N" + m_sons[i].m_id +
                      " [label=\""+m_sons[i].m_localModel.leftSide(m_train) + 
                      "\" ");
          text.append("]\n");
          m_sons[i].graphTree(text);
        }
      }
    } 
  }  

  /**
   * Cleanup in order to save memory: calls the superclass cleanup on this node
   * and recursively on all children.
   */
  public void cleanup() {
    super.cleanup();
    if (!m_isLeaf) {
      for (int i = 0; i < m_sons.length; i++) m_sons[i].cleanup();
    }
  }
  
  /**
   * Returns the revision string.
   * 
   * @return		the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.4 $");
  }
}