/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * @author Aiyun Tang * @mail [email protected] */ package com.tay.redislimiter.core; import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit; import lombok.RequiredArgsConstructor; import redis.clients.jedis.Jedis; import redis.clients.jedis.JedisPool; /** * Redis based Rate limiter * * @author Aiyun Tang <[email protected]> */ @RequiredArgsConstructor public class RedisRateLimiter { private JedisPool jedisPool; private TimeUnit timeUnit; // private int permitsPerUnit; private static final String LUA_SECOND_SCRIPT = " local current; " + " current = redis.call('incr',KEYS[1]); " + " if tonumber(current) == 1 then " + " redis.call('expire',KEYS[1],ARGV[1]); " + " return 1; " + " else" + " if tonumber(current) <= tonumber(ARGV[2]) then " + " return 1; " + " else " + " return -1; " + " end " + " end "; private static final String LUA_PERIOD_SCRIPT = " local currentSectionCount;" + " local previosSectionCount;" + " local totalCountInPeriod;" + " currentSectionCount = redis.call('zcount', KEYS[2], '-inf', '+inf');" + " previosSectionCount = redis.call('zcount', KEYS[1], ARGV[3], '+inf');" + " totalCountInPeriod = tonumber(currentSectionCount)+tonumber(previosSectionCount);" + " if totalCountInPeriod < tonumber(ARGV[5]) then " + " redis.call('zadd',KEYS[2],ARGV[1],ARGV[2]);" + " if tonumber(currentSectionCount) == 0 then " + " redis.call('expire',KEYS[2],ARGV[4]); " + " end " + " return 1" + " else " + " return -1" + " end "; private static final int PERIOD_SECOND_TTL = 10; private static final int PERIOD_MINUTE_TTL = 2 * 60 + 10; private static final int PERIOD_HOUR_TTL = 2 * 3600 + 10; private static final int PERIOD_DAY_TTL = 2 * 3600 * 24 + 10; private static final long MICROSECONDS_IN_MINUTE = 60 * 1000000L; private static final long MICROSECONDS_IN_HOUR = 3600 * 1000000L; private static final long MICROSECONDS_IN_DAY = 24 * 3600 * 1000000L; public RedisRateLimiter(JedisPool jedisPool, TimeUnit timeUnit) { this.jedisPool = jedisPool; this.timeUnit = timeUnit; } public JedisPool getJedisPool() { return jedisPool; } public TimeUnit getTimeUnit() { return timeUnit; } public boolean acquire(String keyPrefix, int permitsPerUnit){ boolean rtv = false; if (jedisPool != null) { Jedis jedis = null; try { jedis = jedisPool.getResource(); if (timeUnit == TimeUnit.SECONDS) { String keyName = getKeyNameForSecond(jedis, keyPrefix); List<String> keys = new ArrayList<String>(); keys.add(keyName); List<String> argvs = new ArrayList<String>(); argvs.add(String.valueOf(getExpire())); argvs.add(String.valueOf(permitsPerUnit)); Long val = (Long)jedis.eval(LUA_SECOND_SCRIPT, keys, argvs); rtv = (val > 0); } else if (timeUnit == TimeUnit.MINUTES || timeUnit == TimeUnit.HOURS || timeUnit == TimeUnit.DAYS) { rtv = doPeriod(jedis, keyPrefix, permitsPerUnit); } } finally { if (jedis != null) { jedis.close(); } } } return rtv; } private boolean doPeriod(Jedis jedis, String keyPrefix, int permitsPerUnit) { List<String> jedisTime = jedis.time(); long currentSecond = Long.parseLong(jedisTime.get(0)); long microSecondsElapseInCurrentSecond = Long.parseLong(jedisTime.get(1)); String[] keyNames = getKeyNames(currentSecond, keyPrefix); //因为redis访问实际上是单线程的,而且jedis.time()方法返回的时间精度为微秒级,每一个jedis.time()调用耗时应该会超过1微秒,因此我们可以认为每次jedis.time()返回的时间都是唯一且递增 //因此这个currentTimeInMicroSecond在多线程情况下不会存在相同 long currentTimeInMicroSecond = currentSecond * 1000000 + microSecondsElapseInCurrentSecond; String previousSectionBeginScore = String.valueOf((currentTimeInMicroSecond - getPeriodMicrosecond())); String expires =String.valueOf(getExpire()); String currentTimeInMicroSecondStr = String.valueOf(currentTimeInMicroSecond); List<String> keys = new ArrayList<String>(); keys.add(keyNames[0]); keys.add(keyNames[1]); List<String> argvs = new ArrayList<String>(); argvs.add(currentTimeInMicroSecondStr); argvs.add(currentTimeInMicroSecondStr); argvs.add(previousSectionBeginScore); argvs.add(expires); argvs.add(String.valueOf(permitsPerUnit)); Long val = (Long)jedis.eval(LUA_PERIOD_SCRIPT, keys, argvs); return (val > 0); } // private long getRedisTime(Jedis jedis) { // List<String> jedisTime = jedis.time(); // long currentSecond = Long.parseLong(jedisTime.get(0)); // long microSecondsElapseInCurrentSecond = Long.parseLong(jedisTime.get(1)); // long currentTimeInMicroSecond = currentSecond * 1000000 + microSecondsElapseInCurrentSecond; // return currentTimeInMicroSecond; // } private String getKeyNameForSecond(Jedis jedis, String keyPrefix) { return keyPrefix + ":" + jedis.time().get(0); } private String[] getKeyNames(long currentSecond, String keyPrefix) { String[] keyNames = null; if (timeUnit == TimeUnit.MINUTES) { long index = currentSecond / 60; String keyName1 = keyPrefix + ":" + (index - 1); String keyName2 = keyPrefix + ":" + index; keyNames = new String[] { keyName1, keyName2 }; } else if (timeUnit == TimeUnit.HOURS) { long index = currentSecond / 3600; String keyName1 = keyPrefix + ":" + (index - 1); String keyName2 = keyPrefix + ":" + index; keyNames = new String[] { keyName1, keyName2 }; } else if (timeUnit == TimeUnit.DAYS) { long index = currentSecond / (3600 * 24); String keyName1 = keyPrefix + ":" + (index - 1); String keyName2 = keyPrefix + ":" + index; keyNames = new String[] { keyName1, keyName2 }; } else { throw new java.lang.IllegalArgumentException("Don't support this TimeUnit: " + timeUnit); } return keyNames; } private int getExpire() { int expire = 0; if (timeUnit == TimeUnit.SECONDS) { expire = PERIOD_SECOND_TTL; } else if (timeUnit == TimeUnit.MINUTES) { expire = PERIOD_MINUTE_TTL; } else if (timeUnit == TimeUnit.HOURS) { expire = PERIOD_HOUR_TTL; } else if (timeUnit == TimeUnit.DAYS) { expire = PERIOD_DAY_TTL; } else { throw new java.lang.IllegalArgumentException("Don't support this TimeUnit: " + timeUnit); } return expire; } private long getPeriodMicrosecond() { if (timeUnit == TimeUnit.MINUTES) { return MICROSECONDS_IN_MINUTE; } else if (timeUnit == TimeUnit.HOURS) { return MICROSECONDS_IN_HOUR; } else if (timeUnit == TimeUnit.DAYS) { return MICROSECONDS_IN_DAY; } else { throw new java.lang.IllegalArgumentException("Don't support this TimeUnit: " + timeUnit); } } }