RateLimiter
/**
 * Rate limiter implementation based on the token bucket algorithm. There are two parameters:
 * <ul>
 * <li>
 * burst size - maximum number of requests allowed into the system as a burst
 * </li>
 * <li>
 * average rate - expected number of requests per time unit; the unit is {@link TimeUnit#SECONDS}
 * or {@link TimeUnit#MINUTES}, chosen at construction time
 * </li>
 * </ul>
 * Both parameters are passed on every {@code acquire} call, so they may change between calls.
 * All mutable state is kept in atomics and updated with compare-and-set loops.
 *
 * @author Tomasz Bak
 */
public class RateLimiter {

    /** Milliseconds in one unit of the average rate: 1000 for SECONDS, 60000 for MINUTES. */
    private final long rateToMsConversion;

    /** Tokens currently taken from the bucket; the bucket is empty once this reaches the burst size. */
    private final AtomicInteger consumedTokens = new AtomicInteger();

    /** Time (ms) up to which tokens have been credited; 0 means no refill has happened yet. */
    private final AtomicLong lastRefillTime = new AtomicLong(0);

    /**
     * Creates a rate limiter whose average rate is expressed per second.
     *
     * @deprecated use {@link #RateLimiter(TimeUnit)} and state the rate unit explicitly
     */
    @Deprecated
    public RateLimiter() {
        this(TimeUnit.SECONDS);
    }

    /**
     * Creates a rate limiter whose average rate is expressed per the given time unit.
     *
     * @param averageRateUnit unit of the average rate; only SECONDS and MINUTES are supported
     * @throws IllegalArgumentException if any other time unit is given
     */
    public RateLimiter(TimeUnit averageRateUnit) {
        switch (averageRateUnit) {
            case SECONDS:
                rateToMsConversion = 1000;
                break;
            case MINUTES:
                rateToMsConversion = 60 * 1000;
                break;
            default:
                throw new IllegalArgumentException("TimeUnit of " + averageRateUnit + " is not supported");
        }
    }

    /**
     * Tries to take one token, using {@link System#currentTimeMillis()} as the current time.
     *
     * @param burstSize   maximum number of tokens the bucket can hold
     * @param averageRate tokens added per time unit
     * @return {@code true} if the request is admitted, {@code false} if the bucket is empty
     */
    public boolean acquire(int burstSize, long averageRate) {
        return acquire(burstSize, averageRate, System.currentTimeMillis());
    }

    /**
     * Tries to take one token at the given point in time.
     *
     * @param burstSize         maximum number of tokens the bucket can hold
     * @param averageRate       tokens added per time unit
     * @param currentTimeMillis current time in milliseconds
     * @return {@code true} if the request is admitted (always the case when {@code burstSize}
     *         or {@code averageRate} is not positive), {@code false} if the bucket is empty
     */
    public boolean acquire(int burstSize, long averageRate, long currentTimeMillis) {
        if (burstSize <= 0 || averageRate <= 0) { // Instead of throwing exception, we just let all the traffic go
            return true;
        }
        refillToken(burstSize, averageRate, currentTimeMillis);
        return consumeToken(burstSize);
    }

    /**
     * Credits the tokens earned since the last refill. Callers guarantee {@code averageRate > 0}.
     */
    private void refillToken(int burstSize, long averageRate, long currentTimeMillis) {
        long refillTime = lastRefillTime.get();
        long timeDelta = currentTimeMillis - refillTime;
        if (timeDelta <= 0) {
            // No time has passed (or the clock moved backwards): nothing to credit. Checking this
            // up front also keeps a negative delta from wrapping into a positive product below.
            return;
        }

        // timeDelta * averageRate can exceed Long.MAX_VALUE, most easily on the very first call,
        // when refillTime is 0 and timeDelta is the full epoch time. A wrapped product could be
        // negative, which would skip the refill and leave lastRefillTime at 0 forever.
        boolean overflow = timeDelta > Long.MAX_VALUE / averageRate;
        long newTokens = overflow
                ? Long.MAX_VALUE // far more than any int-sized bucket can hold
                : timeDelta * averageRate / rateToMsConversion;
        if (newTokens <= 0) {
            return;
        }

        // Advance the refill time only by the time that whole tokens account for, so the
        // fractional remainder carries over to the next call. On the first refill, or when the
        // product overflowed, simply start counting from now.
        long newRefillTime = (refillTime == 0 || overflow)
                ? currentTimeMillis
                : refillTime + newTokens * rateToMsConversion / averageRate;

        // If the compare-and-set fails, the refill time was changed concurrently; skip this refill.
        if (lastRefillTime.compareAndSet(refillTime, newRefillTime)) {
            while (true) {
                int currentLevel = consumedTokens.get();
                int adjustedLevel = Math.min(currentLevel, burstSize); // In case burstSize decreased
                int newLevel = (int) Math.max(0, adjustedLevel - newTokens);
                if (consumedTokens.compareAndSet(currentLevel, newLevel)) {
                    return;
                }
            }
        }
    }

    /**
     * Takes one token if the bucket is not empty, retrying on concurrent modification.
     *
     * @return {@code true} if a token was taken, {@code false} if {@code burstSize} tokens are
     *         already consumed
     */
    private boolean consumeToken(int burstSize) {
        while (true) {
            int currentLevel = consumedTokens.get();
            if (currentLevel >= burstSize) {
                return false;
            }
            if (consumedTokens.compareAndSet(currentLevel, currentLevel + 1)) {
                return true;
            }
        }
    }

    /**
     * Returns the limiter to its initial state: no tokens consumed and no refill recorded.
     */
    public void reset() {
        consumedTokens.set(0);
        lastRefillTime.set(0);
    }
}
调用
/**
 * Decides whether a registry fetch request must be rejected because a rate limit is exhausted.
 * <p>
 * Every request takes a token from the registry-fetch limiter. A {@code Target.FullFetch}
 * request additionally takes a token from the full-fetch limiter, and it does so even when
 * the first limiter has already refused.
 *
 * @param target kind of fetch being requested
 * @return {@code true} if any limiter that was consulted refused to hand out a token
 */
private boolean isOverloaded(Target target) {
    final int burstSize = serverConfig.getRateLimiterBurstSize();
    final int registryFetchRate = serverConfig.getRateLimiterRegistryFetchAverageRate();
    final boolean registryFetchAllowed = registryFetchRateLimiter.acquire(burstSize, registryFetchRate);

    if (target != Target.FullFetch) {
        return !registryFetchAllowed;
    }

    // The full-fetch limiter is consulted unconditionally so its bucket is charged as well.
    final int fullFetchRate = serverConfig.getRateLimiterFullFetchAverageRate();
    final boolean fullFetchAllowed = registryFullFetchRateLimiter.acquire(burstSize, fullFetchRate);
    return !(registryFetchAllowed && fullFetchAllowed);
}