/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.ltr.response.transform;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.ScoreMode;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.ltr.CSVFeatureLogger;
import org.apache.solr.ltr.FeatureLogger;
import org.apache.solr.ltr.LTRRescorer;
import org.apache.solr.ltr.LTRScoringQuery;
import org.apache.solr.ltr.LTRThreadModule;
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.ltr.search.LTRQParserPlugin;
import org.apache.solr.ltr.store.FeatureStore;
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.ResultContext;
import org.apache.solr.response.transform.DocTransformer;
import org.apache.solr.response.transform.TransformerFactory;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.util.SolrPluginUtils;

/**
 * This transformer generates the features declared in the feature store of the
 * current reranking model, or in a specified feature store, and appends them to
 * the response.  Ex. <code>fl=id,[features store=myStore efi.user_text="ibm"]</code>
 * 
 * <h2>Parameters</h2>
 * <code>store</code> - The feature store to extract features from. If not provided it
 * will default to the features used by your reranking model.<br>
 * <code>efi.*</code> - External feature information variables required by the features
 * you are extracting.<br>
 * <code>format</code> - The format you want the features to be returned in.  Supports (dense|sparse). Defaults to dense.<br>
*/

public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {

  // used inside fl to specify the format (dense|sparse) of the extracted features
  private static final String FV_FORMAT = "format";

  // used inside fl to specify the feature store to use for the feature extraction
  private static final String FV_STORE = "store";

  // initial value of loggingModelName; a true constant, hence final
  private static final String DEFAULT_LOGGING_MODEL_NAME = "logging-model";

  // cache name handed to the CSVFeatureLogger; createFeatureLogger rejects a null value
  private String fvCacheName;
  // name given to the LoggingModel built when features are logged without the reranking model
  private String loggingModelName = DEFAULT_LOGGING_MODEL_NAME;
  // feature store used when the fl local params do not carry a 'store' value
  private String defaultStore;
  // format used when the fl local params do not carry a 'format' value
  private FeatureLogger.FeatureFormat defaultFormat = FeatureLogger.FeatureFormat.DENSE;
  // separators handed to the CSVFeatureLogger
  private char csvKeyValueDelimiter = CSVFeatureLogger.DEFAULT_KEY_VALUE_SEPARATOR;
  private char csvFeatureSeparator = CSVFeatureLogger.DEFAULT_FEATURE_SEPARATOR;

  // obtained from the init args in init(); may stay null, callers null-check it
  private LTRThreadModule threadManager = null;

  /**
   * Stores the feature vector cache name that is later handed to the
   * {@link CSVFeatureLogger}. Feature logger creation fails while this is unset.
   *
   * @param fvCacheName the cache name to use
   */
  public void setFvCacheName(String fvCacheName) {
    this.fvCacheName = fvCacheName;
  }

  /**
   * Overrides the name given to the logging model that is built when features
   * are extracted without going through the reranking model.
   *
   * @param loggingModelName the model name to use
   */
  public void setLoggingModelName(String loggingModelName) {
    this.loggingModelName = loggingModelName;
  }

  /**
   * Stores the feature store name used whenever the transformer's local
   * params do not specify a <code>store</code>.
   *
   * @param defaultStore the fallback feature store name
   */
  public void setDefaultStore(String defaultStore) {
    this.defaultStore = defaultStore;
  }

  /**
   * Stores the format used whenever the transformer's local params do not
   * specify a <code>format</code>.
   *
   * @param defaultFormat name of a {@link FeatureLogger.FeatureFormat} constant,
   *          matched case-insensitively
   * @throws IllegalArgumentException if the name matches no constant
   */
  public void setDefaultFormat(String defaultFormat) {
    // enum constants are looked up by their upper-case name, independent of the JVM locale
    final String constantName = defaultFormat.toUpperCase(Locale.ROOT);
    this.defaultFormat = FeatureLogger.FeatureFormat.valueOf(constantName);
  }

  /**
   * Stores the character handed to the {@link CSVFeatureLogger} as key/value
   * delimiter.
   *
   * @param csvKeyValueDelimiter a string holding exactly one character
   * @throws IllegalArgumentException if the string length is not 1
   */
  public void setCsvKeyValueDelimiter(String csvKeyValueDelimiter) {
    final boolean singleCharacter = (csvKeyValueDelimiter.length() == 1);
    if (!singleCharacter) {
      throw new IllegalArgumentException("csvKeyValueDelimiter must be exactly 1 character");
    }
    this.csvKeyValueDelimiter = csvKeyValueDelimiter.charAt(0);
  }

  /**
   * Stores the character handed to the {@link CSVFeatureLogger} as feature
   * separator.
   *
   * @param csvFeatureSeparator a string holding exactly one character
   * @throws IllegalArgumentException if the string length is not 1
   */
  public void setCsvFeatureSeparator(String csvFeatureSeparator) {
    final boolean singleCharacter = (csvFeatureSeparator.length() == 1);
    if (!singleCharacter) {
      throw new IllegalArgumentException("csvFeatureSeparator must be exactly 1 character");
    }
    this.csvFeatureSeparator = csvFeatureSeparator.charAt(0);
  }

  /**
   * Initializes the factory from its configuration arguments: delegates to the
   * superclass, obtains the {@link LTRThreadModule} from the arguments and then
   * lets {@link SolrPluginUtils#invokeSetters} apply the arguments to this
   * instance.
   *
   * @param args the configuration arguments of this factory
   */
  @Override
  @SuppressWarnings({"unchecked"})
  public void init(@SuppressWarnings("rawtypes") NamedList args) {
    super.init(args);
    // NOTE(review): keep this call ahead of invokeSetters - presumably
    // LTRThreadModule.getInstance consumes its own entries from args; confirm
    threadManager = LTRThreadModule.getInstance(args);
    SolrPluginUtils.invokeSetters(this, args);
  }

  /**
   * Records the feature extraction settings on the request context and
   * returns the transformer that will add the feature vector to each document.
   *
   * @param name name of the document field that will hold the feature vector
   * @param localparams local params of the transformer (<code>store</code>,
   *          <code>format</code>, <code>efi.*</code>)
   * @param req the current request
   * @return a new {@link FeatureTransformer}
   */
  @Override
  public DocTransformer create(String name, SolrParams localparams,
      SolrQueryRequest req) {

    // Hint to enable feature vector cache since we are requesting features
    SolrQueryRequestContextUtils.setIsExtractingFeatures(req);

    // Communicate which feature store we are requesting features for,
    // falling back to the configured default store
    final String storeName = localparams.get(FV_STORE, defaultStore);
    SolrQueryRequestContextUtils.setFvStoreName(req, storeName);

    // Create and supply the feature logger to be used
    final String requestedFormat = localparams.get(FV_FORMAT);
    final FeatureLogger logger = createFeatureLogger(requestedFormat);
    SolrQueryRequestContextUtils.setFeatureLogger(req, logger);

    return new FeatureTransformer(name, localparams, req);
  }

  /**
   * Returns a {@link FeatureLogger} that logs the features in the requested
   * format.
   *
   * @param formatStr name of a {@link FeatureLogger.FeatureFormat} constant,
   *          matched case-insensitively ('dense' or 'sparse'); null or empty
   *          selects the configured default format, which is dense unless
   *          changed through {@link #setDefaultFormat(String)}
   * @return a feature logger for the format specified.
   * @throws IllegalArgumentException if formatStr names no known format, or if
   *           no fvCacheName has been configured
   */
  private FeatureLogger createFeatureLogger(String formatStr) {
    final FeatureLogger.FeatureFormat format;
    if (formatStr == null || formatStr.isEmpty()) {
      // an absent or blank 'format' local param means "use the default"
      format = this.defaultFormat;
    } else {
      format = FeatureLogger.FeatureFormat.valueOf(formatStr.toUpperCase(Locale.ROOT));
    }
    if (fvCacheName == null) {
      throw new IllegalArgumentException("a fvCacheName must be configured");
    }
    return new CSVFeatureLogger(fvCacheName, format, csvKeyValueDelimiter, csvFeatureSeparator);
  }

  /**
   * Document transformer that adds the feature vector of each returned
   * document as a field. The features come from the request's scoring query
   * or, when there is none or another store is requested, from a logging-only
   * query built in {@link #setContext(ResultContext)}.
   */
  class FeatureTransformer extends DocTransformer {

    private final String name;
    private final SolrParams localparams;
    private final SolrQueryRequest req;

    // state below is populated by setContext
    private List<LeafReaderContext> leafContexts;
    private SolrIndexSearcher searcher;
    private LTRScoringQuery scoringQuery;
    private LTRScoringQuery.ModelWeight modelWeight;
    private FeatureLogger featureLogger;
    // true when the request carried no scoring query
    private boolean docsWereNotReranked;

    /**
     * @param name
     *          Name of the field to be added in a document representing the
     *          feature vectors
     * @param localparams
     *          local params of the transformer, source of the efi values
     * @param req
     *          the request this transformer was created for
     */
    public FeatureTransformer(String name, SolrParams localparams,
        SolrQueryRequest req) {
      this.name = name;
      this.localparams = localparams;
      this.req = req;
    }

    @Override
    public String getName() {
      return name;
    }

    /**
     * Prepares the searcher, the scoring query, the feature logger and the
     * model weight used by the transform methods. Does nothing beyond the
     * superclass call when the context or its request is null.
     *
     * @throws SolrException (BAD_REQUEST) if the searcher is null, the logging
     *           query cannot be built, or no model weight can be created
     */
    @Override
    public void setContext(ResultContext context) {
      super.setContext(context);
      if (context == null || context.getRequest() == null) {
        return;
      }

      searcher = context.getSearcher();
      if (searcher == null) {
        throw new SolrException(
            SolrException.ErrorCode.BAD_REQUEST,
            "searcher is null");
      }
      leafContexts = searcher.getTopReaderContext().leaves();
      if (threadManager != null) {
        threadManager.setExecutor(context.getRequest().getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor());
      }

      // Setup LTRScoringQuery
      scoringQuery = SolrQueryRequestContextUtils.getScoringQuery(req);
      docsWereNotReranked = (scoringQuery == null);
      final String requestedStore = SolrQueryRequestContextUtils.getFvStoreName(req);

      // a logging-only query is needed when nothing was reranked, or when the
      // requested store differs from the store of the reranking model
      boolean useLoggingModel = docsWereNotReranked;
      if (!useLoggingModel && requestedStore != null) {
        final String modelStore = scoringQuery.getScoringModel().getFeatureStoreName();
        useLoggingModel = !requestedStore.equals(modelStore);
      }
      if (useLoggingModel) {
        // if store is set in the transformer we should overwrite the logger
        scoringQuery = createLoggingScoringQuery(requestedStore);
      }

      if (scoringQuery.getOriginalQuery() == null) {
        scoringQuery.setOriginalQuery(context.getQuery());
      }
      if (scoringQuery.getFeatureLogger() == null) {
        final FeatureLogger requestLogger = SolrQueryRequestContextUtils.getFeatureLogger(req);
        scoringQuery.setFeatureLogger(requestLogger);
      }
      scoringQuery.setRequest(req);

      featureLogger = scoringQuery.getFeatureLogger();

      try {
        modelWeight = scoringQuery.createWeight(searcher, ScoreMode.COMPLETE, 1f);
      } catch (final IOException e) {
        throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e);
      }
      if (modelWeight == null) {
        throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
            "error logging the features, model weight is null");
      }
    }

    /**
     * Builds a scoring query around a {@link LoggingModel} holding all the
     * features of the given store, with the efi values taken from the local
     * params.
     *
     * @param requestedStore store name to look up, may be null
     * @throws SolrException (BAD_REQUEST) if the model or the query cannot be
     *           constructed
     */
    private LTRScoringQuery createLoggingScoringQuery(String requestedStore) {
      final ManagedFeatureStore managedStore = ManagedFeatureStore.getManagedFeatureStore(req.getCore());
      final FeatureStore store = managedStore.getFeatureStore(requestedStore);
      // if requestedStore was null the store reports its actual name
      final String storeName = store.getName();

      try {
        final LoggingModel loggingModel = new LoggingModel(loggingModelName,
            storeName, store.getFeatures());

        // request feature weights to be created for all features
        return new LTRScoringQuery(loggingModel,
            LTRQParserPlugin.extractEFIParams(localparams),
            true,
            threadManager);
      } catch (final Exception e) {
        throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
            "retrieving the feature store "+storeName, e);
      }
    }

    @Override
    public void transform(SolrDocument doc, int docid, float score)
        throws IOException {
      implTransform(doc, docid, score);
    }

    @Override
    public void transform(SolrDocument doc, int docid)
        throws IOException {
      implTransform(doc, docid, null);
    }

    /**
     * Adds the feature vector of the document to it, under the field
     * {@link #getName()}: the vector is asked from the feature logger first and
     * only computed when the logger returns none.
     *
     * @param score score of the document, null when not available
     */
    private void implTransform(SolrDocument doc, int docid, Float score)
        throws IOException {
      Object featureVector = featureLogger.getFeatureVector(docid, scoringQuery, searcher);
      if (featureVector == null) { // FV for this document was not in the cache
        // the score is only forwarded when the documents were not reranked
        final Float forwardedScore = docsWereNotReranked ? score : null;
        featureVector = featureLogger.makeFeatureVector(
            LTRRescorer.extractFeaturesInfo(
                modelWeight,
                docid,
                forwardedScore,
                leafContexts));
      }

      doc.addField(name, featureVector);
    }

  }

  /**
   * Scoring model used only to log features: it is built without model
   * features or normalizers, always scores 0 and explains itself as a
   * logging model.
   */
  private static class LoggingModel extends LTRScoringModel {

    /**
     * Creates a logging model over all the features of a store, with empty
     * model features, normalizers and params.
     */
    public LoggingModel(String name, String featureStoreName, List<Feature> allFeatures){
      this(name,
          Collections.<Feature>emptyList(),
          Collections.<Normalizer>emptyList(),
          featureStoreName,
          allFeatures,
          Collections.<String,Object>emptyMap());
    }

    protected LoggingModel(String name, List<Feature> features,
        List<Normalizer> norms, String featureStoreName,
        List<Feature> allFeatures, Map<String,Object> params) {
      super(name, features, norms, featureStoreName, allFeatures, params);
    }

    /** Always returns 0: this model exists for feature logging, not scoring. */
    @Override
    public float score(float[] modelFeatureValuesNormalized) {
      return 0f;
    }

    /** Explains the given score with a note that this is a logging-only model. */
    @Override
    public Explanation explain(LeafReaderContext context, int doc, float finalScore, List<Explanation> featureExplanations) {
      final String description = toString()
          + " logging model, used only for logging the features";
      return Explanation.match(finalScore, description);
    }

  }

}