/** * Amazon Kinesis Scaling Utility * * Copyright 2014, Amazon.com, Inc. or its affiliates. All Rights Reserved. * * SPDX-License-Identifier: Apache-2.0 */ package com.amazonaws.services.kinesis.scaling; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.amazonaws.services.kinesis.scaling.StreamScaler.SortOrder; import software.amazon.awssdk.services.kinesis.KinesisClient; import software.amazon.awssdk.services.kinesis.model.DescribeStreamSummaryRequest; import software.amazon.awssdk.services.kinesis.model.DescribeStreamSummaryResponse; import software.amazon.awssdk.services.kinesis.model.LimitExceededException; import software.amazon.awssdk.services.kinesis.model.ListShardsRequest; import software.amazon.awssdk.services.kinesis.model.ListShardsResponse; import software.amazon.awssdk.services.kinesis.model.MergeShardsRequest; import software.amazon.awssdk.services.kinesis.model.ResourceInUseException; import software.amazon.awssdk.services.kinesis.model.Shard; import software.amazon.awssdk.services.kinesis.model.SplitShardRequest; import software.amazon.awssdk.services.kinesis.model.StreamDescriptionSummary; import software.amazon.awssdk.services.sns.SnsClient; import software.amazon.awssdk.services.sns.model.PublishRequest; public class StreamScalingUtils { private static final Logger LOG = LoggerFactory.getLogger(StreamScalingUtils.class); public static final int DESCRIBE_RETRIES = 10; public static final int MODIFY_RETRIES = 10; // retry timeout set to 100ms as API's will potentially throttle > 10/sec public static final int RETRY_TIMEOUT_MS = 100; // rounding scale for BigInteger and BigDecimal comparisons public static final int PCT_COMPARISON_SCALE = 10; public static final RoundingMode ROUNDING_MODE = RoundingMode.HALF_DOWN; private static interface KinesisOperation { public Object run(KinesisClient client); } /** * Method to do a fuzzy comparison between two doubles, so that we can make * generalisations about allocation of keyspace to shards. For example, when we * have a stream of 3 shards, we'll have shards of 33, 33, and 34% of the * keyspace - these must all be treated as equal * * @param a * @param b * @return */ public static int softCompare(double a, double b) { // allow variation by 1 order of magnitude greater than the comparison // scale final BigDecimal acceptedVariation = BigDecimal.valueOf(1d) .divide(BigDecimal.valueOf(10d).pow(PCT_COMPARISON_SCALE - 1)); BigDecimal first = new BigDecimal(a).setScale(PCT_COMPARISON_SCALE, ROUNDING_MODE); BigDecimal second = new BigDecimal(b).setScale(PCT_COMPARISON_SCALE, ROUNDING_MODE); BigDecimal variation = first.subtract(second).abs(); // if the variation of the two values is within the accepted variation, // then we return 'equal' if (variation.compareTo(acceptedVariation) < 0) { return 0; } else { return first.compareTo(second); } } /** * Wait for a Stream to become available or transition to the indicated status * * @param streamName * @param status * @throws Exception */ public static void waitForStreamStatus(KinesisClient kinesisClient, String streamName, String status) throws Exception { boolean ok = false; String streamStatus; // stream mutation takes around 30 seconds, so we'll start with 20 as // a timeout int waitTimeout = 20000; do { streamStatus = getStreamStatus(kinesisClient, streamName); if (!streamStatus.equals(status)) { Thread.sleep(waitTimeout); // reduce the wait timeout from the initial wait time waitTimeout = 1000; } else { ok = true; } } while (!ok); } /** * Get the status of a Stream * * @param streamName * @return */ protected static String getStreamStatus(KinesisClient kinesisClient, String streamName) throws Exception { return describeStream(kinesisClient, streamName).streamStatus().name(); } public static StreamDescriptionSummary describeStream(final KinesisClient kinesisClient, final String streamName) throws Exception { KinesisOperation describe = new KinesisOperation() { public Object run(KinesisClient client) { DescribeStreamSummaryResponse result = client .describeStreamSummary(DescribeStreamSummaryRequest.builder().streamName(streamName).build()); return result.streamDescriptionSummary(); } }; return (StreamDescriptionSummary) doOperation(kinesisClient, describe, streamName, DESCRIBE_RETRIES, false); } public static List<Shard> listShards(final KinesisClient kinesisClient, final String streamName, final String shardIdStart) throws Exception { LOG.debug(String.format("Listing Stream %s from Shard %s", streamName, shardIdStart)); KinesisOperation describe = new KinesisOperation() { public Object run(KinesisClient client) { ListShardsRequest.Builder builder = ListShardsRequest.builder().streamName(streamName); ListShardsRequest req = null; boolean hasMoreResults = true; List<Shard> shards = new ArrayList<>(); while (hasMoreResults) { if (shardIdStart != null && (req != null && req.nextToken() == null)) { builder.exclusiveStartShardId(shardIdStart); } ListShardsResponse result = client.listShards(builder.build()); shards.addAll(result.shards()); if (result.nextToken() == null) { hasMoreResults = false; } else { req = ListShardsRequest.builder().nextToken(result.nextToken()).build(); } } return shards; } }; return (List<Shard>) doOperation(kinesisClient, describe, streamName, DESCRIBE_RETRIES, false); } public static Shard getShard(final KinesisClient kinesisClient, final String streamName, final String shardIdStart) throws Exception { LOG.debug(String.format("Getting Shard %s for Stream %s", shardIdStart, streamName)); KinesisOperation describe = new KinesisOperation() { public Object run(KinesisClient client) { // reduce the shardIdStart by 1 as the API uses it as an exclusive start key not // a filter String shardIdToQuery = new BigDecimal(shardIdStart).subtract(new BigDecimal("1")).toString(); ListShardsRequest req = ListShardsRequest.builder().streamName(streamName) .exclusiveStartShardId(shardIdToQuery).build(); ListShardsResponse result = client.listShards(req); return result.shards().get(0); } }; return (Shard) doOperation(kinesisClient, describe, streamName, DESCRIBE_RETRIES, false); } public static void splitShard(final KinesisClient kinesisClient, final String streamName, final String shardId, final BigInteger targetHash, final boolean waitForActive) throws Exception { LOG.debug(String.format("Splitting Shard %s at %s", shardId, targetHash.toString())); KinesisOperation split = new KinesisOperation() { public Object run(KinesisClient client) { final SplitShardRequest req = SplitShardRequest.builder().streamName(streamName).shardToSplit(shardId) .newStartingHashKey(targetHash.toString()).build(); client.splitShard(req); return null; } }; doOperation(kinesisClient, split, streamName, MODIFY_RETRIES, waitForActive); } public static void mergeShards(final KinesisClient kinesisClient, final String streamName, final ShardHashInfo lowerShard, final ShardHashInfo higherShard, final boolean waitForActive) throws Exception { LOG.debug(String.format("Merging Shard %s and %s", lowerShard, higherShard)); KinesisOperation merge = new KinesisOperation() { public Object run(KinesisClient client) { final MergeShardsRequest req = MergeShardsRequest.builder().streamName(streamName) .shardToMerge(lowerShard.getShardId()).adjacentShardToMerge(higherShard.getShardId()).build(); client.mergeShards(req); return null; } }; doOperation(kinesisClient, merge, streamName, MODIFY_RETRIES, waitForActive); } private static Object doOperation(KinesisClient kinesisClient, KinesisOperation operation, String streamName, int retries, boolean waitForActive) throws Exception { boolean done = false; int attempts = 0; Object result = null; do { attempts++; try { result = operation.run(kinesisClient); if (waitForActive) { waitForStreamStatus(kinesisClient, streamName, "ACTIVE"); } done = true; } catch (ResourceInUseException e) { // thrown when the Shard is mutating - wait until we are able to // do the modification or ResourceNotFoundException is thrown Thread.sleep(1000); } catch (LimitExceededException lee) { // API Throttling LOG.warn(String.format("LimitExceededException for Stream %s", streamName)); Thread.sleep(getTimeoutDuration(attempts)); } } while (!done && attempts < retries); if (!done) { throw new Exception(String.format("Unable to Complete Kinesis Operation after %s Retries", retries)); } else { return result; } } // calculate an exponential backoff based on the attempt count private static final long getTimeoutDuration(int attemptCount) { return new Double(Math.pow(2, attemptCount) * RETRY_TIMEOUT_MS).longValue(); } private static final int compareShardsByStartHash(Shard o1, Shard o2) { return new BigInteger(o1.hashKeyRange().startingHashKey()) .compareTo(new BigInteger(o2.hashKeyRange().startingHashKey())); } public static int getOpenShardCount(KinesisClient kinesisClient, String streamName) throws Exception { return StreamScalingUtils.describeStream(kinesisClient, streamName).openShardCount(); } /** * Get a list of all Open shards ordered by their start hash * * @param streamName * @return A Map of only Open Shards indexed by the Shard ID */ public static Map<String, ShardHashInfo> getOpenShards(KinesisClient kinesisClient, String streamName, String lastShardId) throws Exception { return getOpenShards(kinesisClient, streamName, SortOrder.ASCENDING, lastShardId); } public static ShardHashInfo getOpenShard(KinesisClient kinesisClient, String streamName, String shardId) throws Exception { Shard s = getShard(kinesisClient, streamName, shardId); if (!s.shardId().equals(shardId)) { throw new Exception(String.format("Shard %s not found in Stream %s", shardId, streamName)); } else { return new ShardHashInfo(streamName, s); } } public static Map<String, ShardHashInfo> getOpenShards(KinesisClient kinesisClient, String streamName, SortOrder sortOrder, String lastShardId) throws Exception { Collection<String> openShardNames = new ArrayList<>(); Map<String, ShardHashInfo> shardMap = new LinkedHashMap<>(); // load all the open shards on the Stream and sort if required for (Shard shard : listShards(kinesisClient, streamName, lastShardId)) { openShardNames.add(shard.shardId()); shardMap.put(shard.shardId(), new ShardHashInfo(streamName, shard)); // remove this Shard's parents from the set of active shards - they // are now closed and cannot be modified or written to if (shard.parentShardId() != null) { openShardNames.remove(shard.parentShardId()); shardMap.remove(shard.parentShardId()); } if (shard.adjacentParentShardId() != null) { openShardNames.remove(shard.adjacentParentShardId()); shardMap.remove(shard.adjacentParentShardId()); } } // create a List of Open shards for sorting List<Shard> sortShards = new ArrayList<>(); for (String s : openShardNames) { // paranoid null check in case we get a null map entry if (s != null) { sortShards.add(shardMap.get(s).getShard()); } } if (sortOrder.equals(SortOrder.ASCENDING)) { // sort the list into lowest start hash order Collections.sort(sortShards, new Comparator<Shard>() { public int compare(Shard o1, Shard o2) { return compareShardsByStartHash(o1, o2); } }); } else if (sortOrder.equals(SortOrder.DESCENDING)) { // sort the list into highest start hash order Collections.sort(sortShards, new Comparator<Shard>() { public int compare(Shard o1, Shard o2) { return compareShardsByStartHash(o1, o2) * -1; } }); } // else we were supplied a NONE sort order so no sorting // build the Shard map into the correct order shardMap.clear(); for (Shard s : sortShards) { shardMap.put(s.shardId(), new ShardHashInfo(streamName, s)); } return shardMap; } public static void sendNotification(SnsClient snsClient, String notificationARN, String subject, String message) { final PublishRequest req = PublishRequest.builder().topicArn(notificationARN).message(message).subject(subject) .build(); snsClient.publish(req); } }