一. 概述
参考开源项目https://github.com/xkcoding/spring-boot-demo
本Demo使用通过 AOP 结合 Redis + Lua 脚本实现分布式限流, 旨在保护 API 被恶意频繁访问的问题
本Demo限流策略为:滑动窗口计数器算法
二. 滑动窗口计数器算法
滑动窗口计数器算法 算的上是固定窗口计数器算法的升级版。 滑动窗口计数器算法相比于固定窗口计数器算法的优化在于:它把时间以一定比例分片 。 例如我们的借口限流每分钟处理 60 个请求,我们可以把 1 分钟分为 60 个窗口。每隔 1 秒移动一次,每个窗口一秒只能处理 不大于 60(请求数)/60(窗口数) 的请求, 如果当前窗口的请求计数总和超过了限制的数量的话就不再处理其他请求。 很显然, 当滑动窗口的格子划分的越多,滑动窗口的滚动就越平滑,限流的统计就会越精确。
三. SpringBootDemo
3.1 依赖
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- 对象池,使用redis时必须引入 -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-pool2</artifactId>
</dependency>
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
</dependency>
3.2 application.yml
server:
port: 8080
servlet:
context-path: /demo
spring:
redis:
host: localhost
# 连接超时时间(记得添加单位,Duration)
timeout: 10000ms
# Redis默认情况下有16个分片,这里配置具体使用的分片
# database: 0
lettuce:
pool:
# 连接池最大连接数(使用负值表示没有限制) 默认 8
max-active: 8
# 连接池最大阻塞等待时间(使用负值表示没有限制) 默认 -1
max-wait: -1ms
# 连接池中的最大空闲连接 默认 8
max-idle: 8
# 连接池中的最小空闲连接 默认 0
min-idle: 0
3.3 LUA脚本: scripts/redis/limit.lua
-- 下标从 1 开始 获取key
local key = KEYS[1]
-- 下标从 1 开始 获取参数
local now = tonumber(ARGV[1]) -- 当前时间错
local ttl = tonumber(ARGV[2]) -- 有效
local expired = tonumber(ARGV[3]) --
local max = tonumber(ARGV[4])
-- 清除过期的数据
-- 移除指定分数区间内的所有元素,expired 即已经过期的 score
-- 根据当前时间毫秒数 - 超时毫秒数,得到过期时间 expired
redis.call('zremrangebyscore', key, 0, expired)
-- 获取 zset 中的当前元素个数
local current = tonumber(redis.call('zcard', key))
local next = current + 1
if next > max then
-- 达到限流大小 返回 0
return 0;
else
-- 往 zset 中添加一个值、得分均为当前时间戳的元素,[value,score]
redis.call("zadd", key, now, now)
-- 每次访问均重新设置 zset 的过期时间,单位毫秒
redis.call("pexpire", key, ttl)
return next
end
3.4 脚本注入:RedisConfig.java
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
@Configuration
public class RedisConfig {
@Bean
public RedisScript<Long> limitRedisScript() {
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/redis/limit.lua")));
redisScript.setResultType(Long.class);
return redisScript;
}
}
3.5 限流注解:RateLimiter.java
限流注解,添加了 {@link AliasFor} 必须通过 {@link AnnotationUtils} 获取,才会生效
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
long DEFAULT_REQUEST = 10;
/**
* max 最大请求数
*/
@AliasFor("max") long value() default DEFAULT_REQUEST;
/**
* max 最大请求数
*/
@AliasFor("value") long max() default DEFAULT_REQUEST;
/**
* 限流key,默认为: 类路径+方法名
*/
String key() default "";
/**
* 滑动窗口时间,默认1分钟
*/
long timeout() default 1;
/**
* 滑动窗口时间单位,默认 分钟
*/
TimeUnit timeUnit() default TimeUnit.MINUTES;
}
3.6 代理: RateLimiterAspect.java
@Slf4j
@Aspect
@Component
public class RateLimiterAspect {
private final static String SEPARATOR = ":";
private final static String REDIS_LIMIT_KEY_PREFIX = "limit:";
@Resource
private StringRedisTemplate stringRedisTemplate;
@Resource
private RedisScript<Long> limitRedisScript;
@Pointcut("@annotation(com.xkcoding.ratelimit.redis.annotation.RateLimiter)")
public void rateLimit() {
}
@Around("rateLimit()")
public Object pointcut(ProceedingJoinPoint point) throws Throwable {
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
// 通过 AnnotationUtils.findAnnotation 获取 RateLimiter 注解
RateLimiter rateLimiter = AnnotationUtils.findAnnotation(method, RateLimiter.class);
if (rateLimiter != null) {
String key = rateLimiter.key();
// 默认用类名+方法名做限流的 key 前缀
if (StrUtil.isBlank(key)) {
key = method.getDeclaringClass().getName() + StrUtil.DOT + method.getName();
}
// 最终限流的 key 为 前缀 + IP地址
// TODO: 此时需要考虑局域网多用户访问的情况,因此 key 后续需要加上方法参数更加合理
key = key + SEPARATOR + IpUtil.getIpAddr();
long max = rateLimiter.max();
long timeout = rateLimiter.timeout();
TimeUnit timeUnit = rateLimiter.timeUnit();
boolean limited = shouldLimited(key, max, timeout, timeUnit);
if (limited) {
throw new RuntimeException("手速太快了,慢点儿吧~");
}
}
return point.proceed();
}
private boolean shouldLimited(String key, long max, long timeout, TimeUnit timeUnit) {
// 最终的 key 格式为:
// limit:自定义key:IP
// limit:类名.方法名:IP
key = REDIS_LIMIT_KEY_PREFIX + key;
// 统一使用单位毫秒
long ttl = timeUnit.toMillis(timeout);
// 当前时间毫秒数
long now = Instant.now().toEpochMilli();
long expired = now - ttl;
/**
* 注意这里必须转为 String,否则会报错 java.lang.Long cannot be cast to java.lang.String
* stringRedisTemplate.execute(RedisScript<T> script, List<K> keys, Object... args)
*/
Long executeTimes = stringRedisTemplate.execute(limitRedisScript, Collections.singletonList(key), now + "", ttl + "", expired + "", max + "");
if (executeTimes != null) {
if (executeTimes == 0) {
log.error("【{}】在单位时间 {} 毫秒内已达到访问上限,当前接口上限 {}", key, ttl, max);
return true;
} else {
log.info("【{}】在单位时间 {} 毫秒内访问 {} 次", key, ttl, executeTimes);
return false;
}
}
return false;
}
}
3.7 工具
@Slf4j
public class IpUtil {
private final static String UNKNOWN = "unknown";
private final static int MAX_LENGTH = 15;
/**
* 获取IP地址
* 使用Nginx等反向代理软件, 则不能通过request.getRemoteAddr()获取IP地址
* 如果使用了多级反向代理的话,X-Forwarded-For的值并不止一个,而是一串IP地址,X-Forwarded-For中第一个非unknown的有效IP字符串,则为真实IP地址
*/
public static String getIpAddr() {
HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
String ip = null;
try {
ip = request.getHeader("x-forwarded-for");
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (StrUtil.isEmpty(ip) || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_CLIENT_IP");
}
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_X_FORWARDED_FOR");
}
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
} catch (Exception e) {
log.error("IPUtils ERROR ", e);
}
// 使用代理,则获取第一个IP地址
if (!StrUtil.isEmpty(ip) && ip.length() > MAX_LENGTH) {
if (ip.indexOf(StrUtil.COMMA) > 0) {
ip = ip.substring(0, ip.indexOf(StrUtil.COMMA));
}
}
return ip;
}
}
3.8 使用
@Slf4j
@RestController
public class TestController {
/**
* 每分钟只能请求5次
* @return
*/
@RateLimiter(value = 5)
@GetMapping("/test1")
public Dict test1() {
log.info("【test1】被执行了。。。。。");
return Dict.create().set("msg", "hello,world!").set("description", "别想一直看到我,不信你快速刷新看看~");
}
/**
* 每分钟只能请求2次
* @return
*/
@RateLimiter(value = 2, key = "测试自定义key")
@GetMapping("/test3")
public Dict test3() {
log.info("【test3】被执行了。。。。。");
return Dict.create().set("msg", "hello,world!").set("description", "别想一直看到我,不信你快速刷新看看~");
}
}