"""
Train module nodes
==================

This module contains nodes that train a random forest model on the digits
data set loaded through ``sklearn.datasets.load_digits`` (referred to as
"MNIST" in the function names below), together with the Airflow DAG that
chains the training step and the in-sample prediction step.
"""
import datetime as dt
import os

from airflow import DAG
from airflow.operators.python_operator import PythonOperator
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelBinarizer

try:
    # ``sklearn.externals.joblib`` is no longer shipped by recent
    # scikit-learn releases, so prefer the standalone package.
    import joblib
except ImportError:
    # Older environments that only have the copy vendored by scikit-learn.
    from sklearn.externals import joblib


# NOTE: Normally you would put this in a config file
MODEL_OUTPUT_PATH = "models/example.joblib"


def get_mnist_data():
    """Loads the digits data set into memory.

    The data comes from ``sklearn.datasets.load_digits``; the labels are
    passed through ``LabelBinarizer`` before being returned.

    Returns
    -------
    X : array-like, shape=[n_samples, n_features]
        Training data for the digits data set.
    y : array-like, shape=[n_samples, n_classes]
        Binarized labels, as produced by ``LabelBinarizer().fit_transform``.
    """
    digits = load_digits()
    X, y = digits.data, digits.target
    y = LabelBinarizer().fit_transform(y)
    return X, y


def fit_estimator(model_path, **kwargs):
    """Estimates a random forest on the digits data set and persists it.

    Parameters
    ----------
    model_path : str
        Path the pickled model is written to. Missing parent directories
        are created.
    kwargs : dict
        Keyword arguments for Airflow compatibility.
    """
    X, y = get_mnist_data()
    model = RandomForestClassifier(n_estimators=50)
    model.fit(X, y)

    # ``joblib.dump`` does not create directories; make sure the target
    # directory exists so a fresh worker does not fail on the first run.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    joblib.dump(model, model_path)


def predict_samples(model_path, **kwargs):
    """Computes in-sample predictions on the digits data set, using the
    model built by ``fit_estimator``.

    Parameters
    ----------
    model_path : str
        Path the pickled model is loaded from.
    kwargs : dict
        Keyword arguments for Airflow compatibility.
    """
    model = joblib.load(model_path)
    # Only the features are needed for prediction; the labels are discarded.
    X, _ = get_mnist_data()
    # NOTE: Normally you would save the predictions to somewhere here.
    model.predict(X)


model_output_path = MODEL_OUTPUT_PATH

default_args = {
    "owner": "me",
    "start_date": dt.datetime(2017, 6, 1),
    "retries": 1,
    "retry_delay": dt.timedelta(minutes=5),
}

# Cron expression "0 * * * *": run at minute 0 of every hour.
# NOTE(review): the start date lies far in the past; depending on the
# deployment's catchup setting the scheduler may create a run for every
# missed interval -- confirm this is intended.
with DAG(
    "dummy_ml_pipeline", default_args=default_args, schedule_interval="0 * * * *"
) as dag:
    train_model = PythonOperator(
        task_id="train_model",
        provide_context=True,
        op_kwargs={"model_path": model_output_path},
        python_callable=fit_estimator,
    )
    predict_data = PythonOperator(
        task_id="predict_data",
        provide_context=True,
        op_kwargs={"model_path": model_output_path},
        python_callable=predict_samples,
    )
    # Prediction loads the model written by training, so it must run after it.
    train_model.set_downstream(predict_data)