/* * Copyright 2013-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.cloud.gateway.route; import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Flux; import org.springframework.cloud.gateway.config.GatewayProperties; import org.springframework.cloud.gateway.event.FilterArgsEvent; import org.springframework.cloud.gateway.event.PredicateArgsEvent; import org.springframework.cloud.gateway.filter.FilterDefinition; import org.springframework.cloud.gateway.filter.GatewayFilter; import org.springframework.cloud.gateway.filter.OrderedGatewayFilter; import org.springframework.cloud.gateway.filter.factory.GatewayFilterFactory; import org.springframework.cloud.gateway.handler.AsyncPredicate; import org.springframework.cloud.gateway.handler.predicate.PredicateDefinition; import org.springframework.cloud.gateway.handler.predicate.RoutePredicateFactory; import org.springframework.cloud.gateway.support.ConfigurationService; import org.springframework.cloud.gateway.support.HasRouteId; import org.springframework.core.Ordered; import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.web.server.ServerWebExchange; /** * {@link RouteLocator} that loads routes from a {@link RouteDefinitionLocator}. * * @author Spencer Gibb */ public class RouteDefinitionRouteLocator implements RouteLocator { /** * Default filters name. */ public static final String DEFAULT_FILTERS = "defaultFilters"; protected final Log logger = LogFactory.getLog(getClass()); private final RouteDefinitionLocator routeDefinitionLocator; private final ConfigurationService configurationService; private final Map<String, RoutePredicateFactory> predicates = new LinkedHashMap<>(); private final Map<String, GatewayFilterFactory> gatewayFilterFactories = new HashMap<>(); private final GatewayProperties gatewayProperties; public RouteDefinitionRouteLocator(RouteDefinitionLocator routeDefinitionLocator, List<RoutePredicateFactory> predicates, List<GatewayFilterFactory> gatewayFilterFactories, GatewayProperties gatewayProperties, ConfigurationService configurationService) { this.routeDefinitionLocator = routeDefinitionLocator; this.configurationService = configurationService; initFactories(predicates); gatewayFilterFactories.forEach( factory -> this.gatewayFilterFactories.put(factory.name(), factory)); this.gatewayProperties = gatewayProperties; } private void initFactories(List<RoutePredicateFactory> predicates) { predicates.forEach(factory -> { String key = factory.name(); if (this.predicates.containsKey(key)) { this.logger.warn("A RoutePredicateFactory named " + key + " already exists, class: " + this.predicates.get(key) + ". It will be overwritten."); } this.predicates.put(key, factory); if (logger.isInfoEnabled()) { logger.info("Loaded RoutePredicateFactory [" + key + "]"); } }); } @Override public Flux<Route> getRoutes() { Flux<Route> routes = this.routeDefinitionLocator.getRouteDefinitions() .map(this::convertToRoute); if (!gatewayProperties.isFailOnRouteDefinitionError()) { // instead of letting error bubble up, continue routes = routes.onErrorContinue((error, obj) -> { if (logger.isWarnEnabled()) { logger.warn("RouteDefinition id " + ((RouteDefinition) obj).getId() + " will be ignored. Definition has invalid configs, " + error.getMessage()); } }); } return routes.map(route -> { if (logger.isDebugEnabled()) { logger.debug("RouteDefinition matched: " + route.getId()); } return route; }); } private Route convertToRoute(RouteDefinition routeDefinition) { AsyncPredicate<ServerWebExchange> predicate = combinePredicates(routeDefinition); List<GatewayFilter> gatewayFilters = getFilters(routeDefinition); return Route.async(routeDefinition).asyncPredicate(predicate) .replaceFilters(gatewayFilters).build(); } @SuppressWarnings("unchecked") List<GatewayFilter> loadGatewayFilters(String id, List<FilterDefinition> filterDefinitions) { ArrayList<GatewayFilter> ordered = new ArrayList<>(filterDefinitions.size()); for (int i = 0; i < filterDefinitions.size(); i++) { FilterDefinition definition = filterDefinitions.get(i); GatewayFilterFactory factory = this.gatewayFilterFactories .get(definition.getName()); if (factory == null) { throw new IllegalArgumentException( "Unable to find GatewayFilterFactory with name " + definition.getName()); } if (logger.isDebugEnabled()) { logger.debug("RouteDefinition " + id + " applying filter " + definition.getArgs() + " to " + definition.getName()); } // @formatter:off Object configuration = this.configurationService.with(factory) .name(definition.getName()) .properties(definition.getArgs()) .eventFunction((bound, properties) -> new FilterArgsEvent( // TODO: why explicit cast needed or java compile fails RouteDefinitionRouteLocator.this, id, (Map<String, Object>) properties)) .bind(); // @formatter:on // some filters require routeId // TODO: is there a better place to apply this? if (configuration instanceof HasRouteId) { HasRouteId hasRouteId = (HasRouteId) configuration; hasRouteId.setRouteId(id); } GatewayFilter gatewayFilter = factory.apply(configuration); if (gatewayFilter instanceof Ordered) { ordered.add(gatewayFilter); } else { ordered.add(new OrderedGatewayFilter(gatewayFilter, i + 1)); } } return ordered; } private List<GatewayFilter> getFilters(RouteDefinition routeDefinition) { List<GatewayFilter> filters = new ArrayList<>(); // TODO: support option to apply defaults after route specific filters? if (!this.gatewayProperties.getDefaultFilters().isEmpty()) { filters.addAll(loadGatewayFilters(DEFAULT_FILTERS, new ArrayList<>(this.gatewayProperties.getDefaultFilters()))); } if (!routeDefinition.getFilters().isEmpty()) { filters.addAll(loadGatewayFilters(routeDefinition.getId(), new ArrayList<>(routeDefinition.getFilters()))); } AnnotationAwareOrderComparator.sort(filters); return filters; } private AsyncPredicate<ServerWebExchange> combinePredicates( RouteDefinition routeDefinition) { List<PredicateDefinition> predicates = routeDefinition.getPredicates(); AsyncPredicate<ServerWebExchange> predicate = lookup(routeDefinition, predicates.get(0)); for (PredicateDefinition andPredicate : predicates.subList(1, predicates.size())) { AsyncPredicate<ServerWebExchange> found = lookup(routeDefinition, andPredicate); predicate = predicate.and(found); } return predicate; } @SuppressWarnings("unchecked") private AsyncPredicate<ServerWebExchange> lookup(RouteDefinition route, PredicateDefinition predicate) { RoutePredicateFactory<Object> factory = this.predicates.get(predicate.getName()); if (factory == null) { throw new IllegalArgumentException( "Unable to find RoutePredicateFactory with name " + predicate.getName()); } if (logger.isDebugEnabled()) { logger.debug("RouteDefinition " + route.getId() + " applying " + predicate.getArgs() + " to " + predicate.getName()); } // @formatter:off Object config = this.configurationService.with(factory) .name(predicate.getName()) .properties(predicate.getArgs()) .eventFunction((bound, properties) -> new PredicateArgsEvent( RouteDefinitionRouteLocator.this, route.getId(), properties)) .bind(); // @formatter:on return factory.applyAsync(config); } }