/*
 * Copyright 2018-present Open Networking Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.atomix.core.semaphore.impl;

import com.google.common.base.MoreObjects;
import com.google.common.base.Objects;
import io.atomix.core.semaphore.AtomicSemaphoreType;
import io.atomix.core.semaphore.QueueStatus;
import io.atomix.primitive.PrimitiveType;
import io.atomix.primitive.service.AbstractPrimitiveService;
import io.atomix.primitive.service.BackupInput;
import io.atomix.primitive.service.BackupOutput;
import io.atomix.primitive.session.Session;
import io.atomix.primitive.session.SessionId;
import io.atomix.utils.concurrent.Scheduled;
import io.atomix.utils.serializer.Namespace;
import io.atomix.utils.serializer.Serializer;

import java.time.Duration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.concurrent.TimeUnit;

/**
 * Base service implementation for the atomic semaphore primitive.
 *
 * <p>The service tracks the number of available permits, the permits held by each session and a
 * queue of acquire requests that could not be satisfied immediately.
 */
public abstract class AbstractAtomicSemaphoreService extends AbstractPrimitiveService<AtomicSemaphoreClient> implements AtomicSemaphoreService {
  /** Serializer used to encode/decode backed-up state; also returned from {@link #serializer()}. */
  private static final Serializer SERIALIZER = Serializer.using(Namespace.builder()
      .register(AtomicSemaphoreType.instance().namespace())
      .register(Waiter.class)
      .build());

  /** Number of permits currently available for acquisition. */
  private int available;
  /** Permits currently held, keyed by session id. */
  private Map<Long, Integer> holders = new HashMap<>();
  /** Acquire requests waiting for permits, in the order they were queued. */
  private LinkedList<Waiter> waiterQueue = new LinkedList<>();
  /** Timeout tasks for waiters that have a deadline, keyed by waiter index; not part of the backup. */
  private final Map<Long, Scheduled> timers = new HashMap<>();

  /**
   * Creates the service with an initial number of available permits.
   *
   * @param primitiveType   the primitive type passed to the base service
   * @param initialCapacity the initial value of the available permit count
   */
  public AbstractAtomicSemaphoreService(PrimitiveType primitiveType, int initialCapacity) {
    super(primitiveType, AtomicSemaphoreClient.class);
    this.available = initialCapacity;
  }

  /**
   * Writes the service state to the backup: the available permit count, the holders map and the
   * waiter queue, in that order. {@link #restore(BackupInput)} reads them back in the same order.
   * Timers are not written.
   *
   * @param output the backup output to write to
   */
  @Override
  public void backup(BackupOutput output) {
    output.writeInt(available);
    output.writeObject(holders, SERIALIZER::encode);
    output.writeObject(waiterQueue, SERIALIZER::encode);
  }

  /**
   * Restores the service state from a backup and rebuilds the timeout timers.
   *
   * <p>State is read in the order it was written by {@link #backup(BackupOutput)}. All existing
   * timers are cancelled and dropped, then a new timer is scheduled for every restored waiter that
   * carries a positive expiration timestamp. The delay is the difference between that timestamp
   * and the current wall-clock time.
   *
   * @param input the backup input to read from
   */
  @Override
  public void restore(BackupInput input) {
    available = input.readInt();
    holders = input.readObject(SERIALIZER::decode);
    waiterQueue = input.readObject(SERIALIZER::decode);

    // Timers are not part of the backup, so discard whatever was scheduled before the restore.
    for (Scheduled stale : timers.values()) {
      stale.cancel();
    }
    timers.clear();

    for (Waiter waiter : waiterQueue) {
      // Waiters without a positive deadline have no timeout.
      if (waiter.expire <= 0) {
        continue;
      }

      long remainingMillis = waiter.expire - getWallClock().getTime().unixTimestamp();
      Scheduled timer = getScheduler().schedule(Duration.ofMillis(remainingMillis), () -> {
        timers.remove(waiter.index);
        waiterQueue.remove(waiter);
        fail(waiter.session, waiter.id);
      });
      timers.put(waiter.index, timer);
    }
  }

  /**
   * Handles session expiration by releasing the permits held by the session.
   *
   * @param session the expired session
   */
  @Override
  public void onExpire(Session session) {
    releaseSession(session);
  }

  /**
   * Handles session close by releasing the permits held by the session.
   *
   * @param session the closed session
   */
  @Override
  public void onClose(Session session) {
    releaseSession(session);
  }

  /**
   * Attempts to acquire permits on behalf of the current session.
   *
   * <p>If enough permits are available they are granted immediately. Otherwise the behavior
   * depends on {@code timeout}:
   * <ul>
   *   <li>{@code timeout > 0}: the request is queued and fails if it is still queued after
   *       {@code timeout} milliseconds;</li>
   *   <li>{@code timeout == 0}: the request fails immediately;</li>
   *   <li>{@code timeout < 0}: the request is queued without a deadline.</li>
   * </ul>
   *
   * @param id      the operation id echoed back to the client on success or failure
   * @param permits the number of permits requested
   * @param timeout the timeout in milliseconds, interpreted as described above
   */
  @Override
  public void acquire(long id, int permits, long timeout) {
    SessionId sessionId = getCurrentSession().sessionId();
    // Capture the index now: the timeout callback runs later, when getCurrentIndex() may differ,
    // and the timer must be removed under the same key it was stored with.
    long index = getCurrentIndex();

    if (available >= permits) {
      acquire(sessionId, id, permits, index);
      return;
    }

    if (timeout == 0) {
      fail(sessionId, id);
      return;
    }

    // An expire value of 0 marks a waiter without a deadline.
    long expire = timeout > 0 ? getWallClock().getTime().unixTimestamp() + timeout : 0;
    Waiter waiter = new Waiter(sessionId, index, id, permits, expire);
    waiterQueue.add(waiter);

    if (timeout > 0) {
      timers.put(index, getScheduler().schedule(timeout, TimeUnit.MILLISECONDS, () -> {
        timers.remove(index);
        waiterQueue.remove(waiter);
        fail(sessionId, id);
      }));
    }
  }

  /**
   * Releases permits on behalf of the current session.
   *
   * @param permits the number of permits to release
   */
  @Override
  public void release(int permits) {
    long currentSessionId = getCurrentSession().sessionId().id();
    release(currentSessionId, permits);
  }

  /**
   * Returns the number of permits currently available.
   *
   * @return the current available permit count
   */
  @Override
  public int available() {
    return available;
  }

  /**
   * Takes every currently available permit for the current session and resets the available
   * count to zero.
   *
   * <p>The holders map is only updated when the drained amount is positive; a zero or negative
   * amount is still returned as-is.
   *
   * @return the value of the available count before it was reset
   */
  @Override
  public int drain() {
    int drained = available;
    available = 0;

    if (drained <= 0) {
      return drained;
    }

    // Add the drained permits to whatever the session already holds.
    holders.merge(getCurrentSession().sessionId().id(), drained, Integer::sum);
    return drained;
  }

  /**
   * Adds permits to the available count and then hands permits to queued waiters that can now be
   * satisfied.
   *
   * @param permits the number of permits to add
   * @return the available permit count after waiters have been served
   */
  @Override
  public int increase(int permits) {
    increaseAvailable(permits);
    checkAndNotifyWaiters();
    return available;
  }

  /**
   * Subtracts permits from the available count. Holders and queued waiters are not touched.
   *
   * @param permits the number of permits to subtract
   * @return the available permit count after the subtraction
   */
  @Override
  public int reduce(int permits) {
    return decreaseAvailable(permits);
  }

  /**
   * Summarizes the waiter queue.
   *
   * @return a status holding the number of queued waiters and the total permits they request
   */
  @Override
  public QueueStatus queueStatus() {
    int requestedPermits = 0;
    for (Waiter waiter : waiterQueue) {
      requestedPermits += waiter.acquirePermits;
    }
    return new QueueStatus(waiterQueue.size(), requestedPermits);
  }

  /**
   * Returns the permits currently held, keyed by session id.
   *
   * <p>A copy is returned so that callers cannot observe or cause later changes to the service's
   * internal bookkeeping.
   *
   * @return a snapshot of the holders map
   */
  @Override
  public Map<Long, Integer> holderStatus() {
    return new HashMap<>(holders);
  }

  /**
   * Grants permits to a session: subtracts them from the available count, adds them to the
   * session's holdings and notifies the session's client of the success.
   *
   * @param sessionId      the session receiving the permits
   * @param operationId    the operation id reported back to the client
   * @param acquirePermits the number of permits granted
   * @param version        the version reported back to the client
   */
  private void acquire(SessionId sessionId, long operationId, int acquirePermits, long version) {
    decreaseAvailable(acquirePermits);
    // Start from the granted amount for a new holder, otherwise add to the existing holdings.
    holders.merge(sessionId.id(), acquirePermits, Integer::sum);
    success(sessionId, operationId, acquirePermits, version);
  }


  /**
   * Returns permits to the available count, reduces the session's holdings accordingly and then
   * serves queued waiters that can now be satisfied.
   *
   * <p>The available count grows by {@code releasePermits} whether or not the session is a known
   * holder. A holder entry is dropped once its remaining permits are zero or less.
   *
   * @param sessionId      id of the session releasing permits
   * @param releasePermits number of permits to release
   */
  private void release(long sessionId, int releasePermits) {
    increaseAvailable(releasePermits);

    holders.computeIfPresent(sessionId, (key, held) -> {
      int remaining = held - releasePermits;
      if (remaining > 0) {
        return remaining;
      }
      // Returning null removes the entry from the map.
      return null;
    });

    checkAndNotifyWaiters();
  }

  /**
   * Notifies the client of the given session that an acquire operation succeeded.
   *
   * @param sessionId      the session to notify
   * @param operationId    the id of the acquire operation
   * @param acquirePermits the number of permits granted
   * @param version        the version passed to the client
   */
  private void success(SessionId sessionId, long operationId, int acquirePermits, long version) {
    // NOTE(review): assumes getSession(sessionId) returns a non-null session — confirm for closed sessions.
    getSession(sessionId).accept(client -> client.succeeded(operationId, version, acquirePermits));
  }

  /**
   * Notifies the client of the given session that an acquire operation failed.
   *
   * @param sessionId   the session to notify
   * @param operationId the id of the acquire operation
   */
  private void fail(SessionId sessionId, long operationId) {
    // NOTE(review): assumes getSession(sessionId) returns a non-null session — confirm for closed sessions.
    getSession(sessionId).accept(client -> client.failed(operationId));
  }

  /**
   * Cleans up after a session that has expired or been closed.
   *
   * <p>Pending acquire requests of the session are removed from the waiter queue and their
   * timeout timers are cancelled, then every permit the session still holds is released.
   *
   * @param session the session that went away
   */
  private void releaseSession(Session session) {
    SessionId sessionId = session.sessionId();

    // Drop the session's queued requests first: the release below serves waiters, and permits
    // granted to a session that is gone would never be returned.
    Iterator<Waiter> iterator = waiterQueue.iterator();
    while (iterator.hasNext()) {
      Waiter waiter = iterator.next();
      if (waiter.session.equals(sessionId)) {
        iterator.remove();
        Scheduled timer = timers.remove(waiter.index);
        if (timer != null) {
          timer.cancel();
        }
      }
    }

    Integer held = holders.get(sessionId.id());
    if (held != null) {
      release(sessionId.id(), held);
    }
  }

  /**
   * Walks the waiter queue in order and grants permits to every waiter whose request fits into
   * the currently available count. Waiters that do not fit are skipped and stay queued. The walk
   * stops as soon as no permits are left.
   */
  private void checkAndNotifyWaiters() {
    Iterator<Waiter> pending = waiterQueue.iterator();
    while (available > 0 && pending.hasNext()) {
      Waiter candidate = pending.next();
      if (candidate.acquirePermits > available) {
        // Not enough permits for this request; a later, smaller one may still fit.
        continue;
      }

      pending.remove();

      // The waiter is being served, so its timeout (if any) must not fire anymore.
      Scheduled timeout = timers.remove(candidate.index);
      if (timeout != null) {
        timeout.cancel();
      }

      acquire(candidate.session, candidate.id, candidate.acquirePermits, candidate.index);
    }
  }

  /**
   * Adds {@code permits} to the available count, saturating at the bounds of {@code int} instead
   * of wrapping around.
   *
   * <p>A negative {@code permits} value lowers the count by that amount rather than being treated
   * as an overflow.
   *
   * @param permits the number of permits to add
   * @return the available permit count after the addition
   */
  private int increaseAvailable(int permits) {
    // Widen to long so the sum itself cannot overflow, then clamp back into the int range.
    long sum = (long) available + permits;
    available = (int) Math.max(Integer.MIN_VALUE, Math.min(Integer.MAX_VALUE, sum));
    return available;
  }

  /**
   * Subtracts {@code permits} from the available count, saturating at the bounds of {@code int}
   * instead of wrapping around. The result may be negative.
   *
   * <p>A negative {@code permits} value raises the count by that amount rather than being treated
   * as an underflow.
   *
   * @param permits the number of permits to subtract
   * @return the available permit count after the subtraction
   */
  private int decreaseAvailable(int permits) {
    // Widen to long so the difference itself cannot overflow, then clamp back into the int range.
    long difference = (long) available - permits;
    available = (int) Math.max(Integer.MIN_VALUE, Math.min(Integer.MAX_VALUE, difference));
    return available;
  }

  /**
   * Returns the serializer used by this service.
   *
   * @return the shared service serializer
   */
  @Override
  public Serializer serializer() {
    return SERIALIZER;
  }

  private class Waiter {
    private final SessionId session;
    private final long index;
    private final long id;
    private final int acquirePermits;
    private final long expire;

    Waiter(SessionId session, long index, long id, int acquirePermits, long expire) {
      this.session = session;
      this.index = index;
      this.id = id;
      this.acquirePermits = acquirePermits;
      this.expire = expire;
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }
      Waiter waiter = (Waiter) o;
      return session.equals(waiter.session)
          && index == waiter.index
          && id == waiter.id
          && acquirePermits == waiter.acquirePermits;
    }

    @Override
    public int hashCode() {
      return Objects.hashCode(session, index, id, acquirePermits);
    }

    @Override
    public String toString() {
      return MoreObjects.toStringHelper(this)
          .add("session", session)
          .add("index", index)
          .add("id", id)
          .add("acquirePermits", acquirePermits)
          .add("expire", expire)
          .toString();
    }
  }
}