/* * Copyright (c) IBM Corporation 2017. All Rights Reserved. * Project name: java-async-util * This project is licensed under the Apache License 2.0, see LICENSE. */ package com.ibm.asyncutil.util; import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import com.ibm.asyncutil.util.TestUtil.CompletableStage; /** * A rudimentary implementation of {@link CompletionStage}. Used to validate CompletionStage * behavior against a class that isn't {@link CompletableFuture} */ public class SimpleCompletionStage<T> implements CompletableStage<T> { @SuppressWarnings("rawtypes") private static final AtomicReferenceFieldUpdater<SimpleCompletionStage, Object> RESULT_UPDATER = AtomicReferenceFieldUpdater.newUpdater(SimpleCompletionStage.class, Object.class, "result"); @SuppressWarnings("rawtypes") private static final AtomicReferenceFieldUpdater<SimpleCompletionStage, Node> STACK_UPDATER = AtomicReferenceFieldUpdater.newUpdater(SimpleCompletionStage.class, Node.class, "stack"); private static final Executor LOCAL = Runnable::run; private static final Executor ASYNC = CompletableFuture::runAsync; private static final Object NIL = new ExceptionalResult(null); private volatile Object result; @SuppressWarnings("unused") // used via STACK_UPDATER private volatile Node<T> stack; private static class Node<T> { Node<T> next; private final CompletionAction<T> action; public Node(final CompletionAction<T> action) { this.action = action; } } private interface CompletionAction<T> { void accept(Object res, T t, Throwable exc); } private static class ExceptionalResult { private final Throwable throwable; public ExceptionalResult(final Throwable throwable) { this.throwable = throwable; } } /** * @return true iff stage wasn't complete and is now complete with the given result */ @Override public boolean complete(final T result) { return completeEncoded(result == null ? NIL : result); } /** * @return true iff stage wasn't complete and is now complete with the given exception */ @Override public boolean completeExceptionally(final Throwable exception) { return completeEncoded(new ExceptionalResult(Objects.requireNonNull(exception))); } private boolean completeEncoded(final Object res) { if (RESULT_UPDATER.compareAndSet(this, null, res)) { popAllAndRun(res); return true; } return false; } @Override public boolean isDone() { return this.result != null; } private void push(final CompletionAction<T> action) { final Node<T> n = new Node<>(action); STACK_UPDATER.updateAndGet(this, head -> { @SuppressWarnings("unchecked") final Node<T> safeNode = head; n.next = safeNode; return n; }); final Object res = this.result; if (res != null) { popAllAndRun(res); } } private void popAllAndRun(final Object res) { final T t; final Throwable exc; if (res instanceof ExceptionalResult) { t = null; exc = ((ExceptionalResult) res).throwable; } else { @SuppressWarnings("unchecked") final T safeT = (T) res; t = safeT; exc = null; } for (@SuppressWarnings("unchecked") Node<T> n = STACK_UPDATER.getAndSet(this, null); n != null; n = n.next) { n.action.accept(res, t, exc); } } private void runOrPush(final CompletionAction<T> action) { final Object res = this.result; if (res == null) { push(action); } else { if (res instanceof ExceptionalResult) { action.accept(res, null, ((ExceptionalResult) res).throwable); } else { @SuppressWarnings("unchecked") final T safeT = (T) res; action.accept(res, safeT, null); } } } private static <T> CompletionException wrapIfNecessary(final Throwable e) { return e instanceof CompletionException ? (CompletionException) e : new CompletionException(e); } @Override public <U> CompletionStage<U> thenApply(final Function<? super T, ? extends U> fn) { return thenApplyAsync(fn, LOCAL); } @Override public <U> CompletionStage<U> thenApplyAsync(final Function<? super T, ? extends U> fn) { return thenApplyAsync(fn, ASYNC); } @Override public <U> CompletionStage<U> thenApplyAsync(final Function<? super T, ? extends U> fn, final Executor executor) { Objects.requireNonNull(fn); Objects.requireNonNull(executor); final SimpleCompletionStage<U> scs = new SimpleCompletionStage<>(); runOrPush((res, t, exc) -> { try { executor.execute(() -> { if (exc == null) { try { scs.complete(fn.apply(t)); } catch (final Throwable e) { scs.completeExceptionally(SimpleCompletionStage.wrapIfNecessary(e)); } } else { scs.completeEncoded(res); } }); } catch (final Throwable e) { scs.completeExceptionally(e); } }); return scs; } @Override public CompletionStage<Void> thenAccept(final Consumer<? super T> action) { return thenAcceptAsync(action, LOCAL); } @Override public CompletionStage<Void> thenAcceptAsync(final Consumer<? super T> action) { return thenAcceptAsync(action, ASYNC); } @Override public CompletionStage<Void> thenAcceptAsync(final Consumer<? super T> action, final Executor executor) { Objects.requireNonNull(action); return thenApplyAsync(t -> { action.accept(t); return null; }, executor); } @Override public CompletionStage<Void> thenRun(final Runnable action) { return thenRunAsync(action, LOCAL); } @Override public CompletionStage<Void> thenRunAsync(final Runnable action) { return thenRunAsync(action, ASYNC); } @Override public CompletionStage<Void> thenRunAsync(final Runnable action, final Executor executor) { Objects.requireNonNull(action); return thenApplyAsync(t -> { action.run(); return null; }, executor); } @Override public <U, V> CompletionStage<V> thenCombine(final CompletionStage<? extends U> other, final BiFunction<? super T, ? super U, ? extends V> fn) { return thenCombineAsync(other, fn, LOCAL); } @Override public <U, V> CompletionStage<V> thenCombineAsync(final CompletionStage<? extends U> other, final BiFunction<? super T, ? super U, ? extends V> fn) { return thenCombineAsync(other, fn, ASYNC); } @Override public <U, V> CompletionStage<V> thenCombineAsync(final CompletionStage<? extends U> other, final BiFunction<? super T, ? super U, ? extends V> fn, final Executor executor) { Objects.requireNonNull(fn); final SimpleCompletionStage<V> scs = new SimpleCompletionStage<>(); other.whenCompleteAsync((u, uExc) -> { try { this.whenCompleteAsync((t, tExc) -> { if (uExc == null) { if (tExc == null) { try { scs.complete(fn.apply(t, u)); } catch (final Throwable e) { scs.completeExceptionally(SimpleCompletionStage.wrapIfNecessary(e)); } } else { scs.completeExceptionally(tExc); } } else { if (tExc != null) { uExc.addSuppressed(tExc); } scs.completeExceptionally(uExc); } }, executor); } catch (final Throwable e) { scs.completeExceptionally(e); } }, executor); return scs; } @Override public <U> CompletionStage<Void> thenAcceptBoth(final CompletionStage<? extends U> other, final BiConsumer<? super T, ? super U> action) { return thenAcceptBothAsync(other, action, LOCAL); } @Override public <U> CompletionStage<Void> thenAcceptBothAsync(final CompletionStage<? extends U> other, final BiConsumer<? super T, ? super U> action) { return thenAcceptBothAsync(other, action, ASYNC); } @Override public <U> CompletionStage<Void> thenAcceptBothAsync(final CompletionStage<? extends U> other, final BiConsumer<? super T, ? super U> action, final Executor executor) { Objects.requireNonNull(action); return thenCombineAsync(other, (t, u) -> { action.accept(t, u); return null; }, executor); } @Override public CompletionStage<Void> runAfterBoth(final CompletionStage<?> other, final Runnable action) { return runAfterBothAsync(other, action, LOCAL); } @Override public CompletionStage<Void> runAfterBothAsync(final CompletionStage<?> other, final Runnable action) { return runAfterBothAsync(other, action, ASYNC); } @Override public CompletionStage<Void> runAfterBothAsync(final CompletionStage<?> other, final Runnable action, final Executor executor) { Objects.requireNonNull(action); return thenCombineAsync(other, (t, u) -> { action.run(); return null; }, executor); } @Override public <U> CompletionStage<U> applyToEither(final CompletionStage<? extends T> other, final Function<? super T, U> fn) { return applyToEitherAsync(other, fn, LOCAL); } @Override public <U> CompletionStage<U> applyToEitherAsync(final CompletionStage<? extends T> other, final Function<? super T, U> fn) { return applyToEitherAsync(other, fn, ASYNC); } @Override public <U> CompletionStage<U> applyToEitherAsync(final CompletionStage<? extends T> other, final Function<? super T, U> fn, final Executor executor) { Objects.requireNonNull(fn); final SimpleCompletionStage<T> scs = new SimpleCompletionStage<>(); runOrPush((res, t, exc) -> scs.completeEncoded(res)); other.whenComplete((t, exc) -> { if (exc == null) { scs.complete(t); } else { scs.completeExceptionally(exc); } }); return scs.thenApplyAsync(fn, executor); } @Override public CompletionStage<Void> acceptEither(final CompletionStage<? extends T> other, final Consumer<? super T> action) { return acceptEitherAsync(other, action, LOCAL); } @Override public CompletionStage<Void> acceptEitherAsync(final CompletionStage<? extends T> other, final Consumer<? super T> action) { return acceptEitherAsync(other, action, ASYNC); } @Override public CompletionStage<Void> acceptEitherAsync(final CompletionStage<? extends T> other, final Consumer<? super T> action, final Executor executor) { Objects.requireNonNull(action); return applyToEitherAsync(other, t -> { action.accept(t); return null; }, executor); } @Override public CompletionStage<Void> runAfterEither(final CompletionStage<?> other, final Runnable action) { return runAfterEitherAsync(other, action, LOCAL); } @Override public CompletionStage<Void> runAfterEitherAsync(final CompletionStage<?> other, final Runnable action) { return runAfterEitherAsync(other, action, ASYNC); } @Override public CompletionStage<Void> runAfterEitherAsync(final CompletionStage<?> other, final Runnable action, final Executor executor) { Objects.requireNonNull(action); @SuppressWarnings("unchecked") final CompletionStage<T> safeStage = (CompletionStage<T>) other; return applyToEitherAsync(safeStage, t -> { action.run(); return null; }, executor); } @Override public <U> CompletionStage<U> thenCompose( final Function<? super T, ? extends CompletionStage<U>> fn) { return thenComposeAsync(fn, LOCAL); } @Override public <U> CompletionStage<U> thenComposeAsync( final Function<? super T, ? extends CompletionStage<U>> fn) { return thenComposeAsync(fn, ASYNC); } @Override public <U> CompletionStage<U> thenComposeAsync( final Function<? super T, ? extends CompletionStage<U>> fn, final Executor executor) { Objects.requireNonNull(fn); Objects.requireNonNull(executor); final SimpleCompletionStage<U> scs = new SimpleCompletionStage<>(); runOrPush((res, t, tExc) -> { if (tExc == null) { try { executor.execute(() -> { try { fn.apply(t).whenComplete((u, uExc) -> { if (uExc == null) { scs.complete(u); } else { scs.completeExceptionally(uExc); } }); } catch (final Throwable e) { scs.completeExceptionally(SimpleCompletionStage.wrapIfNecessary(e)); } }); } catch (final Throwable e) { scs.completeExceptionally(e); } } else { scs.completeEncoded(res); } }); return scs; } @Override public CompletionStage<T> exceptionally(final Function<Throwable, ? extends T> fn) { Objects.requireNonNull(fn); final SimpleCompletionStage<T> scs = new SimpleCompletionStage<>(); runOrPush((res, t, exc) -> { if (exc == null) { scs.completeEncoded(res); } else { try { scs.complete(fn.apply(exc)); } catch (final Throwable e) { scs.completeExceptionally(SimpleCompletionStage.wrapIfNecessary(e)); } } }); return scs; } @Override public CompletionStage<T> whenComplete(final BiConsumer<? super T, ? super Throwable> action) { return whenCompleteAsync(action, LOCAL); } @Override public CompletionStage<T> whenCompleteAsync( final BiConsumer<? super T, ? super Throwable> action) { return whenCompleteAsync(action, ASYNC); } @Override public CompletionStage<T> whenCompleteAsync(final BiConsumer<? super T, ? super Throwable> action, final Executor executor) { Objects.requireNonNull(action); final SimpleCompletionStage<T> scs = new SimpleCompletionStage<>(); runOrPush((res, t, exc) -> { try { executor.execute(() -> { try { action.accept(t, exc); } catch (final Throwable e) { if (exc == null) { scs.completeExceptionally(SimpleCompletionStage.wrapIfNecessary(e)); return; } exc.addSuppressed(e); } scs.completeEncoded(res); }); } catch (final Throwable e) { scs.completeExceptionally(e); } }); return scs; } @Override public <U> CompletionStage<U> handle(final BiFunction<? super T, Throwable, ? extends U> fn) { return handleAsync(fn, LOCAL); } @Override public <U> CompletionStage<U> handleAsync( final BiFunction<? super T, Throwable, ? extends U> fn) { return handleAsync(fn, ASYNC); } @Override public <U> CompletionStage<U> handleAsync(final BiFunction<? super T, Throwable, ? extends U> fn, final Executor executor) { Objects.requireNonNull(fn); final SimpleCompletionStage<U> scs = new SimpleCompletionStage<>(); whenCompleteAsync((t, exc) -> { try { scs.complete(fn.apply(t, exc)); } catch (final Throwable e) { scs.completeExceptionally(SimpleCompletionStage.wrapIfNecessary(e)); } }, executor); return scs; } @Override public CompletableFuture<T> toCompletableFuture() { final CompletableFuture<T> cf = new CompletableFuture<>(); runOrPush((res, t, exc) -> { if (exc == null) { cf.complete(t); } else { cf.completeExceptionally(exc); } }); return cf; } }