import logging
import os
import re
import subprocess
import sys

from six.moves import shlex_quote  # noqa: F401  (kept for backward compatibility of this module's namespace)

from mlflow.models import FlavorBackend
from mlflow.tracking.artifact_utils import _download_artifact_from_uri

_logger = logging.getLogger(__name__)


class RFuncBackend(FlavorBackend):
    """
    Flavor backend implementation for the generic R models.
    Predict and serve locally models with 'crate' flavor.
    """

    # Extracts "X.Y.Z" from the banner printed by ``Rscript --version``.
    version_pattern = re.compile(r"version ([0-9]+[.][0-9]+[.][0-9]+)")

    def predict(self, model_uri, input_path, output_path, content_type, json_format):
        """
        Generate predictions using R model saved with MLflow.
        Return the prediction results as a JSON.

        :param model_uri: URI of the MLflow model; downloaded to a local path first.
        :param input_path: Path of the input data, or None (passed to R as NULL).
        :param output_path: Path for the predictions, or None (passed to R as NULL).
        :param content_type: Content type of the input, or None (passed to R as NULL).
        :param json_format: JSON orientation of the input, or None (passed to R as NULL).
        :raises Exception: if the ``Rscript`` process exits with a non-zero code.
        """
        model_path = _download_artifact_from_uri(model_uri)
        str_cmd = "mlflow:::mlflow_rfunc_predict(model_path = {0}, input_path = {1}, " \
                  "output_path = {2}, content_type = {3}, json_format = {4})"
        # Arguments are embedded as R string literals (not shell-quoted): the command is
        # handed to Rscript as a single argv element, so no shell ever parses it.
        command = str_cmd.format(_str_literal(model_path),
                                 _str_optional(input_path),
                                 _str_optional(output_path),
                                 _str_optional(content_type),
                                 _str_optional(json_format))
        _execute(command)

    def serve(self, model_uri, port, host):
        """
        Serve R model locally.

        :param model_uri: URI of the MLflow model; downloaded to a local path first.
        :param port: Port to listen on; interpolated into the R call as-is.
        :param host: Host/interface to bind; embedded as an R string literal.
        :raises Exception: if the ``Rscript`` process exits with a non-zero code.
        """
        model_path = _download_artifact_from_uri(model_uri)
        command = "mlflow::mlflow_rfunc_serve({0}, port = {1}, host = {2})".format(
            _str_literal(model_path), port, _str_literal(str(host)))
        _execute(command)

    def can_score_model(self):
        """
        Return True if an ``Rscript`` executable of version >= 3.3 is available.

        Returns False (rather than raising) when ``Rscript`` cannot be launched, exits
        with a non-zero code, or prints no recognizable version string.
        """
        try:
            process = subprocess.Popen(["Rscript", "--version"], close_fds=True,
                                       stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError:
            # Rscript is not installed or not on PATH.
            return False
        stdout, stderr = process.communicate()
        if process.returncode != 0:
            return False
        # Search both streams so detection does not depend on which one this R build
        # writes its version banner to.
        output = stderr.decode("utf-8", "replace") + "\n" + stdout.decode("utf-8", "replace")
        match = self.version_pattern.search(output)
        if not match:
            return False
        major, minor, _ = (int(part) for part in match.group(1).split("."))
        return (major, minor) >= (3, 3)


def _execute(command):
    """
    Run the R expression ``command`` via ``Rscript -e``, sharing this process's stdio.

    :raises Exception: if ``Rscript`` exits with a non-zero code.
    """
    _logger.debug("Executing R command: %s", command)
    env = os.environ.copy()
    process = subprocess.Popen(["Rscript", "-e", command], env=env, close_fds=False,
                               stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr)
    return_code = process.wait()
    if return_code != 0:
        raise Exception("Command returned non zero exit code: {}.".format(return_code))


def _str_escape(s):
    """Escape ``s`` for use inside a single-quoted R string literal."""
    # Backslashes first, so the escapes added for quotes are not doubled.
    return s.replace("\\", "\\\\").replace("'", "\\'")


def _str_literal(s):
    """Return ``s`` as a single-quoted, escaped R string literal."""
    return "'{}'".format(_str_escape(s))


def _str_optional(s):
    """Return ``NULL`` for None, otherwise ``str(s)`` as an escaped R string literal."""
    return "NULL" if s is None else _str_literal(str(s))