Spring Boot 整合Redis使用Lua脚本实现限流

目录

    • 一、简介
    • 二、maven依赖
    • 三、编码实现
      • 3.1、配置文件
      • 3.2、配置类
      • 3.3、注解类
      • 3.4、切面类
      • 3.5、lua脚本
      • 3.6、自定义异常和全局异常
      • 3.7、控制层
    • 四、验证
    • 4.1、单用户限流
    • 4.2、接口限流
    • 结语

一、简介

  本篇文章主要来讲Spring Boot 整合Redis使用Lua脚本实现限流,实现限流有多种方式,我们今天主要讲使用Lua脚本。

  为什么我们使用Lua脚本来限流?因为Lua脚本具有原子性,那为什么lua脚本具有原子性?简单来说,因为Redis使用相同的Lua解释器来运行所有命令,Redis保证脚本以原子方式执行:在执行脚本时,不会执行其他脚本或Redis命令。因此从所有其他客户端的角度来看,脚本的效果要么仍然不可见,要么已经执行完成了。

二、maven依赖

pom.xml


<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0modelVersion>
    <parent>
        <groupId>org.springframework.bootgroupId>
        <artifactId>spring-boot-starter-parentartifactId>
        <version>2.6.0version>
        <relativePath/> 
    parent>

    <groupId>com.aliangroupId>
    <artifactId>redis-limit-luaartifactId>
    <version>0.0.1-SNAPSHOTversion>
    <name>redisCachename>
    <description>redis-limit-luadescription>

    <properties>
        <project.build.sourceEncoding>UTF-8project.build.sourceEncoding>
        <project.reporting.outputEncoding>UTF-8project.reporting.outputEncoding>
        <project.package.directory>targetproject.package.directory>
        <java.version>1.8java.version>
        
        <jackson.version>2.9.10jackson.version>
        
        <lombok.version>1.16.14lombok.version>
        
        <fastjson.version>1.2.68fastjson.version>
        
        <junit.version>4.12junit.version>
    properties>

    <dependencies>

        <dependency>
            <groupId>org.springframework.bootgroupId>
            <artifactId>spring-boot-starter-webartifactId>
        dependency>

		
        <dependency>
            <groupId>org.springframework.bootgroupId>
            <artifactId>spring-boot-starter-aopartifactId>
        dependency>

        
        <dependency>
            <groupId>org.springframework.bootgroupId>
            <artifactId>spring-boot-starter-data-redisartifactId>
        dependency>

        
        <dependency>
            <groupId>com.fasterxml.jackson.coregroupId>
            <artifactId>jackson-databindartifactId>
            <version>${jackson.version}version>
        dependency>

        
        <dependency>
            <groupId>com.fasterxml.jackson.datatypegroupId>
            <artifactId>jackson-datatype-jsr310artifactId>
            <version>${jackson.version}version>
        dependency>

        
        <dependency>
            <groupId>com.alibabagroupId>
            <artifactId>fastjsonartifactId>
            <version>${fastjson.version}version>
        dependency>

        <dependency>
            <groupId>org.apache.commonsgroupId>
            <artifactId>commons-lang3artifactId>
            <version>3.12.0version>
        dependency>

        
        <dependency>
            <groupId>org.projectlombokgroupId>
            <artifactId>lombokartifactId>
            <version>${lombok.version}version>
        dependency>

    dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.bootgroupId>
                <artifactId>spring-boot-maven-pluginartifactId>
            plugin>
        plugins>
    build>

project>

三、编码实现

3.1、配置文件

application.properties

# 端口
server.port=8090
# 上下文路径
server.servlet.context-path=/rateLimit

# Redis数据库索引(默认为0)
spring.redis.database=0
# Redis服务器地址
spring.redis.host=127.0.0.1
# Redis服务器连接端口
spring.redis.port=6379
# Redis服务器连接密码(默认为空)
spring.redis.password=123456
# 连接池最大连接数(使用负值表示没有限制)
spring.redis.jedis.pool.max-active=20
# 连接池中的最小空闲连接
spring.redis.jedis.pool.min-idle=10
# 连接池中的最大空闲连接
spring.redis.jedis.pool.max-idle=10
# 连接池最大阻塞等待时间(使用负值表示没有限制)
spring.redis.jedis.pool.max-wait=20000
# 读时间(毫秒)
spring.redis.timeout=10000
# 连接超时时间(毫秒)
spring.redis.connect-timeout=10000

3.2、配置类

RedisConfiguration.java

package com.alian.redisLimit.config;

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.fasterxml.jackson.datatype.jsr310.deser.LocalDateDeserializer;
import com.fasterxml.jackson.datatype.jsr310.deser.LocalDateTimeDeserializer;
import com.fasterxml.jackson.datatype.jsr310.deser.LocalTimeDeserializer;
import com.fasterxml.jackson.datatype.jsr310.ser.LocalDateSerializer;
import com.fasterxml.jackson.datatype.jsr310.ser.LocalDateTimeSerializer;
import com.fasterxml.jackson.datatype.jsr310.ser.LocalTimeSerializer;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cache.annotation.EnableCaching;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.scripting.support.ResourceScriptSource;

import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.format.DateTimeFormatter;

@Slf4j
@Configuration
@EnableCaching
public class RedisConfiguration {

    @Bean
    public RedisScript<Boolean> redisRequestRateLimiterScript() {
        DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(
                new ClassPathResource("META-INF/scripts/request_rate_limiter.lua")));
        redisScript.setResultType(Boolean.class);
        return redisScript;
    }

    /**
     * redis配置
     *
     * @param redisConnectionFactory
     * @return
     */
    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        // 实例化redisTemplate
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        //设置连接工厂
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        // key采用String的序列化
        redisTemplate.setKeySerializer(keySerializer());
        // value采用jackson序列化
        redisTemplate.setValueSerializer(valueSerializer());
        // Hash key采用String的序列化
        redisTemplate.setHashKeySerializer(keySerializer());
        // Hash value采用jackson序列化
        redisTemplate.setHashValueSerializer(valueSerializer());
        //执行函数,初始化RedisTemplate
        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }

    /**
     * key类型采用String序列化
     *
     * @return
     */
    private RedisSerializer<String> keySerializer() {
        return new StringRedisSerializer();
    }

    /**
     * value采用JSON序列化
     *
     * @return
     */
    private RedisSerializer<Object> valueSerializer() {
        //设置jackson序列化
        Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        //设置序列化对象
        jackson2JsonRedisSerializer.setObjectMapper(getMapper());
        return jackson2JsonRedisSerializer;
    }


    /**
     * 使用com.fasterxml.jackson.databind.ObjectMapper
     * 对数据进行处理包括java8里的时间
     *
     * @return
     */
    private ObjectMapper getMapper() {
        ObjectMapper mapper = new ObjectMapper();
        //设置可见性
        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        //默认键入对象
        mapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        //设置Java 8 时间序列化
        JavaTimeModule timeModule = new JavaTimeModule();
        timeModule.addSerializer(LocalDateTime.class, new LocalDateTimeSerializer(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
        timeModule.addSerializer(LocalDate.class, new LocalDateSerializer(DateTimeFormatter.ofPattern("yyyy-MM-dd")));
        timeModule.addSerializer(LocalTime.class, new LocalTimeSerializer(DateTimeFormatter.ofPattern("HH:mm:ss")));
        timeModule.addDeserializer(LocalDateTime.class, new LocalDateTimeDeserializer(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
        timeModule.addDeserializer(LocalDate.class, new LocalDateDeserializer(DateTimeFormatter.ofPattern("yyyy-MM-dd")));
        timeModule.addDeserializer(LocalTime.class, new LocalTimeDeserializer(DateTimeFormatter.ofPattern("HH:mm:ss")));
        //禁用把时间转为时间戳
        mapper.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false);
        mapper.registerModule(timeModule);
        return mapper;
    }

}

相比我们之前整合redis,就是多了如下配置:

    @Bean
    public RedisScript<Boolean> redisRequestRateLimiterScript() {
        DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(
                new ClassPathResource("META-INF/scripts/request_rate_limiter.lua")));
        redisScript.setResultType(Boolean.class);
        return redisScript;
    }
  • 实例化默认的DefaultRedisScript
  • 设置Lua 脚本的路径(一般放到resources/META-INF/scripts/ 目录下)
  • 设置Lua 脚本执行返回的结果类型,脚本返回的类型要和这里返回的结果类型要一致,本文返回的结果是布尔值,所以结果是 Boolean.class,如果你Lua执行后返回的结果是字符串,数字或者多个对象(元组),则设置对应的结果类型:String.classLong.classList.class

3.3、注解类

RateLimiters.java

package com.alian.redisLimit.annotate;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiters {

    RateLimiter[] value();
}

上面就是一个复合注解。

RateLimiter.java

package com.alian.redisLimit.annotate;

import java.lang.annotation.*;

@Documented
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiter {

    /**
     * Spel表达式
     */
    String [] keys() default {};

    /**
     * 令牌桶的容量,默认300
     */
    int capacity() default 300;

    /**
     * 生成令牌的速度,默认每秒100个
     */
    int rate() default 100;

    /**
     * 拒绝请求时的提示信息
     */
    String showPromptMsg() default "服务器繁忙,请稍候再试";
    
}

  自定义注解也没有什么好说的,主要是定义了:

  • key的名称,用于Redis锁的键
  • 令牌桶的容量,默认300
  • 每秒生成令牌的速度,默认每秒100个
  • 限流时返回给前端的提示信息

3.4、切面类

RateLimiterAspectHandler.java

package com.alian.redisLimit.aop;

import com.alian.redisLimit.annotate.RateLimiter;
import com.alian.redisLimit.annotate.RateLimiters;
import com.alian.redisLimit.exception.RateLimiterException;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.util.DigestUtils;

import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;

@Slf4j
@Component
@Aspect
public class RateLimiterAspectHandler {

    @Autowired
    private RedisScript<Boolean> redisScript;

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    @Autowired
    private RateLimiterKeyProvider keyProvider;

    @Around(value = "@annotation(rateLimiter)", argNames = "point,rateLimiter")
    public Object around(ProceedingJoinPoint point, RateLimiter rateLimiter) throws Throwable {
        isAllow(point, rateLimiter);
        return point.proceed();
    }

    @Around(value = " @annotation(rateLimiters)", argNames = "point,rateLimiters")
    public Object around(ProceedingJoinPoint point, RateLimiters rateLimiters) throws Throwable {
        RateLimiter[] limiters = rateLimiters.value();
        for (RateLimiter rateLimiter : limiters) {
            isAllow(point, rateLimiter);
        }
        return point.proceed();
    }

    private void isAllow(ProceedingJoinPoint point, RateLimiter rateLimiter) {
        // 获取key
        String key = keyProvider.getKey(point, rateLimiter);
        // 类路径+方法,然后计算md5
        String uniqueKey = getUniqueKey((MethodSignature) point.getSignature());
        // key名称
        key = StringUtils.isNotBlank(key) ? uniqueKey + "." + key : uniqueKey;
        // 拼接成最后的Redis的键,传入需要操作的key到lua脚本中
        List<String> operateKeys = getOperateKeys(key);
        // 执行lua脚本
        Boolean allowed = this.redisTemplate.execute(redisScript, operateKeys, rateLimiter.capacity(), rateLimiter.rate(), Instant.now().getEpochSecond(), 1);
        log.info("rateLimiter {}, result is {}", key, allowed);
        if (Boolean.FALSE.equals(allowed)) {
            log.warn("触发限流,key is : {} ", key);
            throw new RateLimiterException(rateLimiter.showPromptMsg());
        }
    }

    private String getUniqueKey(MethodSignature signature) {
        String format = String.format("%s.%s", signature.getDeclaringTypeName(), signature.getMethod().getName());
        return DigestUtils.md5DigestAsHex(format.getBytes(StandardCharsets.UTF_8));
    }

    private List<String> getOperateKeys(String id) {
        String tokenKey = "request_rate_limiter.{" + id + "}.token";
        String timestampKey = "request_rate_limiter.{" + id + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

}
  • 切面是针对所有使用了@RateLimiter @RateLimiters 注解的方法
  • 首先是获取定义的key的值,设置了key就是针对特定一类限流,没设置就是针对整个接口限流
  • 获取一个方法的唯一值作为Redis中key的一部分,本文是获取类路径+方法名,然后计算md5值作为这个前缀
  • 拼接成最后的Redis的key,传到需要操作的Lua脚本中
  • 执行lua脚本,传入的key就是KEYS[] ,传入的参数就是ARGV[] ,下标从1开始取值,参数要注意类型
  • 如果未获取到则抛出异常(限流了),做一个全局异常捕获,统一返回处理

RateLimiterKeyProvider.java

package com.alian.redisLimit.aop;

import com.alian.redisLimit.annotate.RateLimiter;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.context.expression.MethodBasedEvaluationContext;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

@Slf4j
@Component
public class RateLimiterKeyProvider {

    private ParameterNameDiscoverer discoverer = new DefaultParameterNameDiscoverer();

    private ExpressionParser parser = new SpelExpressionParser();

    public String getKey(JoinPoint joinPoint, RateLimiter rateLimiter) {
        List<String> keyList = new ArrayList<>();
        Method method = getMethod(joinPoint);
        List<String> definitionKeys = getSpelDefinitionKey(rateLimiter.keys(), method, joinPoint.getArgs());
        keyList.addAll(definitionKeys);
        return StringUtils.collectionToDelimitedString(keyList,".","","");
    }

    private Method getMethod(JoinPoint joinPoint) {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        if (method.getDeclaringClass().isInterface()) {
            try {
                method = joinPoint.getTarget().getClass().getDeclaredMethod(signature.getName(),
                        method.getParameterTypes());
            } catch (Exception e) {
                log.error(null,e);
            }
        }
        return method;
    }

    private List<String> getSpelDefinitionKey(String[] definitionKeys, Method method, Object[] parameterValues) {
        List<String> definitionKeyList = new ArrayList<>();
        for (String definitionKey : definitionKeys) {
            if (definitionKey != null && !definitionKey.isEmpty()) {
                EvaluationContext context = new MethodBasedEvaluationContext(null, method, parameterValues, discoverer);
                String key = parser.parseExpression(definitionKey).getValue(context).toString();
                definitionKeyList.add(key);
            }
        }
        return definitionKeyList;
    }
}

3.5、lua脚本

request_rate_limiter.lua

-- 传入的要操作的key:tokenKey
local tokenKey = KEYS[1]
-- 传入的要操作的key:timestampKey
local timestampKey = KEYS[2]

-- 参数1:令牌桶的大小
local capacity = tonumber(ARGV[1])
-- 参数2:生成令牌的速度
local rate = tonumber(ARGV[2])
-- 参数3:当前时间的秒数
local nowTimestamp = tonumber(ARGV[3])
-- 参数4:请求令牌数
local requested = tonumber(ARGV[4])

-- redis.log(redis.LOG_NOTICE,"tokenKey:" .. tokenKey)
-- redis.log(redis.LOG_NOTICE,"timestampKey:" .. timestampKey)
-- redis.log(redis.LOG_NOTICE,"capacity:" .. capacity)
-- redis.log(redis.LOG_NOTICE,"rate:" .. rate)
-- redis.log(redis.LOG_NOTICE,"nowTimestamp:" .. nowTimestamp)
-- redis.log(redis.LOG_NOTICE,"requested:" .. requested)

-- 计算令牌桶填充时间,令牌桶的大小/生成令牌的速度
local fillTime = capacity / rate
-- 失效时间向下取整,采用两倍填充时间保证失效时间充足
local expireTime = math.floor(fillTime * 2)

-- 从redis获取上一次tokenKey的值,如果返回nil,则初始化令牌桶,结果转为数字
local lastToken = tonumber(redis.call("get", tokenKey) or capacity)
-- 从redis获取上一次timestampKey的值,如果返回nil,则时间设置为0,结果转为数字
local lastTimestamp = tonumber(redis.call("get", timestampKey) or 0)
-- 当前时间和最后一次获取的时间的差值:秒数取值范围是从0到expireTime 或者 当前时间值
local timeGaps = math.max(0, nowTimestamp - lastTimestamp)

-- redis.log(redis.LOG_NOTICE,"fillTime:" .. fillTime)
-- redis.log(redis.LOG_NOTICE,"expireTime:" .. expireTime)
-- redis.log(redis.LOG_NOTICE,"lastToken:" .. lastToken)
-- redis.log(redis.LOG_NOTICE,"lastTimestamp:" .. lastTimestamp)
-- redis.log(redis.LOG_NOTICE,"timeGaps:" .. timeGaps)

-- 同1秒内的timeGaps的值都是0,令牌桶的数不会增加,直到扣减完,超过1秒的都会填充令牌
local filledToken = math.min(capacity, lastToken + (timeGaps * rate))
-- 新拿到的令牌值默认是填充后的filledToken
local newToken = filledToken
-- 令牌数大于等于请求令牌数说明可以获取到令牌
local allowed = filledToken >= requested

-- 如果可以拿到令牌,则令牌数扣减掉请求数,得到令牌值
if allowed
then
  newToken = filledToken - requested
end

-- redis.log(redis.LOG_NOTICE,"filledToken:" .. filledToken)
-- redis.log(redis.LOG_NOTICE,"allowed:" .. tostring(allowed))
-- redis.log(redis.LOG_NOTICE,"newToken:" .. newToken)

-- 通过redis设置tokenKey的值为newToken,失效时间为expireTime
redis.call("setex", tokenKey, expireTime, newToken)
-- 通过redis设置timestampKey的值为nowTimestamp,失效时间为expireTime
redis.call("setex", timestampKey, expireTime, nowTimestamp)

-- 返回结果:是否拿到令牌
return allowed

  我想我的脚本已经解释的很详细的,小伙伴可要认真看哦。顺便提一下,如果要打印Lua脚本的日志,则可以使用如下方式:

redis.log(redis.LOG_NOTICE,"filledToken:" .. filledToken)

  需要注意的是修改redis的配置文件的两个值:日志级别和日志文件

# Specify the server verbosity level.
# This can be one of:
# debug (a lot of information, useful for development/testing)
# verbose (many rarely useful info, but not a mess like the debug level)
# notice (moderately verbose, what you want in production probably)
# warning (only very important / critical messages are logged)
loglevel notice

# Specify the log file name. Also 'stdout' can be used to force
# Redis to log on the standard output.
logfile "C:/myProgram/Redis-x64-5.0.14.1/redis.log"

  redis.LOG_NOTICE 对应你配置的 loglevel,改完保存后,重启服务,windows环境记得带配置文件启动

C:\myProgram\Redis-x64-5.0.14.1>redis-server.exe redis.windows.conf

3.6、自定义异常和全局异常

RateLimiterException.java

package com.alian.redisLimit.exception;

public class RateLimiterException extends RuntimeException {

    public RateLimiterException(String message) {
        super(message);
    }

}

  自定义异常类,也没啥好说的,下面就是全局异常,为了省篇幅没有把所有的异常都列出来,小伙伴可以自行添加,主要是对我们RateLimiterException 进行处理。

GlobalExceptionHandler.java

package com.alian.redisLimit.exception;

import com.alian.redisLimit.dto.ApiResponseDto;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.HttpRequestMethodNotSupportedException;
import org.springframework.web.bind.MissingServletRequestParameterException;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.ResponseStatus;

import javax.servlet.http.HttpServletRequest;

@Slf4j
@Component
@ControllerAdvice
public class GlobalExceptionHandler {

    @ExceptionHandler
    @ResponseBody
    @ResponseStatus(HttpStatus.OK)
    public ApiResponseDto<?> handle(HttpRequestMethodNotSupportedException exception, HttpServletRequest request) {
        return logWarn(request.getRequestURI() + " " + exception.getMessage(), null, ApiResponseDto.errRequestMethod("请求方法错误"));
    }

    @ExceptionHandler
    @ResponseBody
    @ResponseStatus(HttpStatus.OK)
    public ApiResponseDto handle(MissingServletRequestParameterException exception) {
        return logWarn(exception.getMessage(), null, ApiResponseDto.errParam("参数错误"));
    }

    @ExceptionHandler
    @ResponseBody
    @ResponseStatus(HttpStatus.OK)
    public ApiResponseDto handle(RateLimiterException exception) {
        return ApiResponseDto.fail(exception.getMessage());
    }

    @ExceptionHandler
    @ResponseBody
    @ResponseStatus(HttpStatus.OK)
    public ApiResponseDto handle(Exception exception) {
        log.info("异常类:{}", exception.getClass().getCanonicalName());
        return logError(null, exception, ApiResponseDto.exception("系统异常"));
    }

    private static ApiResponseDto logWarn(String msg, Exception e, ApiResponseDto responseDto) {
        long timestamp = responseDto.getTimestamp();
        String m = "timestamp is " + timestamp;
        if (msg != null) {
            m += ", " + msg;
        }
        if (e == null) {
            log.warn(m);
        } else {
            log.warn(m, e);
        }
        return responseDto;
    }

    private static ApiResponseDto logError(String msg, Exception e, ApiResponseDto responseDto) {
        long timestamp = responseDto.getTimestamp();
        String m = "timestamp is " + timestamp;
        if (msg != null) {
            m += ", " + msg;
        }
        log.error(m, e);
        return responseDto;
    }

}

对应的统一返回封装如下:

ApiResponseDto.java

package com.alian.redisLimit.dto;

import lombok.*;
import lombok.experimental.Accessors;

@Setter
@Getter
@Accessors(chain = true)
@NoArgsConstructor
@AllArgsConstructor
@ToString(exclude = "content")
public class ApiResponseDto<T> {

    /** 成功 */
    public static String CODE_SUCCESS="0000";
    /** 失败 */
    public static String CODE_FAIL="1000";
    /** 系统异常 */
    public static String CODE_EXCEPTION="1001";
    /** 签名错误 */
    public static String CODE_ERR_SIGN="1002";
    /** 参数错误 */
    public static String CODE_ERR_PARAM="1003";
    /** 业务异常 */
    public static String CODE_BIZ_ERR="1004";
    /** 查询无数据,使用明确的参数(如id)进行查询时未找到记录时返回此错误码 */
    public static String CODE_NO_DATA="1005";
    /** 错误的请求方法 */
    public static String CODE_ERR_REQUEST_METHOD="1006";
    /** 错误的请求内容类型 */
    public static String CODE_ERR_CONTENT_TYPE="1007";
    /** 系统繁忙 */
    public static String CODE_SYS_BUSY="1008";
    /** 显示提示 */
    public static String CODE_SHOW_TIP="1009";
    /** 根据bizCode进行处理 */
    public static String CODE_DEAL_BIZ_CODE="1012";
    /** 未找到请求 */
    public static String CODE_NOT_FOUND_CODE="1013";

    public final static ApiResponseDto SUCCESS=new ApiResponseDto();


    private String code =CODE_SUCCESS;

    /** 状态说明 */
    private String msg ="success";

    /** 请求是否成功 */
    public boolean isSuccess(){
        return CODE_SUCCESS.equals(code);
    }

    /** 结果内容 */
    private T content;

    /** 时间戳 */
    private long timestamp=System.currentTimeMillis();

    /** 业务状态码,由业务接口定义 */
    private String bizCode;

    /** 业务状态说明 */
    private String bizMsg;

    public ApiResponseDto(T content) {
        this.content=content;
    }

    public static <T> ApiResponseDto<T> success(){
        return SUCCESS;
    }

    public static <T> ApiResponseDto<T> success(T content){
        return new ApiResponseDto<T>(content);
    }

    public static <T> ApiResponseDto<T> fail(String msg){
        ApiResponseDto<T> response = new ApiResponseDto<>();
        response.setCode(CODE_FAIL);
        response.setMsg(msg);
        return response;
    }

    public static <T> ApiResponseDto<T> exception(String msg){
        ApiResponseDto<T> response = new ApiResponseDto<>();
        response.setCode(CODE_EXCEPTION);
        response.setMsg(msg);
        return response;
    }

    public static <T> ApiResponseDto<T> errSign(String msg){
        ApiResponseDto<T> response = new ApiResponseDto<>();
        response.setCode(CODE_ERR_SIGN);
        response.setMsg(msg);
        return response;
    }

    public static <T> ApiResponseDto<T> errParam(String msg){
        ApiResponseDto<T> response = new ApiResponseDto<>();
        response.setCode(CODE_ERR_PARAM);
        response.setMsg(msg);
        return response;
    }

    public static <T> ApiResponseDto<T> bizErr(String msg){
        ApiResponseDto<T> response = new ApiResponseDto<>();
        response.setCode(CODE_BIZ_ERR);
        response.setMsg(msg);
        return response;
    }

    public static <T> ApiResponseDto<T> notFound(String msg){
        ApiResponseDto<T> response = new ApiResponseDto<>();
        response.setCode(CODE_NOT_FOUND_CODE);
        response.setMsg(msg);
        return response;
    }

    public static <T> ApiResponseDto<T> noData(String msg){
        ApiResponseDto<T> response = new ApiResponseDto<>();
        response.setCode(CODE_NO_DATA);
        response.setMsg(msg);
        return response;
    }
    public static <T> ApiResponseDto<T>  errRequestMethod(String msg){
        ApiResponseDto<T> response = new ApiResponseDto<>();
        response.setCode(CODE_ERR_REQUEST_METHOD);
        response.setMsg(msg);
        return response;
    }
    public static <T> ApiResponseDto<T> errContentType(){
        ApiResponseDto<T> response = new ApiResponseDto<>();
        response.setCode(CODE_ERR_CONTENT_TYPE);
        response.setMsg("错误的请求内容类型");
        return response;
    }
    public static <T> ApiResponseDto<T> sysBusy(){
        ApiResponseDto<T> response = new ApiResponseDto<>();
        response.setCode(CODE_SYS_BUSY);
        response.setMsg("系统繁忙");
        return response;
    }
    public static <T> ApiResponseDto<T>  showTip(String tip){
        ApiResponseDto<T> response = new ApiResponseDto<>();
        response.setCode(CODE_SHOW_TIP);
        response.setMsg(tip);
        return response;
    }

    public ApiResponseDto<T> bizInfo(String bizCode,String bizMsg){
        this.code=bizCode;
        this.msg=bizMsg;
        return this;
    }

    public static <T> ApiResponseDto<T>  dealBizCode(String bizCode,String bizMsg,T content){
        ApiResponseDto<T> response = new ApiResponseDto<>(content);
        response.setCode(CODE_DEAL_BIZ_CODE);
        response.setMsg("根据bizCode进行处理");
        response.setBizCode(bizCode);
        response.setBizMsg(bizMsg);
        return response;
    }
}

3.7、控制层

UserController.java

package com.alian.redisLimit.controller;

import com.alian.redisLimit.annotate.RateLimiter;
import com.alian.redisLimit.annotate.RateLimiters;
import com.alian.redisLimit.dto.ApiResponseDto;
import com.alian.redisLimit.dto.UserDto;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.Map;

@Slf4j
@RequestMapping("/user")
@RestController
public class UserController {

    private static Map<String, UserDto> map = new HashMap<String, UserDto>() {{
        put("BAT001", new UserDto("BAT001", "梁南生", 27, "研发部", 18000.0, LocalDateTime.of(2020, 5, 20, 9, 0, 0)));
        put("BAT002", new UserDto("BAT002", "包雅馨", 25, "财务部", 8800.0, LocalDateTime.of(2016, 11, 10, 8, 30, 0)));
        put("BAT003", new UserDto("BAT003", "罗考聪", 35, "测试部", 6400.0, LocalDateTime.of(2017, 3, 20, 14, 0, 0)));
    }};

    @RateLimiters(value = {
            @RateLimiter(keys = "#id", capacity = 1, rate = 1, showPromptMsg = "您查询太快了,请稍后再试"),
            @RateLimiter(capacity = 5, rate = 2, showPromptMsg = "系统繁忙,请稍后再试")
    })
    @RequestMapping("/findById/{id}")
    public ApiResponseDto<UserDto> findById(@PathVariable("id") String id) {
        UserDto userDto = map.getOrDefault(id, null);
        if (userDto != null) {
            return ApiResponseDto.success(userDto);
        }
        return ApiResponseDto.noData("未查询到数据");
    }

}

  简单模拟根据用户编号查询用户的接口,关键是我们使用注解@RateLimiter 的方法可以做限流,看是否能达到我们的要求。这里有两层意思:

  • 一个用户每秒最多支持1次请求,每秒最多生成1个令牌
  • 整个接口每秒最多支持5次请求,每秒最多生成2个令牌(生产根据需求调整即可)

四、验证

4.1、单用户限流

  我这里就采用压力测试工具 jmeter 进行一个简单的压测了:因为我们代码暂时写的令牌容量是1个请求,每秒最多生成1个令牌。我们模拟1个用户5秒内请求10个次(不是那么精准),看看会有多少触发限流。 jmeter 设置如下图:

Spring Boot 整合Redis使用Lua脚本实现限流_第1张图片

后台结果:

21:17:05 956 INFO [http-nio-8090-exec-1]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2.BAT001, result is true
21:17:05 958 INFO [http-nio-8090-exec-1]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2, result is true
21:17:05 984 INFO [http-nio-8090-exec-2]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2.BAT001, result is false
21:17:05 984 WARN [http-nio-8090-exec-2]:触发限流,key is : 33cd75b80483ce52ca96e58699ae97a2.BAT001 
21:17:06 486 INFO [http-nio-8090-exec-3]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2.BAT001, result is true
21:17:06 486 INFO [http-nio-8090-exec-3]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2, result is true
21:17:06 993 INFO [http-nio-8090-exec-5]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2.BAT001, result is false
21:17:06 993 WARN [http-nio-8090-exec-5]:触发限流,key is : 33cd75b80483ce52ca96e58699ae97a2.BAT001 
21:17:07 483 INFO [http-nio-8090-exec-7]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2.BAT001, result is true
21:17:07 483 INFO [http-nio-8090-exec-7]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2, result is true
21:17:07 977 INFO [http-nio-8090-exec-9]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2.BAT001, result is false
21:17:07 977 WARN [http-nio-8090-exec-9]:触发限流,key is : 33cd75b80483ce52ca96e58699ae97a2.BAT001 
21:17:08 474 INFO [http-nio-8090-exec-2]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2.BAT001, result is true
21:17:08 474 INFO [http-nio-8090-exec-2]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2, result is true
21:17:08 983 INFO [http-nio-8090-exec-3]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2.BAT001, result is false
21:17:08 983 WARN [http-nio-8090-exec-3]:触发限流,key is : 33cd75b80483ce52ca96e58699ae97a2.BAT001 
21:17:09 490 INFO [http-nio-8090-exec-5]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2.BAT001, result is true
21:17:09 493 INFO [http-nio-8090-exec-5]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2, result is true
21:17:09 983 INFO [http-nio-8090-exec-7]:rateLimiter 33cd75b80483ce52ca96e58699ae97a2.BAT001, result is false
21:17:09 983 WARN [http-nio-8090-exec-7]:触发限流,key is : 33cd75b80483ce52ca96e58699ae97a2.BAT001 

jmeter结果:

  Spring Boot 整合Redis使用Lua脚本实现限流_第2张图片
  从上面的结果我们看到10个请求5秒内,有5个通过,5个限流了。

4.2、接口限流

  当我们的用户是满足上面,一秒请求一次的情况下,假设有10个用户并发请求到我们系统了,同样会触发限流。不过,你可以不要认为1秒内10个用户发10个请求,一定是前面5个能请求通过,后面5个不通过,如果是同一秒内,那么结果是这样的,如果是跨秒,比如像下面这样(10个请求还是在一秒内):

开始 时间
前段时间 2023-03-02 16:30:20 500 2023-03-02 16:30:20 999
后段时间 2023-03-02 16:30:21 000 2023-03-02 16:30:21 500

  则前段时间请求数可以从0到10,我们分别列举下情况(假设当前秒内令牌桶数是5,每秒最多填充2个令牌):

前段请求数(个) 后段请求数(个) 前段通过数(个) 后段通过数(个) 限流数(个)
0 10 0 5 5
1 9 1 5 4
2 8 2 5 3
3 7 3 4 3
4 6 4 3 3
5 5 5 2 3
6 4 5 2 3
7 3 5 2 3
8 2 5 2 3
9 1 5 1 4
10 0 5 0 5

  从这里可以看到限流可能与请求是否跨秒和请求数总容量和填充速率配置有关。大家也看的我写的每秒最多填充2个令牌,如果超过1秒但是还没有超时,那么一定要弄懂这句,这句就是核心Lua代码。

-- 同1秒内的timeGaps的值都是0,令牌桶的数不会增加,直到扣减完,超过1秒的都会填充令牌
local filledToken = math.min(capacity, lastToken + (timeGaps * rate))

结语

  从上面大家就发现这个组合注解 @RateLimiters 的强大了,因为它可以同时解决单用户限流和多用户限流。可能小伙伴觉得这个不灵活,令牌的控制都写死在接口了(比如:令牌桶数,令牌生成速率),不能随时调整,这个还不简单?你可以把自定义注解的值都通过系统配置,然后缓存到redis中,然后AOP中处理@RateLimiter 通过缓存或接口获取即可。我们主要是要了解这个限流的思路,实际中一般都会把这个AOP写成一个自定义的Starter,供其他的项目引入依赖使用。

你可能感兴趣的:(Redis笔记,Lua脚本限流,Lua脚本日志输出,Redis实现限流)