package io.catbird.util.effect

import cats.effect.{ ContextShift, IO }
import com.twitter.util.{ ExecutorServiceFuturePool, Future, FuturePool }
import org.scalatest.Outcome
import org.scalatest.funsuite.FixtureAnyFunSuite
import scala.concurrent.ExecutionContext

/**
 * Checks which thread pool an `IO` computation continues on after it has awaited a Twitter `Future`.
 *
 * `futureToAsync` keeps running on the Future's pool. `futureToAsyncAndShift` shifts back to the pool
 * behind the implicit `ContextShift[IO]`.
 */
class ContextShiftingSuite extends FixtureAnyFunSuite with ThreadPoolNamingSupport {

  /**
   * Per-test fixture that owns two named thread pools. Each step reports the name of the thread it ran
   * on, which tells the test which pool executed it. Both pools are released by `shutdown()`.
   */
  protected final class FixtureParam {
    val ioPoolName = "io-pool"
    val futurePoolName = "future-pool"

    // Thread pool of IO (in an application this is often provided by IOApp).
    val ioPool = newNamedThreadPool(ioPoolName)

    // Thread pool of Future (in an application this is often managed by a library like finagle-http).
    // Kept separately so it can be shut down once the test is finished.
    private[this] val futureExecutor = newNamedThreadPool(futurePoolName)

    val futurePool: ExecutorServiceFuturePool = FuturePool(futureExecutor)

    /** `ContextShift[IO]` whose shifts and fiber starts run on `ioPool`. */
    val contextShift: ContextShift[IO] = IO.contextShift(ExecutionContext.fromExecutor(ioPool))

    /** An `IO` yielding the name of the thread it was evaluated on. */
    def newIO: IO[String] = IO(currentThreadName())

    /** A `Future` run on `futurePool`, yielding the name of the thread it ran on. */
    def newFuture: Future[String] = futurePool {
      // The sleep keeps the Future pending while the IO side registers its completion callback.
      // NOTE(review): presumably, if the Future were already satisfied at that point, the callback would
      // run right away on the registering (io-pool) thread, and no pool switch would be observable.
      Thread.sleep(200)
      currentThreadName()
    }

    /** Shuts down both thread pools owned by this fixture. */
    def shutdown(): Unit = {
      ioPool.shutdown()
      futureExecutor.shutdown()
    }
  }

  test("After resolving the Future with futureToAsync stay on the Future threadpool") { f =>
    implicit val contextShift: ContextShift[IO] = f.contextShift

    val (futureThread, ioThread) = (for {
      onFuture <- futureToAsync[IO, String](f.newFuture)
      onIO <- f.newIO
    } yield (onFuture, onIO))
      .start(contextShift) // start the computation as a fiber on the io-pool...
      .flatMap(_.join) // ...wait for that fiber's result...
      .unsafeRunSync() // ...and block the test thread until it is available

    assert(futureThread == f.futurePoolName)
    assert(ioThread == f.futurePoolName) // Uh oh, this is likely not what the user wants
  }

  test("After resolving the Future with futureToAsyncAndShift shift back to the threadpool of ContextShift[F]") { f =>
    implicit val contextShift: ContextShift[IO] = f.contextShift

    // With `futureToAsync` instead, the step after the Future would stay on the future-pool
    // (see the previous test).
    val (futureThread, ioThread) = (for {
      onFuture <- futureToAsyncAndShift[IO, String](f.newFuture)
      onIO <- f.newIO
    } yield (onFuture, onIO))
      .start(contextShift) // start the computation as a fiber on the io-pool...
      .flatMap(_.join) // ...wait for that fiber's result...
      .unsafeRunSync() // ...and block the test thread until it is available

    assert(futureThread == f.futurePoolName)
    assert(ioThread == f.ioPoolName)
  }

  /** Builds a fresh fixture for each test and always releases its thread pools afterwards. */
  override protected def withFixture(test: OneArgTest): Outcome = {
    val fixture = new FixtureParam
    try withFixture(test.toNoArgTest(fixture))
    finally fixture.shutdown()
  }
}