/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.solr.search.join;

import java.io.IOException;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.FixedBitSet;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.eq.FieldEqualitor;
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
import org.apache.solr.client.solrj.io.stream.SolrStream;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.UniqueStream;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.cloud.CloudDescriptor;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.cloud.ClusterState;
import org.apache.solr.common.cloud.DocRouter;
import org.apache.solr.common.cloud.Slice;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.schema.FieldType;
import org.apache.solr.search.BitDocSet;
import org.apache.solr.search.DocSet;
import org.apache.solr.search.DocSetUtil;
import org.apache.solr.search.Filter;
import org.apache.solr.search.SolrIndexSearcher;

/**
 * A query matching local documents whose {@code toField} value is one of the distinct
 * {@code fromField} values of the documents that match {@code query} in another collection.
 *
 * <p>The join keys are fetched with the streaming API. A {@link CloudSolrStream} is used when a
 * ZooKeeper host is given, or when no Solr URL is given (the local ZooKeeper address is then used).
 * Otherwise a {@link SolrStream} is sent to the {@code /stream} handler of {@code solrUrl}. In both
 * cases the remote side is queried through {@code /export}, sorted by {@code fromField} ascending
 * and de-duplicated on that field.
 *
 * <p>Two instances with the same parameters compare equal only if they were created less than
 * {@code ttl} seconds apart (presumably so cached results are reused for that window only).
 */
public class CrossCollectionJoinQuery extends Query {

  protected final String query;
  protected final String zkHost;
  protected final String solrUrl;
  protected final String collection;
  protected final String fromField;
  protected final String toField;
  protected final boolean routedByJoinKey;

  /** Creation time taken from {@link System#nanoTime()}; compared against {@link #ttl} in equals. */
  protected final long timestamp;
  /** Window, in seconds, during which otherwise identical instances compare equal. */
  protected final int ttl;

  protected SolrParams otherParams;
  /** String form of {@link #otherParams}; SolrParams doesn't implement equals(), so this is compared. */
  protected String otherParamsString;

  /**
   * Creates a cross-collection join query.
   *
   * @param query           query string run against the remote collection
   * @param zkHost          ZooKeeper address of the remote cluster, or {@code null}
   * @param solrUrl         base Solr URL used with the {@code /stream} handler when
   *                        {@code zkHost} is {@code null}; may be {@code null}
   * @param collection      name of the remote collection
   * @param fromField       field whose values are read from the remote collection
   * @param toField         local field matched against the remote values
   * @param routedByJoinKey if true, the remote query is restricted to the hash range of the
   *                        local shard
   * @param ttl             seconds within which otherwise identical instances compare equal
   * @param otherParams     extra parameters forwarded to the remote search; may be {@code null}
   */
  public CrossCollectionJoinQuery(String query, String zkHost, String solrUrl,
                                  String collection, String fromField, String toField,
                                  boolean routedByJoinKey, int ttl, SolrParams otherParams) {

    this.query = query;
    this.zkHost = zkHost;
    this.solrUrl = solrUrl;
    this.collection = collection;
    this.fromField = fromField;
    this.toField = toField;
    this.routedByJoinKey = routedByJoinKey;

    this.timestamp = System.nanoTime();
    this.ttl = ttl;

    this.otherParams = otherParams;
    // SolrParams doesn't implement equals(), so use this string to compare them
    if (otherParams != null) {
      this.otherParamsString = otherParams.toString();
    }
  }

  /** Accumulates remote join-key values and turns them into a local DocSet. */
  private interface JoinKeyCollector {
    void collect(Object value) throws IOException;
    DocSet getDocSet() throws IOException;
  }

  /** Collector for term-based (non-point) fields: seeks each key in the terms dictionary. */
  private static class TermsJoinKeyCollector implements JoinKeyCollector {

    private final FieldType fieldType;
    private final SolrIndexSearcher searcher;

    private final TermsEnum termsEnum;
    private final BytesRefBuilder bytes;
    // Reused across seeks to avoid reallocating a postings enum per key
    private PostingsEnum postingsEnum;

    private final FixedBitSet bitSet;

    public TermsJoinKeyCollector(FieldType fieldType, Terms terms, SolrIndexSearcher searcher) throws IOException {
      this.fieldType = fieldType;
      this.searcher = searcher;

      this.termsEnum = terms.iterator();
      this.bytes = new BytesRefBuilder();

      this.bitSet = new FixedBitSet(searcher.maxDoc());
    }

    @Override
    public void collect(Object value) throws IOException {
      // toString() is the identity for String keys; for other key types (e.g. numbers coming
      // from a numeric remote field) it lets the field type parse the readable form instead of
      // failing with a ClassCastException.
      fieldType.readableToIndexed(value.toString(), bytes);
      if (termsEnum.seekExact(bytes.get())) {
        postingsEnum = termsEnum.postings(postingsEnum, PostingsEnum.NONE);
        bitSet.or(postingsEnum);
      }
    }

    @Override
    public DocSet getDocSet() throws IOException {
      // Postings include deleted documents, so mask them out with the live docs
      if (searcher.getIndexReader().hasDeletions()) {
        bitSet.and(searcher.getLiveDocSet().getBits());
      }
      return new BitDocSet(bitSet);
    }
  }

  /** Collector for point fields: gathers integral keys and builds a points query over toField. */
  private class PointJoinKeyCollector extends GraphPointsCollector implements JoinKeyCollector {

    private final SolrIndexSearcher searcher;

    public PointJoinKeyCollector(SolrIndexSearcher searcher) {
      super(searcher.getSchema().getField(toField), null, null);
      this.searcher = searcher;
    }

    @Override
    public void collect(Object value) throws IOException {
      if (value instanceof Long || value instanceof Integer) {
        set.add(((Number) value).longValue());
      } else {
        throw new UnsupportedOperationException("Unsupported field type for XCJFQuery");
      }
    }

    @Override
    public DocSet getDocSet() throws IOException {
      Query query = getResultQuery(searcher.getSchema().getField(toField), false);
      if (query == null) {
        return DocSet.empty();
      }
      return DocSetUtil.createDocSet(searcher, query, null);
    }
  }

  /** Weight that fetches the remote join keys once and serves a constant-score iterator per leaf. */
  private class CrossCollectionJoinQueryWeight extends ConstantScoreWeight {

    private final SolrIndexSearcher searcher;
    private final ScoreMode scoreMode;
    // Computed lazily on the first scorer() call, then reused for every leaf
    private Filter filter;

    public CrossCollectionJoinQueryWeight(SolrIndexSearcher searcher, ScoreMode scoreMode, float score) {
      super(CrossCollectionJoinQuery.this, score);
      this.scoreMode = scoreMode;
      this.searcher = searcher;
    }

    /**
     * Builds a {@code hash_range} filter restricting the remote query to the local shard's range,
     * or returns {@code null} when the collections are not routed by the join key.
     */
    private String createHashRangeFq() {
      if (!routedByJoinKey) {
        return null;
      }
      ClusterState clusterState = searcher.getCore().getCoreContainer().getZkController().getClusterState();
      CloudDescriptor desc = searcher.getCore().getCoreDescriptor().getCloudDescriptor();
      Slice slice = clusterState.getCollection(desc.getCollectionName()).getSlicesMap().get(desc.getShardId());
      DocRouter.Range range = slice.getRange();

      // In CompositeIdRouter, the routing prefix only affects the top 16 bits
      int min = range.min & 0xffff0000;
      int max = range.max | 0x0000ffff;

      return String.format(Locale.ROOT, "{!hash_range f=%s l=%d u=%d}", fromField, min, max);
    }

    /** Builds a de-duplicating CloudSolrStream over the remote collection's /export handler. */
    private TupleStream createCloudSolrStream(SolrClientCache solrClientCache) throws IOException {
      String streamZkHost = (zkHost != null)
              ? zkHost
              : searcher.getCore().getCoreContainer().getZkController().getZkServerAddress();

      // ModifiableSolrParams tolerates a null source, so otherParams may be null here
      ModifiableSolrParams params = new ModifiableSolrParams(otherParams);
      params.set(CommonParams.Q, query);
      String fq = createHashRangeFq();
      if (fq != null) {
        params.add(CommonParams.FQ, fq);
      }
      params.set(CommonParams.FL, fromField);
      params.set(CommonParams.SORT, fromField + " asc");
      params.set(CommonParams.QT, "/export");
      params.set(CommonParams.WT, CommonParams.JAVABIN);

      StreamContext streamContext = new StreamContext();
      streamContext.setSolrClientCache(solrClientCache);

      TupleStream cloudSolrStream = new CloudSolrStream(streamZkHost, collection, params);
      TupleStream uniqueStream = new UniqueStream(cloudSolrStream, new FieldEqualitor(fromField));
      uniqueStream.setStreamContext(streamContext);
      return uniqueStream;
    }

    /** Builds a SolrStream that runs a unique(search(...)) expression on solrUrl's /stream handler. */
    private TupleStream createSolrStream() {
      StreamExpression searchExpr = new StreamExpression("search")
              .withParameter(collection)
              .withParameter(new StreamExpressionNamedParameter(CommonParams.Q, query));
      String fq = createHashRangeFq();
      if (fq != null) {
        searchExpr.withParameter(new StreamExpressionNamedParameter(CommonParams.FQ, fq));
      }
      searchExpr.withParameter(new StreamExpressionNamedParameter(CommonParams.FL, fromField))
              .withParameter(new StreamExpressionNamedParameter(CommonParams.SORT, fromField + " asc"))
              .withParameter(new StreamExpressionNamedParameter(CommonParams.QT, "/export"));

      // The constructor accepts a null otherParams; iterating it directly would throw an NPE
      if (otherParams != null) {
        for (Map.Entry<String,String[]> entry : otherParams) {
          for (String value : entry.getValue()) {
            searchExpr.withParameter(new StreamExpressionNamedParameter(entry.getKey(), value));
          }
        }
      }

      StreamExpression uniqueExpr = new StreamExpression("unique");
      uniqueExpr.withParameter(searchExpr)
              .withParameter(new StreamExpressionNamedParameter("over", fromField));

      ModifiableSolrParams params = new ModifiableSolrParams();
      params.set("expr", uniqueExpr.toString());
      params.set(CommonParams.QT, "/stream");
      params.set(CommonParams.WT, CommonParams.JAVABIN);

      return new SolrStream(solrUrl + "/" + collection, params);
    }

    /**
     * Streams the remote join keys and maps them to the local documents whose toField matches.
     *
     * @throws SolrException if the stream reports an exception tuple or fails with an IOException
     */
    private DocSet getDocSet() throws IOException {
      // Build the collector first: if the local field has no terms there is nothing to match,
      // and no remote stream needs to be created at all.
      FieldType fieldType = searcher.getSchema().getFieldType(toField);
      JoinKeyCollector collector;
      if (fieldType.isPointField()) {
        collector = new PointJoinKeyCollector(searcher);
      } else {
        Terms terms = searcher.getSlowAtomicReader().terms(toField);
        if (terms == null) {
          return DocSet.empty();
        }
        collector = new TermsJoinKeyCollector(fieldType, terms, searcher);
      }

      SolrClientCache solrClientCache = searcher.getCore().getCoreContainer().getSolrClientCache();
      TupleStream solrStream;
      if (zkHost != null || solrUrl == null) {
        solrStream = createCloudSolrStream(solrClientCache);
      } else {
        solrStream = createSolrStream();
      }

      try {
        solrStream.open();
        while (true) {
          Tuple tuple = solrStream.read();
          if (tuple.EXCEPTION) {
            throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, tuple.getException());
          }
          if (tuple.EOF) {
            break;
          }

          // tuple.get() yields null when the tuple carries no value for fromField; such a
          // tuple cannot match anything, so skip it instead of failing inside the collector.
          Object value = tuple.get(fromField);
          if (value != null) {
            collector.collect(value);
          }
        }
      } catch (IOException e) {
        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
      } finally {
        solrStream.close();
      }

      return collector.getDocSet();
    }

    @Override
    public Scorer scorer(LeafReaderContext context) throws IOException {
      if (filter == null) {
        filter = getDocSet().getTopFilter();
      }

      DocIdSet readerSet = filter.getDocIdSet(context, null);
      if (readerSet == null) {
        return null;
      }
      DocIdSetIterator readerSetIterator = readerSet.iterator();
      if (readerSetIterator == null) {
        return null;
      }
      return new ConstantScoreScorer(this, score(), scoreMode, readerSetIterator);
    }

    @Override
    public boolean isCacheable(LeafReaderContext ctx) {
      return false;
    }
  }

  @Override
  public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
    return new CrossCollectionJoinQueryWeight((SolrIndexSearcher) searcher, scoreMode, boost);
  }

  @Override
  public void visit(QueryVisitor visitor) {
    visitor.visitLeaf(this);
  }

  @Override
  public int hashCode() {
    final int prime = 31;
    int result = classHash();
    result = prime * result + Objects.hashCode(query);
    result = prime * result + Objects.hashCode(zkHost);
    result = prime * result + Objects.hashCode(solrUrl);
    result = prime * result + Objects.hashCode(collection);
    result = prime * result + Objects.hashCode(fromField);
    result = prime * result + Objects.hashCode(toField);
    result = prime * result + Boolean.hashCode(routedByJoinKey);
    result = prime * result + Objects.hashCode(otherParamsString);
    // timestamp and ttl must not be included: equal instances may differ in both
    return result;
  }

  @Override
  public boolean equals(Object other) {
    return sameClassAs(other) &&
            equalsTo(getClass().cast(other));
  }

  private boolean equalsTo(CrossCollectionJoinQuery other) {
    return Objects.equals(query, other.query) &&
            Objects.equals(zkHost, other.zkHost) &&
            Objects.equals(solrUrl, other.solrUrl) &&
            Objects.equals(collection, other.collection) &&
            Objects.equals(fromField, other.fromField) &&
            Objects.equals(toField, other.toField) &&
            routedByJoinKey == other.routedByJoinKey &&
            Objects.equals(otherParamsString, other.otherParamsString) &&
            isWithinTtl(other);
  }

  /**
   * Returns whether the two instances were created less than the smaller of their TTLs apart.
   *
   * <p>NOTE(review): with {@code ttl <= 0} this is false even when {@code other == this}, so
   * equals() is not reflexive in that case — confirm whether callers rely on this to disable reuse.
   */
  private boolean isWithinTtl(CrossCollectionJoinQuery other) {
    long elapsedSeconds = TimeUnit.SECONDS.convert(Math.abs(timestamp - other.timestamp), TimeUnit.NANOSECONDS);
    return elapsedSeconds < Math.min(ttl, other.ttl);
  }

  @Override
  public String toString(String field) {
    return String.format(Locale.ROOT, "{!xcjf collection=%s from=%s to=%s routed=%b ttl=%d}%s",
            collection, fromField, toField, routedByJoinKey, ttl, query);
  }
}