"""Integration tests for LivySession and LivyBatch against a running Livy server."""

from dataclasses import dataclass

import pytest
import requests
import pandas

from livy import (
    LivySession,
    LivyBatch,
    SessionKind,
    SparkRuntimeError,
    SessionState,
)


def livy_available(livy_url):
    return requests.get(livy_url).status_code == 200


def session_stopped(livy_url, session_id):
    # Livy returns 404 for a deleted session; a session still being torn
    # down reports the "shutting_down" state.
    response = requests.get(f"{livy_url}/sessions/{session_id}")
    if response.status_code == 404:
        return True
    else:
        return response.json()["state"] == "shutting_down"


@dataclass
class Parameters:
    """Code snippets and expected outputs for one session kind."""

    print_foo_code: str
    print_foo_output: str
    create_dataframe_code: str
    dataframe_count_code: str
    dataframe_count_output: str
    error_code: str


SPARK_CREATE_DF = """
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
val rdd = sc.parallelize(0 to 99)
val schema = StructType(List(
    StructField("value", IntegerType, nullable = false)
))
val df = spark.createDataFrame(rdd.map { i => Row(i) }, schema)
"""

SPARK_TEST_PARAMETERS = Parameters(
    print_foo_code='println("foo")',
    print_foo_output="foo\n\n",
    create_dataframe_code=SPARK_CREATE_DF,
    dataframe_count_code="df.count()",
    dataframe_count_output="res1: Long = 100\n\n",
    error_code="1 / 0",
)

PYSPARK_CREATE_DF = """
from pyspark.sql import Row
df = spark.createDataFrame([Row(value=i) for i in range(100)])
"""

PYSPARK_TEST_PARAMETERS = Parameters(
    print_foo_code='print("foo")',
    print_foo_output="foo\n",
    create_dataframe_code=PYSPARK_CREATE_DF,
    dataframe_count_code="df.count()",
    dataframe_count_output="100\n",
    error_code="1 / 0",
)

SPARKR_CREATE_DF = """
df <- createDataFrame(data.frame(value = 0:99))
"""

SPARKR_TEST_PARAMETERS = Parameters(
    print_foo_code='print("foo")',
    print_foo_output='[1] "foo"\n',
    create_dataframe_code=SPARKR_CREATE_DF,
    dataframe_count_code="count(df)",
    dataframe_count_output="[1] 100\n",
    error_code="missing_function()",
)


@pytest.mark.integration
@pytest.mark.parametrize(
    "session_kind, params",
    [
        (SessionKind.SPARK, SPARK_TEST_PARAMETERS),
        (SessionKind.PYSPARK, PYSPARK_TEST_PARAMETERS),
        (SessionKind.SPARKR, SPARKR_TEST_PARAMETERS),
    ],
)
def test_session(integration_url, capsys, session_kind, params):
    assert livy_available(integration_url)

    with LivySession.create(integration_url, kind=session_kind) as session:
        assert session.state == SessionState.IDLE

        # run() echoes the statement's output locally, so capsys captures it.
        session.run(params.print_foo_code)
        assert capsys.readouterr() == (params.print_foo_output, "")

        session.run(params.create_dataframe_code)
        capsys.readouterr()

        session.run(params.dataframe_count_code)
        assert capsys.readouterr() == (params.dataframe_count_output, "")

        with pytest.raises(SparkRuntimeError):
            session.run(params.error_code)

        expected = pandas.DataFrame({"value": range(100)})
        assert session.read("df").equals(expected)

    assert session_stopped(integration_url, session.session_id)


SQL_CREATE_VIEW = """
CREATE TEMPORARY VIEW view AS SELECT * FROM RANGE(100)
"""


@pytest.mark.integration
def test_sql_session(integration_url):
    assert livy_available(integration_url)

    with LivySession.create(integration_url, kind=SessionKind.SQL) as session:
        assert session.state == SessionState.IDLE

        session.run(SQL_CREATE_VIEW)

        output = session.run("SELECT COUNT(*) FROM view")
        assert output.json["data"] == [[100]]

        with pytest.raises(SparkRuntimeError):
            session.run("not valid SQL!")

        expected = pandas.DataFrame({"id": range(100)})
        assert session.read_sql("SELECT * FROM view").equals(expected)

    assert session_stopped(integration_url, session.session_id)


@pytest.mark.integration
def test_batch_job(integration_url):
    assert livy_available(integration_url)

    # Run SparkPi from the pre-built Spark examples jar as a batch job.
    batch = LivyBatch.create(
        integration_url,
        file=(
"https://repo.typesafe.com/typesafe/maven-releases/org/apache/" "spark/spark-examples_2.11/1.6.0-typesafe-001/" "spark-examples_2.11-1.6.0-typesafe-001.jar" ), class_name="org.apache.spark.examples.SparkPi", ) assert batch.state == SessionState.RUNNING batch.wait() assert batch.state == SessionState.SUCCESS assert any( "spark.SparkContext: Successfully stopped SparkContext" in line for line in batch.log() )