diff --git a/README.md b/README.md index 251f005..53460df 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,32 @@ This is a library meant for library authors that want to build libraries that wo Read the [Javadoc](https://javadoc.io/doc/org.funfix/tasks-jvm/0.4.1/org/funfix/tasks/jvm/package-summary.html). Better documentation is coming. +### Migration Note (v0.5.0) + +The `AsyncFun` interface has changed to improve cancellation management and simplify the API. This is a source and binary incompatible change. + +**Old shape:** +```java +Task.fromAsync((executor, callback) -> { + // ... + return () -> { /* cleanup */ }; +}); +``` + +**New shape:** +```java +Task.fromAsync(continuation -> { + var executor = continuation.getExecutor(); + continuation.invokeOnCancellation(() -> { /* cleanup */ }); + // ... +}); +``` + +Key differences: +- The `executor` and `callback` are now encapsulated in the `Continuation`. +- Cancellation cleanup is registered via `continuation.invokeOnCancellation(finalizer)` instead of returning a `Cancellable`. +- `continuation.onCancellation()` signals that the task has completed due to cancellation, whereas `invokeOnCancellation(finalizer)` registers a cleanup action to run when cancellation occurs. + --- Maven: diff --git a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/AsyncFun.java b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/AsyncFun.java index 4e9ed31..e45808f 100644 --- a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/AsyncFun.java +++ b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/AsyncFun.java @@ -4,18 +4,28 @@ import org.jspecify.annotations.Nullable; import java.io.Serializable; -import java.util.concurrent.Executor; /** * A function that is a delayed, asynchronous computation. *

- * This function type is what's needed to describe {@link Task} instances. + * The injected {@link Continuation} provides: + *

+ *

+ * Example: + *

{@code
+ * Task.fromAsync(continuation -> {
+ *     final Executor executor = continuation.getExecutor();
+ *     continuation.invokeOnCancellation(() -> System.out.println("cleanup"));
+ *     executor.execute(() -> continuation.onSuccess("ok"));
+ * });
+ * }
*/ @FunctionalInterface @NonBlocking public interface AsyncFun extends Serializable { - Cancellable invoke( - Executor executor, - CompletionCallback continuation - ); + void invoke(Continuation continuation); } diff --git a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Cancellable.java b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Cancellable.java index 5a3c9eb..09e0101 100644 --- a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Cancellable.java +++ b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Cancellable.java @@ -10,8 +10,8 @@ /** * This is a token that can be used for interrupting a scheduled or * a running task. - * - *

The contract for {@code cancel} is: + *

+ * The contract for {@code cancel} is: *

    *
  1. Its execution is idempotent, meaning that calling it multiple times * has the same effect as calling it once.
  2. @@ -52,19 +52,6 @@ final class CancellableUtils { static final Cancellable EMPTY = () -> {}; } -/** - * Represents a forward reference to a {@link Cancellable} that was already - * registered and needs to be filled in later. - *

    - * INTERNAL API: Internal apis are subject to change or removal - * without any notice. When code depends on internal APIs, it is subject to - * breakage between minor version updates. - */ -@ApiStatus.Internal -interface CancellableForwardRef extends Cancellable { - void set(Cancellable cancellable); -} - /** * INTERNAL API. *

    @@ -74,58 +61,101 @@ interface CancellableForwardRef extends Cancellable { */ @ApiStatus.Internal final class MutableCancellable implements Cancellable { - private final AtomicReference ref; + private final AtomicReference<@Nullable State> ref; - MutableCancellable(final Cancellable initialRef) { + MutableCancellable(final @Nullable Cancellable initialRef) { ref = new AtomicReference<>(new State.Active(initialRef, 0, null)); } MutableCancellable() { - this(CancellableUtils.EMPTY); + this(null); } + /** + * Cancels all registered cancellation tokens. + *

    + * Tries to execute the finalizers on the same thread, for as long as + * possible. Normally, `Cancellation` tokens should not be expensive + * (non-blocking), but we can't force this via the compiler, and it's + * best to avoid concurrency issues if possible. Although, note that + * this is not a contract that we can provide. E.g., if the task + * was canceled, we're forced to cancel future cancellation tokens + * on the thread calling `register`. + */ @Override public void cancel() { - var state = ref.getAndSet(State.Closed.INSTANCE); - while (state instanceof State.Active active) { - try { - active.token.cancel(); - } catch (Exception e) { - UncaughtExceptionHandler.logOrRethrow(e); + while (true) { + final var current = ref.get(); + if (current instanceof State.Active active) { + final var update = new State.Cancelling( + // Adding a dummy that preserves the `order` for incoming + // subscriptions. + new State.Active(null, active.order, null) + ); + if (ref.compareAndSet(current, update)) { + // NOTE: this will eventually set the state to Cancelled + // (after managing to cancel all tokens) + startCancellationOfEverything(active); + return; + } + } else { + return; } - state = active.rest; } } - public CancellableForwardRef newCancellableRef() { - final var current = ref.get(); - if (current instanceof State.Closed) { - return new CancellableForwardRef() { - @Override - public void set(Cancellable cancellable) { - cancellable.cancel(); - } - @Override - public void cancel() {} - }; - } else if (current instanceof State.Active active) { - return new CancellableForwardRef() { - @Override - public void set(Cancellable cancellable) { - registerOrdered( - active.order, - cancellable, - active - ); + private void startCancellationOfEverything(State.@Nullable Active active) { + while (active != null) { + if (active.token != null) { + invoke(active.token); + } + active = active.rest; + // Tries fetching more tokens, since the state may have been updated + if (active == null) { + // Kind of hacky, but we need the loop due to the CAS + while (true) { + final var current = ref.get(); + if (current instanceof State.Cancelling cancelling) { + // We could have newly registered cancellable references. + // If we have, then active != null and will get processed. + active = cancelling.toCancel; + final var update = active == null + // If not, closing the loop with a final update + ? State.Cancelled.INSTANCE + : new State.Cancelling(null); + // If CAS succeeds, we break from the loop + // otherwise a concurrent update happened, so continue + if (ref.compareAndSet(current, update)) { + break; + } + } else { + // Once in `Cancelling`, concurrent updates are only allowed + // to go Cancelling -> Cancelling + final var name = current != null ? current.getClass().getName() : "null"; + throw new IllegalStateException("Bug — found: " + name); + } } + } + } + } - @Override - public void cancel() { - unregister(active.order); + /** + * The purpose of this method is to be called when invoking the methods + * on {@code CompletionCallback} in order to prevent leaks. + *

    + * NOTE: in case a concurrent `cancel()` has already happened, + * then `complete()` becomes a NO-OP. + */ + void complete() { + while (true) { + final var current = ref.get(); + if (current instanceof State.Active) { + if (ref.compareAndSet(current, State.Completed.INSTANCE)) { + return; } - }; - } else { - throw new IllegalStateException("Invalid state: " + current); + } else { + return; + } } } @@ -134,11 +164,19 @@ public void cancel() { while (true) { final var current = ref.get(); if (current instanceof State.Active active) { - final var newOrder = active.order + 1; - final var update = new State.Active(token, newOrder, active); - if (ref.compareAndSet(current, update)) { return () -> unregister(newOrder); } - } else if (current instanceof State.Closed) { - token.cancel(); + final var update = active.register(token); + if (ref.compareAndSet(current, update)) { + return () -> unregister(update.order); + } + } else if (current instanceof State.Cancelling cancelling) { + final var update = cancelling.register(token); + if (ref.compareAndSet(current, update)) { + return null; + } + } else if (current instanceof State.Cancelled) { + invoke(token); + return null; + } else if (current instanceof State.Completed) { return null; } else { throw new IllegalStateException("Invalid state: " + current); @@ -151,65 +189,68 @@ private void unregister(final long order) { final var current = ref.get(); if (current instanceof State.Active active) { State.@Nullable Active cursor = active; - State.@Nullable Active acc = null; + State.@Nullable Active newListReversed = null; while (cursor != null) { if (cursor.order != order) { - acc = new State.Active(cursor.token, cursor.order, acc); + newListReversed = new State.Active(cursor.token, cursor.order, newListReversed); } cursor = cursor.rest; } // Reversing State.@Nullable Active update = null; - while (acc != null) { - update = new State.Active(acc.token, acc.order, update); - acc = acc.rest; + while (newListReversed != null) { + update = new State.Active(newListReversed.token, newListReversed.order, update); + newListReversed = newListReversed.rest; } if (update == null) { - update = new State.Active(Cancellable.getEmpty(), 0, null); + update = new State.Active(null, 0, null); } if (ref.compareAndSet(current, update)) { return; } - } else if (current instanceof State.Closed) { - return; } else { - throw new IllegalStateException("Invalid state: " + current); + return; } } } - private void registerOrdered( - final long order, - final Cancellable newToken, - State current - ) { - while (true) { - if (current instanceof State.Active active) { - // Double-check ordering - if (active.order != order) { return; } - // Try to update - final var update = new State.Active(newToken, order + 1, null); - if (ref.compareAndSet(current, update)) { return; } - // Retry - current = ref.get(); - } else if (current instanceof State.Closed) { - newToken.cancel(); - return; - } else { - throw new IllegalStateException("Invalid state: " + current); - } + private static void invoke(Cancellable token) { + try { + token.cancel(); + } catch (Throwable e) { + UncaughtExceptionHandler.logOrRethrow(e); } } sealed interface State { record Active( - Cancellable token, + @Nullable Cancellable token, long order, @Nullable Active rest - ) implements State {} + ) implements State { + Active register(Cancellable token) { + final var newOrder = order + 1; + return new State.Active(token, newOrder, this); + } + } + + record Cancelling( + @Nullable Active toCancel + ) implements State { + Cancelling register(Cancellable token) { + final var newOrder = toCancel != null ? toCancel.order + 1 : 1; + return new State.Cancelling( + new State.Active(token, newOrder, toCancel) + ); + } + } + + record Cancelled() implements State { + static final Cancelled INSTANCE = new Cancelled(); + } - record Closed() implements State { - static final Closed INSTANCE = new Closed(); + record Completed() implements State { + static final Completed INSTANCE = new Completed(); } } } diff --git a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/CompletionCallback.java b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/CompletionCallback.java index 4cd0e08..bc1da16 100644 --- a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/CompletionCallback.java +++ b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/CompletionCallback.java @@ -9,7 +9,6 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.AbstractQueuedSynchronizer; import org.jetbrains.annotations.ApiStatus; -import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; /** @@ -28,10 +27,8 @@ public interface CompletionCallback< > extends Serializable { /** * Signals the completion of the task. - * - * @param outcome */ - void onOutcome(Outcome outcome); + void onOutcome(Outcome outcome); /** * Must be called when the task completes successfully. @@ -63,7 +60,7 @@ default void onCancellation() { */ static CompletionCallback empty() { return outcome -> { - if (outcome instanceof Outcome.Failure f) { + if (outcome instanceof Outcome.Failure f) { UncaughtExceptionHandler.logOrRethrow(f.exception()); } }; @@ -104,7 +101,7 @@ ManyCompletionCallback withExtraListener( } @Override - public void onOutcome(Outcome outcome) { + public void onOutcome(Outcome outcome) { for (final CompletionCallback listener : listeners) { try { listener.onOutcome(outcome); @@ -169,7 +166,7 @@ final class AsyncContinuationCallback< private final AtomicReference> listenerRef; private final TaskExecutor executor; - private @Nullable Outcome outcome; + private @Nullable Outcome outcome; private @Nullable T successValue; private @Nullable Throwable failureCause; private boolean isCancelled = false; @@ -205,12 +202,12 @@ public void run() { } @Override - public void onOutcome(final Outcome outcome) { + public void onOutcome(final Outcome outcome) { Objects.requireNonNull(outcome, "outcome"); if (isWaiting.getAndSet(false)) { this.outcome = outcome; executor.resumeOnExecutor(this); - } else if (outcome instanceof Outcome.Failure f) { + } else if (outcome instanceof Outcome.Failure f) { UncaughtExceptionHandler.logOrRethrow(f.exception()); } } @@ -341,10 +338,10 @@ public void onCancellation() { } @Override - public void onOutcome(Outcome outcome) { - if (outcome instanceof Outcome.Success success) { + public void onOutcome(Outcome outcome) { + if (outcome instanceof Outcome.Success success) { onSuccess(success.value()); - } else if (outcome instanceof Outcome.Failure failure) { + } else if (outcome instanceof Outcome.Failure failure) { onFailure(failure.exception()); } else { onCancellation(); diff --git a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Continuation.java b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Continuation.java index 268a9a2..5ee6e9f 100644 --- a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Continuation.java +++ b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Continuation.java @@ -6,38 +6,53 @@ import java.util.concurrent.Executor; /** - * INTERNAL API. - *

    * Continuation objects are used to complete tasks, or for registering * {@link Cancellable} references that can be used to interrupt running tasks. *

    * {@code Continuation} objects get injected in {@link AsyncFun} functions. * See {@link Task#fromAsync(AsyncFun)}. + *

    + * The {@code Continuation} provides the execution context and the means to + * signal the task's outcome. It also allows registering cleanup actions + * that run if the task is cancelled. * * @param is the type of the value that the task will complete with */ -@ApiStatus.Internal -interface Continuation +public interface Continuation extends CompletionCallback { /** * Returns the {@link Executor} that the task can use to run its * asynchronous computation. */ - TaskExecutor getExecutor(); + Executor getExecutor(); /** - * Registers a {@link Cancellable} reference that can be used to interrupt - * a running task. + * Registers a finalizer that is invoked if the running task is cancelled + * before completion. + *

    + * The finalizer is called only for cancellation before completion. + * If cancellation already happened, the finalizer may be called immediately. + * The finalizer is not called after success, failure, or terminal completion. + *

    + * Finalizers must be idempotent, fast, non-blocking, and thread-safe. + * Exceptions from finalizers are handled via {@link UncaughtExceptionHandler}. * - * @param cancellable is the reference to the cancellable object that this - * continuation will register. + * @param finalizer the cleanup action to run on cancellation */ - @Nullable Cancellable registerCancellable(Cancellable cancellable); + @Nullable Cancellable invokeOnCancellation(Cancellable finalizer); +} - CancellableForwardRef registerForwardCancellable(); +/** + * INTERNAL API. + */ +@ApiStatus.Internal +interface InternalContinuation + extends Continuation { + + TaskExecutor getTaskExecutor(); - Continuation withExecutorOverride(TaskExecutor executor); + InternalContinuation withExecutorOverride(TaskExecutor executor); void registerExtraCallback(CompletionCallback extraCallback); } @@ -48,7 +63,7 @@ interface Continuation @ApiStatus.Internal @FunctionalInterface interface AsyncContinuationFun { - void invoke(Continuation continuation); + void invoke(InternalContinuation continuation); } /** @@ -56,7 +71,7 @@ interface AsyncContinuationFun { */ @ApiStatus.Internal final class CancellableContinuation - implements Continuation, Cancellable { + implements InternalContinuation, Cancellable { private final ContinuationCallback callback; private final MutableCancellable cancellableRef; @@ -84,47 +99,51 @@ public CancellableContinuation( } @Override - public TaskExecutor getExecutor() { + public Executor getExecutor() { return this.executor; } @Override - public void cancel() { - cancellableRef.cancel(); + public TaskExecutor getTaskExecutor() { + return this.executor; } @Override - public CancellableForwardRef registerForwardCancellable() { - return cancellableRef.newCancellableRef(); + public void cancel() { + cancellableRef.cancel(); } @Override - public @Nullable Cancellable registerCancellable(Cancellable cancellable) { - return this.cancellableRef.register(cancellable); + public @Nullable Cancellable invokeOnCancellation(Cancellable finalizer) { + return cancellableRef.register(finalizer); } @Override - public void onOutcome(Outcome outcome) { + public void onOutcome(Outcome outcome) { + cancellableRef.complete(); callback.onOutcome(outcome); } @Override public void onSuccess(T value) { + cancellableRef.complete(); callback.onSuccess(value); } @Override public void onFailure(Throwable e) { + cancellableRef.complete(); callback.onFailure(e); } @Override public void onCancellation() { + cancellableRef.complete(); callback.onCancellation(); } @Override - public Continuation withExecutorOverride(TaskExecutor executor) { + public InternalContinuation withExecutorOverride(TaskExecutor executor) { return new CancellableContinuation<>( executor, callback, diff --git a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Fiber.java b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Fiber.java index 86c22a0..5026a24 100644 --- a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Fiber.java +++ b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Fiber.java @@ -301,7 +301,7 @@ public NotCompletedException() { @ApiStatus.Internal final class ExecutedFiber implements Fiber { private final TaskExecutor executor; - private final Continuation continuation; + private final InternalContinuation continuation; private final MutableCancellable cancellableRef; private final AtomicReference> stateRef; @@ -389,7 +389,7 @@ record Cancelled( ) implements State {} record Completed( - Outcome outcome + Outcome outcome ) implements State {} default void triggerListeners(TaskExecutor executor) { @@ -478,7 +478,7 @@ public void onCancellation() { } @Override - public void onOutcome(Outcome outcome) { + public void onOutcome(Outcome outcome) { while (true) { State current = stateRef.get(); if (current instanceof State.Active) { @@ -493,7 +493,7 @@ public void onOutcome(Outcome outcome) { return; } } else if (current instanceof State.Completed) { - if (outcome instanceof Outcome.Failure failure) { + if (outcome instanceof Outcome.Failure failure) { UncaughtExceptionHandler.logOrRethrow(failure.exception()); } return; diff --git a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Task.java b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Task.java index 8259e1f..9bbc538 100644 --- a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Task.java +++ b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/Task.java @@ -38,10 +38,10 @@ private Task(final AsyncContinuationFun createFun) { */ public Task ensureRunningOnExecutor(final @Nullable Executor executor) { return new Task<>((cont) -> { - final Continuation cont2 = executor != null + final InternalContinuation cont2 = executor != null ? cont.withExecutorOverride(TaskExecutor.from(executor)) : cont; - cont2.getExecutor().resumeOnExecutor(() -> createFun.invoke(cont2)); + cont2.getTaskExecutor().resumeOnExecutor(() -> createFun.invoke(cont2)); }); } @@ -91,7 +91,7 @@ public Task withOnComplete(final CompletionCallback callback) { @SuppressWarnings("unchecked") final var extraCallback = (CompletionCallback) Objects.requireNonNull(callback); cont.registerExtraCallback(extraCallback); - cont.getExecutor().resumeOnExecutor(() -> createFun.invoke(cont)); + cont.getTaskExecutor().resumeOnExecutor(() -> createFun.invoke(cont)); }); } @@ -102,15 +102,15 @@ public Task withCancellation( final Cancellable cancellable ) { return new Task<>((cont) -> { - cont.registerCancellable(cancellable); - cont.getExecutor().resumeOnExecutor(() -> createFun.invoke(cont)); + cont.invokeOnCancellation(cancellable); + cont.getTaskExecutor().resumeOnExecutor(() -> createFun.invoke(cont)); }); } /** * Executes the task asynchronously. *

    - * This method ensures that the start starts execution on a different thread, + * This method ensures that the task starts execution on a different thread, * managed by the given executor. * * @param callback will be invoked when the task completes @@ -163,7 +163,7 @@ public Cancellable runAsync(final CompletionCallback callback) { * Similar to {@link #runAsync(Executor, CompletionCallback)}, this method * starts the execution on a different thread. * - * @param executor is the {@link Fiber} that may be used to run the task + * @param executor is the {@link Executor} that may be used to run the task * @return a {@link Fiber} that can be used to wait for the outcome, * or to cancel the running fiber. */ @@ -259,10 +259,10 @@ public T runBlocking() throws ExecutionException, InterruptedException { * @throws ExecutionException if the task fails with an exception * @throws InterruptedException if the current thread is interrupted. * The running task is also cancelled, and this method does not - * return until `onCancel` is signaled. + * return until `onCancellation` is signaled. * @throws TimeoutException if the task doesn't complete within the * specified timeout. The running task is also cancelled on timeout, - * and this method does not returning until `onCancel` is signaled. + * and this method does not return until `onCancellation` is signaled. */ public T runBlockingTimed( final @Nullable Executor executor, @@ -291,10 +291,10 @@ public T runBlockingTimed( * @throws ExecutionException if the task fails with an exception * @throws InterruptedException if the current thread is interrupted. * The running task is also cancelled, and this method does not - * return until `onCancel` is signaled. + * return until `onCancellation` is signaled. * @throws TimeoutException if the task doesn't complete within the * specified timeout. The running task is also cancelled on timeout, - * and this method does not returning until `onCancel` is signaled. + * and this method does not return until `onCancellation` is signaled. */ public T runBlockingTimed(final Duration timeout) throws ExecutionException, InterruptedException, TimeoutException { @@ -311,19 +311,20 @@ public T runBlockingTimed(final Duration timeout) *

  3. Trampolined execution to avoid stack-overflows
  4. *
*

+ * The provided {@link AsyncFun} receives a {@link Continuation} which + * provides the executor and methods to signal completion or register + * cancellation cleanup. * * @param start is the function that will trigger the async computation, - * injecting a callback that will be used to signal the result, - * and an executor that can be used for creating additional threads. + * injecting a {@link Continuation} that will be used to signal + * the result and manage cancellation. * @return a new task that will execute the given builder function upon execution */ public static Task fromAsync(final AsyncFun start) { return new Task<>((cont) -> Trampoline.execute(() -> { try { - cont.registerForwardCancellable().set( - start.invoke(cont.getExecutor(), cont) - ); + start.invoke(cont); } catch (final Throwable e) { UncaughtExceptionHandler.rethrowIfFatal(e); cont.onFailure(e); @@ -340,23 +341,17 @@ public T runBlockingTimed(final Duration timeout) public static Task fromBlockingIO(final DelayedFun run) { return new Task<>((cont) -> { Thread th = Thread.currentThread(); - final var registration = cont.registerCancellable(th::interrupt); - if (registration == null) { + cont.invokeOnCancellation(th::interrupt); + if (th.isInterrupted()) { cont.onCancellation(); return; } try { - T result; - try { - TaskLocalContext.signalTheStartOfBlockingCall(); - result = run.invoke(); - } finally { - registration.cancel(); - } + TaskLocalContext.signalTheStartOfBlockingCall(); + T result = run.invoke(); if (th.isInterrupted()) { throw new InterruptedException(); } - //noinspection DataFlowIssue cont.onSuccess(result); } catch (final InterruptedException | TaskCancellationException e) { cont.onCancellation(); @@ -497,18 +492,14 @@ public TaskFromCancellableFuture(DelayedFun> buil } @Override - public void invoke(Continuation continuation) { + public void invoke(InternalContinuation continuation) { try { - final var cancellableRef = - continuation.registerForwardCancellable(); - final var cancellableFuture = - Objects.requireNonNull(builder.invoke()); + final var cancellableFuture = Objects.requireNonNull(builder.invoke()); final CompletableFuture future = getCompletableFuture(cancellableFuture, continuation); - final Cancellable cancellable = - cancellableFuture.cancellable(); + final Cancellable cancellable = cancellableFuture.cancellable(); - cancellableRef.set(() -> { + continuation.invokeOnCancellation(() -> { try { future.cancel(true); } finally { diff --git a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/TaskUtils.java b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/TaskUtils.java index a7a0ec5..182033d 100644 --- a/tasks-jvm/src/main/java/org/funfix/tasks/jvm/TaskUtils.java +++ b/tasks-jvm/src/main/java/org/funfix/tasks/jvm/TaskUtils.java @@ -12,15 +12,15 @@ final class TaskUtils { static Task taskUninterruptibleBlockingIO( final DelayedFun func ) { - return Task.fromAsync((ec, callback) -> { + return Task.fromAsync((continuation) -> { try { - callback.onSuccess(func.invoke()); + continuation.onSuccess(func.invoke()); } catch (final InterruptedException e) { - callback.onCancellation(); - } catch (final Exception e) { - callback.onFailure(e); + continuation.onCancellation(); + } catch (final Throwable e) { + UncaughtExceptionHandler.rethrowIfFatal(e); + continuation.onFailure(e); } - return () -> {}; }); } diff --git a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/CompletionCallbackTest.java b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/CompletionCallbackTest.java index e92e534..8539382 100644 --- a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/CompletionCallbackTest.java +++ b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/CompletionCallbackTest.java @@ -52,7 +52,7 @@ public void onCancellation() { } @Override - public void onOutcome(Outcome outcome) { + public void onOutcome(Outcome outcome) { outcomeRef.set(outcome); called.incrementAndGet(); } @@ -91,7 +91,7 @@ public void onCancellation() { } @Override - public void onOutcome(Outcome outcome) { + public void onOutcome(Outcome outcome) { outcomeRef.set(outcome); called.incrementAndGet(); } @@ -139,7 +139,7 @@ public void onCancellation() { } @Override - public void onOutcome(Outcome outcome) { + public void onOutcome(Outcome outcome) { outcomeRef.set(outcome); called.incrementAndGet(); } diff --git a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/ContinuationCancellationTest.java b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/ContinuationCancellationTest.java new file mode 100644 index 0000000..3464443 --- /dev/null +++ b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/ContinuationCancellationTest.java @@ -0,0 +1,161 @@ +package org.funfix.tasks.jvm; + +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class ContinuationCancellationTest { + @Test + void finalizerInvokedExactlyOnceWhenCancelledBeforeCompletion() + throws InterruptedException { + + final var counter = new AtomicInteger(); + final var registered = new CountDownLatch(1); + final var task = Task.fromAsync((continuation) -> { + continuation.invokeOnCancellation(() -> { + counter.incrementAndGet(); + registered.countDown(); + }); + }); + + final var fiber = task.runFiber(); + fiber.cancel(); + + TimedAwait.latchAndExpectCompletion(registered, "registered"); + assertEquals(1, counter.get()); + } + + @Test + void finalizerInvokedImmediatelyWhenRegisteredAfterCancellation() { + final var continuation = newContinuation(); + final var counter = new AtomicInteger(); + + continuation.cancel(); + continuation.invokeOnCancellation(counter::incrementAndGet); + + assertEquals(1, counter.get()); + } + + @Test + void finalizerNotInvokedAfterSuccess() { + final var continuation = newContinuation(); + final var counter = new AtomicInteger(); + + continuation.onSuccess("done"); + continuation.invokeOnCancellation(counter::incrementAndGet); + continuation.cancel(); + + assertEquals(0, counter.get()); + } + + @Test + void finalizerNotInvokedAfterFailure() { + final var continuation = newContinuation(); + final var counter = new AtomicInteger(); + + continuation.onFailure(new RuntimeException("boom")); + continuation.invokeOnCancellation(counter::incrementAndGet); + continuation.cancel(); + + assertEquals(0, counter.get()); + } + + @Test + void normalCompletionClearsFinalizersSoLaterCancelDoesNotRunCleanup() + throws Exception { + + final var counter = new AtomicInteger(); + final var task = Task.fromAsync((continuation) -> { + continuation.invokeOnCancellation(counter::incrementAndGet); + continuation.onSuccess("done"); + }); + + final var fiber = task.runFiber(); + assertEquals("done", fiber.awaitBlockingTimed(TestSettings.TIMEOUT)); + fiber.cancel(); + + assertEquals(0, counter.get()); + } + + @Test + void registrationRacingWithCancellationNeverMissesCleanup() + throws InterruptedException { + + final var counter = new AtomicInteger(); + final var finalizerInvoked = new CountDownLatch(1); + final var backgroundStarted = new CountDownLatch(1); + final var cancelIssued = new CountDownLatch(1); + final var backgroundDone = new CountDownLatch(1); + final var backgroundError = new AtomicReference(); + final var task = Task.fromAsync((continuation) -> { + final var thread = new Thread(() -> { + try { + backgroundStarted.countDown(); + TimedAwait.latchAndExpectCompletion(cancelIssued, "cancelIssued"); + continuation.invokeOnCancellation(() -> { + counter.incrementAndGet(); + finalizerInvoked.countDown(); + }); + } catch (final Throwable e) { + backgroundError.set(e); + } finally { + backgroundDone.countDown(); + } + }); + thread.start(); + }); + + final var fiber = task.runFiber(); + TimedAwait.latchAndExpectCompletion(backgroundStarted, "backgroundStarted"); + fiber.cancel(); + cancelIssued.countDown(); + + TimedAwait.latchAndExpectCompletion(finalizerInvoked, "finalizerInvoked"); + TimedAwait.latchAndExpectCompletion(backgroundDone, "backgroundDone"); + assertNull(backgroundError.get()); + assertEquals(1, counter.get()); + } + + @Test + void throwingFinalizerIsHandledWithoutPreventingOthers() { + final var continuation = newContinuation(); + final var counter = new AtomicInteger(); + final var reportedException = new AtomicReference(); + final var currentThread = Thread.currentThread(); + final var previousHandler = currentThread.getUncaughtExceptionHandler(); + final var thrown = new RuntimeException("boom"); + + currentThread.setUncaughtExceptionHandler((thread, error) -> reportedException.set(error)); + try { + continuation.invokeOnCancellation(counter::incrementAndGet); + continuation.invokeOnCancellation(() -> { + throw thrown; + }); + + continuation.cancel(); + } finally { + currentThread.setUncaughtExceptionHandler(previousHandler); + } + + assertEquals(1, counter.get()); + assertEquals(thrown, reportedException.get()); + } + + private static CancellableContinuation newContinuation() { + return new CancellableContinuation<>( + TaskExecutor.from(Runnable::run), + new ContinuationCallback<>() { + @Override + public void onOutcome(Outcome outcome) {} + + @Override + public void registerExtraCallback(CompletionCallback extraCallback) {} + } + ); + } +} diff --git a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/FiberTests.java b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/FiberTests.java index d50ebda..eef1dd4 100644 --- a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/FiberTests.java +++ b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/FiberTests.java @@ -296,7 +296,7 @@ void awaitAsyncHappyPath() throws InterruptedException { final var fiber = startFiber(() -> 1); final var result = new AtomicInteger(0); fiber.awaitAsync(outcome -> { - if (outcome instanceof Outcome.Success success) { + if (outcome instanceof Outcome.Success success) { //noinspection DataFlowIssue result.set(success.value()); latch.countDown(); @@ -320,7 +320,7 @@ void awaitAsyncCanFail() throws InterruptedException { }); final var result = new AtomicReference<@Nullable Throwable>(null); fiber.awaitAsync(outcome -> { - if (outcome instanceof Outcome.Failure failure) { + if (outcome instanceof Outcome.Failure failure) { result.set(failure.exception()); latch.countDown(); } else { @@ -353,9 +353,9 @@ void awaitAsyncSignalsCancellation() throws InterruptedException { final var awaitLatch = new CountDownLatch(1); fiber.awaitAsync(outcome -> { - if (outcome instanceof Outcome.Cancellation) { + if (outcome instanceof Outcome.Cancellation) { wasInterrupted.incrementAndGet(); - } else if (outcome instanceof Outcome.Failure failure) { + } else if (outcome instanceof Outcome.Failure failure) { UncaughtExceptionHandler.logOrRethrow(failure.exception()); } awaitLatch.countDown(); diff --git a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/MutableCancellableTest.java b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/MutableCancellableTest.java new file mode 100644 index 0000000..30f752d --- /dev/null +++ b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/MutableCancellableTest.java @@ -0,0 +1,766 @@ +package org.funfix.tasks.jvm; + +import org.jspecify.annotations.Nullable; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.*; + +public class MutableCancellableTest { + + // ────────────────────────────────────────────────────────────────── + // Basic cancel() behavior + // ────────────────────────────────────────────────────────────────── + + @Test + void cancelInvokesRegisteredToken() { + final var called = new AtomicInteger(0); + final var mc = new MutableCancellable(called::incrementAndGet); + + mc.cancel(); + assertEquals(1, called.get()); + } + + @Test + void cancelIsIdempotent() { + final var called = new AtomicInteger(0); + final var mc = new MutableCancellable(called::incrementAndGet); + + mc.cancel(); + mc.cancel(); + mc.cancel(); + + assertEquals(1, called.get()); + } + + @Test + void cancelInvokesAllRegisteredTokens() { + final var mc = new MutableCancellable(); + final var order = Collections.synchronizedList(new ArrayList()); + + mc.register(() -> order.add(1)); + mc.register(() -> order.add(2)); + mc.register(() -> order.add(3)); + + mc.cancel(); + + // Tokens are stored as a stack (prepend), so the cancellation order is + // the reverse of registration: the last registered is cancelled first. + assertEquals(List.of(3, 2, 1), order); + } + + @Test + void cancelInvokesInitialTokenToo() { + final var order = Collections.synchronizedList(new ArrayList()); + final var mc = new MutableCancellable(() -> order.add(0)); + + mc.register(() -> order.add(1)); + mc.register(() -> order.add(2)); + + mc.cancel(); + + // Stack order: 2, 1, 0 (initial) + assertEquals(List.of(2, 1, 0), order); + } + + @Test + void emptyMutableCancellableCanBeCancelledSafely() { + final var mc = new MutableCancellable(); + // Should not throw + mc.cancel(); + mc.cancel(); + } + + // ────────────────────────────────────────────────────────────────── + // register() behavior + // ────────────────────────────────────────────────────────────────── + + @Test + void registerReturnsNonNullHandleInActiveState() { + final var mc = new MutableCancellable(); + final var handle = mc.register(() -> {}); + assertNotNull(handle); + } + + @Test + void registerAfterCancelledInvokesTokenImmediately() { + final var mc = new MutableCancellable(); + mc.cancel(); + + final var called = new AtomicInteger(0); + final var handle = mc.register(called::incrementAndGet); + + assertEquals(1, called.get()); + // Returns null when already cancelled + assertNull(handle); + } + + @Test + void registerAfterCompletedReturnsNullAndDoesNotInvoke() { + final var mc = new MutableCancellable(); + mc.complete(); + + final var called = new AtomicInteger(0); + final var handle = mc.register(called::incrementAndGet); + + assertEquals(0, called.get()); + assertNull(handle); + } + + @Test + void registerReturnsNullDuringCancelling() throws InterruptedException { + // A slow token blocks the cancelling thread, allowing us to + // observe the Cancelling state from another thread. + final var slowTokenStarted = new CountDownLatch(1); + final var slowTokenRelease = new CountDownLatch(1); + final var registrationDone = new CountDownLatch(1); + final var handleRef = new AtomicReference<@Nullable Cancellable>(); + final var errorRef = new AtomicReference<@Nullable Throwable>(); + + final var mc = new MutableCancellable(() -> { + slowTokenStarted.countDown(); + try { + TimedAwait.latchAndExpectCompletion(slowTokenRelease, "slowTokenRelease"); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + // Start cancellation on a separate thread — it will block in the slow token + final var cancelThread = new Thread(mc::cancel); + cancelThread.start(); + + // Wait until the slow token is actually running + TimedAwait.latchAndExpectCompletion(slowTokenStarted, "slowTokenStarted"); + + // Now register from another thread while in Cancelling state + final var registerThread = new Thread(() -> { + try { + handleRef.set(mc.register(() -> {})); + } catch (Throwable e) { + errorRef.set(e); + } finally { + registrationDone.countDown(); + } + }); + registerThread.start(); + + // Let registrationDone complete (register should return quickly) + TimedAwait.latchAndExpectCompletion(registrationDone, "registrationDone"); + assertNull(errorRef.get()); + // During Cancelling state, register returns null (no unregister handle) + assertNull(handleRef.get()); + + // Release the slow token so cancel() completes + slowTokenRelease.countDown(); + cancelThread.join(TestSettings.TIMEOUT.toMillis()); + } + + // ────────────────────────────────────────────────────────────────── + // unregister() behavior + // ────────────────────────────────────────────────────────────────── + + @Test + void unregisterPreventsTokenFromBeingCancelled() { + final var mc = new MutableCancellable(); + final var called = new AtomicInteger(0); + + final var handle = mc.register(called::incrementAndGet); + assertNotNull(handle); + + // Unregister before cancel + handle.cancel(); + mc.cancel(); + + assertEquals(0, called.get()); + } + + @Test + void unregisterOnlyRemovesTargetToken() { + final var mc = new MutableCancellable(); + final var calls = Collections.synchronizedList(new ArrayList()); + + mc.register(() -> calls.add("A")); + final var handleB = mc.register(() -> calls.add("B")); + mc.register(() -> calls.add("C")); + + assertNotNull(handleB); + handleB.cancel(); + mc.cancel(); + + // B should be absent + assertEquals(List.of("C", "A"), calls); + } + + @Test + void unregisterIsIdempotent() { + final var mc = new MutableCancellable(); + final var called = new AtomicInteger(0); + final var handle = mc.register(called::incrementAndGet); + assertNotNull(handle); + + handle.cancel(); + handle.cancel(); // idempotent, should not throw + mc.cancel(); + + assertEquals(0, called.get()); + } + + @Test + void unregisterAfterCancelIsNoOp() { + final var mc = new MutableCancellable(); + final var called = new AtomicInteger(0); + final var handle = mc.register(called::incrementAndGet); + assertNotNull(handle); + + mc.cancel(); + assertEquals(1, called.get()); + + // unregister after cancel is a no-op + handle.cancel(); + assertEquals(1, called.get()); + } + + @Test + void unregisterAfterCompleteIsNoOp() { + final var mc = new MutableCancellable(); + final var called = new AtomicInteger(0); + final var handle = mc.register(called::incrementAndGet); + assertNotNull(handle); + + mc.complete(); + + // unregister after complete is a no-op + handle.cancel(); + assertEquals(0, called.get()); + } + + // ────────────────────────────────────────────────────────────────── + // complete() behavior + // ────────────────────────────────────────────────────────────────── + + @Test + void completePreventsCancelFromFiringTokens() { + final var called = new AtomicInteger(0); + final var mc = new MutableCancellable(called::incrementAndGet); + + mc.complete(); + mc.cancel(); + + assertEquals(0, called.get()); + } + + @Test + void completeIsIdempotent() { + final var mc = new MutableCancellable(); + mc.complete(); + mc.complete(); // should not throw + } + + @Test + void completeAfterCancelIsNoOp() { + final var mc = new MutableCancellable(); + mc.cancel(); + mc.complete(); // should not throw + } + + // ────────────────────────────────────────────────────────────────── + // Error handling + // ────────────────────────────────────────────────────────────────── + + @Test + void throwingTokenDoesNotPreventOtherTokensFromBeingCancelled() { + final var mc = new MutableCancellable(); + final var order = Collections.synchronizedList(new ArrayList()); + final var reportedErrors = Collections.synchronizedList(new ArrayList()); + + final var currentThread = Thread.currentThread(); + final var previousHandler = currentThread.getUncaughtExceptionHandler(); + currentThread.setUncaughtExceptionHandler((t, e) -> reportedErrors.add(e)); + + try { + mc.register(() -> order.add(1)); + mc.register(() -> { throw new RuntimeException("boom"); }); + mc.register(() -> order.add(3)); + + mc.cancel(); + } finally { + currentThread.setUncaughtExceptionHandler(previousHandler); + } + + // Both non-throwing tokens should have been cancelled + assertEquals(List.of(3, 1), order); + // The exception should have been reported + assertEquals(1, reportedErrors.size()); + assertEquals("boom", reportedErrors.get(0).getMessage()); + } + + @Test + void throwingTokenOnImmediateRegisterAfterCancelIsHandled() { + final var mc = new MutableCancellable(); + mc.cancel(); + + final var reportedErrors = Collections.synchronizedList(new ArrayList()); + final var currentThread = Thread.currentThread(); + final var previousHandler = currentThread.getUncaughtExceptionHandler(); + currentThread.setUncaughtExceptionHandler((t, e) -> reportedErrors.add(e)); + + try { + mc.register(() -> { throw new RuntimeException("late boom"); }); + } finally { + currentThread.setUncaughtExceptionHandler(previousHandler); + } + + assertEquals(1, reportedErrors.size()); + assertEquals("late boom", reportedErrors.get(0).getMessage()); + } + + // ────────────────────────────────────────────────────────────────── + // Concurrent behavior: tokens registered during Cancelling state + // are cancelled on the cancelling thread + // ────────────────────────────────────────────────────────────────── + + @Test + void tokenRegisteredDuringCancellingIsCancelledOnCancellingThread() + throws InterruptedException { + + // We use latches to orchestrate the following sequence: + // + // Cancel thread Register thread + // ───────────── ─────────────── + // 1. cancel() starts + // 2. slow token begins + // -> signals slowStarted + // 3. waits for slowStarted + // 4. register(newToken) + // -> signals registerDone + // 5. waits for registerDone + // 6. slow token completes + // 7. picks up newToken from + // Cancelling state + // 8. cancels newToken on + // THIS (cancel) thread + + final var slowStarted = new CountDownLatch(1); + final var registerDone = new CountDownLatch(1); + final var allDone = new CountDownLatch(1); + + final var cancellingThreadRef = new AtomicReference<@Nullable Thread>(); + final var newTokenThreadRef = new AtomicReference<@Nullable Thread>(); + final var errorRef = new AtomicReference<@Nullable Throwable>(); + + final var mc = new MutableCancellable(() -> { + // Record which thread is doing the cancellation + cancellingThreadRef.set(Thread.currentThread()); + slowStarted.countDown(); + try { + // Wait until the register thread has registered a new token + TimedAwait.latchAndExpectCompletion(registerDone, "registerDone"); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + // The cancelling thread + final var cancelThread = new Thread(() -> { + try { + mc.cancel(); + } catch (Throwable e) { + errorRef.set(e); + } finally { + allDone.countDown(); + } + }, "cancel-thread"); + + // The register thread + final var registerThread = new Thread(() -> { + try { + TimedAwait.latchAndExpectCompletion(slowStarted, "slowStarted"); + mc.register(() -> { + // Record which thread invokes this new token's cancel + newTokenThreadRef.set(Thread.currentThread()); + }); + } catch (Throwable e) { + errorRef.set(e); + } finally { + registerDone.countDown(); + } + }, "register-thread"); + + cancelThread.start(); + registerThread.start(); + + TimedAwait.latchAndExpectCompletion(allDone, "allDone"); + cancelThread.join(TestSettings.TIMEOUT.toMillis()); + registerThread.join(TestSettings.TIMEOUT.toMillis()); + + assertNull(errorRef.get(), () -> "Unexpected error: " + errorRef.get()); + assertNotNull(cancellingThreadRef.get()); + assertNotNull(newTokenThreadRef.get()); + + // The key assertion: the new token is cancelled on the cancel thread, + // NOT on the register thread + assertSame( + cancellingThreadRef.get(), + newTokenThreadRef.get(), + "New token should be cancelled on the cancelling thread" + ); + } + + @Test + void multipleTokensRegisteredDuringCancellingAreAllCancelled() + throws InterruptedException { + + // Orchestration: + // Cancel thread runs cancel(), hits slow token which blocks. + // Register thread registers N new tokens while slow token is blocked. + // When slow token completes, the cancel thread picks up all N tokens. + + final int extraTokenCount = 5; + final var slowStarted = new CountDownLatch(1); + final var registerDone = new CountDownLatch(1); + final var allDone = new CountDownLatch(1); + + final var cancelledTokens = Collections.synchronizedList(new ArrayList()); + final var errorRef = new AtomicReference<@Nullable Throwable>(); + + final var mc = new MutableCancellable(() -> { + slowStarted.countDown(); + try { + TimedAwait.latchAndExpectCompletion(registerDone, "registerDone"); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + final var cancelThread = new Thread(() -> { + try { + mc.cancel(); + } catch (Throwable e) { + errorRef.set(e); + } finally { + allDone.countDown(); + } + }, "cancel-thread"); + + final var registerThread = new Thread(() -> { + try { + TimedAwait.latchAndExpectCompletion(slowStarted, "slowStarted"); + for (int i = 0; i < extraTokenCount; i++) { + final int idx = i; + mc.register(() -> cancelledTokens.add(idx)); + } + } catch (Throwable e) { + errorRef.set(e); + } finally { + registerDone.countDown(); + } + }, "register-thread"); + + cancelThread.start(); + registerThread.start(); + + TimedAwait.latchAndExpectCompletion(allDone, "allDone"); + cancelThread.join(TestSettings.TIMEOUT.toMillis()); + registerThread.join(TestSettings.TIMEOUT.toMillis()); + + assertNull(errorRef.get(), () -> "Unexpected error: " + errorRef.get()); + assertEquals(extraTokenCount, cancelledTokens.size(), + "All tokens registered during Cancelling should be cancelled"); + } + + @Test + void tokensRegisteredInWavesDuringCancellingAreAllCancelled() + throws InterruptedException { + + // This test verifies the loop in startCancellationOfEverything: + // Wave 1: slow token blocks, register thread adds token A + // Wave 2: when slow token finishes and picks up A, A is also slow, + // and register thread adds token B during A's execution. + // + // All tokens (A, B) must be cancelled. + + final var wave1Started = new CountDownLatch(1); + final var wave1RegisterDone = new CountDownLatch(1); + final var wave2Started = new CountDownLatch(1); + final var wave2RegisterDone = new CountDownLatch(1); + final var allDone = new CountDownLatch(1); + + final var cancelledTokens = Collections.synchronizedList(new ArrayList()); + final var cancellingThread = new AtomicReference<@Nullable Thread>(); + final var tokenAThread = new AtomicReference<@Nullable Thread>(); + final var tokenBThread = new AtomicReference<@Nullable Thread>(); + final var errorRef = new AtomicReference<@Nullable Throwable>(); + + final var mc = new MutableCancellable(() -> { + // Wave 1: initial slow token + cancellingThread.set(Thread.currentThread()); + cancelledTokens.add("initial"); + wave1Started.countDown(); + try { + TimedAwait.latchAndExpectCompletion(wave1RegisterDone, "wave1RegisterDone"); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + final var cancelThread = new Thread(() -> { + try { + mc.cancel(); + } catch (Throwable e) { + errorRef.set(e); + } finally { + allDone.countDown(); + } + }, "cancel-thread"); + + final var registerThread = new Thread(() -> { + try { + // Wave 1: register token A while initial token is slow + TimedAwait.latchAndExpectCompletion(wave1Started, "wave1Started"); + mc.register(() -> { + tokenAThread.set(Thread.currentThread()); + cancelledTokens.add("A"); + // Token A is also slow — wave 2 + wave2Started.countDown(); + try { + TimedAwait.latchAndExpectCompletion(wave2RegisterDone, "wave2RegisterDone"); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + wave1RegisterDone.countDown(); + + // Wave 2: register token B while token A is slow + TimedAwait.latchAndExpectCompletion(wave2Started, "wave2Started"); + mc.register(() -> { + tokenBThread.set(Thread.currentThread()); + cancelledTokens.add("B"); + }); + wave2RegisterDone.countDown(); + } catch (Throwable e) { + errorRef.set(e); + } + }, "register-thread"); + + cancelThread.start(); + registerThread.start(); + + TimedAwait.latchAndExpectCompletion(allDone, "allDone"); + cancelThread.join(TestSettings.TIMEOUT.toMillis()); + registerThread.join(TestSettings.TIMEOUT.toMillis()); + + assertNull(errorRef.get(), () -> "Unexpected error: " + errorRef.get()); + assertEquals(List.of("initial", "A", "B"), cancelledTokens); + + // All tokens should be cancelled on the cancel thread + assertSame(cancellingThread.get(), tokenAThread.get(), + "Token A should be cancelled on the cancelling thread"); + assertSame(cancellingThread.get(), tokenBThread.get(), + "Token B should be cancelled on the cancelling thread"); + } + + @Test + void concurrentCancelAndRegisterNeverLosesToken() + throws InterruptedException { + + // Stress-ish test: cancel and register race, but the token + // must always be invoked (either immediately by register when + // state is Cancelled, or picked up during Cancelling). + + for (int i = 0; i < TestSettings.CONCURRENCY_REPEATS; i++) { + final var mc = new MutableCancellable(); + final var called = new CountDownLatch(1); + + final var cancelThread = new Thread(mc::cancel); + final var registerThread = new Thread(() -> + mc.register(called::countDown) + ); + + cancelThread.start(); + registerThread.start(); + + TimedAwait.latchAndExpectCompletion(called, "called (iteration " + i + ")"); + cancelThread.join(TestSettings.TIMEOUT.toMillis()); + registerThread.join(TestSettings.TIMEOUT.toMillis()); + } + } + + @Test + void concurrentCancelAndCompleteDoNotDoubleInvoke() + throws InterruptedException { + + // When cancel() and complete() race, the token should be invoked + // at most once (by cancel if it wins, zero times if complete wins). + + for (int i = 0; i < TestSettings.CONCURRENCY_REPEATS; i++) { + final var called = new AtomicInteger(0); + final var mc = new MutableCancellable(called::incrementAndGet); + + final var cancelThread = new Thread(mc::cancel); + final var completeThread = new Thread(mc::complete); + + cancelThread.start(); + completeThread.start(); + + cancelThread.join(TestSettings.TIMEOUT.toMillis()); + completeThread.join(TestSettings.TIMEOUT.toMillis()); + + assertTrue(called.get() <= 1, + "Token should be invoked at most once, got: " + called.get()); + } + } + + @Test + void concurrentDoubleCancelDoesNotThrow() + throws InterruptedException { + + // Verifies that cancel() is truly idempotent even when a second + // cancel() call arrives while the first is still in the Cancelling + // state (draining tokens). The second call should return harmlessly. + + for (int i = 0; i < TestSettings.CONCURRENCY_REPEATS; i++) { + final var slowStarted = new CountDownLatch(1); + final var slowRelease = new CountDownLatch(1); + final var called = new AtomicInteger(0); + final var errorRef = new AtomicReference<@Nullable Throwable>(); + + final var mc = new MutableCancellable(() -> { + called.incrementAndGet(); + slowStarted.countDown(); + try { + TimedAwait.latchAndExpectCompletion(slowRelease, "slowRelease"); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + final var cancelThread1 = new Thread(mc::cancel, "cancel-1"); + cancelThread1.start(); + + // Wait until the first cancel is in the middle of draining + TimedAwait.latchAndExpectCompletion(slowStarted, "slowStarted (iteration " + i + ")"); + + // Second cancel while in Cancelling state — must not throw + final var cancelThread2 = new Thread(() -> { + try { + mc.cancel(); + } catch (Throwable e) { + errorRef.set(e); + } + }, "cancel-2"); + cancelThread2.start(); + cancelThread2.join(TestSettings.TIMEOUT.toMillis()); + + assertNull(errorRef.get(), + "Second cancel() should not throw, got: " + errorRef.get()); + + // Let the first cancel finish + slowRelease.countDown(); + cancelThread1.join(TestSettings.TIMEOUT.toMillis()); + + // Token invoked exactly once + assertEquals(1, called.get()); + } + } + + // ────────────────────────────────────────────────────────────────── + // State transition: Cancelling -> Cancelled + // ────────────────────────────────────────────────────────────────── + + @Test + void afterCancelCompletesRegisterInvokesImmediately() throws InterruptedException { + // Ensure that once cancellation finishes (state = Cancelled), + // any subsequent register() invokes the token immediately. + + final var mc = new MutableCancellable(); + mc.cancel(); + + final var called = new AtomicInteger(0); + mc.register(called::incrementAndGet); + assertEquals(1, called.get()); + + // And again + mc.register(called::incrementAndGet); + assertEquals(2, called.get()); + } + + // ────────────────────────────────────────────────────────────────── + // Edge cases + // ────────────────────────────────────────────────────────────────── + + @Test + void registerThrowsOnNullToken() { + final var mc = new MutableCancellable(); + assertThrows(NullPointerException.class, () -> mc.register(null)); + } + + @Test + void throwingTokenDuringCancellingDoesNotPreventNewTokensFromBeingCancelled() + throws InterruptedException { + + // Initial token throws, but a token registered during Cancelling + // should still be cancelled. + + final var throwingStarted = new CountDownLatch(1); + final var registerDone = new CountDownLatch(1); + final var allDone = new CountDownLatch(1); + final var newTokenCalled = new AtomicInteger(0); + final var errorRef = new AtomicReference<@Nullable Throwable>(); + final var reportedErrors = Collections.synchronizedList(new ArrayList()); + + final var mc = new MutableCancellable(() -> { + throwingStarted.countDown(); + try { + TimedAwait.latchAndExpectCompletion(registerDone, "registerDone"); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException("initial boom"); + }); + + final var cancelThread = new Thread(() -> { + final var th = Thread.currentThread(); + final var prev = th.getUncaughtExceptionHandler(); + th.setUncaughtExceptionHandler((t, e) -> reportedErrors.add(e)); + try { + mc.cancel(); + } catch (Throwable e) { + errorRef.set(e); + } finally { + th.setUncaughtExceptionHandler(prev); + allDone.countDown(); + } + }, "cancel-thread"); + + final var registerThread = new Thread(() -> { + try { + TimedAwait.latchAndExpectCompletion(throwingStarted, "throwingStarted"); + mc.register(newTokenCalled::incrementAndGet); + } catch (Throwable e) { + errorRef.set(e); + } finally { + registerDone.countDown(); + } + }, "register-thread"); + + cancelThread.start(); + registerThread.start(); + + TimedAwait.latchAndExpectCompletion(allDone, "allDone"); + cancelThread.join(TestSettings.TIMEOUT.toMillis()); + registerThread.join(TestSettings.TIMEOUT.toMillis()); + + assertNull(errorRef.get(), () -> "Unexpected error: " + errorRef.get()); + assertEquals(1, newTokenCalled.get(), + "New token should still be cancelled despite initial token throwing"); + assertEquals(1, reportedErrors.size()); + assertEquals("initial boom", reportedErrors.get(0).getMessage()); + } +} diff --git a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskCreateTest.java b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskCreateTest.java index ee70253..a1a796f 100644 --- a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskCreateTest.java +++ b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskCreateTest.java @@ -35,12 +35,11 @@ void successful() throws ExecutionException, InterruptedException, TimeoutExcept final var noErrors = new CountDownLatch(1); final var reportedException = new AtomicReference<@Nullable Throwable>(null); - final Task task = fromAsyncTask((executor, cb) -> { - cb.onSuccess("Hello, world!"); + final Task task = fromAsyncTask((continuation) -> { + continuation.onSuccess("Hello, world!"); // callback is idempotent - cb.onSuccess("Hello, world! (2)"); + continuation.onSuccess("Hello, world! (2)"); noErrors.countDown(); - return Cancellable.getEmpty(); }); final String result = task.runBlockingTimed(TIMEOUT); @@ -53,13 +52,12 @@ void successful() throws ExecutionException, InterruptedException, TimeoutExcept void failed() throws InterruptedException { final var noErrors = new CountDownLatch(1); final var reportedException = new AtomicReference<@Nullable Throwable>(null); - final Task task = fromAsyncTask((executor, cb) -> { + final Task task = fromAsyncTask((continuation) -> { Thread.setDefaultUncaughtExceptionHandler((t, ex) -> reportedException.set(ex)); - cb.onFailure(new RuntimeException("Sample exception")); + continuation.onFailure(new RuntimeException("Sample exception")); // callback is idempotent - cb.onFailure(new RuntimeException("Sample exception (2)")); + continuation.onFailure(new RuntimeException("Sample exception (2)")); noErrors.countDown(); - return Cancellable.getEmpty(); }); try { if (executor != null) @@ -85,12 +83,14 @@ void cancelled() throws InterruptedException, ExecutionException, Fiber.NotCompl final var noErrors = new CountDownLatch(1); final var reportedException = new AtomicReference<@Nullable Throwable>(null); - final Task task = fromAsyncTask((executor, cb) -> () -> { - Thread.setDefaultUncaughtExceptionHandler((t, ex) -> reportedException.set(ex)); - cb.onCancellation(); - // callback is idempotent - cb.onCancellation(); - noErrors.countDown(); + final Task task = fromAsyncTask((continuation) -> { + continuation.invokeOnCancellation(() -> { + Thread.setDefaultUncaughtExceptionHandler((t, ex) -> reportedException.set(ex)); + continuation.onCancellation(); + // callback is idempotent + continuation.onCancellation(); + noErrors.countDown(); + }); }); final var fiber = diff --git a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskEnsureExecutorTest.java b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskEnsureExecutorTest.java index 917d68c..bea5207 100644 --- a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskEnsureExecutorTest.java +++ b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskEnsureExecutorTest.java @@ -82,10 +82,10 @@ void switchesBackOnCallbackForRunAsync() throws InterruptedException { Task.fromBlockingIO(() -> Thread.currentThread().getName()) .ensureRunningOnExecutor(ec1) .runAsync(ec2, (CompletionCallback) outcome -> { - if (outcome instanceof Outcome.Success value) { + if (outcome instanceof Outcome.Success value) { threadName1.set(value.value()); threadName2.set(Thread.currentThread().getName()); - } else if (outcome instanceof Outcome.Failure f) { + } else if (outcome instanceof Outcome.Failure f) { UncaughtExceptionHandler.logOrRethrow(f.exception()); } isDone.countDown(); @@ -118,7 +118,7 @@ void testEnsureExecutorOnFiberAfterCompletion() throws ExecutionException, Inter ); final var isComplete = new CountDownLatch(1); - final var r2 = new AtomicReference<@Nullable Outcome>(null); + final var r2 = new AtomicReference<@Nullable Outcome>(null); fiber.awaitAsync(outcome -> { r2.set(outcome); isComplete.countDown(); diff --git a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskExecuteTest.java b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskExecuteTest.java index 126188a..a85bef8 100644 --- a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskExecuteTest.java +++ b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskExecuteTest.java @@ -17,7 +17,7 @@ void runAsyncWorksForSuccess() throws InterruptedException, TaskCancellationExce final var latch = new CountDownLatch(1); final var outcomeRef = - new AtomicReference<@Nullable Outcome>(null); + new AtomicReference<@Nullable Outcome>(null); final var task = Task.fromBlockingIO(() -> "Hello!"); task.runAsync(outcome -> { @@ -35,7 +35,7 @@ void runAsyncWorksForFailure() throws InterruptedException, TaskCancellationExce final var latch = new CountDownLatch(1); final var outcomeRef = - new AtomicReference<@Nullable Outcome>(null); + new AtomicReference<@Nullable Outcome>(null); final var expectedError = new RuntimeException("Error"); diff --git a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskFromBlockingIOTest.java b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskFromBlockingIOTest.java index bc7dd8a..a116378 100644 --- a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskFromBlockingIOTest.java +++ b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskFromBlockingIOTest.java @@ -7,6 +7,7 @@ import java.util.Objects; import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import static org.funfix.tasks.jvm.TestSettings.TIMEOUT; @@ -67,24 +68,84 @@ public void canFail() throws InterruptedException { } @Test - public void isCancellable() throws InterruptedException, ExecutionException, Fiber.NotCompletedException { + public void isCancellable() throws InterruptedException, ExecutionException, Fiber.NotCompletedException, TimeoutException { Objects.requireNonNull(executor); - final var latch = new CountDownLatch(1); + final var wasStarted = new CountDownLatch(1); + final var wasInterrupted = new CountDownLatch(1); + final var interrupted = new AtomicBoolean(false); @SuppressWarnings("NullAway") final var task = Task.fromBlockingIO(() -> { - latch.countDown(); - Thread.sleep(30000); - return null; + wasStarted.countDown(); + try { + Thread.sleep(5000); + return null; + } catch (final InterruptedException e) { + interrupted.set(true); + wasInterrupted.countDown(); + throw e; + } }); final var fiber = task.runFiber(executor); - TimedAwait.latchAndExpectCompletion(latch, "latch"); + TimedAwait.latchAndExpectCompletion(wasStarted, "wasStarted"); fiber.cancel(); - fiber.joinBlocking(); + fiber.joinBlockingTimed(TIMEOUT); try { fiber.getResultOrThrow(); fail("Should have thrown a CancellationException"); } catch (final TaskCancellationException ignored) {} + TimedAwait.latchAndExpectCompletion(wasInterrupted, "wasInterrupted"); + assertTrue(interrupted.get(), "Blocking thread should have been interrupted"); + } + + @Test + public void runBlockingSuccessDoesNotLeaveCurrentThreadInterrupted() + throws ExecutionException, InterruptedException { + + Objects.requireNonNull(executor); + Thread.interrupted(); + + final var result = Task.fromBlockingIO(() -> "Hello, world!") + .runBlocking(executor); + + assertEquals("Hello, world!", result); + assertFalse(Thread.currentThread().isInterrupted()); + } + + @Test + public void cancellationBeforeBlockingWorkStartsIsHandled() + throws InterruptedException, ExecutionException, Fiber.NotCompletedException, TimeoutException { + + final var gate = new CountDownLatch(1); + final var blockingWorkStarted = new AtomicBoolean(false); + final ExecutorService queuedExecutor = Executors.newSingleThreadExecutor(); + try { + final Future blocker = queuedExecutor.submit(() -> { + try { + TimedAwait.latchAndExpectCompletion(gate, "gate"); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + final var fiber = Task.fromBlockingIO(() -> { + blockingWorkStarted.set(true); + return "unexpected"; + }).runFiber(queuedExecutor); + + fiber.cancel(); + gate.countDown(); + TimedAwait.future(blocker); + fiber.joinBlockingTimed(TIMEOUT); + + try { + fiber.getResultOrThrow(); + fail("Should have thrown a TaskCancellationException"); + } catch (final TaskCancellationException ignored) {} + assertFalse(blockingWorkStarted.get(), "Blocking work should not have started"); + } finally { + queuedExecutor.shutdownNow(); + } } } diff --git a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskFromCancellableFutureTest.java b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskFromCancellableFutureTest.java new file mode 100644 index 0000000..6084830 --- /dev/null +++ b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskFromCancellableFutureTest.java @@ -0,0 +1,85 @@ +package org.funfix.tasks.jvm; + +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.funfix.tasks.jvm.TestSettings.TIMEOUT; +import static org.junit.jupiter.api.Assertions.*; + +public class TaskFromCancellableFutureTest { + @Test + void cancellationInvokesFutureCancelAndTokenCancel() + throws InterruptedException, ExecutionException, Fiber.NotCompletedException, TimeoutException { + + final var future = new CompletableFuture(); + final var tokenCancelCount = new AtomicInteger(); + final var tokenCancelled = new CountDownLatch(1); + final Cancellable token = () -> { + tokenCancelCount.incrementAndGet(); + tokenCancelled.countDown(); + }; + + final var fiber = Task.fromCancellableFuture(() -> new CancellableFuture<>(future, token)) + .runFiber(); + + fiber.cancel(); + fiber.joinBlockingTimed(TIMEOUT); + + try { + fiber.getResultOrThrow(); + fail("Should have thrown a TaskCancellationException"); + } catch (final TaskCancellationException ignored) {} + TimedAwait.latchAndExpectCompletion(tokenCancelled, "tokenCancelled"); + assertTrue(future.isCancelled(), "Future should have been cancelled"); + assertEquals(1, tokenCancelCount.get()); + } + + @Test + void cancellationBeforeFutureRegistrationInvokesCleanupImmediately() + throws InterruptedException, ExecutionException, Fiber.NotCompletedException, TimeoutException { + + final var future = new CompletableFuture(); + final var tokenCancelCount = new AtomicInteger(); + final var tokenCancelled = new CountDownLatch(1); + final var builderStarted = new CountDownLatch(1); + final var cancellationIssued = new CountDownLatch(1); + final var builderReturned = new AtomicBoolean(false); + final Cancellable token = () -> { + tokenCancelCount.incrementAndGet(); + tokenCancelled.countDown(); + }; + final ExecutorService executor = Executors.newSingleThreadExecutor(); + try { + final var fiber = Task.fromCancellableFuture(() -> { + builderStarted.countDown(); + TimedAwait.latchAndExpectCompletion(cancellationIssued, "cancellationIssued"); + builderReturned.set(true); + return new CancellableFuture<>(future, token); + }).runFiber(executor); + + TimedAwait.latchAndExpectCompletion(builderStarted, "builderStarted"); + fiber.cancel(); + cancellationIssued.countDown(); + TimedAwait.latchAndExpectCompletion(tokenCancelled, "tokenCancelled"); + fiber.joinBlockingTimed(TIMEOUT); + + try { + fiber.getResultOrThrow(); + fail("Should have thrown a TaskCancellationException"); + } catch (final TaskCancellationException ignored) {} + assertTrue(builderReturned.get(), "Builder should have returned after cancellation"); + assertTrue(future.isCancelled(), "Future should have been cancelled"); + assertEquals(1, tokenCancelCount.get()); + } finally { + executor.shutdownNow(); + } + } +} diff --git a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskWithCancellationTest.java b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskWithCancellationTest.java index f31e13e..82856f4 100644 --- a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskWithCancellationTest.java +++ b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskWithCancellationTest.java @@ -15,7 +15,7 @@ public class TaskWithCancellationTest { void testTaskWithCancellation() throws InterruptedException { for (int r = 0; r < TestSettings.CONCURRENCY_REPEATS; r++) { final var cancelTokensRef = new ConcurrentLinkedQueue(); - final var outcomeRef = new AtomicReference<@Nullable Outcome>(null); + final var outcomeRef = new AtomicReference<@Nullable Outcome>(null); final var startedLatch = new CountDownLatch(1); final var taskLatch = new CountDownLatch(1); diff --git a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskWithOnCompletionTest.java b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskWithOnCompletionTest.java index 2cc57ab..0f3f8a7 100644 --- a/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskWithOnCompletionTest.java +++ b/tasks-jvm/src/test/java/org/funfix/tasks/jvm/TaskWithOnCompletionTest.java @@ -18,8 +18,8 @@ public class TaskWithOnCompletionTest { @Test void guaranteeOnSuccess() throws ExecutionException, InterruptedException { for (int t = 0; t < CONCURRENCY_REPEATS; t++) { - final var ref1 = new AtomicReference<@Nullable Outcome>(null); - final var ref2 = new AtomicReference<@Nullable Outcome>(null); + final var ref1 = new AtomicReference<@Nullable Outcome>(null); + final var ref2 = new AtomicReference<@Nullable Outcome>(null); final var outcome1 = Task .fromBlockingIO(() -> "Success") .withOnComplete(ref1::set) @@ -35,8 +35,8 @@ void guaranteeOnSuccess() throws ExecutionException, InterruptedException { @Test void guaranteeOnSuccessWithFibers() throws ExecutionException, InterruptedException, TaskCancellationException { for (int t = 0; t < CONCURRENCY_REPEATS; t++) { - final var ref1 = new AtomicReference<@Nullable Outcome>(null); - final var ref2 = new AtomicReference<@Nullable Outcome>(null); + final var ref1 = new AtomicReference<@Nullable Outcome>(null); + final var ref2 = new AtomicReference<@Nullable Outcome>(null); final var fiber = Task .fromBlockingIO(() -> "Success") .withOnComplete(ref1::set) @@ -52,8 +52,8 @@ void guaranteeOnSuccessWithFibers() throws ExecutionException, InterruptedExcept @Test void guaranteeOnSuccessWithBlockingIO() throws ExecutionException, InterruptedException, TimeoutException { for (int t = 0; t < CONCURRENCY_REPEATS; t++) { - final var ref1 = new AtomicReference<@Nullable Outcome>(null); - final var ref2 = new AtomicReference<@Nullable Outcome>(null); + final var ref1 = new AtomicReference<@Nullable Outcome>(null); + final var ref2 = new AtomicReference<@Nullable Outcome>(null); final var r = Task .fromBlockingIO(() -> "Success") .withOnComplete(ref1::set) @@ -70,8 +70,8 @@ void guaranteeOnSuccessWithBlockingIO() throws ExecutionException, InterruptedEx @Test void guaranteeOnFailure() throws InterruptedException { for (int t = 0; t < CONCURRENCY_REPEATS; t++) { - final var ref1 = new AtomicReference<@Nullable Outcome>(null); - final var ref2 = new AtomicReference<@Nullable Outcome>(null); + final var ref1 = new AtomicReference<@Nullable Outcome>(null); + final var ref2 = new AtomicReference<@Nullable Outcome>(null); final var error = new RuntimeException("Failure"); try { @@ -94,8 +94,8 @@ void guaranteeOnFailure() throws InterruptedException { @Test void guaranteeOnFailureWithFibers() throws InterruptedException, TaskCancellationException { for (int t = 0; t < CONCURRENCY_REPEATS; t++) { - final var ref1 = new AtomicReference<@Nullable Outcome>(null); - final var ref2 = new AtomicReference<@Nullable Outcome>(null); + final var ref1 = new AtomicReference<@Nullable Outcome>(null); + final var ref2 = new AtomicReference<@Nullable Outcome>(null); final var error = new RuntimeException("Failure"); try { @@ -119,8 +119,8 @@ void guaranteeOnFailureWithFibers() throws InterruptedException, TaskCancellatio @Test void guaranteeOnFailureBlockingIO() throws InterruptedException { for (int t = 0; t < CONCURRENCY_REPEATS; t++) { - final var ref1 = new AtomicReference<@Nullable Outcome>(null); - final var ref2 = new AtomicReference<@Nullable Outcome>(null); + final var ref1 = new AtomicReference<@Nullable Outcome>(null); + final var ref2 = new AtomicReference<@Nullable Outcome>(null); final var error = new RuntimeException("Failure"); try { @@ -143,8 +143,8 @@ void guaranteeOnFailureBlockingIO() throws InterruptedException { @Test void guaranteeOnCancellation() throws InterruptedException, ExecutionException, TimeoutException { for (int t = 0; t < CONCURRENCY_REPEATS; t++) { - final var ref1 = new AtomicReference<@Nullable Outcome>(null); - final var ref2 = new AtomicReference<@Nullable Outcome>(null); + final var ref1 = new AtomicReference<@Nullable Outcome>(null); + final var ref2 = new AtomicReference<@Nullable Outcome>(null); final var latch = new CountDownLatch(1); final var task = Task diff --git a/tasks-kotlin-coroutines/src/jvmMain/kotlin/org/funfix/tasks/kotlin/coroutines.jvm.kt b/tasks-kotlin-coroutines/src/jvmMain/kotlin/org/funfix/tasks/kotlin/coroutines.jvm.kt index 7b5929d..ed7b0e4 100644 --- a/tasks-kotlin-coroutines/src/jvmMain/kotlin/org/funfix/tasks/kotlin/coroutines.jvm.kt +++ b/tasks-kotlin-coroutines/src/jvmMain/kotlin/org/funfix/tasks/kotlin/coroutines.jvm.kt @@ -55,26 +55,26 @@ public suspend fun Task.runSuspending( public fun suspendAsTask( coroutineContext: CoroutineContext = EmptyCoroutineContext, block: suspend () -> T -): Task = Task.fromAsync { executor, callback -> +): Task = Task.fromAsync { continuation -> val job = GlobalScope.launch( - executor.asCoroutineDispatcher() + coroutineContext + continuation.executor.asCoroutineDispatcher() + coroutineContext ) { try { val r = block() - callback.onSuccess(r) + continuation.onSuccess(r) } catch (e: Throwable) { UncaughtExceptionHandler.rethrowIfFatal(e) when (e) { is CancellationException, is TaskCancellationException, is InterruptedException -> - callback.onCancellation() + continuation.onCancellation() else -> - callback.onFailure(e) + continuation.onFailure(e) } } } - Cancellable { + continuation.invokeOnCancellation { job.cancel() } } @@ -95,7 +95,7 @@ internal class CoroutineAsCompletionCallback( false } - override fun onOutcome(outcome: Outcome) { + override fun onOutcome(outcome: Outcome) { when (outcome) { is Outcome.Success -> onSuccess(outcome.value()) is Outcome.Failure -> onFailure(outcome.exception()) diff --git a/tasks-kotlin-coroutines/src/jvmTest/kotlin/org/funfix/tasks/kotlin/CoroutinesJvmTest.kt b/tasks-kotlin-coroutines/src/jvmTest/kotlin/org/funfix/tasks/kotlin/CoroutinesJvmTest.kt index 6c6c84d..01efd32 100644 --- a/tasks-kotlin-coroutines/src/jvmTest/kotlin/org/funfix/tasks/kotlin/CoroutinesJvmTest.kt +++ b/tasks-kotlin-coroutines/src/jvmTest/kotlin/org/funfix/tasks/kotlin/CoroutinesJvmTest.kt @@ -5,7 +5,6 @@ import java.util.concurrent.Executors import kotlin.test.* import kotlinx.coroutines.* import kotlinx.coroutines.test.runTest -import org.funfix.tasks.jvm.Cancellable import org.funfix.tasks.jvm.Outcome import org.funfix.tasks.jvm.Task import org.funfix.tasks.jvm.TaskCancellationException @@ -13,9 +12,8 @@ import org.funfix.tasks.jvm.TaskCancellationException class CoroutinesJvmTest { @Test fun `runSuspending signals Success`() = runTest { - val task = Task.fromAsync { _, cb -> - cb.onSuccess(42) - Cancellable {} + val task = Task.fromAsync { continuation -> + continuation.onSuccess(42) } val result = task.runSuspending() @@ -26,9 +24,8 @@ class CoroutinesJvmTest { @Test fun `runSuspending signals Failure`() = runTest { val ex = RuntimeException("Boom") - val task = Task.fromAsync { _, cb -> - cb.onFailure(ex) - Cancellable {} + val task = Task.fromAsync { continuation -> + continuation.onFailure(ex) } val thrown = assertFailsWith { task.runSuspending() } @@ -41,11 +38,11 @@ class CoroutinesJvmTest { fun `runSuspending cancels the task token`() = runTest { val cancelled = CompletableDeferred() val started = CompletableDeferred() - val task = Task.fromAsync { _, cb -> + val task = Task.fromAsync { continuation -> started.complete(Unit) - Cancellable { + continuation.invokeOnCancellation { cancelled.complete(Unit) - cb.onCancellation() + continuation.onCancellation() } } @@ -62,7 +59,7 @@ class CoroutinesJvmTest { val task = suspendAsTask { 21 + 21 } - val deferred = CompletableDeferred>() + val deferred = CompletableDeferred>() task.runAsync { outcome -> deferred.complete(outcome) } assertEquals(Outcome.Success(42), deferred.await()) @@ -74,7 +71,7 @@ class CoroutinesJvmTest { val task = suspendAsTask { throw ex } - val deferred = CompletableDeferred>() + val deferred = CompletableDeferred>() task.runAsync { outcome -> deferred.complete(outcome) } assertEquals(Outcome.Failure(ex), deferred.await()) @@ -87,7 +84,7 @@ class CoroutinesJvmTest { started.complete(Unit) awaitCancellation() } - val deferred = CompletableDeferred>() + val deferred = CompletableDeferred>() val cancel = task.runAsync { outcome -> deferred.complete(outcome) } started.await() @@ -119,9 +116,8 @@ class CoroutinesJvmTest { @Test fun `runSuspending translates async callback success`() = runTest { - val task = Task.fromAsync { _, cb -> - cb.onSuccess(42) - Cancellable {} + val task = Task.fromAsync { continuation -> + continuation.onSuccess(42) } assertEquals(42, task.runSuspending()) @@ -130,9 +126,8 @@ class CoroutinesJvmTest { @Test fun `runSuspending translates async callback failure`() = runTest { val ex = RuntimeException("Boom") - val task = Task.fromAsync { _, cb -> - cb.onFailure(ex) - Cancellable {} + val task = Task.fromAsync { continuation -> + continuation.onFailure(ex) } val thrown = assertFailsWith { task.runSuspending() } @@ -142,9 +137,8 @@ class CoroutinesJvmTest { @Test fun `runSuspending resumes with task cancellation`() = runTest { - val task = Task.fromAsync { _, cb -> - cb.onCancellation() - Cancellable {} + val task = Task.fromAsync { continuation -> + continuation.onCancellation() } assertFailsWith { task.runSuspending() } @@ -153,7 +147,7 @@ class CoroutinesJvmTest { @Test fun `runSuspending forwards runAsync failure`() = runTest { val ex = RuntimeException("Boom") - val task = Task.fromAsync { _, _ -> + val task = Task.fromAsync { throw ex } @@ -192,7 +186,7 @@ class CoroutinesJvmTest { val task = suspendAsTask { throw CancellationException("cancelled") } - val deferred = CompletableDeferred>() + val deferred = CompletableDeferred>() task.runAsync { outcome -> deferred.complete(outcome) } assertEquals(Outcome.Cancellation(), deferred.await()) @@ -203,7 +197,7 @@ class CoroutinesJvmTest { val task = suspendAsTask { throw TaskCancellationException("stop") } - val deferred = CompletableDeferred>() + val deferred = CompletableDeferred>() task.runAsync { outcome -> deferred.complete(outcome) } assertEquals(Outcome.Cancellation(), deferred.await()) @@ -214,7 +208,7 @@ class CoroutinesJvmTest { val task = suspendAsTask { throw InterruptedException("stop") } - val deferred = CompletableDeferred>() + val deferred = CompletableDeferred>() task.runAsync { outcome -> deferred.complete(outcome) } assertEquals(Outcome.Cancellation(), deferred.await())