package com.johnsnowlabs.nlp.pretrained import com.johnsnowlabs.nlp.LightPipeline import org.apache.spark.ml.PipelineModel import org.apache.spark.sql.DataFrame case class PretrainedPipeline( downloadName: String, lang: String = "en", source: String = ResourceDownloader.publicLoc, parseEmbeddingsVectors: Boolean = false, diskLocation: Option[String] = None ) { /** Support for java default argument interoperability */ def this(downloadName: String) { this(downloadName, "en", ResourceDownloader.publicLoc) } def this(downloadName: String, lang: String) { this(downloadName, lang, ResourceDownloader.publicLoc) } val model: PipelineModel = if (diskLocation.isEmpty) { ResourceDownloader .downloadPipeline(downloadName, Option(lang), source) } else { PipelineModel.load(diskLocation.get) } lazy val lightModel = new LightPipeline(model, parseEmbeddingsVectors) def annotate(dataset: DataFrame, inputColumn: String): DataFrame = { model .transform(dataset.withColumnRenamed(inputColumn, "text")) } def annotate(target: String): Map[String, Seq[String]] = lightModel.annotate(target) def annotate(target: Array[String]): Array[Map[String, Seq[String]]] = lightModel.annotate(target) def transform(dataFrame: DataFrame): DataFrame = model.transform(dataFrame) } object PretrainedPipeline { def fromDisk(path: String, parseEmbeddings: Boolean = false): PretrainedPipeline = { PretrainedPipeline(null, null, null, parseEmbeddings, Some(path)) } def fromDisk(path: String): PretrainedPipeline = { fromDisk(path, false) } }