package polynote.server

import cats.instances.list._
import cats.syntax.traverse._
import fs2.concurrent.Topic
import polynote.buildinfo.BuildInfo
import polynote.kernel.util.Publish
import polynote.kernel.{BaseEnv, StreamThrowableOps}
import polynote.kernel.environment.{Env, PublishMessage, Config}
import polynote.kernel.interpreter.Interpreter
import polynote.kernel.logging.Logging
import polynote.messages._
import polynote.server.auth.IdentityProvider.checkPermission
import polynote.server.auth.{IdentityProvider, Permission, UserIdentity}
import uzhttp.websocket.Frame
import zio.stream.ZStream
import zio.stream.{Stream, Take}
import zio.Queue
import zio.{Promise, RIO, Task, URIO, ZIO}

import scala.collection.immutable.SortedMap

/**
  * Server side of a single (non-notebook) websocket session.
  *
  * It answers notebook-management requests (list/create/rename/copy/delete notebooks, list running kernels).
  * It also relays server-wide broadcast messages to the client.
  */
object SocketSession {

  /**
    * Builds the outbound frame stream for one websocket connection.
    *
    * An unbounded queue is created and registered in the environment as this session's [[PublishMessage]].
    * Broadcasts from `broadcastAll` are forwarded into that queue on a daemon fiber until the session's
    * `closed` promise completes. The returned stream runs three streams via `parallelStreams`:
    *   - the handshake followed by everything published to the queue, encoded as frames;
    *   - the handling of incoming frames from `in`, followed by `closeStream`;
    *   - the keepalive stream.
    * The helpers `parallelStreams`, `toFrames`, `closeQueueIf`, `closeStream`, `keepaliveStream` and the
    * `handleMessages` / `interruptAndIgnoreWhen` syntax are defined elsewhere in this package.
    *
    * @param in           frames received from the client
    * @param broadcastAll topic of messages to relay to every session; a `None` element is filtered out by `unNone`
    * @return the frames to send to the client. Typed failures of that stream are logged and end it instead of
    *         propagating; any other failure cause (e.g. interruption) ends the stream quietly.
    */
  def apply(in: Stream[Throwable, Frame], broadcastAll: Topic[Task, Option[Message]]): URIO[SessionEnv with NotebookManager, Stream[Throwable, Frame]] =
    for {
      output         <- Queue.unbounded[Take[Nothing, Message]]
      publishMessage <- Env.add[SessionEnv with NotebookManager](Publish(output): Publish[Task, Message])
      env            <- ZIO.environment[SessionEnv with NotebookManager with PublishMessage]
      closed         <- Promise.make[Throwable, Unit]
      // Relay server-wide broadcasts (buffering up to 32) into this session's output until it is closed.
      _              <- broadcastAll.subscribe(32).unNone.interruptAndIgnoreWhen(closed).through(publishMessage.publish).compile.drain.forkDaemon
      close           = closeQueueIf(closed, output)
    } yield parallelStreams(
      toFrames(ZStream.fromEffect(handshake) ++ Stream.fromQueue(output).unTake),
      in.handleMessages(close)(handler andThen errorHandler) ++ closeStream(closed, output),
      keepaliveStream(closed)
    ).provide(env).catchAllCause { cause =>
      // The session ends either way; record real failures instead of silently discarding them.
      cause.failureOption.fold[Stream[Throwable, Frame]](ZStream.empty) { err =>
        ZStream.fromEffect(Logging.error(err)).drain.provide(env)
      }
    }

  /**
    * Handles a single client request.
    *
    * Mutating requests validate every path with `NotebookManager.assertValidPath` and check permissions before
    * acting. Creating needs `CreateNotebook` on the path. Copying needs `CreateNotebook` on the destination.
    * Renaming needs `CreateNotebook` on the new path and `DeleteNotebook` on the old one. Deleting needs
    * `DeleteNotebook`.
    *
    * @return `Some(reply)` for `ListNotebooks` and `RunningKernels`; `None` for every other message
    */
  private val handler: Message => RIO[SessionEnv with PublishMessage with NotebookManager, Option[Message]] = {
    case ListNotebooks(_) =>
      NotebookManager.list().map { notebooks =>
        Some(ListNotebooks(notebooks.map(ShortString.apply)))
      }

    case CreateNotebook(path, maybeContent) =>
      NotebookManager.assertValidPath(path) *>
        checkPermission(Permission.CreateNotebook(path)) *>
        NotebookManager.create(path, maybeContent).as(None)

    case RenameNotebook(path, newPath) =>
      (NotebookManager.assertValidPath(path) &> NotebookManager.assertValidPath(newPath)) *>
        checkPermission(Permission.CreateNotebook(newPath)) *>
        checkPermission(Permission.DeleteNotebook(path)) *>
        NotebookManager.rename(path, newPath).as(None)

    case CopyNotebook(path, newPath) =>
      (NotebookManager.assertValidPath(path) &> NotebookManager.assertValidPath(newPath)) *>
        checkPermission(Permission.CreateNotebook(newPath)) *>
        NotebookManager.copy(path, newPath).as(None)

    case DeleteNotebook(path) =>
      NotebookManager.assertValidPath(path) *>
        checkPermission(Permission.DeleteNotebook(path)) *>
        NotebookManager.delete(path).as(None)

    case RunningKernels(_) =>
      for {
        paths          <- NotebookManager.listRunning()
        // Statuses are fetched in parallel; collectAllPar keeps them in the same order as `paths`.
        statuses       <- ZIO.collectAllPar(paths.map(NotebookManager.status))
        kernelStatuses  = paths.zip(statuses).map { case (p, s) => ShortString(p) -> s }
      } yield Some(RunningKernels(kernelStatuses))

    case _ =>
      ZIO.succeed(None)
  }

  /**
    * Wraps a request-handling effect so that a failure is logged and answered with an `Error(0, err)` message.
    * A failing request therefore does not fail the session's stream.
    */
  val errorHandler: RIO[SessionEnv with PublishMessage with NotebookManager, Option[Message]] => RIO[SessionEnv with PublishMessage with NotebookManager, Option[Message]] =
    _.catchAll { err =>
      Logging.error(err).as(Some(Error(0, err)))
    }

  /**
    * Builds the handshake message sent as the first message of every session.
    *
    * It advertises:
    *   - the available interpreters, keyed by their key in the factory map and sorted by that key;
    *   - the server version and commit;
    *   - the current user's identity, if any;
    *   - the configured Spark property sets (empty when no Spark config is present).
    */
  def handshake: RIO[SessionEnv, ServerHandshake] =
    for {
      factories <- Interpreter.Factories.access
      identity  <- UserIdentity.access
      config    <- Config.access
    } yield {
      // Each language is advertised by its first factory. A language with no factories is skipped rather than
      // crashing on `.head`, which would fail the stream and leave the client without a handshake.
      val languageNames = factories.toList.flatMap {
        case (language, languageFactories) => languageFactories.headOption.map(f => language -> f.languageName)
      }

      ServerHandshake(
        (SortedMap.empty[String, String] ++ languageNames).asInstanceOf[TinyMap[TinyString, TinyString]],
        serverVersion = BuildInfo.version,
        serverCommit = BuildInfo.commit,
        identity = identity.map(i => Identity(i.name, i.avatar.map(ShortString))),
        sparkTemplates = config.spark.flatMap(_.propertySets).getOrElse(Nil)
      )
    }
}