/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.blur.lucene.security.index;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocsAndPositionsEnum;
import org.apache.lucene.index.DocsEnum;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.Fields;
import org.apache.lucene.index.FilterAtomicReader;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.CompiledAutomaton;

import com.google.common.base.Splitter;

/**
 * The {@link SecureAtomicReader} protects access to documents based on the
 * {@link AccessControlReader} it is constructed with.
 * 
 * NOTE: If you are using the {@link Fields} and {@link Terms} with
 * {@link TermsEnum} to create a type-ahead, make sure that you check that the
 * {@link TermsEnum} actually points to at least one document, because the
 * {@link SecureAtomicReader} will otherwise leak terms that users don't have
 * access to read or discover.
 */
public class SecureAtomicReader extends FilterAtomicReader {

  // Access control reader used for every per-document check in this class;
  // assigned in the constructor from accessControlReader.clone(in).
  private final AccessControlReader _accessControl;
  // The reader passed to the constructor, returned as-is by getOriginalReader().
  private final AtomicReader _original;

  /**
   * Factory method: asks {@code accessControlFactory} for an
   * {@link AccessControlReader} configured with the given authorizations and
   * wraps {@code in} with it.
   *
   * @param accessControlFactory
   *          the factory that supplies the {@link AccessControlReader}.
   * @param in
   *          the reader to wrap.
   * @param readAuthorizations
   *          passed through to the factory.
   * @param discoverAuthorizations
   *          passed through to the factory.
   * @param discoverableFields
   *          passed through to the factory.
   * @param defaultReadMaskMessage
   *          passed through to the factory.
   * @return a new {@link SecureAtomicReader} around {@code in}.
   * @throws IOException
   *           if the factory or the constructor fails.
   */
  public static SecureAtomicReader create(AccessControlFactory accessControlFactory, AtomicReader in,
      Collection<String> readAuthorizations, Collection<String> discoverAuthorizations, Set<String> discoverableFields,
      String defaultReadMaskMessage) throws IOException {
    return new SecureAtomicReader(in, accessControlFactory.getReader(readAuthorizations, discoverAuthorizations,
        discoverableFields, defaultReadMaskMessage));
  }

  /**
   * Wraps {@code in}. The access control reader actually used by this instance
   * is the result of {@code accessControlReader.clone(in)}, not the argument
   * itself.
   *
   * @param in
   *          the reader to wrap; also kept for {@link #getOriginalReader()}.
   * @param accessControlReader
   *          the reader to clone against {@code in}.
   * @throws IOException
   *           if cloning the access control reader fails.
   */
  public SecureAtomicReader(AtomicReader in, AccessControlReader accessControlReader) throws IOException {
    super(in);
    _original = in;
    _accessControl = accessControlReader.clone(in);
  }

  /**
   * Returns the reader that was handed to the constructor. Reads made through
   * the returned reader do not go through this class's access checks.
   *
   * @return the wrapped {@link AtomicReader}.
   */
  public AtomicReader getOriginalReader() {
    return _original;
  }

  /**
   * Returns a {@link Bits} view in which a document is live only if it is live
   * in the wrapped reader (or the wrapped reader has no deletions) and the
   * access control reader grants {@link ReadType#LIVEDOCS} access to it. This
   * method never returns null.
   */
  @Override
  public Bits getLiveDocs() {
    final Bits wrappedLiveDocs = in.getLiveDocs();
    final int size = maxDoc();
    return new Bits() {

      @Override
      public boolean get(int index) {
        // A document deleted in the wrapped reader is never live; no access
        // check is made for it.
        if (wrappedLiveDocs != null && !wrappedLiveDocs.get(index)) {
          return false;
        }
        try {
          return _accessControl.hasAccess(ReadType.LIVEDOCS, index);
        } catch (IOException e) {
          // Bits.get cannot throw a checked exception.
          throw new RuntimeException(e);
        }
      }

      @Override
      public int length() {
        return size;
      }

    };
  }

  /**
   * Term vectors are not available through this reader: every call throws a
   * {@link RuntimeException} with the message "Not implemented.".
   */
  @Override
  public Fields getTermVectors(int docID) throws IOException {
    // Original note: use doc auth (presumably document-level authorization
    // should guard this once implemented -- TODO confirm).
    throw new RuntimeException("Not implemented.");
  }

  /**
   * Feeds the stored fields of {@code docID} to {@code visitor}, subject to
   * access control:
   * <ul>
   * <li>With {@link ReadType#DOCUMENT_FETCH_READ} access the document is read
   * once to collect its read-mask entries and then a second time for the
   * caller; if any entries were found the caller's visitor is wrapped in a
   * {@link ReadMaskStoredFieldVisitor}.</li>
   * <li>Otherwise, with {@link ReadType#DOCUMENT_FETCH_DISCOVER} access, only
   * fields for which {@code canDiscoverField} is true are offered to the
   * visitor.</li>
   * <li>With neither, the visitor is never called.</li>
   * </ul>
   */
  @Override
  public void document(int docID, final StoredFieldVisitor visitor) throws IOException {
    if (_accessControl.hasAccess(ReadType.DOCUMENT_FETCH_READ, docID)) {
      // First pass: find out which fields of this document are read-masked.
      GetReadMaskFields maskCollector = new GetReadMaskFields(_accessControl.getDefaultReadMaskMessage());
      in.document(docID, maskCollector);
      Map<String, String> maskedFields = maskCollector.getReadMaskFields();

      // Second pass: deliver the document, masking only when necessary.
      StoredFieldVisitor target = maskedFields.isEmpty() ? visitor : new ReadMaskStoredFieldVisitor(visitor,
          maskedFields);
      in.document(docID, target);
    } else if (_accessControl.hasAccess(ReadType.DOCUMENT_FETCH_DISCOVER, docID)) {
      in.document(docID, new StoredFieldVisitor() {

        @Override
        public Status needsField(FieldInfo fieldInfo) throws IOException {
          // Non-discoverable fields are skipped without consulting the caller.
          if (!_accessControl.canDiscoverField(fieldInfo.name)) {
            return Status.NO;
          }
          return visitor.needsField(fieldInfo);
        }

        @Override
        public void binaryField(FieldInfo fieldInfo, byte[] value) throws IOException {
          visitor.binaryField(fieldInfo, value);
        }

        @Override
        public void stringField(FieldInfo fieldInfo, String value) throws IOException {
          visitor.stringField(fieldInfo, value);
        }

        @Override
        public void intField(FieldInfo fieldInfo, int value) throws IOException {
          visitor.intField(fieldInfo, value);
        }

        @Override
        public void longField(FieldInfo fieldInfo, long value) throws IOException {
          visitor.longField(fieldInfo, value);
        }

        @Override
        public void floatField(FieldInfo fieldInfo, float value) throws IOException {
          visitor.floatField(fieldInfo, value);
        }

        @Override
        public void doubleField(FieldInfo fieldInfo, double value) throws IOException {
          visitor.doubleField(fieldInfo, value);
        }

      });
    }
  }

  /**
   * A {@link StoredFieldVisitor} decorator that hides the values of read-masked
   * fields from the wrapped visitor.
   *
   * For a field whose name is a key in the mask map, the real value is never
   * forwarded. If the mapped message is non-empty it is delivered to the
   * wrapped visitor as a string field instead; if it is empty nothing is
   * delivered. Fields that are not in the map are forwarded unchanged.
   */
  private static class ReadMaskStoredFieldVisitor extends StoredFieldVisitor {

    private final StoredFieldVisitor _visitor;
    private final Map<String, String> _readMaskFieldsAndMessages;

    /**
     * @param visitor
     *          the visitor that receives the (possibly masked) fields.
     * @param readMaskFieldsAndMessages
     *          field name to replacement message; an empty message means the
     *          field is dropped silently.
     */
    public ReadMaskStoredFieldVisitor(StoredFieldVisitor visitor, Map<String, String> readMaskFieldsAndMessages) {
      _visitor = visitor;
      _readMaskFieldsAndMessages = readMaskFieldsAndMessages;
    }

    /**
     * Defers to the wrapped visitor so that its field selection and its
     * {@code STOP} request are honoured. Answering {@code YES} unconditionally
     * would push fields into the wrapped visitor that it never asked for.
     */
    @Override
    public Status needsField(FieldInfo fieldInfo) throws IOException {
      return _visitor.needsField(fieldInfo);
    }

    @Override
    public void binaryField(FieldInfo fieldInfo, byte[] value) throws IOException {
      if (!checkReadMask(fieldInfo)) {
        _visitor.binaryField(fieldInfo, value);
      }
    }

    /**
     * Applies the read mask for {@code fieldInfo}.
     *
     * @return true if the field is masked (the caller must not forward the
     *         real value), false if the field is not masked.
     */
    private boolean checkReadMask(FieldInfo fieldInfo) throws IOException {
      final String message = _readMaskFieldsAndMessages.get(fieldInfo.name);
      if (message == null) {
        return false;
      }
      // A non-empty message replaces the value; an empty one just hides it.
      if (!message.isEmpty()) {
        _visitor.stringField(fieldInfo, message);
      }
      return true;
    }

    @Override
    public void stringField(FieldInfo fieldInfo, String value) throws IOException {
      if (!checkReadMask(fieldInfo)) {
        _visitor.stringField(fieldInfo, value);
      }
    }

    @Override
    public void intField(FieldInfo fieldInfo, int value) throws IOException {
      if (!checkReadMask(fieldInfo)) {
        _visitor.intField(fieldInfo, value);
      }
    }

    @Override
    public void longField(FieldInfo fieldInfo, long value) throws IOException {
      if (!checkReadMask(fieldInfo)) {
        _visitor.longField(fieldInfo, value);
      }
    }

    @Override
    public void floatField(FieldInfo fieldInfo, float value) throws IOException {
      if (!checkReadMask(fieldInfo)) {
        _visitor.floatField(fieldInfo, value);
      }
    }

    @Override
    public void doubleField(FieldInfo fieldInfo, double value) throws IOException {
      if (!checkReadMask(fieldInfo)) {
        _visitor.doubleField(fieldInfo, value);
      }
    }

  }

  /**
   * A {@link StoredFieldVisitor} that looks only at the
   * {@link FilterAccessControlFactory#READ_MASK_FIELD} stored field and
   * collects its entries into a map of field name to mask message.
   *
   * Each stored value is split on {@code '|'}: the first part is the masked
   * field name and the second part, when present, is its message. When there
   * is no second part the default read mask message is used.
   */
  private static class GetReadMaskFields extends StoredFieldVisitor {

    private final Map<String, String> _fieldsAndMessages = new HashMap<String, String>();
    private final Splitter splitter = Splitter.on('|');
    private final String _defaultReadMask;

    /**
     * @param defaultReadMask
     *          message used for entries without their own message; null is
     *          treated as the empty string.
     */
    GetReadMaskFields(String defaultReadMask) {
      if (defaultReadMask == null) {
        _defaultReadMask = "";
      } else {
        _defaultReadMask = defaultReadMask;
      }
    }

    @Override
    public Status needsField(FieldInfo fieldInfo) throws IOException {
      return fieldInfo.name.equals(FilterAccessControlFactory.READ_MASK_FIELD) ? Status.YES : Status.NO;
    }

    @Override
    public void stringField(FieldInfo fieldInfo, String value) throws IOException {
      final Iterator<String> parts = splitter.split(value).iterator();
      if (!parts.hasNext()) {
        return;
      }
      final String maskedField = parts.next();
      // Anything after a second '|' is ignored.
      final String explicitMessage = parts.hasNext() ? parts.next() : null;
      _fieldsAndMessages.put(maskedField, explicitMessage == null ? _defaultReadMask : explicitMessage);
    }

    /**
     * @return the live map of masked field name to message collected so far.
     */
    Map<String, String> getReadMaskFields() {
      return _fieldsAndMessages;
    }

  }

  /**
   * Returns the wrapped reader's {@link Fields} decorated with
   * {@link SecureFields}, or null when the wrapped reader reports no postings
   * (that is, when {@code in.fields()} returns null).
   */
  @Override
  public Fields fields() throws IOException {
    final Fields wrapped = in.fields();
    if (wrapped == null) {
      // Propagate "no postings" rather than wrapping null, which would fail
      // with a NullPointerException on first use.
      return null;
    }
    return new SecureFields(wrapped, _accessControl, maxDoc());
  }

  /**
   * Returns the wrapped reader's numeric doc values for {@code field} guarded
   * by {@link ReadType#NUMERIC_DOC_VALUE} access checks, or null if the
   * wrapped reader has none.
   */
  @Override
  public NumericDocValues getNumericDocValues(String field) throws IOException {
    final NumericDocValues wrapped = in.getNumericDocValues(field);
    return secureNumericDocValues(wrapped, ReadType.NUMERIC_DOC_VALUE);
  }

  /**
   * Wraps {@code numericDocValues} so that the real value is returned only for
   * documents the access control reader grants {@code type} access to; every
   * other document reads as {@code 0L}.
   *
   * @param numericDocValues
   *          the values to guard; may be null.
   * @param type
   *          the {@link ReadType} used for the access check.
   * @return the guarded values, or null if {@code numericDocValues} is null.
   */
  private NumericDocValues secureNumericDocValues(final NumericDocValues numericDocValues, final ReadType type) {
    if (numericDocValues != null) {
      return new NumericDocValues() {

        @Override
        public long get(int docID) {
          try {
            // 0L is the default missing value.
            return _accessControl.hasAccess(type, docID) ? numericDocValues.get(docID) : 0L;
          } catch (IOException e) {
            throw new RuntimeException(e);
          }
        }
      };
    }
    return null;
  }

  /**
   * Returns the wrapped reader's binary doc values for {@code field} guarded
   * by {@link ReadType#BINARY_DOC_VALUE} access checks, or null if the wrapped
   * reader has none. Documents without access read as the empty
   * {@code MISSING} value.
   */
  @Override
  public BinaryDocValues getBinaryDocValues(String field) throws IOException {
    final BinaryDocValues wrapped = in.getBinaryDocValues(field);
    if (wrapped == null) {
      return null;
    }
    return new BinaryDocValues() {

      @Override
      public void get(int docID, BytesRef result) {
        final boolean permitted;
        try {
          permitted = _accessControl.hasAccess(ReadType.BINARY_DOC_VALUE, docID);
        } catch (IOException e) {
          throw new RuntimeException(e);
        }
        if (permitted) {
          wrapped.get(docID, result);
        } else {
          // Default missing value.
          result.bytes = MISSING;
          result.length = 0;
          result.offset = 0;
        }
      }
    };
  }

  /**
   * Returns the wrapped reader's sorted doc values for {@code field}, or null
   * if the wrapped reader has none. {@code getOrd} returns {@code -1} for
   * documents without {@link ReadType#SORTED_DOC_VALUE} access;
   * {@code lookupOrd} and {@code getValueCount} are passed through unchecked.
   */
  @Override
  public SortedDocValues getSortedDocValues(String field) throws IOException {
    final SortedDocValues wrapped = in.getSortedDocValues(field);
    if (wrapped == null) {
      return null;
    }
    return new SortedDocValues() {

      @Override
      public int getOrd(int docID) {
        try {
          if (!_accessControl.hasAccess(ReadType.SORTED_DOC_VALUE, docID)) {
            // Default missing value.
            return -1;
          }
          return wrapped.getOrd(docID);
        } catch (IOException e) {
          throw new RuntimeException(e);
        }
      }

      @Override
      public void lookupOrd(int ord, BytesRef result) {
        wrapped.lookupOrd(ord, result);
      }

      @Override
      public int getValueCount() {
        return wrapped.getValueCount();
      }
    };
  }

  /**
   * Returns the wrapped reader's sorted-set doc values for {@code field}, or
   * null if the wrapped reader has none. The access decision is made in
   * {@code setDocument} using {@link ReadType#SORTED_SET_DOC_VALUE} and then
   * governs {@code nextOrd} and {@code lookupOrd} until the next
   * {@code setDocument} call.
   */
  @Override
  public SortedSetDocValues getSortedSetDocValues(String field) throws IOException {
    final SortedSetDocValues wrapped = in.getSortedSetDocValues(field);
    if (wrapped == null) {
      return null;
    }
    return new SortedSetDocValues() {

      // Access decision for the document most recently passed to setDocument;
      // false until setDocument has been called.
      private boolean _access;

      @Override
      public void setDocument(int docID) {
        try {
          _access = _accessControl.hasAccess(ReadType.SORTED_SET_DOC_VALUE, docID);
          if (_access) {
            wrapped.setDocument(docID);
          }
        } catch (IOException e) {
          throw new RuntimeException(e);
        }
      }

      @Override
      public long nextOrd() {
        return _access ? wrapped.nextOrd() : NO_MORE_ORDS;
      }

      @Override
      public void lookupOrd(long ord, BytesRef result) {
        if (!_access) {
          // NOTE(review): the lookup is gated on the current document's
          // access rather than on the ord itself -- confirm this is intended.
          result.bytes = BinaryDocValues.MISSING;
          result.length = 0;
          result.offset = 0;
          return;
        }
        wrapped.lookupOrd(ord, result);
      }

      @Override
      public long getValueCount() {
        return wrapped.getValueCount();
      }
    };
  }

  /**
   * Returns the wrapped reader's norm values for {@code field} guarded by
   * {@link ReadType#NORM_VALUE} access checks, or null if the wrapped reader
   * has none.
   */
  @Override
  public NumericDocValues getNormValues(String field) throws IOException {
    final NumericDocValues wrappedNorms = in.getNormValues(field);
    return secureNumericDocValues(wrappedNorms, ReadType.NORM_VALUE);
  }

  /**
   * {@link Fields} decorator whose {@link #terms(String)} returns
   * access-controlled {@link Terms}. When a companion read-mask field (the
   * field name plus {@link FilterAccessControlFactory#READ_MASK_SUFFIX})
   * exists, the terms are additionally wrapped in {@link ReadMaskTerms}.
   */
  static class SecureFields extends FilterFields {

    private final int _maxDoc;
    private final AccessControlReader _accessControlReader;

    public SecureFields(Fields in, AccessControlReader accessControlReader, int maxDoc) {
      super(in);
      _maxDoc = maxDoc;
      _accessControlReader = accessControlReader;
    }

    @Override
    public Terms terms(String field) throws IOException {
      final Terms rawTerms = in.terms(field);
      if (rawTerms == null) {
        return null;
      }
      final Terms maskTerms = getReadMaskTerms(in, field);
      final Terms secured = new SecureTerms(rawTerms, _accessControlReader, _maxDoc);
      if (maskTerms != null) {
        return new ReadMaskTerms(secured, maskTerms);
      }
      return secured;
    }

    /**
     * Looks up the terms of the read-mask companion of {@code field}.
     *
     * @return the companion field's terms, or null if there is none.
     */
    private Terms getReadMaskTerms(Fields in, String field) throws IOException {
      final String maskFieldName = field + FilterAccessControlFactory.READ_MASK_SUFFIX;
      return in.terms(maskFieldName);
    }

  }

  /**
   * {@link Terms} decorator that pairs every {@link TermsEnum} it hands out
   * with an enum over the read-mask terms, so that {@link ReadMaskTermsEnum}
   * can hide terms that only occur in masked documents.
   */
  static class ReadMaskTerms extends FilterTerms {

    private final Terms _readMask;

    /**
     * @param in
     *          the terms to filter.
     * @param readMask
     *          the terms of the companion read-mask field.
     */
    public ReadMaskTerms(Terms in, Terms readMask) {
      super(in);
      _readMask = readMask;
    }

    @Override
    public TermsEnum iterator(TermsEnum reuse) throws IOException {
      TermsEnum maskTermsEnum = _readMask.iterator(null);
      return new ReadMaskTermsEnum(maskTermsEnum, in.iterator(reuse));
    }

    /**
     * Only the real terms are intersected with the automaton. The mask enum
     * must be a plain iterator because {@link ReadMaskTermsEnum} positions it
     * with {@code seekExact}, and Lucene documents that the enum returned by
     * {@code Terms.intersect} cannot seek.
     */
    @Override
    public TermsEnum intersect(CompiledAutomaton compiled, BytesRef startTerm) throws IOException {
      TermsEnum maskTermsEnum = _readMask.iterator(null);
      return new ReadMaskTermsEnum(maskTermsEnum, in.intersect(compiled, startTerm));
    }

  }

  /**
   * {@link Terms} decorator that wraps every {@link TermsEnum} it hands out,
   * from both {@code iterator} and {@code intersect}, in a
   * {@link SecureTermsEnum}.
   */
  static class SecureTerms extends FilterTerms {

    private final int _maxDoc;
    private final AccessControlReader _accessControlReader;

    public SecureTerms(Terms in, AccessControlReader accessControlReader, int maxDoc) {
      super(in);
      _maxDoc = maxDoc;
      _accessControlReader = accessControlReader;
    }

    @Override
    public TermsEnum iterator(TermsEnum reuse) throws IOException {
      final TermsEnum wrapped = in.iterator(reuse);
      return new SecureTermsEnum(wrapped, _accessControlReader, _maxDoc);
    }

    @Override
    public TermsEnum intersect(CompiledAutomaton compiled, BytesRef startTerm) throws IOException {
      final TermsEnum wrapped = in.intersect(compiled, startTerm);
      return new SecureTermsEnum(wrapped, _accessControlReader, _maxDoc);
    }
  }

  /**
   * {@link TermsEnum} decorator that skips read-masked terms.
   *
   * A term of the wrapped enum is exposed when it does not exist in the mask
   * enum, or when it does but at least one of its documents is not among the
   * mask enum's documents for that term (see {@link #checkDocs()}).
   */
  static class ReadMaskTermsEnum extends FilterTermsEnum {

    private final TermsEnum _maskTermsEnum;

    /**
     * @param maskTermsEnum
     *          enum over the read-mask terms; positioned with
     *          {@code seekExact}.
     * @param realTermsEnum
     *          the enum being filtered.
     */
    public ReadMaskTermsEnum(TermsEnum maskTermsEnum, TermsEnum realTermsEnum) {
      super(realTermsEnum);
      _maskTermsEnum = maskTermsEnum;
    }

    @Override
    public SeekStatus seekCeil(BytesRef text, boolean useCache) throws IOException {
      final SeekStatus status = in.seekCeil(text, useCache);
      if (status == SeekStatus.END) {
        return SeekStatus.END;
      }
      // Seek the mask enum with a private copy of the current term.
      final BytesRef currentTerm = new BytesRef();
      currentTerm.copyBytes(in.term());
      // Term is either not masked at all, or visible through at least one
      // unmasked document.
      if (!_maskTermsEnum.seekExact(currentTerm, true) || checkDocs()) {
        return status;
      }
      // The term we landed on is fully masked; move on to the next visible one.
      return next() == null ? SeekStatus.END : SeekStatus.NOT_FOUND;
    }

    @Override
    public void seekExact(long ord) throws IOException {
      throw new IOException("Not supported");
    }

    @Override
    public BytesRef next() throws IOException {
      BytesRef candidate;
      while ((candidate = in.next()) != null) {
        if (!_maskTermsEnum.seekExact(candidate, true) || checkDocs()) {
          return candidate;
        }
      }
      return null;
    }

    /**
     * Compares the documents of the current term in both enums.
     *
     * @return true if some document of the wrapped enum's current term is not
     *         a document of the mask enum's current term; false if every
     *         document is masked.
     */
    private boolean checkDocs() throws IOException {
      final DocsEnum maskedDocs = _maskTermsEnum.docs(null, null, DocsEnum.FLAG_NONE);
      final DocsEnum termDocs = in.docs(null, null, DocsEnum.FLAG_NONE);
      for (int doc = termDocs.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = termDocs.nextDoc()) {
        if (maskedDocs.advance(doc) != doc) {
          return true;
        }
      }
      return false;
    }
  }

  /**
   * {@link TermsEnum} decorator that exposes a term only if at least one of
   * its documents passes the {@link ReadType#TERMS_ENUM} access check, and
   * that restricts the postings it returns through
   * {@link SecureAtomicReader#getSecureLiveDocs(Bits, int, AccessControlReader)}.
   */
  static class SecureTermsEnum extends FilterTermsEnum {

    private final int _maxDoc;
    private final AccessControlReader _accessControlReader;

    public SecureTermsEnum(TermsEnum in, AccessControlReader accessControlReader, int maxDoc) {
      super(in);
      _maxDoc = maxDoc;
      _accessControlReader = accessControlReader;
    }

    @Override
    public SeekStatus seekCeil(BytesRef text, boolean useCache) throws IOException {
      final SeekStatus status = in.seekCeil(text, useCache);
      if (status == SeekStatus.END) {
        return SeekStatus.END;
      }
      if (hasAccess(in.term())) {
        return status;
      }
      // The term we landed on is hidden; advance to the next visible one.
      return next() == null ? SeekStatus.END : SeekStatus.NOT_FOUND;
    }

    @Override
    public BytesRef next() throws IOException {
      for (BytesRef candidate = in.next(); candidate != null; candidate = in.next()) {
        if (hasAccess(candidate)) {
          return candidate;
        }
      }
      return null;
    }

    /**
     * Checks the documents of the term the wrapped enum is currently
     * positioned on. The {@code term} argument itself is not consulted.
     *
     * @return true as soon as one document passes the access check, false if
     *         none does.
     */
    private boolean hasAccess(BytesRef term) throws IOException {
      final DocsEnum termDocs = in.docs(null, null, DocsEnum.FLAG_NONE);
      for (int doc = termDocs.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = termDocs.nextDoc()) {
        if (_accessControlReader.hasAccess(ReadType.TERMS_ENUM, doc)) {
          return true;
        }
      }
      return false;
    }

    @Override
    public DocsEnum docs(Bits liveDocs, DocsEnum reuse, int flags) throws IOException {
      return in.docs(getSecureLiveDocs(liveDocs, _maxDoc, _accessControlReader), reuse, flags);
    }

    @Override
    public DocsAndPositionsEnum docsAndPositions(Bits liveDocs, DocsAndPositionsEnum reuse, int flags)
        throws IOException {
      return in.docsAndPositions(getSecureLiveDocs(liveDocs, _maxDoc, _accessControlReader), reuse, flags);
    }

  }

  /**
   * Builds a {@link Bits} view in which a document is set only if it is set in
   * {@code bits} and passes the {@link ReadType#DOCS_ENUM} access check.
   *
   * @param bits
   *          the live docs to restrict; null is treated as "all documents up
   *          to {@code maxDoc} are live".
   * @param maxDoc
   *          length of the match-all bits used when {@code bits} is null.
   * @param accessControlReader
   *          performs the per-document access check.
   * @return the restricted view; its length is that of {@code bits}, or
   *         {@code maxDoc} when {@code bits} is null.
   */
  public static Bits getSecureLiveDocs(Bits bits, int maxDoc, final AccessControlReader accessControlReader) {
    final Bits base = bits != null ? bits : getMatchAll(maxDoc);
    final int size = base.length();
    return new Bits() {

      @Override
      public boolean get(int index) {
        if (!base.get(index)) {
          return false;
        }
        try {
          return accessControlReader.hasAccess(ReadType.DOCS_ENUM, index);
        } catch (IOException e) {
          // Bits.get cannot throw a checked exception.
          throw new RuntimeException(e);
        }
      }

      @Override
      public int length() {
        return size;
      }
    };
  }

  /**
   * Creates a {@link Bits} of the given length in which every index is set.
   * {@code get} performs no bounds check.
   *
   * @param length
   *          the value reported by {@link Bits#length()}.
   * @return a match-all {@link Bits}.
   */
  public static Bits getMatchAll(final int length) {
    final Bits matchAll = new Bits() {

      @Override
      public boolean get(int index) {
        return true;
      }

      @Override
      public int length() {
        return length;
      }
    };
    return matchAll;
  }
}