/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.fnexecution.artifact;

import com.google.auto.value.AutoValue;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.Channels;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.fnexecution.FnService;
import org.apache.beam.sdk.fn.IdGenerator;
import org.apache.beam.sdk.fn.IdGenerators;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.io.fs.MoveOptions;
import org.apache.beam.sdk.io.fs.ResolveOptions;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.util.MimeTypes;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.Status;
import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.StatusException;
import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.hash.Hashing;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ArtifactStagingService
    extends ArtifactStagingServiceGrpc.ArtifactStagingServiceImplBase implements FnService {

  private static final Logger LOG = LoggerFactory.getLogger(ArtifactStagingService.class);

  private final ArtifactDestinationProvider destinationProvider;

  private final ConcurrentMap<String, Map<String, List<RunnerApi.ArtifactInformation>>> toStage =
      new ConcurrentHashMap<>();

  private final ConcurrentMap<String, Map<String, List<RunnerApi.ArtifactInformation>>> staged =
      new ConcurrentHashMap<>();

  public ArtifactStagingService(ArtifactDestinationProvider destinationProvider) {
    this.destinationProvider = destinationProvider;
  }

  /**
   * Registers a set of artifacts to be staged with this service.
   *
   * <p>A client (e.g. a Beam SDK) is expected to connect to this service with the given staging
   * token and offer resolution and retrieval of this set of artifacts.
   *
   * @param stagingToken a staging token for this job
   * @param artifacts all artifacts to stage, keyed by environment
   */
  public void registerJob(
      String stagingToken, Map<String, List<RunnerApi.ArtifactInformation>> artifacts) {
    assert !toStage.containsKey(stagingToken);
    toStage.put(stagingToken, artifacts);
  }

  /**
   * Returns the rewritten artifacts associated with this job, keyed by environment.
   *
   * <p>This should be called after the client has finished offering artifacts.
   *
   * @param stagingToken a staging token for this job
   */
  public Map<String, List<RunnerApi.ArtifactInformation>> getStagedArtifacts(String stagingToken) {
    toStage.remove(stagingToken);
    return staged.remove(stagingToken);
  }

  public void removeStagedArtifacts(String stagingToken) throws IOException {
    destinationProvider.removeStagedArtifacts(stagingToken);
  }

  /** Provides a concrete location to which artifacts can be staged on retrieval. */
  public interface ArtifactDestinationProvider {
    ArtifactDestination getDestination(String stagingToken, String name) throws IOException;

    void removeStagedArtifacts(String stagingToken) throws IOException;
  }

  /**
   * A pairing of a newly created artifact type and an ouptut stream that will be readable at that
   * type.
   */
  @AutoValue
  public abstract static class ArtifactDestination {
    public static ArtifactDestination create(
        String typeUrn, ByteString typePayload, OutputStream out) {
      return new AutoValue_ArtifactStagingService_ArtifactDestination(typeUrn, typePayload, out);
    }

    public static ArtifactDestination fromFile(String path) throws IOException {
      return fromFile(
          path,
          Channels.newOutputStream(
              FileSystems.create(
                  FileSystems.matchNewResource(path, false /* isDirectory */), MimeTypes.BINARY)));
    }

    public static ArtifactDestination fromFile(String path, OutputStream out) {
      return create(
          ArtifactRetrievalService.FILE_ARTIFACT_URN,
          RunnerApi.ArtifactFilePayload.newBuilder().setPath(path).build().toByteString(),
          out);
    }

    public abstract String getTypeUrn();

    public abstract ByteString getTypePayload();

    public abstract OutputStream getOutputStream();
  }

  /**
   * An ArtifactDestinationProvider that places new artifacts as files in a Beam filesystem.
   *
   * @param root the directory in which to place all artifacts
   */
  public static ArtifactDestinationProvider beamFilesystemArtifactDestinationProvider(String root) {
    return new ArtifactDestinationProvider() {
      @Override
      public ArtifactDestination getDestination(String stagingToken, String name)
          throws IOException {
        ResourceId path =
            stagingDir(stagingToken)
                .resolve(name, ResolveOptions.StandardResolveOptions.RESOLVE_FILE);
        return ArtifactDestination.fromFile(path.toString());
      }

      @Override
      public void removeStagedArtifacts(String stagingToken) throws IOException {
        // TODO(robertwb): Consider adding recursive delete.
        ResourceId stagingDir = stagingDir(stagingToken);
        List<ResourceId> toDelete = new ArrayList<>();
        for (MatchResult match :
            FileSystems.matchResources(
                ImmutableList.of(
                    stagingDir.resolve("*", ResolveOptions.StandardResolveOptions.RESOLVE_FILE)))) {
          for (MatchResult.Metadata m : match.metadata()) {
            toDelete.add(m.resourceId());
          }
        }
        FileSystems.delete(toDelete, MoveOptions.StandardMoveOptions.IGNORE_MISSING_FILES);
        FileSystems.delete(
            ImmutableList.of(stagingDir), MoveOptions.StandardMoveOptions.IGNORE_MISSING_FILES);
      }

      private ResourceId stagingDir(String stagingToken) {
        return FileSystems.matchNewResource(root, true)
            .resolve(
                Hashing.sha256().hashString(stagingToken, Charsets.UTF_8).toString(),
                ResolveOptions.StandardResolveOptions.RESOLVE_DIRECTORY);
      }
    };
  }

  private enum State {
    START,
    RESOLVE,
    GET,
    GETCHUNK,
    DONE,
    ERROR,
  }

  /**
   * Like the standard Semaphore, but allows an aquire to go over the limit if there is any room.
   *
   * <p>Also allows setting an error, to avoid issues with un-released aquires after error.
   */
  private static class OverflowingSemaphore {
    private int totalPermits;
    private int usedPermits;
    private Exception exception;

    public OverflowingSemaphore(int totalPermits) {
      this.totalPermits = totalPermits;
      this.usedPermits = 0;
    }

    synchronized void aquire(int permits) throws Exception {
      while (usedPermits >= totalPermits) {
        if (exception != null) {
          throw exception;
        }
        this.wait();
      }
      usedPermits += permits;
    }

    synchronized void release(int permits) {
      usedPermits -= permits;
      this.notifyAll();
    }

    synchronized void setException(Exception exception) {
      this.exception = exception;
      this.notifyAll();
    }
  }

  /** A task that pulls bytes off a queue and actually writes them to a staging location. */
  private class StoreArtifact implements Callable<RunnerApi.ArtifactInformation> {

    private String stagingToken;
    private String name;
    private RunnerApi.ArtifactInformation originalArtifact;
    private BlockingQueue<ByteString> bytesQueue;
    private OverflowingSemaphore totalPendingBytes;

    public StoreArtifact(
        String stagingToken,
        String name,
        RunnerApi.ArtifactInformation originalArtifact,
        BlockingQueue<ByteString> bytesQueue,
        OverflowingSemaphore totalPendingBytes) {
      this.stagingToken = stagingToken;
      this.name = name;
      this.originalArtifact = originalArtifact;
      this.bytesQueue = bytesQueue;
      this.totalPendingBytes = totalPendingBytes;
    }

    @Override
    public RunnerApi.ArtifactInformation call() throws IOException {
      try {
        ArtifactDestination dest = destinationProvider.getDestination(stagingToken, name);
        LOG.debug("Storing artifact for {}.{} at {}", stagingToken, name, dest);
        ByteString chunk = bytesQueue.take();
        while (chunk.size() > 0) {
          totalPendingBytes.release(chunk.size());
          chunk.writeTo(dest.getOutputStream());
          chunk = bytesQueue.take();
        }
        dest.getOutputStream().close();
        return originalArtifact
            .toBuilder()
            .setTypeUrn(dest.getTypeUrn())
            .setTypePayload(dest.getTypePayload())
            .build();
      } catch (IOException | InterruptedException exn) {
        // As this thread will no longer be draining the queue, we don't want to get stuck writing
        // to it.
        totalPendingBytes.setException(exn);
        LOG.error("Exception staging artifacts", exn);
        if (exn instanceof IOException) {
          throw (IOException) exn;
        } else {
          throw new RuntimeException(exn);
        }
      }
    }
  }

  @Override
  public StreamObserver<ArtifactApi.ArtifactResponseWrapper> reverseArtifactRetrievalService(
      StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {

    return new StreamObserver<ArtifactApi.ArtifactResponseWrapper>() {

      /** The maximum number of parallel threads to use to stage. */
      public static final int THREAD_POOL_SIZE = 10;

      /** The maximum number of bytes to buffer across all writes before throttling. */
      public static final int MAX_PENDING_BYTES = 100 << 20; // 100 MB

      IdGenerator idGenerator = IdGenerators.incrementingLongs();

      String stagingToken;
      Map<String, List<RunnerApi.ArtifactInformation>> toResolve;
      Map<String, List<Future<RunnerApi.ArtifactInformation>>> stagedFutures;
      ExecutorService stagingExecutor;
      OverflowingSemaphore totalPendingBytes;

      State state = State.START;
      Queue<String> pendingResolves;
      String currentEnvironment;
      Queue<RunnerApi.ArtifactInformation> pendingGets;
      BlockingQueue<ByteString> currentOutput;

      @Override
      @SuppressFBWarnings(value = "SF_SWITCH_FALLTHROUGH", justification = "fallthrough intended")
      // May be called by different threads for the same request; synchronized for memory
      // synchronization.
      public synchronized void onNext(ArtifactApi.ArtifactResponseWrapper responseWrapper) {
        switch (state) {
          case START:
            stagingToken = responseWrapper.getStagingToken();
            LOG.info("Staging artifacts for {}.", stagingToken);
            toResolve = toStage.get(stagingToken);
            if (toResolve == null) {
              responseObserver.onError(
                  new StatusException(
                      Status.INVALID_ARGUMENT.withDescription(
                          "Unknown staging token " + stagingToken)));
              return;
            }
            stagedFutures = new ConcurrentHashMap<>();
            pendingResolves = new ArrayDeque<>();
            pendingResolves.addAll(toResolve.keySet());
            stagingExecutor = Executors.newFixedThreadPool(THREAD_POOL_SIZE);
            totalPendingBytes = new OverflowingSemaphore(MAX_PENDING_BYTES);
            resolveNextEnvironment(responseObserver);
            break;

          case RESOLVE:
            {
              currentEnvironment = pendingResolves.remove();
              stagedFutures.put(currentEnvironment, new ArrayList<>());
              pendingGets = new ArrayDeque<>();
              for (RunnerApi.ArtifactInformation artifact :
                  responseWrapper.getResolveArtifactResponse().getReplacementsList()) {
                Optional<RunnerApi.ArtifactInformation> fetched = getLocal(artifact);
                if (fetched.isPresent()) {
                  stagedFutures
                      .get(currentEnvironment)
                      .add(CompletableFuture.completedFuture(fetched.get()));
                } else {
                  pendingGets.add(artifact);
                  responseObserver.onNext(
                      ArtifactApi.ArtifactRequestWrapper.newBuilder()
                          .setGetArtifact(
                              ArtifactApi.GetArtifactRequest.newBuilder().setArtifact(artifact))
                          .build());
                }
              }
              LOG.info(
                  "Getting {} artifacts for {}.{}.",
                  pendingGets.size(),
                  stagingToken,
                  pendingResolves.peek());
              if (pendingGets.isEmpty()) {
                resolveNextEnvironment(responseObserver);
              } else {
                state = State.GET;
              }
              break;
            }

          case GET:
            RunnerApi.ArtifactInformation currentArtifact = pendingGets.remove();
            String name = createFilename(currentEnvironment, currentArtifact);
            try {
              LOG.debug("Storing artifacts for {} as {}", stagingToken, name);
              currentOutput = new ArrayBlockingQueue<ByteString>(100);
              stagedFutures
                  .get(currentEnvironment)
                  .add(
                      stagingExecutor.submit(
                          new StoreArtifact(
                              stagingToken,
                              name,
                              currentArtifact,
                              currentOutput,
                              totalPendingBytes)));
            } catch (Exception exn) {
              LOG.error("Error submitting.", exn);
              responseObserver.onError(exn);
            }
            state = State.GETCHUNK;
            // fall through

          case GETCHUNK:
            try {
              ByteString chunk = responseWrapper.getGetArtifactResponse().getData();
              if (chunk.size() > 0) { // Make sure we don't accidentally send the EOF value.
                totalPendingBytes.aquire(chunk.size());
                currentOutput.put(chunk);
              }
              if (responseWrapper.getIsLast()) {
                currentOutput.put(ByteString.EMPTY); // The EOF value.
                if (pendingGets.isEmpty()) {
                  resolveNextEnvironment(responseObserver);
                } else {
                  state = State.GET;
                  LOG.debug("Waiting for {}", pendingGets.peek());
                }
              }
            } catch (Exception exn) {
              LOG.error("Error submitting.", exn);
              onError(exn);
            }
            break;

          default:
            responseObserver.onError(
                new StatusException(
                    Status.INVALID_ARGUMENT.withDescription("Illegal state " + state)));
        }
      }

      private void resolveNextEnvironment(
          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
        if (pendingResolves.isEmpty()) {
          finishStaging(responseObserver);
        } else {
          state = State.RESOLVE;
          LOG.info("Resolving artifacts for {}.{}.", stagingToken, pendingResolves.peek());
          responseObserver.onNext(
              ArtifactApi.ArtifactRequestWrapper.newBuilder()
                  .setResolveArtifact(
                      ArtifactApi.ResolveArtifactsRequest.newBuilder()
                          .addAllArtifacts(toResolve.get(pendingResolves.peek())))
                  .build());
        }
      }

      private void finishStaging(
          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
        LOG.debug("Finishing staging for {}.", stagingToken);
        Map<String, List<RunnerApi.ArtifactInformation>> staged = new HashMap<>();
        try {
          for (Map.Entry<String, List<Future<RunnerApi.ArtifactInformation>>> entry :
              stagedFutures.entrySet()) {
            List<RunnerApi.ArtifactInformation> envStaged = new ArrayList<>();
            for (Future<RunnerApi.ArtifactInformation> future : entry.getValue()) {
              envStaged.add(future.get());
            }
            staged.put(entry.getKey(), envStaged);
          }
          ArtifactStagingService.this.staged.put(stagingToken, staged);
          stagingExecutor.shutdown();
          state = State.DONE;
          LOG.info("Artifacts fully staged for {}.", stagingToken);
          responseObserver.onCompleted();
        } catch (Exception exn) {
          LOG.error("Error staging artifacts", exn);
          responseObserver.onError(exn);
          state = State.ERROR;
          return;
        }
      }

      /**
       * Return an alternative artifact if we do not need to get this over the artifact API, or
       * possibly at all.
       */
      private Optional<RunnerApi.ArtifactInformation> getLocal(
          RunnerApi.ArtifactInformation artifact) {
        return Optional.empty();
      }

      /**
       * Attempts to provide a reasonable filename for the artifact.
       *
       * @param index a monotonically increasing index, which provides uniqueness
       * @param environment the environment id
       * @param artifact the artifact itself
       */
      private String createFilename(String environment, RunnerApi.ArtifactInformation artifact) {
        String path;
        try {
          if (artifact.getRoleUrn().equals(ArtifactRetrievalService.STAGING_TO_ARTIFACT_URN)) {
            path =
                RunnerApi.ArtifactStagingToRolePayload.parseFrom(artifact.getRolePayload())
                    .getStagedName();
          } else if (artifact.getTypeUrn().equals(ArtifactRetrievalService.FILE_ARTIFACT_URN)) {
            path = RunnerApi.ArtifactFilePayload.parseFrom(artifact.getTypePayload()).getPath();
          } else if (artifact.getTypeUrn().equals(ArtifactRetrievalService.URL_ARTIFACT_URN)) {
            path = RunnerApi.ArtifactUrlPayload.parseFrom(artifact.getTypePayload()).getUrl();
          } else {
            path = "artifact";
          }
        } catch (InvalidProtocolBufferException exn) {
          throw new RuntimeException(exn);
        }
        // Limit to the last contiguous alpha-numeric sequence. In particular, this will exclude
        // all path separators.
        List<String> components = Splitter.onPattern("[^A-Za-z-_.]]").splitToList(path);
        String base = components.get(components.size() - 1);
        return clip(
            String.format("%s-%s-%s", idGenerator.getId(), clip(environment, 25), base), 100);
      }

      private String clip(String s, int maxLength) {
        return s.length() < maxLength ? s : s.substring(0, maxLength);
      }

      @Override
      public void onError(Throwable throwable) {
        stagingExecutor.shutdownNow();
        LOG.error("Error staging artifacts", throwable);
        state = State.ERROR;
      }

      @Override
      public void onCompleted() {
        Preconditions.checkArgument(state == State.DONE);
      }
    };
  }

  @Override
  public void close() throws Exception {
    // Nothing to close.
  }

  /**
   * Lazily stages artifacts by letting an ArtifactStagingService resolve and request artifacts.
   *
   * @param retrievalService an ArtifactRetrievalService used to resolve and retrieve artifacts
   * @param stagingService an ArtifactStagingService stub which will request artifacts
   * @param stagingToken the staging token of the job whose artifacts will be retrieved
   * @throws InterruptedException
   * @throws IOException
   */
  public static void offer(
      ArtifactRetrievalService retrievalService,
      ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stagingService,
      String stagingToken)
      throws ExecutionException, InterruptedException {
    new StagingDriver(retrievalService, stagingService, stagingToken).getCompletionFuture().get();
  }

  /** Actually implements the reverse retrieval protocol. */
  private static class StagingDriver implements StreamObserver<ArtifactApi.ArtifactRequestWrapper> {

    private final ArtifactRetrievalService retrievalService;
    private final StreamObserver<ArtifactApi.ArtifactResponseWrapper> responseObserver;
    private final CompletableFuture<Void> completionFuture;

    public StagingDriver(
        ArtifactRetrievalService retrievalService,
        ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stagingService,
        String stagingToken) {
      this.retrievalService = retrievalService;
      completionFuture = new CompletableFuture<Void>();
      responseObserver = stagingService.reverseArtifactRetrievalService(this);
      responseObserver.onNext(
          ArtifactApi.ArtifactResponseWrapper.newBuilder().setStagingToken(stagingToken).build());
    }

    public CompletableFuture<?> getCompletionFuture() {
      return completionFuture;
    }

    @Override
    public void onNext(ArtifactApi.ArtifactRequestWrapper requestWrapper) {

      if (completionFuture.isCompletedExceptionally()) {
        try {
          completionFuture.get();
        } catch (Throwable th) {
          responseObserver.onError(th);
          return;
        }
      }

      if (requestWrapper.hasResolveArtifact()) {
        retrievalService.resolveArtifacts(
            requestWrapper.getResolveArtifact(),
            new StreamObserver<ArtifactApi.ResolveArtifactsResponse>() {

              @Override
              public void onNext(ArtifactApi.ResolveArtifactsResponse resolveArtifactsResponse) {
                responseObserver.onNext(
                    ArtifactApi.ArtifactResponseWrapper.newBuilder()
                        .setResolveArtifactResponse(resolveArtifactsResponse)
                        .build());
              }

              @Override
              public void onError(Throwable throwable) {
                completionFuture.completeExceptionally(throwable);
                responseObserver.onError(throwable);
              }

              @Override
              public void onCompleted() {}
            });
      } else if (requestWrapper.hasGetArtifact()) {
        retrievalService.getArtifact(
            requestWrapper.getGetArtifact(),
            new StreamObserver<ArtifactApi.GetArtifactResponse>() {

              @Override
              public void onNext(ArtifactApi.GetArtifactResponse getArtifactResponse) {
                responseObserver.onNext(
                    ArtifactApi.ArtifactResponseWrapper.newBuilder()
                        .setGetArtifactResponse(getArtifactResponse)
                        .build());
              }

              @Override
              public void onError(Throwable throwable) {
                completionFuture.completeExceptionally(throwable);
                responseObserver.onError(throwable);
              }

              @Override
              public void onCompleted() {
                responseObserver.onNext(
                    ArtifactApi.ArtifactResponseWrapper.newBuilder()
                        .setGetArtifactResponse(
                            ArtifactApi.GetArtifactResponse.newBuilder().build())
                        .setIsLast(true)
                        .build());
              }
            });
      } else {
        Throwable exn =
            new StatusException(
                Status.INVALID_ARGUMENT.withDescription(
                    "Expected either a resolve or get request."));
        completionFuture.completeExceptionally(exn);
        responseObserver.onError(exn);
      }
    }

    @Override
    public void onError(Throwable throwable) {
      completionFuture.completeExceptionally(throwable);
    }

    @Override
    public void onCompleted() {
      responseObserver.onCompleted();
      completionFuture.complete(null);
    }
  }
}