增加redis限流,增加限流策略,总数限流

This commit is contained in:
dushitaoyuan
2019-09-06 00:19:11 +08:00
parent c8f22d1cb0
commit 85f269a90e
16 changed files with 192 additions and 133 deletions

View File

@@ -86,7 +86,6 @@
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
<optional>true</optional>
</dependency>
</dependencies>

View File

@@ -4,17 +4,12 @@ import com.taoyuanx.securitydemo.security.ratelimit.AbstractRateLimiter;
import com.taoyuanx.securitydemo.security.ratelimit.RedisRateLimiter;
import lombok.Data;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.env.Environment;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.ReactiveStringRedisTemplate;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import java.net.UnknownHostException;
import java.util.List;
/**
@@ -40,19 +35,22 @@ public class GlobalConfig {
private String fileStorageDir;
private String systemFileFormat;
public String getConfig(String configKey) {
return environment.getProperty(configKey);
}
/**
* 限流实现类
*
* @param redisTemplate
* @return
*/
@Bean
public AbstractRateLimiter rateLimiter(RedisTemplate redisTemplate){
RedisRateLimiter redisRateLimiter=new RedisRateLimiter(redisTemplate);
return redisRateLimiter;
@Autowired
public AbstractRateLimiter rateLimiter(StringRedisTemplate redisTemplate) {
RedisRateLimiter redisRateLimiter = new RedisRateLimiter(redisTemplate);
return redisRateLimiter;
}

View File

@@ -154,6 +154,7 @@ public class MvcConfig implements WebMvcConfigurer {
}
@Bean
public FileHandler fileHandler(){
FileHandler fileHandler=new FileHandler(globalConfig.getFileStorageDir(),globalConfig.getTokenKey(),false,globalConfig.getSystemFileFormat());
return fileHandler;

View File

@@ -9,6 +9,7 @@ import com.taoyuanx.securitydemo.dto.AccountDTO;
import com.taoyuanx.securitydemo.exception.ServiceException;
import com.taoyuanx.securitydemo.helper.ToeknHelper;
import com.taoyuanx.securitydemo.security.*;
import com.taoyuanx.securitydemo.utils.CookieUtil;
import com.taoyuanx.securitydemo.utils.FileHandler;
import com.taoyuanx.securitydemo.utils.FileTypeCheckUtil;
import com.taoyuanx.securitydemo.utils.PasswordUtil;
@@ -156,6 +157,19 @@ public class BussinessController {
}
/**
* 登录安全控制
*
* @return
*/
@PostMapping("loginOut")
@ResponseBody
public void loginOut(HttpServletResponse response, HttpServletRequest request) throws Exception {
CookieUtil.removeCookie(response, "/", SystemConstants.TOKEN_COOKIE_KEY);
request.getSession().invalidate();
}
/**
* 黑名单测试
*
@@ -178,7 +192,7 @@ public class BussinessController {
*/
String fileId = DateUtil.format(new Date(), "yyyy-mm-dd") + "/" + multipartFile.getOriginalFilename();
File file = new File(globalConfig.getFileStorageDir(), fileId);
if(!file.getParentFile().exists()){
if (!file.getParentFile().exists()) {
file.getParentFile().mkdirs();
}
multipartFile.transferTo(file);

View File

@@ -68,10 +68,11 @@ public class TokenAuthHandlerIntercepter implements HandlerInterceptor {
return false;
}
Map<String, Object> tokenData = toeknHelper.vafy(token);
Integer tokenAccountId = (Integer) tokenData.get(SystemConstants.TOKEN_ACCOUNTID_KEY);
if (tokenAccountId.longValue() != accountId) {
Long tokenAccountId = toeknHelper.getAccountId(tokenData);
if (tokenAccountId != accountId) {
return false;
}
return true;
}
if (null != publicUrl && publicUrl.size() > 0) {

View File

@@ -14,10 +14,12 @@ public @interface RateLimit {
* limit 每秒并发
* key 限流的key
* type 限流类型 参见:com.taoyuanx.securitydemo.security.RateLimitType
* totalCount 次数限流
*/
double limit() default 100;
String key() default "" ;
RateLimitType type() default RateLimitType.METHOD;
long totalCount() default 0;
}

View File

@@ -8,7 +8,8 @@ package com.taoyuanx.securitydemo.security;
public enum RateLimitType {
IP(0, "IP限流"), METHOD(1, "方法名"),
SERVICE_KEY(3, "业务自定义key"),
GLOBAL(4,"系统全局");
GLOBAL(4,"系统全局"),
TOTAL_COUNT(5,"总次数限制");
private int code;
private String desc;
@@ -31,6 +32,8 @@ public enum RateLimitType {
case 4:
return GLOBAL;
case 5:
return TOTAL_COUNT;
}
return null;
}

View File

@@ -4,6 +4,7 @@ package com.taoyuanx.securitydemo.security.ratelimit;
* @author dushitaoyuan
* @desc 抽象限流
* @date 2019/9/5
*
*/
public abstract class AbstractRateLimiter {
@@ -18,18 +19,16 @@ public abstract class AbstractRateLimiter {
public boolean tryAcquire(String key, Double limit){
return doTryAcquire(1,key,limit);
}
protected abstract boolean doTryAcquire(int permits, String key, Double limit);
/**
* 尝试获取令牌
*
* @param permits 获取令牌数量
* @param key 限流标识
* @param limit 限流速率
* 增加资源访问次数 用户可自行持久化记录
* @param count
* @param key
* @param totalCount
* @return
*/
public boolean tryAcquire(int permits, String key, Double limit){
return doTryAcquire(permits,key,limit);
}
public abstract boolean tryCount(int count,String key,Long totalCount);
protected abstract boolean doTryAcquire(int permits, String key, Double limit);
}

View File

@@ -1,9 +1,13 @@
package com.taoyuanx.securitydemo.security.ratelimit;
import com.google.common.collect.Maps;
import com.google.common.hash.BloomFilter;
import com.google.common.hash.Funnels;
import com.google.common.util.concurrent.RateLimiter;
import java.nio.charset.Charset;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
/**
* @author dushitaoyuan
@@ -11,18 +15,15 @@ import java.util.Map;
* @date 2019/9/5
*/
public class GuavaRateLimiter extends AbstractRateLimiter {
private Map<String, RateLimiter> rateHolder = Maps.newConcurrentMap();
private static final int MAX_HOLDER_SIZE = 50000;
/* @Override
public boolean tryAcquire(String key, Double limit) {
return doTryAcquire(1, key, limit);
}
private Map<String, RateLimiter> rateHolder = new ConcurrentHashMap<>(MAX_HOLDER_SIZE);
private Map<String, LongAdder> countHolder = new ConcurrentHashMap();
/**
* 总数限流到0后,标记
*/
private BloomFilter<CharSequence> TOTAL_LIMIT_ZERO_FLAG = BloomFilter.create(Funnels.stringFunnel(Charset.defaultCharset()), MAX_HOLDER_SIZE * 20);
@Override
public boolean tryAcquire(int permits, String key, Double limit) {
return doTryAcquire(permits, key, limit);
}
*/
protected boolean doTryAcquire(int permits, String key, Double limit) {
//超过固定阈值,清空,重构
if (rateHolder.size() > MAX_HOLDER_SIZE) {
@@ -37,5 +38,36 @@ public class GuavaRateLimiter extends AbstractRateLimiter {
return rateLimiter.tryAcquire(permits);
}
@Override
public boolean tryCount(int count, String key, Long totalCount) {
//标记后,直接返回false
if (TOTAL_LIMIT_ZERO_FLAG.mightContain(key)) {
return false;
}
//超过固定阈值,清空,重构 防止内存溢出
if (countHolder.size() > MAX_HOLDER_SIZE) {
countHolder.clear();
}
LongAdder longAdder = null;
if (countHolder.containsKey(key)) {
longAdder = countHolder.get(key);
longAdder.add(-count);
//资源总数用完后,标记
if (longAdder.longValue() <= 0) {
TOTAL_LIMIT_ZERO_FLAG.put(key);
countHolder.remove(key);
return true;
}
return false;
}
if (count > totalCount) {
return false;
}
longAdder = new LongAdder();
countHolder.putIfAbsent(key, longAdder);
countHolder.get(key).add(count);
return true;
}
}

View File

@@ -1,7 +1,7 @@
package com.taoyuanx.securitydemo.security.ratelimit;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
@@ -17,40 +17,40 @@ import java.util.List;
* @date 2019/9/5
*/
public class RedisRateLimiter extends AbstractRateLimiter {
private RedisTemplate stringRedisTemplate;
private RedisScript<List<Long>> script;
private StringRedisTemplate redisTemplate;
private RedisScript<List<Long>> tokenScript;
private RedisScript<Long> countScript;
public RedisRateLimiter(RedisTemplate redisTemplate) {
public RedisRateLimiter(StringRedisTemplate redisTemplate) {
DefaultRedisScript script = new DefaultRedisScript();
script.setScriptSource(new ResourceScriptSource(
new ClassPathResource("META-INF/demo.lua")));
script.setScriptSource(new ResourceScriptSource(new ClassPathResource("META-INF/rate_limiter_token.lua")));
script.setResultType(List.class);
this.script = script;
this.stringRedisTemplate = redisTemplate;
this.tokenScript = script;
script = new DefaultRedisScript();
script.setScriptSource(new ResourceScriptSource(new ClassPathResource("META-INF/rate_limiter_count.lua")));
script.setResultType(Long.class);
this.countScript = script;
this.redisTemplate = redisTemplate;
}
/* @Override
public boolean tryAcquire(String key, Double limit) {
return doTryAcquire(1, key, limit);
}
@Override
public boolean tryAcquire(int permits, String key, Double limit) {
return doTryAcquire(permits, key, limit);
}*/
@Override
protected boolean doTryAcquire(int permits, String key, Double limit) {
List<Object> scriptArgs = Arrays.asList(limit.longValue(), limit.longValue(), Instant.now().getEpochSecond(), permits);
Object execute = stringRedisTemplate.execute(this.script, getKeys(key),
scriptArgs);
String[] scriptArgs = {limit.longValue() + "", limit.longValue() + "", Instant.now().getEpochSecond() + "", permits + ""};
List<Long> results = redisTemplate.execute(this.tokenScript, getKeys(key), scriptArgs);
return results.get(0) == 1L;
}
return execute == null;
@Override
public boolean tryCount(int count, String key, Long totalCount) {
String[] scriptArgs = {count + "", totalCount + ""};
Long result = redisTemplate.execute(this.countScript, Arrays.asList(key), scriptArgs);
return result == 1L;
}
private List<String> getKeys(String key) {
int keyId = key.hashCode();
String prefix = "request_rate_limiter.{" + keyId;
String tokenKey = prefix + "}.tokens";

View File

@@ -1,6 +1,7 @@
package com.taoyuanx.securitydemo.utils;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletResponse;
public class CookieUtil {
//判断cookie是否存在
@@ -14,7 +15,13 @@ public class CookieUtil {
}
return null;
}
//删除cookie
public static void removeCookie(HttpServletResponse response,String path, String cookieName){
Cookie cookie=new Cookie(cookieName,null);
cookie.setPath(path);
cookie.setMaxAge(0);
response.addCookie(cookie);
}

View File

@@ -1,10 +0,0 @@
--获取KEY
local key1 = KEYS[1]
local key2 = KEYS[2]
--输出参数
redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
return {1,2}

View File

@@ -0,0 +1,24 @@
--如果等于0说明超时,其他则是当前资源的访问数量
--申请资源数量
local count= tonumber(ARGV[1])
if count == nil then
count = 1
end
-- 获取剩余资源数量
local last_count = tonumber(redis.call("get", KEYS[1]))
if last_count == nil then
last_count = ARGV[2];
end
--计数减少
if last_count > count then
redis.call("DECRBY", KEYS[1], count)
return 1
else
return 0
end
--todo bitmap 标记 资源为0

View File

@@ -0,0 +1,53 @@
--令牌算法实现
--限流标识和时间戳key
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
-- rate 令牌产生速率
local rate = tonumber(ARGV[1])
-- 令牌容量
local capacity = tonumber(ARGV[2])
--时间戳
local now = tonumber(ARGV[3])
--申请令牌数量
local requested = tonumber(ARGV[4])
--填满漏桶所需要的时间 容量除以速率
local fill_time = capacity/rate
--过期时间为填满漏桶时间的2倍
local ttl = math.floor(fill_time*2)
--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)
--剩余令牌
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
last_tokens = capacity
end
--上次获取令牌的时间
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
last_refreshed = 0
end
--重新计算令牌数量=当前时间与上次获取令牌的时间差值*令牌生产速率+剩余令牌数量
--申请令牌数量小于当前桶内数量时,请求被允许
local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
new_tokens = filled_tokens - requested
allowed_num = 1
end
redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)
return {allowed_num, new_tokens}

View File

@@ -1,64 +0,0 @@
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)
redis.log(redis.LOG_WARNING, "timestamp_key" .. timestamp_key)
redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)
redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
redis.log(redis.LOG_WARNING, "ttl " .. ttl)
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)
local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
new_tokens = filled_tokens - requested
allowed_num = 1
end
--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)
redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)
return { allowed_num, new_tokens }

View File

@@ -15,9 +15,9 @@ server.port=8080
spring.redis.database=2
spring.redis.host=172.16.0.32
spring.redis.host=192.168.30.211
spring.redis.port=6379
spring.redis.password=guoruiredis
spring.redis.password=
server.servlet.session.timeout=1800