若依源码:接口限流功能的实现

本文发布于 2025年01月07日,阅读 9 次,点赞 0 次,归类于 源码分析

博客:https://www.emanjusaka.com

公众号:emanjusaka的编程栈

by emanjusaka from https://www.emanjusaka.com/archives/ruoyi-ratelimiter 彼岸花开可奈何

本文是若依的源码解读,这是一个系列文章,欢迎关注我的博客或者微信公众号获取后续文章更新。

若依项目中接口限流是通过注解 +AOP+lua 脚本去实现的,下面我们来分析一下具体的代码实现。

定义注解

RateLimiter.java


/**
 * 限流注解
 * 
 * @author ruoyi
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter
{
    /**
     * 限流key
     */
    public String key() default CacheConstants.RATE_LIMIT_KEY;

    /**
     * 限流时间,单位秒
     */
    public int time() default 60;

    /**
     * 限流次数
     */
    public int count() default 100;

    /**
     * 限流类型
     */
    public LimitType limitType() default LimitType.DEFAULT;
}

首先定义一个注解 @RateLimiter,其中包含四个参数分别是限流 key、限流时间、限流次数、限流类型。

全部都给了默认值,限流 key 默认值为 rate_limit:​,这里设置的 key 只是限流 key 的前缀部分,具体使用还会追加类名和方法名。如果限流类型是 IP 还要再加上 IP。

限流时间默认值是 60s。

限流次数默认值是 100 次。

限流时间内访问接口的次数不能超过限流次数。

限流类型默认是全局限流,也就是只要是调用接口就会限制。

通过 AOP 拦截请求实现限流

RateLimiterAspect.java

/**
 * 限流处理
 *
 * @author ruoyi
 */
@Aspect
@Component
public class RateLimiterAspect
{
    private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);

    private RedisTemplate<Object, Object> redisTemplate;

    private RedisScript<Long> limitScript;

    // 注入 redisTemplate
    @Autowired
    public void setRedisTemplate1(RedisTemplate<Object, Object> redisTemplate)
    {
        this.redisTemplate = redisTemplate;
    }
    // 注入 limitScript
    @Autowired
    public void setLimitScript(RedisScript<Long> limitScript)
    {
        this.limitScript = limitScript;
    }
    // 在有注解@RateLimiter的方法前执行
    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable
    {
        int time = rateLimiter.time();
        int count = rateLimiter.count();
        // 获取组合的键
        String combineKey = getCombineKey(rateLimiter, point);
        List<Object> keys = Collections.singletonList(combineKey);
        try
        {
            // 执行Redis脚本,获取限流结果
            Long number = redisTemplate.execute(limitScript, keys, count, time);
            // 如果结果为空或者超过限制次数
            if (StringUtils.isNull(number) || number.intValue() > count)
            {
                throw new ServiceException("访问过于频繁,请稍候再试");
            }
            log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), combineKey);
        }
        catch (ServiceException e)
        {
            throw e;
        }
        catch (Exception e)
        {
            throw new RuntimeException("服务器限流异常,请稍候再试");
        }
    }

    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point)
    {
        // 创建一个StringBuilder对象,并初始化为RateLimiter的key
        StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
        // 如果RateLimiter的限制类型是IP
        if (rateLimiter.limitType() == LimitType.IP)
        {
            // 获取当前请求的IP地址,并追加到StringBuilder对象后,再追加一个"-"
            stringBuffer.append(IpUtils.getIpAddr()).append("-");
        }
        // 获取JoinPoint的MethodSignature对象
        MethodSignature signature = (MethodSignature) point.getSignature();
        // 获取MethodSignature对象的方法对象
        Method method = signature.getMethod();
        // 获取方法的声明类
        Class<?> targetClass = method.getDeclaringClass();
        // 将方法的声明类和方法的名称追加到StringBuilder对象后,并在它们之间追加一个"-"
        stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
        // 返回StringBuilder对象的字符串表示
        return stringBuffer.toString();
    }
}

既然是实现接口限流功能肯定是要在切点前面行(也就是在接口执行之前),所以使用前置通知@Before​。

通过执行 lua 脚本获取接口在限流时间内的执行次数,如果超过了限流次数就抛出异常限制接口的调用。

通过 lua 脚本可以保证操作的原子性。

限流类型是 IP 的话获取组合键的时候需要获取请求来自的 IP。这里若依采取的是自己实现的方法,在这里对其不再展开,如果感兴趣它的实现可以看我的另一篇文章若依源码:获取 IP 方法的工具类

注入 lua 限流脚本

RedisConfig.java

/**
 * redis配置
 *
 * @author ruoyi
 */
@Configuration
@EnableCaching
public class RedisConfig {
    @Bean
    @SuppressWarnings(value = {"unchecked", "rawtypes"})
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
        RedisTemplate<Object, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(connectionFactory);

        FastJson2JsonRedisSerializer serializer = new FastJson2JsonRedisSerializer(Object.class);

        // 使用StringRedisSerializer来序列化和反序列化redis的key值
        template.setKeySerializer(new StringRedisSerializer());
        template.setValueSerializer(serializer);

        // Hash的key也采用StringRedisSerializer的序列化方式
        template.setHashKeySerializer(new StringRedisSerializer());
        template.setHashValueSerializer(serializer);

        template.afterPropertiesSet();
        return template;
    }

    @Bean
    public DefaultRedisScript<Long> limitScript() {
        DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptText(limitScriptText());
        redisScript.setResultType(Long.class);
        return redisScript;
    }

    /**
     * 限流脚本
     */
    private String limitScriptText() {
        return "local key = KEYS[1]\n" +
                "local count = tonumber(ARGV[1])\n" +
                "local time = tonumber(ARGV[2])\n" +
                "local current = redis.call('get', key);\n" +
                "if current and tonumber(current) > count then\n" +
                "    return tonumber(current);\n" +
                "end\n" +
                "current = redis.call('incr', key)\n" +
                "if tonumber(current) == 1 then\n" +
                "    redis.call('expire', key, time)\n" +
                "end\n" +
                "return tonumber(current);";
    }
}

这个 Redis 的配置类做了两件事,一个是设置 RedisTemplate 的序列化规则,一个就是注入了 lua 脚本。

设置 RedisTemplate 的序列化规则代码比较固定没啥可说的。我们主要讲解一下 lua 脚本。

local key = KEYS[1]

定义键名为脚本的第一个参数

local count = tonumber(ARGV[1])

将脚本的第二个参数转换为数字,作为限制的数量

local current = redis.call('get', key);

将脚本的第三个参数转换为数字,作为过期时间

if current and tonumber(current) > count then
    return tonumber(current);
end

如果当前值存在且大于限制数量就返回当前值

current = redis.call('incr', key)

对键的值进行递增操作

if tonumber(current) == 1 then
    redis.call('expire', key, time)
end

如果递增后的值为1,表示这是第一次设置键

设置键的过期时间

使用方法

@RestController
public class TestController {
    @RateLimiter(count = 10, time = 10, limitType = LimitType.IP)
    @GetMapping("/test")
    public void test() {
        System.out.println("test");
    }
}

通过注解的方式使用,也可以不设置参数都采用默认值。怎样调整参数看具体需求。

本篇完