diff --git a/pom.xml b/pom.xml index 3920b14..7ce2cb8 100644 --- a/pom.xml +++ b/pom.xml @@ -26,11 +26,11 @@ org.springframework.boot spring-boot-starter - + org.springframework.boot spring-boot-starter-web @@ -81,11 +81,13 @@ hutool-core 4.0.3 - + + + + org.springframework.boot + spring-boot-starter-data-redis + true + diff --git a/src/main/java/com/taoyuanx/securitydemo/config/GlobalConfig.java b/src/main/java/com/taoyuanx/securitydemo/config/GlobalConfig.java index 8f9cabf..c9f1e60 100644 --- a/src/main/java/com/taoyuanx/securitydemo/config/GlobalConfig.java +++ b/src/main/java/com/taoyuanx/securitydemo/config/GlobalConfig.java @@ -1,11 +1,20 @@ package com.taoyuanx.securitydemo.config; +import com.taoyuanx.securitydemo.security.ratelimit.AbstractRateLimiter; +import com.taoyuanx.securitydemo.security.ratelimit.RedisRateLimiter; import lombok.Data; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.env.Environment; +import org.springframework.data.redis.connection.RedisConnectionFactory; +import org.springframework.data.redis.core.ReactiveStringRedisTemplate; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.StringRedisTemplate; +import java.net.UnknownHostException; import java.util.List; /** @@ -35,5 +44,16 @@ public class GlobalConfig { return environment.getProperty(configKey); } + /** + * 限流实现类 + * @param redisTemplate + * @return + */ + @Bean + public AbstractRateLimiter rateLimiter(RedisTemplate redisTemplate){ + RedisRateLimiter redisRateLimiter=new RedisRateLimiter(redisTemplate); + return redisRateLimiter; + } + } diff --git a/src/main/java/com/taoyuanx/securitydemo/config/MvcConfig.java b/src/main/java/com/taoyuanx/securitydemo/config/MvcConfig.java index c206593..6530228 100644 --- a/src/main/java/com/taoyuanx/securitydemo/config/MvcConfig.java +++ b/src/main/java/com/taoyuanx/securitydemo/config/MvcConfig.java @@ -159,4 +159,6 @@ public class MvcConfig implements WebMvcConfigurer { return fileHandler; } + + } diff --git a/src/main/java/com/taoyuanx/securitydemo/controller/BussinessController.java b/src/main/java/com/taoyuanx/securitydemo/controller/BussinessController.java index 5be8695..87401f0 100644 --- a/src/main/java/com/taoyuanx/securitydemo/controller/BussinessController.java +++ b/src/main/java/com/taoyuanx/securitydemo/controller/BussinessController.java @@ -112,7 +112,7 @@ public class BussinessController { */ @GetMapping("rateLimit_key") @ResponseBody - @RateLimit(type = RateLimitType.SERVICE_KEY, limit = 2, key = "api/rateLimit_key") + @RateLimit(type = RateLimitType.SERVICE_KEY, limit = 100, key = "api/rateLimit_key") public String rateLimitKey() { return "hello rateLimit!"; } diff --git a/src/main/java/com/taoyuanx/securitydemo/security/RateLimit.java b/src/main/java/com/taoyuanx/securitydemo/security/RateLimit.java index 104ed05..e9c578a 100644 --- a/src/main/java/com/taoyuanx/securitydemo/security/RateLimit.java +++ b/src/main/java/com/taoyuanx/securitydemo/security/RateLimit.java @@ -11,7 +11,9 @@ import java.lang.annotation.*; @Retention(RetentionPolicy.RUNTIME) public @interface RateLimit { /** - * 限流并发 和限流的key,类型 + * limit 每秒并发 + * key 限流的key + * type 限流类型 参见:com.taoyuanx.securitydemo.security.RateLimitType */ double limit() default 100; diff --git a/src/main/java/com/taoyuanx/securitydemo/security/RateLimitAspect.java b/src/main/java/com/taoyuanx/securitydemo/security/RateLimitAspect.java index 625481e..5bfe731 100644 --- a/src/main/java/com/taoyuanx/securitydemo/security/RateLimitAspect.java +++ b/src/main/java/com/taoyuanx/securitydemo/security/RateLimitAspect.java @@ -1,9 +1,8 @@ package com.taoyuanx.securitydemo.security; -import com.google.common.collect.Maps; -import com.google.common.util.concurrent.RateLimiter; import com.taoyuanx.securitydemo.exception.LimitException; +import com.taoyuanx.securitydemo.security.ratelimit.AbstractRateLimiter; import com.taoyuanx.securitydemo.utils.RequestUtil; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.Signature; @@ -13,11 +12,11 @@ import org.aspectj.lang.annotation.Pointcut; import org.aspectj.lang.reflect.MethodSignature; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.stereotype.Component; import java.lang.reflect.Method; -import java.util.Map; /** * aop限流 @@ -27,8 +26,8 @@ import java.util.Map; public class RateLimitAspect { private static final Logger LOG = LoggerFactory.getLogger(RateLimitAspect.class); - private Map rateHolder = Maps.newConcurrentMap(); - private static final int MAX_HOLDER_SIZE = 50000; + @Autowired + AbstractRateLimiter rateLimiter; @Pointcut("execution(* com.taoyuanx.securitydemo.controller..*.*(..))&& (@annotation(RateLimit)||@annotation(Rate))") public void ratePointCut() { @@ -46,13 +45,11 @@ public class RateLimitAspect { handleRateLimit(rateLimit, methodName); } else { Rate rate = AnnotationUtils.findAnnotation(currentMethod, Rate.class); - if (rate != null) { - RateLimit[] rateLimitArray = rate.rate(); - if (rateLimitArray != null) { - for (RateLimit limit : rateLimitArray) { - handleRateLimit(limit, methodName); - } + if (rate != null && rate.rate() != null) { + for (RateLimit limit : rate.rate()) { + handleRateLimit(limit, methodName); } + } } @@ -61,13 +58,6 @@ public class RateLimitAspect { private void handleRateLimit(RateLimit rateLimit, String methodName) throws Throwable { - RateLimiter rateLimiter = doGetRateLimiter(rateLimit, methodName); - if (!rateLimiter.tryAcquire()) { - throw new LimitException("请求过于频繁,请稍后再试"); - } - } - - private RateLimiter doGetRateLimiter(RateLimit rateLimit, String methodName) { RateLimitType type = rateLimit.type(); String key = rateLimit.key(); if (type == null) { @@ -84,40 +74,24 @@ public class RateLimitAspect { } else { key = RequestUtil.getRemoteIp() + "_" + serviceKey; } - if (LOG.isDebugEnabled()) { - LOG.debug("采用[{}]限流策略,限流key:{}", RateLimitType.IP, key); - } break; case METHOD: key = methodName; - if (LOG.isDebugEnabled()) { - LOG.debug("采用[{}]限流策略,限流key:{}", RateLimitType.METHOD, key); - } break; case SERVICE_KEY: - if (LOG.isDebugEnabled()) { - LOG.debug("采用[{}]限流策略,限流key:{}", RateLimitType.SERVICE_KEY, key); - } + break; case GLOBAL: key = "global"; - if (LOG.isDebugEnabled()) { - LOG.debug("采用[{}]限流策略,限流key:{}", RateLimitType.GLOBAL, key); - } break; } } - //超过固定阈值,清空,重构 - if (rateHolder.size() > MAX_HOLDER_SIZE) { - rateHolder.clear(); + if (LOG.isDebugEnabled()) { + LOG.debug("采用[{}]限流策略,限流key:{}", RateLimitType.IP, key); } - if (rateHolder.containsKey(key)) { - return rateHolder.get(key); + if (!rateLimiter.tryAcquire(key, rateLimit.limit())) { + throw new LimitException("请求过于频繁,请稍后再试"); } - RateLimiter rateLimiter = RateLimiter.create(rateLimit.limit()); - rateHolder.putIfAbsent(key, rateLimiter); - return rateHolder.get(key); - } diff --git a/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/AbstractRateLimiter.java b/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/AbstractRateLimiter.java new file mode 100644 index 0000000..3b897c7 --- /dev/null +++ b/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/AbstractRateLimiter.java @@ -0,0 +1,35 @@ +package com.taoyuanx.securitydemo.security.ratelimit; + +/** + * @author dushitaoyuan + * @desc 抽象限流 + * @date 2019/9/5 + */ +public abstract class AbstractRateLimiter { + + + /** + * 尝试获取令牌 + * + * @param key 限流标识 + * @param limit 限流速率 + * @return + */ + public boolean tryAcquire(String key, Double limit){ + return doTryAcquire(1,key,limit); + } + + /** + * 尝试获取令牌 + * + * @param permits 获取令牌数量 + * @param key 限流标识 + * @param limit 限流速率 + * @return + */ + public boolean tryAcquire(int permits, String key, Double limit){ + return doTryAcquire(permits,key,limit); + } + + protected abstract boolean doTryAcquire(int permits, String key, Double limit); +} diff --git a/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/GuavaRateLimiter.java b/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/GuavaRateLimiter.java new file mode 100644 index 0000000..af851cc --- /dev/null +++ b/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/GuavaRateLimiter.java @@ -0,0 +1,41 @@ +package com.taoyuanx.securitydemo.security.ratelimit; + +import com.google.common.collect.Maps; +import com.google.common.util.concurrent.RateLimiter; + +import java.util.Map; + +/** + * @author dushitaoyuan + * @desc guava限流实现 + * @date 2019/9/5 + */ +public class GuavaRateLimiter extends AbstractRateLimiter { + private Map rateHolder = Maps.newConcurrentMap(); + private static final int MAX_HOLDER_SIZE = 50000; + /* @Override + public boolean tryAcquire(String key, Double limit) { + return doTryAcquire(1, key, limit); + } + + @Override + public boolean tryAcquire(int permits, String key, Double limit) { + return doTryAcquire(permits, key, limit); + } +*/ + protected boolean doTryAcquire(int permits, String key, Double limit) { + //超过固定阈值,清空,重构 + if (rateHolder.size() > MAX_HOLDER_SIZE) { + rateHolder.clear(); + } + RateLimiter rateLimiter = null; + if (rateHolder.containsKey(key)) { + rateLimiter = rateHolder.get(key); + } + rateLimiter = RateLimiter.create(limit); + rateHolder.putIfAbsent(key, rateLimiter); + return rateLimiter.tryAcquire(permits); + } + + +} diff --git a/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/RedisRateLimiter.java b/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/RedisRateLimiter.java new file mode 100644 index 0000000..0fb0815 --- /dev/null +++ b/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/RedisRateLimiter.java @@ -0,0 +1,62 @@ +package com.taoyuanx.securitydemo.security.ratelimit; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.script.DefaultRedisScript; +import org.springframework.data.redis.core.script.RedisScript; +import org.springframework.scripting.support.ResourceScriptSource; + +import java.time.Instant; +import java.util.Arrays; +import java.util.List; + +/** + * @author dushitaoyuan + * @desc redis 限流实现 + * 复制 spring-cloud-gateway,可与spring-boot,spring-cloud等分离使用 + * @date 2019/9/5 + */ +public class RedisRateLimiter extends AbstractRateLimiter { + private RedisTemplate stringRedisTemplate; + private RedisScript> script; + + public RedisRateLimiter(RedisTemplate redisTemplate) { + DefaultRedisScript script = new DefaultRedisScript(); + script.setScriptSource(new ResourceScriptSource( + new ClassPathResource("META-INF/demo.lua"))); + script.setResultType(List.class); + this.script = script; + this.stringRedisTemplate = redisTemplate; + } + + /* @Override + public boolean tryAcquire(String key, Double limit) { + return doTryAcquire(1, key, limit); + } + + @Override + public boolean tryAcquire(int permits, String key, Double limit) { + return doTryAcquire(permits, key, limit); + }*/ + @Override + protected boolean doTryAcquire(int permits, String key, Double limit) { + List scriptArgs = Arrays.asList(limit.longValue(), limit.longValue(), Instant.now().getEpochSecond(), permits); + Object execute = stringRedisTemplate.execute(this.script, getKeys(key), + scriptArgs); + + + return execute == null; + + } + + private List getKeys(String key) { + + int keyId = key.hashCode(); + String prefix = "request_rate_limiter.{" + keyId; + String tokenKey = prefix + "}.tokens"; + String timestampKey = prefix + "}.timestamp"; + return Arrays.asList(tokenKey, timestampKey); + } + + +} diff --git a/src/main/resources/META-INF/demo.lua b/src/main/resources/META-INF/demo.lua new file mode 100644 index 0000000..5d20fb1 --- /dev/null +++ b/src/main/resources/META-INF/demo.lua @@ -0,0 +1,10 @@ +--获取KEY +local key1 = KEYS[1] +local key2 = KEYS[2] +--输出参数 +redis.log(redis.LOG_WARNING, "rate " .. ARGV[1]) +redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2]) +redis.log(redis.LOG_WARNING, "now " .. ARGV[3]) +redis.log(redis.LOG_WARNING, "requested " .. ARGV[4]) + +return {1,2} \ No newline at end of file diff --git a/src/main/resources/META-INF/request_rate_limiter.lua b/src/main/resources/META-INF/request_rate_limiter.lua new file mode 100644 index 0000000..6eacb43 --- /dev/null +++ b/src/main/resources/META-INF/request_rate_limiter.lua @@ -0,0 +1,64 @@ + +local tokens_key = KEYS[1] +local timestamp_key = KEYS[2] +redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key) + +redis.log(redis.LOG_WARNING, "timestamp_key" .. timestamp_key) + + +redis.log(redis.LOG_WARNING, "rate " .. ARGV[1]) +redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2]) +redis.log(redis.LOG_WARNING, "now " .. ARGV[3]) +redis.log(redis.LOG_WARNING, "requested " .. ARGV[4]) + + + + + + +local rate = tonumber(ARGV[1]) +local capacity = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) +local requested = tonumber(ARGV[4]) + +local fill_time = capacity/rate +local ttl = math.floor(fill_time*2) + +redis.log(redis.LOG_WARNING, "rate " .. ARGV[1]) +redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2]) +redis.log(redis.LOG_WARNING, "now " .. ARGV[3]) +redis.log(redis.LOG_WARNING, "requested " .. ARGV[4]) +redis.log(redis.LOG_WARNING, "filltime " .. fill_time) +redis.log(redis.LOG_WARNING, "ttl " .. ttl) + +local last_tokens = tonumber(redis.call("get", tokens_key)) +if last_tokens == nil then + last_tokens = capacity +end +--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens) + +local last_refreshed = tonumber(redis.call("get", timestamp_key)) +if last_refreshed == nil then + last_refreshed = 0 +end +--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed) + +local delta = math.max(0, now-last_refreshed) +local filled_tokens = math.min(capacity, last_tokens+(delta*rate)) +local allowed = filled_tokens >= requested +local new_tokens = filled_tokens +local allowed_num = 0 +if allowed then + new_tokens = filled_tokens - requested + allowed_num = 1 +end + +--redis.log(redis.LOG_WARNING, "delta " .. delta) +--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens) +--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num) +--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens) + +redis.call("setex", tokens_key, ttl, new_tokens) +redis.call("setex", timestamp_key, ttl, now) + +return { allowed_num, new_tokens } diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 70d9273..3731d5e 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -13,6 +13,12 @@ server.port=8080 #spring.datasource.filters=wall + +spring.redis.database=2 +spring.redis.host=172.16.0.32 +spring.redis.port=6379 +spring.redis.password=guoruiredis + server.servlet.session.timeout=1800 #自定义配置 @@ -33,3 +39,5 @@ application.config.expire-seconds=1800 application.config.file-storage-dir=G:/temp #系统文件访问地址,参见fileHandler application.config.systemFileFormat=http://localhost:${server.port}/api/file + + diff --git a/src/test/java/com/taoyuanx/securitydemo/BoomFilterTest.java b/src/test/java/com/taoyuanx/securitydemo/BoomFilterTest.java index ca8e6ea..241e80d 100644 --- a/src/test/java/com/taoyuanx/securitydemo/BoomFilterTest.java +++ b/src/test/java/com/taoyuanx/securitydemo/BoomFilterTest.java @@ -3,6 +3,15 @@ package com.taoyuanx.securitydemo; import com.google.common.hash.BloomFilter; import com.google.common.hash.Funnels; import org.junit.Test; +import org.springframework.expression.EvaluationContext; +import org.springframework.expression.Expression; +import org.springframework.expression.ExpressionParser; +import org.springframework.expression.common.TemplateParserContext; +import org.springframework.expression.spel.standard.SpelExpressionParser; +import org.springframework.expression.spel.support.StandardEvaluationContext; + +import java.util.HashMap; +import java.util.Map; /** * @author dushitaoyuan @@ -29,4 +38,24 @@ public class BoomFilterTest { System.out.println("失败次数:" + count); } + @Test + public void testEl(){ + String el="${m}"; + ExpressionParser parser = new SpelExpressionParser(); + EvaluationContext context = new StandardEvaluationContext(); + context.setVariable("m","1234"); + Expression expression = parser.parseExpression(el); + + + System.out.println(expression.getValue(context, String.class)); + + // 定义变量 + /* String name = "Tom"; + EvaluationContext context = new StandardEvaluationContext(); // 表达式的上下文, + context.setVariable("myName", name); // 为了让表达式可以访问该对象, 先把对象放到上下文中 + ExpressionParser parser = new SpelExpressionParser(); + System.out.println( parser.parseExpression("#myName").getValue(context, String.class));; // Tom , 使用变量 + +*/ + } } diff --git a/src/test/java/com/taoyuanx/securitydemo/SecurityDemoApplicationTests.java b/src/test/java/com/taoyuanx/securitydemo/SecurityDemoApplicationTests.java index 98303cf..113cc45 100644 --- a/src/test/java/com/taoyuanx/securitydemo/SecurityDemoApplicationTests.java +++ b/src/test/java/com/taoyuanx/securitydemo/SecurityDemoApplicationTests.java @@ -17,4 +17,14 @@ public class SecurityDemoApplicationTests { bussinessController.admin(); } + + @Test + public void reateLimit() + { + int batch=100; + for(int i=0;i