/* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.servingsdk.impl; import com.amazonaws.ml.mms.http.InvalidPluginException; import java.lang.annotation.Annotation; import java.util.HashMap; import java.util.Map; import java.util.ServiceLoader; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.ai.mms.servingsdk.ModelServerEndpoint; import software.amazon.ai.mms.servingsdk.annotations.Endpoint; import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes; public final class PluginsManager { private static final PluginsManager INSTANCE = new PluginsManager(); private Logger logger = LoggerFactory.getLogger(PluginsManager.class); private Map<String, ModelServerEndpoint> inferenceEndpoints; private Map<String, ModelServerEndpoint> managementEndpoints; private PluginsManager() {} public static PluginsManager getInstance() { return INSTANCE; } public void initialize() { inferenceEndpoints = initInferenceEndpoints(); managementEndpoints = initManagementEndpoints(); } private boolean validateEndpointPlugin(Annotation a, EndpointTypes type) { return a instanceof Endpoint && !((Endpoint) a).urlPattern().isEmpty() && ((Endpoint) a).endpointType().equals(type); } private HashMap<String, ModelServerEndpoint> getEndpoints(EndpointTypes type) throws InvalidPluginException { ServiceLoader<ModelServerEndpoint> loader = ServiceLoader.load(ModelServerEndpoint.class); HashMap<String, ModelServerEndpoint> ep = new HashMap<>(); for (ModelServerEndpoint mep : loader) { Class<? extends ModelServerEndpoint> modelServerEndpointClassObj = mep.getClass(); Annotation[] annotations = modelServerEndpointClassObj.getAnnotations(); for (Annotation a : annotations) { if (validateEndpointPlugin(a, type)) { if (ep.get(((Endpoint) a).urlPattern()) != null) { throw new InvalidPluginException( "Multiple plugins found for endpoint " + "\"" + ((Endpoint) a).urlPattern() + "\""); } logger.info("Loading plugin for endpoint {}", ((Endpoint) a).urlPattern()); ep.put(((Endpoint) a).urlPattern(), mep); } } } return ep; } private HashMap<String, ModelServerEndpoint> initInferenceEndpoints() { return getEndpoints(EndpointTypes.INFERENCE); } private HashMap<String, ModelServerEndpoint> initManagementEndpoints() { return getEndpoints(EndpointTypes.MANAGEMENT); } public Map<String, ModelServerEndpoint> getInferenceEndpoints() { return inferenceEndpoints; } public Map<String, ModelServerEndpoint> getManagementEndpoints() { return managementEndpoints; } }