/* * Copyright 2013-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.cloud.gateway.filter; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Random; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; import org.springframework.beans.factory.ObjectProvider; import org.springframework.cloud.gateway.event.PredicateArgsEvent; import org.springframework.cloud.gateway.event.RefreshRoutesEvent; import org.springframework.cloud.gateway.event.WeightDefinedEvent; import org.springframework.cloud.gateway.route.RouteLocator; import org.springframework.cloud.gateway.support.ConfigurationService; import org.springframework.cloud.gateway.support.WeightConfig; import org.springframework.context.ApplicationEvent; import org.springframework.context.event.SmartApplicationListener; import org.springframework.core.Ordered; import org.springframework.core.style.ToStringCreator; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.WEIGHT_ATTR; /** * @author Spencer Gibb * @author Alexey Nakidkin */ public class WeightCalculatorWebFilter implements WebFilter, Ordered, SmartApplicationListener { /** * Order of Weight Calculator Web filter. */ public static final int WEIGHT_CALC_FILTER_ORDER = 10001; private static final Log log = LogFactory.getLog(WeightCalculatorWebFilter.class); private final ObjectProvider<RouteLocator> routeLocator; private final ConfigurationService configurationService; private Random random = new Random(); private int order = WEIGHT_CALC_FILTER_ORDER; private Map<String, GroupWeightConfig> groupWeights = new ConcurrentHashMap<>(); public WeightCalculatorWebFilter(ObjectProvider<RouteLocator> routeLocator, ConfigurationService configurationService) { this.routeLocator = routeLocator; this.configurationService = configurationService; } /* for testing */ static Map<String, String> getWeights(ServerWebExchange exchange) { Map<String, String> weights = exchange.getAttribute(WEIGHT_ATTR); if (weights == null) { weights = new ConcurrentHashMap<>(); exchange.getAttributes().put(WEIGHT_ATTR, weights); } return weights; } @Override public int getOrder() { return order; } public void setOrder(int order) { this.order = order; } public void setRandom(Random random) { this.random = random; } @Override public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) { // from config file return PredicateArgsEvent.class.isAssignableFrom(eventType) || // from java dsl WeightDefinedEvent.class.isAssignableFrom(eventType) || // force initialization RefreshRoutesEvent.class.isAssignableFrom(eventType); } @Override public boolean supportsSourceType(Class<?> sourceType) { return true; } @Override public void onApplicationEvent(ApplicationEvent event) { if (event instanceof PredicateArgsEvent) { handle((PredicateArgsEvent) event); } else if (event instanceof WeightDefinedEvent) { addWeightConfig(((WeightDefinedEvent) event).getWeightConfig()); } else if (event instanceof RefreshRoutesEvent && routeLocator != null) { // forces initialization routeLocator.ifAvailable(locator -> locator.getRoutes().subscribe()); } } public void handle(PredicateArgsEvent event) { Map<String, Object> args = event.getArgs(); if (args.isEmpty() || !hasRelevantKey(args)) { return; } WeightConfig config = new WeightConfig(event.getRouteId()); this.configurationService.with(config).name(WeightConfig.CONFIG_PREFIX) .normalizedProperties(args).bind(); addWeightConfig(config); } private boolean hasRelevantKey(Map<String, Object> args) { return args.keySet().stream() .anyMatch(key -> key.startsWith(WeightConfig.CONFIG_PREFIX + ".")); } /* for testing */ void addWeightConfig(WeightConfig weightConfig) { String group = weightConfig.getGroup(); GroupWeightConfig config; // only create new GroupWeightConfig rather than modify // and put at end of calculations. This avoids concurency problems // later during filter execution. if (groupWeights.containsKey(group)) { config = new GroupWeightConfig(groupWeights.get(group)); } else { config = new GroupWeightConfig(group); } config.weights.put(weightConfig.getRouteId(), weightConfig.getWeight()); // recalculate // normalize weights int weightsSum = 0; for (Integer weight : config.weights.values()) { weightsSum += weight; } final AtomicInteger index = new AtomicInteger(0); for (Map.Entry<String, Integer> entry : config.weights.entrySet()) { String routeId = entry.getKey(); Integer weight = entry.getValue(); Double nomalizedWeight = weight / (double) weightsSum; config.normalizedWeights.put(routeId, nomalizedWeight); // recalculate rangeIndexes config.rangeIndexes.put(index.getAndIncrement(), routeId); } // TODO: calculate ranges config.ranges.clear(); config.ranges.add(0.0); List<Double> values = new ArrayList<>(config.normalizedWeights.values()); for (int i = 0; i < values.size(); i++) { Double currentWeight = values.get(i); Double previousRange = config.ranges.get(i); Double range = previousRange + currentWeight; config.ranges.add(range); } if (log.isTraceEnabled()) { log.trace("Recalculated group weight config " + config); } // only update after all calculations groupWeights.put(group, config); } /* for testing */ Map<String, GroupWeightConfig> getGroupWeights() { return groupWeights; } @Override public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { Map<String, String> weights = getWeights(exchange); for (String group : groupWeights.keySet()) { GroupWeightConfig config = groupWeights.get(group); if (config == null) { if (log.isDebugEnabled()) { log.debug("No GroupWeightConfig found for group: " + group); } continue; // nothing we can do, but this is odd } double r = this.random.nextDouble(); List<Double> ranges = config.ranges; if (log.isTraceEnabled()) { log.trace("Weight for group: " + group + ", ranges: " + ranges + ", r: " + r); } for (int i = 0; i < ranges.size() - 1; i++) { if (r >= ranges.get(i) && r < ranges.get(i + 1)) { String routeId = config.rangeIndexes.get(i); weights.put(group, routeId); break; } } } if (log.isTraceEnabled()) { log.trace("Weights attr: " + weights); } return chain.filter(exchange); } /* for testing */ static class GroupWeightConfig { String group; LinkedHashMap<String, Integer> weights = new LinkedHashMap<>(); LinkedHashMap<String, Double> normalizedWeights = new LinkedHashMap<>(); LinkedHashMap<Integer, String> rangeIndexes = new LinkedHashMap<>(); List<Double> ranges = new ArrayList<>(); GroupWeightConfig(String group) { this.group = group; } GroupWeightConfig(GroupWeightConfig other) { this.group = other.group; this.weights = new LinkedHashMap<>(other.weights); this.normalizedWeights = new LinkedHashMap<>(other.normalizedWeights); this.rangeIndexes = new LinkedHashMap<>(other.rangeIndexes); } @Override public String toString() { return new ToStringCreator(this).append("group", group) .append("weights", weights) .append("normalizedWeights", normalizedWeights) .append("rangeIndexes", rangeIndexes).toString(); } } }