redis限流增加

This commit is contained in:
dushitaoyuan
2019-09-05 18:44:17 +08:00
parent a574dc46ea
commit c8f22d1cb0
14 changed files with 310 additions and 51 deletions

22
pom.xml
View File

@@ -26,11 +26,11 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>
<!-- <dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
&lt;!&ndash; <optional>true</optional>&ndash;&gt;
</dependency>-->
<!-- <dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
&lt;!&ndash; <optional>true</optional>&ndash;&gt;
</dependency>-->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
@@ -81,11 +81,13 @@
<artifactId>hutool-core</artifactId>
<version>4.0.3</version>
</dependency>
<!-- <dependency>
<groupId>commons-fileupload</groupId>
<artifactId>commons-fileupload</artifactId>
<version>1.4</version>
</dependency>-->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
<optional>true</optional>
</dependency>
</dependencies>
<build>

View File

@@ -1,11 +1,20 @@
package com.taoyuanx.securitydemo.config;
import com.taoyuanx.securitydemo.security.ratelimit.AbstractRateLimiter;
import com.taoyuanx.securitydemo.security.ratelimit.RedisRateLimiter;
import lombok.Data;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.env.Environment;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.ReactiveStringRedisTemplate;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import java.net.UnknownHostException;
import java.util.List;
/**
@@ -35,5 +44,16 @@ public class GlobalConfig {
return environment.getProperty(configKey);
}
/**
* 限流实现类
* @param redisTemplate
* @return
*/
@Bean
public AbstractRateLimiter rateLimiter(RedisTemplate redisTemplate){
RedisRateLimiter redisRateLimiter=new RedisRateLimiter(redisTemplate);
return redisRateLimiter;
}
}

View File

@@ -159,4 +159,6 @@ public class MvcConfig implements WebMvcConfigurer {
return fileHandler;
}
}

View File

@@ -112,7 +112,7 @@ public class BussinessController {
*/
@GetMapping("rateLimit_key")
@ResponseBody
@RateLimit(type = RateLimitType.SERVICE_KEY, limit = 2, key = "api/rateLimit_key")
@RateLimit(type = RateLimitType.SERVICE_KEY, limit = 100, key = "api/rateLimit_key")
public String rateLimitKey() {
return "hello rateLimit!";
}

View File

@@ -11,7 +11,9 @@ import java.lang.annotation.*;
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimit {
/**
* 限流并发 和限流的key,类型
* limit 每秒并发
* key 限流的key
* type 限流类型 参见:com.taoyuanx.securitydemo.security.RateLimitType
*/
double limit() default 100;

View File

@@ -1,9 +1,8 @@
package com.taoyuanx.securitydemo.security;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.RateLimiter;
import com.taoyuanx.securitydemo.exception.LimitException;
import com.taoyuanx.securitydemo.security.ratelimit.AbstractRateLimiter;
import com.taoyuanx.securitydemo.utils.RequestUtil;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.Signature;
@@ -13,11 +12,11 @@ import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.stereotype.Component;
import java.lang.reflect.Method;
import java.util.Map;
/**
* aop限流
@@ -27,8 +26,8 @@ import java.util.Map;
public class RateLimitAspect {
private static final Logger LOG = LoggerFactory.getLogger(RateLimitAspect.class);
private Map<String, RateLimiter> rateHolder = Maps.newConcurrentMap();
private static final int MAX_HOLDER_SIZE = 50000;
@Autowired
AbstractRateLimiter rateLimiter;
@Pointcut("execution(* com.taoyuanx.securitydemo.controller..*.*(..))&& (@annotation(RateLimit)||@annotation(Rate))")
public void ratePointCut() {
@@ -46,13 +45,11 @@ public class RateLimitAspect {
handleRateLimit(rateLimit, methodName);
} else {
Rate rate = AnnotationUtils.findAnnotation(currentMethod, Rate.class);
if (rate != null) {
RateLimit[] rateLimitArray = rate.rate();
if (rateLimitArray != null) {
for (RateLimit limit : rateLimitArray) {
handleRateLimit(limit, methodName);
}
if (rate != null && rate.rate() != null) {
for (RateLimit limit : rate.rate()) {
handleRateLimit(limit, methodName);
}
}
}
@@ -61,13 +58,6 @@ public class RateLimitAspect {
private void handleRateLimit(RateLimit rateLimit, String methodName) throws Throwable {
RateLimiter rateLimiter = doGetRateLimiter(rateLimit, methodName);
if (!rateLimiter.tryAcquire()) {
throw new LimitException("请求过于频繁,请稍后再试");
}
}
private RateLimiter doGetRateLimiter(RateLimit rateLimit, String methodName) {
RateLimitType type = rateLimit.type();
String key = rateLimit.key();
if (type == null) {
@@ -84,40 +74,24 @@ public class RateLimitAspect {
} else {
key = RequestUtil.getRemoteIp() + "_" + serviceKey;
}
if (LOG.isDebugEnabled()) {
LOG.debug("采用[{}]限流策略,限流key:{}", RateLimitType.IP, key);
}
break;
case METHOD:
key = methodName;
if (LOG.isDebugEnabled()) {
LOG.debug("采用[{}]限流策略,限流key:{}", RateLimitType.METHOD, key);
}
break;
case SERVICE_KEY:
if (LOG.isDebugEnabled()) {
LOG.debug("采用[{}]限流策略,限流key:{}", RateLimitType.SERVICE_KEY, key);
}
break;
case GLOBAL:
key = "global";
if (LOG.isDebugEnabled()) {
LOG.debug("采用[{}]限流策略,限流key:{}", RateLimitType.GLOBAL, key);
}
break;
}
}
//超过固定阈值,清空,重构
if (rateHolder.size() > MAX_HOLDER_SIZE) {
rateHolder.clear();
if (LOG.isDebugEnabled()) {
LOG.debug("采用[{}]限流策略,限流key:{}", RateLimitType.IP, key);
}
if (rateHolder.containsKey(key)) {
return rateHolder.get(key);
if (!rateLimiter.tryAcquire(key, rateLimit.limit())) {
throw new LimitException("请求过于频繁,请稍后再试");
}
RateLimiter rateLimiter = RateLimiter.create(rateLimit.limit());
rateHolder.putIfAbsent(key, rateLimiter);
return rateHolder.get(key);
}

View File

@@ -0,0 +1,35 @@
package com.taoyuanx.securitydemo.security.ratelimit;
/**
* @author dushitaoyuan
* @desc 抽象限流
* @date 2019/9/5
*/
public abstract class AbstractRateLimiter {
/**
* 尝试获取令牌
*
* @param key 限流标识
* @param limit 限流速率
* @return
*/
public boolean tryAcquire(String key, Double limit){
return doTryAcquire(1,key,limit);
}
/**
* 尝试获取令牌
*
* @param permits 获取令牌数量
* @param key 限流标识
* @param limit 限流速率
* @return
*/
public boolean tryAcquire(int permits, String key, Double limit){
return doTryAcquire(permits,key,limit);
}
protected abstract boolean doTryAcquire(int permits, String key, Double limit);
}

View File

@@ -0,0 +1,41 @@
package com.taoyuanx.securitydemo.security.ratelimit;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.RateLimiter;
import java.util.Map;
/**
* @author dushitaoyuan
* @desc guava限流实现
* @date 2019/9/5
*/
public class GuavaRateLimiter extends AbstractRateLimiter {
private Map<String, RateLimiter> rateHolder = Maps.newConcurrentMap();
private static final int MAX_HOLDER_SIZE = 50000;
/* @Override
public boolean tryAcquire(String key, Double limit) {
return doTryAcquire(1, key, limit);
}
@Override
public boolean tryAcquire(int permits, String key, Double limit) {
return doTryAcquire(permits, key, limit);
}
*/
protected boolean doTryAcquire(int permits, String key, Double limit) {
//超过固定阈值,清空,重构
if (rateHolder.size() > MAX_HOLDER_SIZE) {
rateHolder.clear();
}
RateLimiter rateLimiter = null;
if (rateHolder.containsKey(key)) {
rateLimiter = rateHolder.get(key);
}
rateLimiter = RateLimiter.create(limit);
rateHolder.putIfAbsent(key, rateLimiter);
return rateLimiter.tryAcquire(permits);
}
}

View File

@@ -0,0 +1,62 @@
package com.taoyuanx.securitydemo.security.ratelimit;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
/**
* @author dushitaoyuan
* @desc redis 限流实现
* 复制 spring-cloud-gateway,可与spring-boot,spring-cloud等分离使用
* @date 2019/9/5
*/
public class RedisRateLimiter extends AbstractRateLimiter {
private RedisTemplate stringRedisTemplate;
private RedisScript<List<Long>> script;
public RedisRateLimiter(RedisTemplate redisTemplate) {
DefaultRedisScript script = new DefaultRedisScript();
script.setScriptSource(new ResourceScriptSource(
new ClassPathResource("META-INF/demo.lua")));
script.setResultType(List.class);
this.script = script;
this.stringRedisTemplate = redisTemplate;
}
/* @Override
public boolean tryAcquire(String key, Double limit) {
return doTryAcquire(1, key, limit);
}
@Override
public boolean tryAcquire(int permits, String key, Double limit) {
return doTryAcquire(permits, key, limit);
}*/
@Override
protected boolean doTryAcquire(int permits, String key, Double limit) {
List<Object> scriptArgs = Arrays.asList(limit.longValue(), limit.longValue(), Instant.now().getEpochSecond(), permits);
Object execute = stringRedisTemplate.execute(this.script, getKeys(key),
scriptArgs);
return execute == null;
}
private List<String> getKeys(String key) {
int keyId = key.hashCode();
String prefix = "request_rate_limiter.{" + keyId;
String tokenKey = prefix + "}.tokens";
String timestampKey = prefix + "}.timestamp";
return Arrays.asList(tokenKey, timestampKey);
}
}

View File

@@ -0,0 +1,10 @@
--获取KEY
local key1 = KEYS[1]
local key2 = KEYS[2]
--输出参数
redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
return {1,2}

View File

@@ -0,0 +1,64 @@
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)
redis.log(redis.LOG_WARNING, "timestamp_key" .. timestamp_key)
redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)
redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
redis.log(redis.LOG_WARNING, "ttl " .. ttl)
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)
local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
new_tokens = filled_tokens - requested
allowed_num = 1
end
--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)
redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)
return { allowed_num, new_tokens }

View File

@@ -13,6 +13,12 @@ server.port=8080
#spring.datasource.filters=wall
spring.redis.database=2
spring.redis.host=172.16.0.32
spring.redis.port=6379
spring.redis.password=guoruiredis
server.servlet.session.timeout=1800
#自定义配置
@@ -33,3 +39,5 @@ application.config.expire-seconds=1800
application.config.file-storage-dir=G:/temp
#系统文件访问地址,参见fileHandler
application.config.systemFileFormat=http://localhost:${server.port}/api/file

View File

@@ -3,6 +3,15 @@ package com.taoyuanx.securitydemo;
import com.google.common.hash.BloomFilter;
import com.google.common.hash.Funnels;
import org.junit.Test;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.common.TemplateParserContext;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import java.util.HashMap;
import java.util.Map;
/**
* @author dushitaoyuan
@@ -29,4 +38,24 @@ public class BoomFilterTest {
System.out.println("失败次数:" + count);
}
@Test
public void testEl(){
String el="${m}";
ExpressionParser parser = new SpelExpressionParser();
EvaluationContext context = new StandardEvaluationContext();
context.setVariable("m","1234");
Expression expression = parser.parseExpression(el);
System.out.println(expression.getValue(context, String.class));
// 定义变量
/* String name = "Tom";
EvaluationContext context = new StandardEvaluationContext(); // 表达式的上下文,
context.setVariable("myName", name); // 为了让表达式可以访问该对象, 先把对象放到上下文中
ExpressionParser parser = new SpelExpressionParser();
System.out.println( parser.parseExpression("#myName").getValue(context, String.class));; // Tom , 使用变量
*/
}
}

View File

@@ -17,4 +17,14 @@ public class SecurityDemoApplicationTests {
bussinessController.admin();
}
@Test
public void reateLimit()
{
int batch=100;
for(int i=0;i<batch;i++){
System.out.println(bussinessController.rateLimitKey());
}
}
}