/*
 * Copyright (C) 2017-2019 Dremio Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.dremio.flight;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import javax.inject.Provider;

import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.auth.BasicServerAuthHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.dremio.exec.proto.UserBitShared;
import com.dremio.exec.proto.UserProtos;
import com.dremio.exec.server.SabotContext;
import com.dremio.exec.server.options.SessionOptionManager;
import com.dremio.exec.server.options.SessionOptionManagerFactoryImpl;
import com.dremio.sabot.rpc.user.UserRpcUtils;
import com.dremio.sabot.rpc.user.UserSession;
import com.dremio.service.users.SystemUser;
import com.dremio.service.users.UserLoginException;
import com.dremio.service.users.UserService;

/**
 * Username/password validation for Dremio's Arrow Flight endpoint.
 *
 * <p>On a successful login the validator issues a token and records, keyed by that token, the
 * {@link UserSession}, the password and a fresh {@link FlightSessionOptions}. A second index maps
 * the user name to the token issued by that user's most recent login, which is how the
 * {@code get*(CallContext)} accessors find the state for a call.
 *
 * <p>Each map is individually synchronized. A login updates the maps one after another (the
 * user-name index last), so the four maps are not updated as a single atomic unit.
 *
 * <p>NOTE(review): the token is the UTF-8 encoding of {@code user:password}, i.e. it embeds the
 * clear-text password, and entries are never removed from the maps. Consider an opaque random
 * token with expiry.
 */
public class AuthValidator implements BasicServerAuthHandler.BasicAuthValidator {
  private static final Logger logger = LoggerFactory.getLogger(AuthValidator.class);

  // Synchronized HashMaps (rather than ConcurrentHashMap) because they accept null keys and
  // values: looking up an unknown identity yields a null token, and get(null) must return null.
  private final Map<ByteArrayWrapper, UserSession> sessions = Collections.synchronizedMap(new HashMap<>());
  private final Map<ByteArrayWrapper, String> passwords = Collections.synchronizedMap(new HashMap<>());
  private final Map<ByteArrayWrapper, FlightSessionOptions> options = Collections.synchronizedMap(new HashMap<>());
  private final Map<String, ByteArrayWrapper> tokens = Collections.synchronizedMap(new HashMap<>());
  private final UserService userService;
  private final SabotContext context;

  /**
   * Creates a validator, resolving both providers immediately.
   *
   * @param userService provider of the service used to authenticate users; may supply null
   * @param context provider of the context used to build session option managers
   */
  public AuthValidator(Provider<UserService> userService, Provider<SabotContext> context) {
    this.userService = userService.get();
    this.context = context.get();
  }

  /**
   * Creates a validator from already-resolved collaborators.
   *
   * @param userService service used to authenticate users; when null only the system user with
   *     an empty password is accepted
   * @param context context used to build session option managers
   */
  public AuthValidator(UserService userService, SabotContext context) {
    this.userService = userService;
    this.context = context;
  }


  /**
   * Authenticates the given credentials and, on success, registers a session for them.
   *
   * @param user user name
   * @param password clear-text password
   * @return the issued token, or an empty array when authentication or session creation failed
   * @throws Exception declared by the interface; failures are logged and reported via the empty
   *     token instead
   */
  @Override
  public byte[] getToken(String user, String password) throws Exception {
    try {
      if (userService != null) {
        userService.authenticate(user, password);
      } else if (!(SystemUser.SYSTEM_USERNAME.equals(user) && "".equals(password))) {
        // Without a user service only the built-in system user (empty password) may log in.
        throw new UserLoginException(user, "not default user");
      }
      // Explicit charset so the token bytes do not depend on the JVM's platform default.
      final byte[] token = (user + ":" + password).getBytes(StandardCharsets.UTF_8);
      final ByteArrayWrapper key = new ByteArrayWrapper(token);
      sessions.put(key, build(user, password));
      passwords.put(key, password);
      options.put(key, new FlightSessionOptions());
      // Published last: the per-token state above exists before the user name resolves to it.
      tokens.put(user, key);
      logger.info("authenticated {}", user);
      return token;
    } catch (Exception e) {
      // Keep the cause in the log; the password is deliberately not part of the message.
      logger.error("unable to authenticate {}", user, e);
    }
    return new byte[0];
  }

  /**
   * Resolves a token to the user name of the session registered for it.
   *
   * @param bytes token previously returned by {@link #getToken(String, String)}
   * @return the user name of the matching session, or empty when the token is unknown
   */
  @Override
  public Optional<String> isValid(byte[] bytes) {
    final UserSession session = sessions.get(new ByteArrayWrapper(bytes));
    if (session == null) {
      // The token itself is never logged because it embeds the password.
      logger.debug("token not found among {} registered session(s)", sessions.size());
      return Optional.empty();
    }
    return Optional.ofNullable(session.getCredentials().getUserName());
  }

  /**
   * Builds the session for a freshly authenticated user. The password is attached as the
   * {@code password} user property and the option manager is named {@code flight-session-<user>}.
   */
  private UserSession build(String user, String password) {
    SessionOptionManager optionsManager =
      new SessionOptionManagerFactoryImpl(context.getOptionValidatorListing()).getOrCreate("flight-session-" + user);
    return UserSession.Builder.newBuilder()
      .withCredentials(UserBitShared.UserCredentials.newBuilder().setUserName(user).build())
      .withSessionOptionManager(optionsManager, context.getOptionManager())
      .withUserProperties(
        UserProtos.UserProperties.newBuilder().addProperties(
          UserProtos.Property.newBuilder().setKey("password").setValue(password).build()
        ).build())
      .withClientInfos(UserRpcUtils.getRpcEndpointInfos("Dremio Flight Client"))
      .setSupportComplexTypes(true)
      .build();
  }

  /**
   * Returns the session registered by the most recent login of the call's peer identity.
   *
   * @param callContext context of the current Flight call
   * @return the session, or null when the identity has no registered login
   */
  public UserSession getUserSession(FlightProducer.CallContext callContext) {
    return sessions.get(tokens.get(callContext.peerIdentity()));
  }

  /**
   * Returns the password supplied at the most recent login of the call's peer identity.
   *
   * @param callContext context of the current Flight call
   * @return the password, or null when the identity has no registered login
   */
  public String getUserPassword(FlightProducer.CallContext callContext) {
    return passwords.get(tokens.get(callContext.peerIdentity()));
  }

  /**
   * Returns the Flight session options created at the most recent login of the call's peer
   * identity.
   *
   * @param callContext context of the current Flight call
   * @return the options, or null when the identity has no registered login
   */
  public FlightSessionOptions getSessionOptions(FlightProducer.CallContext callContext) {
    return options.get(tokens.get(callContext.peerIdentity()));
  }

  /**
   * Mutable per-session Flight settings.
   */
  public static class FlightSessionOptions {
    // volatile: the flag is changed after the instance has been shared through the options map.
    private volatile boolean isParallel;

    public boolean isParallel() {
      return isParallel;
    }

    public void setParallel(boolean parallel) {
      isParallel = parallel;
    }
  }

  /**
   * Wrapper that gives a byte[] value-based equals/hashCode so it can be used as a map key.
   */
  private static final class ByteArrayWrapper {
    private final byte[] bytes;
    private final int hash;

    public ByteArrayWrapper(byte[] bytes) {
      // Defensive copy: the caller keeps (and may modify) its array, a map key must not change.
      this.bytes = bytes == null ? null : bytes.clone();
      this.hash = Arrays.hashCode(this.bytes);
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }
      ByteArrayWrapper that = (ByteArrayWrapper) o;
      return Arrays.equals(bytes, that.bytes);
    }

    @Override
    public int hashCode() {
      return hash;
    }
  }
}