/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.webmonitor.handlers; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.rest.handler.HandlerRequest; import org.apache.flink.runtime.rest.handler.HandlerRequestException; import org.apache.flink.runtime.rest.handler.RestHandlerException; import org.apache.flink.runtime.rest.messages.MessageParameter; import org.apache.flink.runtime.rest.messages.MessageQueryParameter; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.util.BlobServerResource; import org.apache.flink.runtime.webmonitor.TestingDispatcherGateway; import org.apache.flink.runtime.webmonitor.retriever.GatewayRetriever; import org.apache.flink.runtime.webmonitor.testutils.ParameterProgram; import org.apache.flink.util.TestLogger; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; import org.junit.Assert; import org.junit.Before; import org.junit.ClassRule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import static junit.framework.TestCase.assertEquals; import static junit.framework.TestCase.fail; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; /** Base test class for jar request handlers. */ public abstract class JarHandlerParameterTest <REQB extends JarRequestBody, M extends JarMessageParameters> extends TestLogger { enum ProgramArgsParType { String, List, Both } static final String[] PROG_ARGS = new String[] {"--host", "localhost", "--port", "1234"}; static final int PARALLELISM = 4; @ClassRule public static final TemporaryFolder TMP = new TemporaryFolder(); @ClassRule public static final BlobServerResource BLOB_SERVER_RESOURCE = new BlobServerResource(); static final AtomicReference<JobGraph> LAST_SUBMITTED_JOB_GRAPH_REFERENCE = new AtomicReference<>(); static TestingDispatcherGateway restfulGateway; static Path jarDir; static GatewayRetriever<TestingDispatcherGateway> gatewayRetriever = () -> CompletableFuture.completedFuture(restfulGateway); static CompletableFuture<String> localAddressFuture = CompletableFuture.completedFuture("shazam://localhost:12345"); static Time timeout = Time.seconds(10); static Map<String, String> responseHeaders = Collections.emptyMap(); static Executor executor = TestingUtils.defaultExecutor(); private static Path jarWithManifest; private static Path jarWithoutManifest; static void init() throws Exception { jarDir = TMP.newFolder().toPath(); // properties are set property by surefire plugin final String parameterProgramJarName = System.getProperty("parameterJarName") + ".jar"; final String parameterProgramWithoutManifestJarName = System.getProperty("parameterJarWithoutManifestName") + ".jar"; final Path jarLocation = Paths.get(System.getProperty("targetDir")); jarWithManifest = Files.copy( jarLocation.resolve(parameterProgramJarName), jarDir.resolve("program-with-manifest.jar")); jarWithoutManifest = Files.copy( jarLocation.resolve(parameterProgramWithoutManifestJarName), jarDir.resolve("program-without-manifest.jar")); restfulGateway = new TestingDispatcherGateway.Builder() .setBlobServerPort(BLOB_SERVER_RESOURCE.getBlobServerPort()) .setSubmitFunction(jobGraph -> { LAST_SUBMITTED_JOB_GRAPH_REFERENCE.set(jobGraph); return CompletableFuture.completedFuture(Acknowledge.get()); }) .build(); gatewayRetriever = () -> CompletableFuture.completedFuture(restfulGateway); localAddressFuture = CompletableFuture.completedFuture("shazam://localhost:12345"); timeout = Time.seconds(10); responseHeaders = Collections.emptyMap(); executor = TestingUtils.defaultExecutor(); } @Before public void reset() { ParameterProgram.actualArguments = null; } @Test public void testDefaultParameters() throws Exception { // baseline, ensure that reasonable defaults are chosen handleRequest(createRequest( getDefaultJarRequestBody(), getUnresolvedJarMessageParameters(), getUnresolvedJarMessageParameters(), jarWithManifest)); validateDefaultGraph(); } @Test public void testConfigurationViaQueryParametersWithProgArgsAsString() throws Exception { testConfigurationViaQueryParameters(ProgramArgsParType.String); } @Test public void testConfigurationViaQueryParametersWithProgArgsAsList() throws Exception { testConfigurationViaQueryParameters(ProgramArgsParType.List); } @Test public void testConfigurationViaQueryParametersFailWithProgArgsAsStringAndList() throws Exception { try { testConfigurationViaQueryParameters(ProgramArgsParType.Both); fail("RestHandlerException is excepted"); } catch (RestHandlerException e) { assertEquals(HttpResponseStatus.BAD_REQUEST, e.getHttpResponseStatus()); } } private void testConfigurationViaQueryParameters(ProgramArgsParType programArgsParType) throws Exception { // configure submission via query parameters handleRequest(createRequest( getDefaultJarRequestBody(), getJarMessageParameters(programArgsParType), getUnresolvedJarMessageParameters(), jarWithoutManifest)); validateGraph(); } @Test public void testConfigurationViaJsonRequestWithProgArgsAsString() throws Exception { testConfigurationViaJsonRequest(ProgramArgsParType.String); } @Test public void testConfigurationViaJsonRequestWithProgArgsAsList() throws Exception { testConfigurationViaJsonRequest(ProgramArgsParType.List); } @Test public void testConfigurationViaJsonRequestFailWithProgArgsAsStringAndList() throws Exception { try { testConfigurationViaJsonRequest(ProgramArgsParType.Both); fail("RestHandlerException is excepted"); } catch (RestHandlerException e) { assertEquals(HttpResponseStatus.BAD_REQUEST, e.getHttpResponseStatus()); } } @Test public void testProvideJobId() throws Exception { JobID jobId = new JobID(); HandlerRequest<REQB, M> request = createRequest( getJarRequestBodyWithJobId(jobId), getUnresolvedJarMessageParameters(), getUnresolvedJarMessageParameters(), jarWithManifest ); handleRequest(request); Optional<JobGraph> jobGraph = getLastSubmittedJobGraphAndReset(); assertThat(jobGraph.isPresent(), is(true)); assertThat(jobGraph.get().getJobID(), is(equalTo(jobId))); } private void testConfigurationViaJsonRequest(ProgramArgsParType programArgsParType) throws Exception { handleRequest(createRequest( getJarRequestBody(programArgsParType), getUnresolvedJarMessageParameters(), getUnresolvedJarMessageParameters(), jarWithoutManifest )); validateGraph(); } @Test public void testParameterPrioritizationWithProgArgsAsString() throws Exception { testParameterPrioritization(ProgramArgsParType.String); } @Test public void testParameterPrioritizationWithProgArgsAsList() throws Exception { testParameterPrioritization(ProgramArgsParType.List); } @Test public void testFailIfProgArgsAreAsStringAndAsList() throws Exception { try { testParameterPrioritization(ProgramArgsParType.Both); fail("RestHandlerException is excepted"); } catch (RestHandlerException e) { assertEquals(HttpResponseStatus.BAD_REQUEST, e.getHttpResponseStatus()); } } private void testParameterPrioritization(ProgramArgsParType programArgsParType) throws Exception { // configure submission via query parameters and JSON request, JSON should be prioritized handleRequest(createRequest( getJarRequestBody(programArgsParType), getWrongJarMessageParameters(programArgsParType), getUnresolvedJarMessageParameters(), jarWithoutManifest)); validateGraph(); } static String getProgramArgsString(ProgramArgsParType programArgsParType) { return programArgsParType == ProgramArgsParType.String || programArgsParType == ProgramArgsParType.Both ? String.join(" ", PROG_ARGS) : null; } static List<String> getProgramArgsList(ProgramArgsParType programArgsParType) { return programArgsParType == ProgramArgsParType.List || programArgsParType == ProgramArgsParType.Both ? Arrays.asList(PROG_ARGS) : null; } private static <REQB extends JarRequestBody, M extends JarMessageParameters> HandlerRequest<REQB, M> createRequest( REQB requestBody, M parameters, M unresolvedMessageParameters, Path jar) throws HandlerRequestException { final Map<String, List<String>> queryParameterAsMap = parameters.getQueryParameters().stream() .filter(MessageParameter::isResolved) .collect(Collectors.toMap( MessageParameter::getKey, JarHandlerParameterTest::getValuesAsString )); return new HandlerRequest<>( requestBody, unresolvedMessageParameters, Collections.singletonMap(JarIdPathParameter.KEY, jar.getFileName().toString()), queryParameterAsMap, Collections.emptyList() ); } private static <X> List<String> getValuesAsString(MessageQueryParameter<X> parameter) { final List<X> values = parameter.getValue(); return values.stream().map(parameter::convertValueToString).collect(Collectors.toList()); } abstract M getUnresolvedJarMessageParameters(); abstract M getJarMessageParameters(ProgramArgsParType programArgsParType); abstract M getWrongJarMessageParameters(ProgramArgsParType programArgsParType); abstract REQB getDefaultJarRequestBody(); abstract REQB getJarRequestBody(ProgramArgsParType programArgsParType); abstract REQB getJarRequestBodyWithJobId(JobID jobId); abstract void handleRequest(HandlerRequest<REQB, M> request) throws Exception; JobGraph validateDefaultGraph() { JobGraph jobGraph = LAST_SUBMITTED_JOB_GRAPH_REFERENCE.getAndSet(null); Assert.assertEquals(0, ParameterProgram.actualArguments.length); Assert.assertEquals(ExecutionConfig.PARALLELISM_DEFAULT, getExecutionConfig(jobGraph).getParallelism()); return jobGraph; } JobGraph validateGraph() { JobGraph jobGraph = LAST_SUBMITTED_JOB_GRAPH_REFERENCE.getAndSet(null); Assert.assertArrayEquals(PROG_ARGS, ParameterProgram.actualArguments); Assert.assertEquals(PARALLELISM, getExecutionConfig(jobGraph).getParallelism()); return jobGraph; } private static Optional<JobGraph> getLastSubmittedJobGraphAndReset() { return Optional.ofNullable(LAST_SUBMITTED_JOB_GRAPH_REFERENCE.getAndSet(null)); } private static ExecutionConfig getExecutionConfig(JobGraph jobGraph) { ExecutionConfig executionConfig; try { executionConfig = jobGraph.getSerializedExecutionConfig().deserializeValue(ParameterProgram.class.getClassLoader()); } catch (Exception e) { throw new AssertionError("Exception while deserializing ExecutionConfig.", e); } return executionConfig; } }