diff --git a/pom.xml b/pom.xml index 7ce2cb8..29cf8ab 100644 --- a/pom.xml +++ b/pom.xml @@ -86,7 +86,6 @@ org.springframework.boot spring-boot-starter-data-redis - true diff --git a/src/main/java/com/taoyuanx/securitydemo/config/GlobalConfig.java b/src/main/java/com/taoyuanx/securitydemo/config/GlobalConfig.java index c9f1e60..9aa8bf3 100644 --- a/src/main/java/com/taoyuanx/securitydemo/config/GlobalConfig.java +++ b/src/main/java/com/taoyuanx/securitydemo/config/GlobalConfig.java @@ -4,17 +4,12 @@ import com.taoyuanx.securitydemo.security.ratelimit.AbstractRateLimiter; import com.taoyuanx.securitydemo.security.ratelimit.RedisRateLimiter; import lombok.Data; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.env.Environment; -import org.springframework.data.redis.connection.RedisConnectionFactory; -import org.springframework.data.redis.core.ReactiveStringRedisTemplate; -import org.springframework.data.redis.core.RedisTemplate; import org.springframework.data.redis.core.StringRedisTemplate; -import java.net.UnknownHostException; import java.util.List; /** @@ -40,19 +35,22 @@ public class GlobalConfig { private String fileStorageDir; private String systemFileFormat; + public String getConfig(String configKey) { return environment.getProperty(configKey); } /** * 限流实现类 + * * @param redisTemplate * @return */ @Bean - public AbstractRateLimiter rateLimiter(RedisTemplate redisTemplate){ - RedisRateLimiter redisRateLimiter=new RedisRateLimiter(redisTemplate); - return redisRateLimiter; + @Autowired + public AbstractRateLimiter rateLimiter(StringRedisTemplate redisTemplate) { + RedisRateLimiter redisRateLimiter = new RedisRateLimiter(redisTemplate); + return redisRateLimiter; } diff --git a/src/main/java/com/taoyuanx/securitydemo/config/MvcConfig.java b/src/main/java/com/taoyuanx/securitydemo/config/MvcConfig.java index 6530228..e54b3f3 100644 --- a/src/main/java/com/taoyuanx/securitydemo/config/MvcConfig.java +++ b/src/main/java/com/taoyuanx/securitydemo/config/MvcConfig.java @@ -154,6 +154,7 @@ public class MvcConfig implements WebMvcConfigurer { } @Bean + public FileHandler fileHandler(){ FileHandler fileHandler=new FileHandler(globalConfig.getFileStorageDir(),globalConfig.getTokenKey(),false,globalConfig.getSystemFileFormat()); return fileHandler; diff --git a/src/main/java/com/taoyuanx/securitydemo/controller/BussinessController.java b/src/main/java/com/taoyuanx/securitydemo/controller/BussinessController.java index 87401f0..cf1df85 100644 --- a/src/main/java/com/taoyuanx/securitydemo/controller/BussinessController.java +++ b/src/main/java/com/taoyuanx/securitydemo/controller/BussinessController.java @@ -9,6 +9,7 @@ import com.taoyuanx.securitydemo.dto.AccountDTO; import com.taoyuanx.securitydemo.exception.ServiceException; import com.taoyuanx.securitydemo.helper.ToeknHelper; import com.taoyuanx.securitydemo.security.*; +import com.taoyuanx.securitydemo.utils.CookieUtil; import com.taoyuanx.securitydemo.utils.FileHandler; import com.taoyuanx.securitydemo.utils.FileTypeCheckUtil; import com.taoyuanx.securitydemo.utils.PasswordUtil; @@ -156,6 +157,19 @@ public class BussinessController { } + + /** + * 登录安全控制 + * + * @return + */ + @PostMapping("loginOut") + @ResponseBody + public void loginOut(HttpServletResponse response, HttpServletRequest request) throws Exception { + CookieUtil.removeCookie(response, "/", SystemConstants.TOKEN_COOKIE_KEY); + request.getSession().invalidate(); + } + /** * 黑名单测试 * @@ -178,7 +192,7 @@ public class BussinessController { */ String fileId = DateUtil.format(new Date(), "yyyy-mm-dd") + "/" + multipartFile.getOriginalFilename(); File file = new File(globalConfig.getFileStorageDir(), fileId); - if(!file.getParentFile().exists()){ + if (!file.getParentFile().exists()) { file.getParentFile().mkdirs(); } multipartFile.transferTo(file); diff --git a/src/main/java/com/taoyuanx/securitydemo/interceptor/TokenAuthHandlerIntercepter.java b/src/main/java/com/taoyuanx/securitydemo/interceptor/TokenAuthHandlerIntercepter.java index 6f0ee25..db11412 100644 --- a/src/main/java/com/taoyuanx/securitydemo/interceptor/TokenAuthHandlerIntercepter.java +++ b/src/main/java/com/taoyuanx/securitydemo/interceptor/TokenAuthHandlerIntercepter.java @@ -68,10 +68,11 @@ public class TokenAuthHandlerIntercepter implements HandlerInterceptor { return false; } Map tokenData = toeknHelper.vafy(token); - Integer tokenAccountId = (Integer) tokenData.get(SystemConstants.TOKEN_ACCOUNTID_KEY); - if (tokenAccountId.longValue() != accountId) { + Long tokenAccountId = toeknHelper.getAccountId(tokenData); + if (tokenAccountId != accountId) { return false; } + return true; } if (null != publicUrl && publicUrl.size() > 0) { diff --git a/src/main/java/com/taoyuanx/securitydemo/security/RateLimit.java b/src/main/java/com/taoyuanx/securitydemo/security/RateLimit.java index e9c578a..ba04cd2 100644 --- a/src/main/java/com/taoyuanx/securitydemo/security/RateLimit.java +++ b/src/main/java/com/taoyuanx/securitydemo/security/RateLimit.java @@ -14,10 +14,12 @@ public @interface RateLimit { * limit 每秒并发 * key 限流的key * type 限流类型 参见:com.taoyuanx.securitydemo.security.RateLimitType + * totalCount 次数限流 */ double limit() default 100; String key() default "" ; RateLimitType type() default RateLimitType.METHOD; + long totalCount() default 0; } \ No newline at end of file diff --git a/src/main/java/com/taoyuanx/securitydemo/security/RateLimitType.java b/src/main/java/com/taoyuanx/securitydemo/security/RateLimitType.java index a4c2c0a..4066288 100644 --- a/src/main/java/com/taoyuanx/securitydemo/security/RateLimitType.java +++ b/src/main/java/com/taoyuanx/securitydemo/security/RateLimitType.java @@ -8,7 +8,8 @@ package com.taoyuanx.securitydemo.security; public enum RateLimitType { IP(0, "IP限流"), METHOD(1, "方法名"), SERVICE_KEY(3, "业务自定义key"), - GLOBAL(4,"系统全局"); + GLOBAL(4,"系统全局"), + TOTAL_COUNT(5,"总次数限制"); private int code; private String desc; @@ -31,6 +32,8 @@ public enum RateLimitType { case 4: return GLOBAL; + case 5: + return TOTAL_COUNT; } return null; } diff --git a/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/AbstractRateLimiter.java b/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/AbstractRateLimiter.java index 3b897c7..398f621 100644 --- a/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/AbstractRateLimiter.java +++ b/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/AbstractRateLimiter.java @@ -4,6 +4,7 @@ package com.taoyuanx.securitydemo.security.ratelimit; * @author dushitaoyuan * @desc 抽象限流 * @date 2019/9/5 + * */ public abstract class AbstractRateLimiter { @@ -18,18 +19,16 @@ public abstract class AbstractRateLimiter { public boolean tryAcquire(String key, Double limit){ return doTryAcquire(1,key,limit); } + protected abstract boolean doTryAcquire(int permits, String key, Double limit); /** - * 尝试获取令牌 - * - * @param permits 获取令牌数量 - * @param key 限流标识 - * @param limit 限流速率 + * 增加资源访问次数 用户可自行持久化记录 + * @param count + * @param key + * @param totalCount * @return */ - public boolean tryAcquire(int permits, String key, Double limit){ - return doTryAcquire(permits,key,limit); - } + public abstract boolean tryCount(int count,String key,Long totalCount); + - protected abstract boolean doTryAcquire(int permits, String key, Double limit); } diff --git a/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/GuavaRateLimiter.java b/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/GuavaRateLimiter.java index af851cc..f4339e8 100644 --- a/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/GuavaRateLimiter.java +++ b/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/GuavaRateLimiter.java @@ -1,9 +1,13 @@ package com.taoyuanx.securitydemo.security.ratelimit; -import com.google.common.collect.Maps; +import com.google.common.hash.BloomFilter; +import com.google.common.hash.Funnels; import com.google.common.util.concurrent.RateLimiter; +import java.nio.charset.Charset; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.LongAdder; /** * @author dushitaoyuan @@ -11,18 +15,15 @@ import java.util.Map; * @date 2019/9/5 */ public class GuavaRateLimiter extends AbstractRateLimiter { - private Map rateHolder = Maps.newConcurrentMap(); private static final int MAX_HOLDER_SIZE = 50000; - /* @Override - public boolean tryAcquire(String key, Double limit) { - return doTryAcquire(1, key, limit); - } + private Map rateHolder = new ConcurrentHashMap<>(MAX_HOLDER_SIZE); + + private Map countHolder = new ConcurrentHashMap(); + /** + * 总数限流到0后,标记 + */ + private BloomFilter TOTAL_LIMIT_ZERO_FLAG = BloomFilter.create(Funnels.stringFunnel(Charset.defaultCharset()), MAX_HOLDER_SIZE * 20); - @Override - public boolean tryAcquire(int permits, String key, Double limit) { - return doTryAcquire(permits, key, limit); - } -*/ protected boolean doTryAcquire(int permits, String key, Double limit) { //超过固定阈值,清空,重构 if (rateHolder.size() > MAX_HOLDER_SIZE) { @@ -37,5 +38,36 @@ public class GuavaRateLimiter extends AbstractRateLimiter { return rateLimiter.tryAcquire(permits); } + @Override + public boolean tryCount(int count, String key, Long totalCount) { + //标记后,直接返回false + if (TOTAL_LIMIT_ZERO_FLAG.mightContain(key)) { + return false; + } + //超过固定阈值,清空,重构 防止内存溢出 + if (countHolder.size() > MAX_HOLDER_SIZE) { + countHolder.clear(); + } + LongAdder longAdder = null; + if (countHolder.containsKey(key)) { + longAdder = countHolder.get(key); + longAdder.add(-count); + //资源总数用完后,标记 + if (longAdder.longValue() <= 0) { + TOTAL_LIMIT_ZERO_FLAG.put(key); + countHolder.remove(key); + return true; + } + return false; + } + if (count > totalCount) { + return false; + } + longAdder = new LongAdder(); + countHolder.putIfAbsent(key, longAdder); + countHolder.get(key).add(count); + return true; + } + } diff --git a/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/RedisRateLimiter.java b/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/RedisRateLimiter.java index 0fb0815..e4b6ecb 100644 --- a/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/RedisRateLimiter.java +++ b/src/main/java/com/taoyuanx/securitydemo/security/ratelimit/RedisRateLimiter.java @@ -1,7 +1,7 @@ package com.taoyuanx.securitydemo.security.ratelimit; import org.springframework.core.io.ClassPathResource; -import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.data.redis.core.script.DefaultRedisScript; import org.springframework.data.redis.core.script.RedisScript; import org.springframework.scripting.support.ResourceScriptSource; @@ -17,40 +17,40 @@ import java.util.List; * @date 2019/9/5 */ public class RedisRateLimiter extends AbstractRateLimiter { - private RedisTemplate stringRedisTemplate; - private RedisScript> script; + private StringRedisTemplate redisTemplate; + private RedisScript> tokenScript; + private RedisScript countScript; - public RedisRateLimiter(RedisTemplate redisTemplate) { + public RedisRateLimiter(StringRedisTemplate redisTemplate) { DefaultRedisScript script = new DefaultRedisScript(); - script.setScriptSource(new ResourceScriptSource( - new ClassPathResource("META-INF/demo.lua"))); + script.setScriptSource(new ResourceScriptSource(new ClassPathResource("META-INF/rate_limiter_token.lua"))); script.setResultType(List.class); - this.script = script; - this.stringRedisTemplate = redisTemplate; + this.tokenScript = script; + + script = new DefaultRedisScript(); + script.setScriptSource(new ResourceScriptSource(new ClassPathResource("META-INF/rate_limiter_count.lua"))); + script.setResultType(Long.class); + this.countScript = script; + + this.redisTemplate = redisTemplate; } - /* @Override - public boolean tryAcquire(String key, Double limit) { - return doTryAcquire(1, key, limit); - } - - @Override - public boolean tryAcquire(int permits, String key, Double limit) { - return doTryAcquire(permits, key, limit); - }*/ @Override protected boolean doTryAcquire(int permits, String key, Double limit) { - List scriptArgs = Arrays.asList(limit.longValue(), limit.longValue(), Instant.now().getEpochSecond(), permits); - Object execute = stringRedisTemplate.execute(this.script, getKeys(key), - scriptArgs); + String[] scriptArgs = {limit.longValue() + "", limit.longValue() + "", Instant.now().getEpochSecond() + "", permits + ""}; + List results = redisTemplate.execute(this.tokenScript, getKeys(key), scriptArgs); + return results.get(0) == 1L; + } - return execute == null; - + @Override + public boolean tryCount(int count, String key, Long totalCount) { + String[] scriptArgs = {count + "", totalCount + ""}; + Long result = redisTemplate.execute(this.countScript, Arrays.asList(key), scriptArgs); + return result == 1L; } private List getKeys(String key) { - int keyId = key.hashCode(); String prefix = "request_rate_limiter.{" + keyId; String tokenKey = prefix + "}.tokens"; diff --git a/src/main/java/com/taoyuanx/securitydemo/utils/CookieUtil.java b/src/main/java/com/taoyuanx/securitydemo/utils/CookieUtil.java index c6c401b..1764544 100644 --- a/src/main/java/com/taoyuanx/securitydemo/utils/CookieUtil.java +++ b/src/main/java/com/taoyuanx/securitydemo/utils/CookieUtil.java @@ -1,6 +1,7 @@ package com.taoyuanx.securitydemo.utils; import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletResponse; public class CookieUtil { //判断cookie是否存在 @@ -14,7 +15,13 @@ public class CookieUtil { } return null; } - + //删除cookie + public static void removeCookie(HttpServletResponse response,String path, String cookieName){ + Cookie cookie=new Cookie(cookieName,null); + cookie.setPath(path); + cookie.setMaxAge(0); + response.addCookie(cookie); + } diff --git a/src/main/resources/META-INF/demo.lua b/src/main/resources/META-INF/demo.lua deleted file mode 100644 index 5d20fb1..0000000 --- a/src/main/resources/META-INF/demo.lua +++ /dev/null @@ -1,10 +0,0 @@ ---获取KEY -local key1 = KEYS[1] -local key2 = KEYS[2] ---输出参数 -redis.log(redis.LOG_WARNING, "rate " .. ARGV[1]) -redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2]) -redis.log(redis.LOG_WARNING, "now " .. ARGV[3]) -redis.log(redis.LOG_WARNING, "requested " .. ARGV[4]) - -return {1,2} \ No newline at end of file diff --git a/src/main/resources/META-INF/rate_limiter_count.lua b/src/main/resources/META-INF/rate_limiter_count.lua new file mode 100644 index 0000000..406a432 --- /dev/null +++ b/src/main/resources/META-INF/rate_limiter_count.lua @@ -0,0 +1,24 @@ +--如果等于0说明超时,其他则是当前资源的访问数量 + +--申请资源数量 +local count= tonumber(ARGV[1]) + +if count == nil then + count = 1 +end + + +-- 获取剩余资源数量 +local last_count = tonumber(redis.call("get", KEYS[1])) +if last_count == nil then + last_count = ARGV[2]; +end + +--计数减少 +if last_count > count then + redis.call("DECRBY", KEYS[1], count) + return 1 +else + return 0 +end +--todo bitmap 标记 资源为0 \ No newline at end of file diff --git a/src/main/resources/META-INF/rate_limiter_token.lua b/src/main/resources/META-INF/rate_limiter_token.lua new file mode 100644 index 0000000..8df1a89 --- /dev/null +++ b/src/main/resources/META-INF/rate_limiter_token.lua @@ -0,0 +1,53 @@ +--令牌算法实现 +--限流标识和时间戳key +local tokens_key = KEYS[1] +local timestamp_key = KEYS[2] + +-- rate 令牌产生速率 +local rate = tonumber(ARGV[1]) +-- 令牌容量 +local capacity = tonumber(ARGV[2]) +--时间戳 +local now = tonumber(ARGV[3]) +--申请令牌数量 +local requested = tonumber(ARGV[4]) +--填满漏桶所需要的时间 容量除以速率 +local fill_time = capacity/rate +--过期时间为填满漏桶时间的2倍 +local ttl = math.floor(fill_time*2) + +--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1]) +--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2]) +--redis.log(redis.LOG_WARNING, "now " .. ARGV[3]) +--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4]) +--redis.log(redis.LOG_WARNING, "filltime " .. fill_time) +--redis.log(redis.LOG_WARNING, "ttl " .. ttl) +--剩余令牌 +local last_tokens = tonumber(redis.call("get", tokens_key)) +if last_tokens == nil then + last_tokens = capacity +end + +--上次获取令牌的时间 +local last_refreshed = tonumber(redis.call("get", timestamp_key)) +if last_refreshed == nil then + last_refreshed = 0 +end + +--重新计算令牌数量=当前时间与上次获取令牌的时间差值*令牌生产速率+剩余令牌数量 +--申请令牌数量小于当前桶内数量时,请求被允许 +local delta = math.max(0, now-last_refreshed) +local filled_tokens = math.min(capacity, last_tokens+(delta*rate)) +local allowed = filled_tokens >= requested +local new_tokens = filled_tokens +local allowed_num = 0 +if allowed then + new_tokens = filled_tokens - requested + allowed_num = 1 +end + + +redis.call("setex", tokens_key, ttl, new_tokens) +redis.call("setex", timestamp_key, ttl, now) + +return {allowed_num, new_tokens} diff --git a/src/main/resources/META-INF/request_rate_limiter.lua b/src/main/resources/META-INF/request_rate_limiter.lua deleted file mode 100644 index 6eacb43..0000000 --- a/src/main/resources/META-INF/request_rate_limiter.lua +++ /dev/null @@ -1,64 +0,0 @@ - -local tokens_key = KEYS[1] -local timestamp_key = KEYS[2] -redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key) - -redis.log(redis.LOG_WARNING, "timestamp_key" .. timestamp_key) - - -redis.log(redis.LOG_WARNING, "rate " .. ARGV[1]) -redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2]) -redis.log(redis.LOG_WARNING, "now " .. ARGV[3]) -redis.log(redis.LOG_WARNING, "requested " .. ARGV[4]) - - - - - - -local rate = tonumber(ARGV[1]) -local capacity = tonumber(ARGV[2]) -local now = tonumber(ARGV[3]) -local requested = tonumber(ARGV[4]) - -local fill_time = capacity/rate -local ttl = math.floor(fill_time*2) - -redis.log(redis.LOG_WARNING, "rate " .. ARGV[1]) -redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2]) -redis.log(redis.LOG_WARNING, "now " .. ARGV[3]) -redis.log(redis.LOG_WARNING, "requested " .. ARGV[4]) -redis.log(redis.LOG_WARNING, "filltime " .. fill_time) -redis.log(redis.LOG_WARNING, "ttl " .. ttl) - -local last_tokens = tonumber(redis.call("get", tokens_key)) -if last_tokens == nil then - last_tokens = capacity -end ---redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens) - -local last_refreshed = tonumber(redis.call("get", timestamp_key)) -if last_refreshed == nil then - last_refreshed = 0 -end ---redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed) - -local delta = math.max(0, now-last_refreshed) -local filled_tokens = math.min(capacity, last_tokens+(delta*rate)) -local allowed = filled_tokens >= requested -local new_tokens = filled_tokens -local allowed_num = 0 -if allowed then - new_tokens = filled_tokens - requested - allowed_num = 1 -end - ---redis.log(redis.LOG_WARNING, "delta " .. delta) ---redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens) ---redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num) ---redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens) - -redis.call("setex", tokens_key, ttl, new_tokens) -redis.call("setex", timestamp_key, ttl, now) - -return { allowed_num, new_tokens } diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 3731d5e..859e814 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -15,9 +15,9 @@ server.port=8080 spring.redis.database=2 -spring.redis.host=172.16.0.32 +spring.redis.host=192.168.30.211 spring.redis.port=6379 -spring.redis.password=guoruiredis +spring.redis.password= server.servlet.session.timeout=1800