package fr.umlv.loom; import java.lang.invoke.MethodHandles; import java.lang.invoke.VarHandle; import java.time.Duration; import java.util.Objects; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Future; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.function.Supplier; public interface Task<T> extends Future<T> { T join() throws CancellationException; T await(Duration duration) throws CancellationException, TimeoutException; final class TaskImpl<T> implements Task<T> { private static final Object CANCELLED = new Object(); private static final VarHandle RESULT_HANDLE; static { try { RESULT_HANDLE = MethodHandles.lookup().findVarHandle(TaskImpl.class, "result", Object.class); } catch (NoSuchFieldException | IllegalAccessException e) { throw new AssertionError(e); } } private final static class $$$<E extends Throwable> { final E throwable; $$$(E throwable) { this.throwable = throwable; } } private final Thread virtualThread; private volatile Object result; // null -> CANCELLED or null -> value | $$$(exception) TaskImpl(ThreadFactory factory, Supplier<? extends T> supplier) { virtualThread = factory.newThread(() -> { Object result; try { result = supplier.get(); } catch(Error e) { // don't capture errors, only checked and unchecked exceptions throw e; } catch(Throwable e) { result = new $$$<>(e); } setResultIfNull(Objects.requireNonNull(result)); }); } private boolean setResultIfNull(Object result) { return RESULT_HANDLE.compareAndSet(this, (Object)null, result); } @Override @SuppressWarnings("unchecked") public T join() { try { virtualThread.join(); } catch (InterruptedException e) { throw new CompletionException(e); } Object result = this.result; if (result == CANCELLED) { throw new CancellationException(); } if (result instanceof $$$<?>) { throw (($$$<RuntimeException>)result).throwable; } return (T)result; } @Override @SuppressWarnings("unchecked") public T await(Duration duration) throws TimeoutException { try { virtualThread.join(duration); } catch(InterruptedException e) { throw new CompletionException(e); } if (setResultIfNull(CANCELLED)) { throw new TimeoutException(); } Object result = this.result; if (result == CANCELLED) { throw new CancellationException(); } if (result instanceof $$$<?>) { throw (($$$<RuntimeException>)result).throwable; } return (T)result; } @Override @SuppressWarnings("unchecked") public T get() throws CancellationException, ExecutionException, InterruptedException { virtualThread.join(); Object result = this.result; if (result == CANCELLED) { throw new CancellationException(); } if (result instanceof $$$<?>) { throw new ExecutionException((($$$<?>)result).throwable); } return (T)result; } @Override @SuppressWarnings("unchecked") public T get(long timeout, TimeUnit unit) throws TimeoutException, ExecutionException, InterruptedException { virtualThread.join(Duration.of(timeout, unit.toChronoUnit())); if (setResultIfNull(CANCELLED)) { throw new TimeoutException(); } Object result = this.result; if (result == CANCELLED) { throw new CancellationException(); } if (result instanceof $$$<?>) { throw new ExecutionException((($$$<?>)result).throwable); } return (T)result; } @Override public boolean isDone() { return result != null; } @Override public boolean cancel(boolean mayInterruptIfRunning) { return setResultIfNull(CANCELLED); } @Override public boolean isCancelled() { return result == CANCELLED; } } public static <T> Task<T> async(Supplier<? extends T> supplier) { return async0(runnable -> Thread.newThread(Thread.VIRTUAL, runnable), supplier); } public static <T> Task<T> async(Executor executor, Supplier<? extends T> supplier) { return async0(runnable -> Thread.builder().virtual(executor).task(runnable).build(), supplier); } private static <T> Task<T> async0(ThreadFactory factory, Supplier<? extends T> supplier) { var task = new TaskImpl<T>(factory, supplier); task.virtualThread.start(); return task; } }