Fix repeated read issue of request body of CallbackFilter #478

This commit is contained in:
Francis Dong
2023-06-01 15:39:02 +08:00
parent aa84548441
commit 0652b70019
2 changed files with 65 additions and 49 deletions

View File

@@ -17,6 +17,25 @@
package com.fizzgate.filter; package com.fizzgate.filter;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import javax.annotation.Resource;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.annotation.Order;
import org.springframework.data.redis.core.ReactiveStringRedisTemplate;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.fizzgate.config.AggregateRedisConfig; import com.fizzgate.config.AggregateRedisConfig;
import com.fizzgate.plugin.auth.ApiConfig; import com.fizzgate.plugin.auth.ApiConfig;
@@ -27,33 +46,15 @@ import com.fizzgate.proxy.CallbackService;
import com.fizzgate.proxy.DiscoveryClientUriSelector; import com.fizzgate.proxy.DiscoveryClientUriSelector;
import com.fizzgate.proxy.ServiceInstance; import com.fizzgate.proxy.ServiceInstance;
import com.fizzgate.service_registry.RegistryCenterService; import com.fizzgate.service_registry.RegistryCenterService;
import com.fizzgate.spring.http.server.reactive.ext.FizzServerHttpRequestDecorator;
import com.fizzgate.spring.web.server.ext.FizzServerWebExchangeDecorator;
import com.fizzgate.util.Consts; import com.fizzgate.util.Consts;
import com.fizzgate.util.NettyDataBufferUtils; import com.fizzgate.util.NettyDataBufferUtils;
import com.fizzgate.util.ThreadContext; import com.fizzgate.util.ThreadContext;
import com.fizzgate.util.WebUtils; import com.fizzgate.util.WebUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.annotation.Order;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.PooledDataBuffer;
import org.springframework.data.redis.core.ReactiveStringRedisTemplate;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import javax.annotation.Resource;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
/** /**
* @author hongqiaowei * @author hongqiaowei
*/ */
@@ -89,34 +90,49 @@ public class CallbackFilter extends FizzWebFilter {
@Resource @Resource
private GatewayGroupService gatewayGroupService; private GatewayGroupService gatewayGroupService;
@Override
public Mono<Void> doFilter(ServerWebExchange exchange, WebFilterChain chain) {
FilterResult pfr = WebUtils.getPrevFilterResult(exchange); @Override
if (!pfr.success) { public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return WebUtils.getDirectResponse(exchange); String traceId = WebUtils.getTraceId(exchange);
org.apache.logging.log4j.ThreadContext.put(Consts.TRACE_ID, traceId);
ServerHttpRequest req = exchange.getRequest();
if (req instanceof FizzServerHttpRequestDecorator) {
return doFilter(exchange, chain);
} }
return
NettyDataBufferUtils.join(req.getBody()).defaultIfEmpty(NettyDataBufferUtils.EMPTY_DATA_BUFFER)
.flatMap(
body -> {
FizzServerHttpRequestDecorator requestDecorator = new FizzServerHttpRequestDecorator(req);
if (body != NettyDataBufferUtils.EMPTY_DATA_BUFFER) {
try {
requestDecorator.setBody(body);
} finally {
NettyDataBufferUtils.release(body);
}
}
ServerWebExchange mutatedExchange = exchange.mutate().request(requestDecorator).build();
ServerWebExchange newExchange = mutatedExchange;
MediaType contentType = req.getHeaders().getContentType();
if (MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(contentType)) {
newExchange = new FizzServerWebExchangeDecorator(mutatedExchange);
}
return doFilter(newExchange, chain);
}
);
}
public Mono<Void> doFilter(ServerWebExchange exchange, WebFilterChain chain) {
String traceId = WebUtils.getTraceId(exchange);
org.apache.logging.log4j.ThreadContext.put(Consts.TRACE_ID, traceId);
ApiConfig ac = WebUtils.getApiConfig(exchange); ApiConfig ac = WebUtils.getApiConfig(exchange);
if (ac != null && ac.type == ApiConfig.Type.CALLBACK) { if (ac != null && ac.type == ApiConfig.Type.CALLBACK) {
CallbackConfig cc = ac.callbackConfig; CallbackConfig cc = ac.callbackConfig;
ServerHttpRequest req = exchange.getRequest(); FizzServerHttpRequestDecorator req = (FizzServerHttpRequestDecorator) exchange.getRequest();
return return req.getBody().defaultIfEmpty(NettyDataBufferUtils.EMPTY_DATA_BUFFER).single().flatMap(b -> {
DataBufferUtils.join(req.getBody()).defaultIfEmpty(NettyDataBufferUtils.EMPTY_DATA_BUFFER) String body = b.toString(StandardCharsets.UTF_8);
.flatMap(
b -> {
DataBuffer body = null;
if (b != NettyDataBufferUtils.EMPTY_DATA_BUFFER) {
if (b instanceof PooledDataBuffer) {
try {
body = NettyDataBufferUtils.copy2heap(b);
} finally {
NettyDataBufferUtils.release(b);
}
} else {
body = b;
}
}
HashMap<String, ServiceInstance> service2instMap = getService2instMap(ac); HashMap<String, ServiceInstance> service2instMap = getService2instMap(ac);
HttpHeaders headers = WebUtils.mergeAppendHeaders(exchange); HttpHeaders headers = WebUtils.mergeAppendHeaders(exchange);
pushReq2manager(exchange, headers, body, service2instMap, cc.id, ac.gatewayGroups.iterator().next()); pushReq2manager(exchange, headers, body, service2instMap, cc.id, ac.gatewayGroups.iterator().next());
@@ -175,7 +191,7 @@ public class CallbackFilter extends FizzWebFilter {
private static final String _receivers = "\"receivers\":"; private static final String _receivers = "\"receivers\":";
private static final String _gatewayGroup = "\"gatewayGroup\":"; private static final String _gatewayGroup = "\"gatewayGroup\":";
private void pushReq2manager(ServerWebExchange exchange, HttpHeaders headers, DataBuffer body, HashMap<String, ServiceInstance> service2instMap, int callbackConfigId, private void pushReq2manager(ServerWebExchange exchange, HttpHeaders headers, Object body, HashMap<String, ServiceInstance> service2instMap, int callbackConfigId,
String gatewayGroup) { String gatewayGroup) {
ServerHttpRequest req = exchange.getRequest(); ServerHttpRequest req = exchange.getRequest();
@@ -215,7 +231,8 @@ public class CallbackFilter extends FizzWebFilter {
if (body != null) { if (body != null) {
b.append(Consts.S.COMMA); b.append(Consts.S.COMMA);
String bodyStr = body.toString(StandardCharsets.UTF_8); // String bodyStr = body.toString(StandardCharsets.UTF_8);
String bodyStr = body.toString();
MediaType contentType = req.getHeaders().getContentType(); MediaType contentType = req.getHeaders().getContentType();
if (contentType != null && contentType.getSubtype().equalsIgnoreCase(json)) { if (contentType != null && contentType.getSubtype().equalsIgnoreCase(json)) {
b.append(_body); b.append(JSON.toJSONString(bodyStr)); b.append(_body); b.append(JSON.toJSONString(bodyStr));

View File

@@ -19,7 +19,6 @@ package com.fizzgate.proxy;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
@@ -81,7 +80,7 @@ public class CallbackService {
aggrConfigPrefix = systemConfig.getGatewayPrefix() + '/'; aggrConfigPrefix = systemConfig.getGatewayPrefix() + '/';
} }
public Mono<Void> requestBackends(ServerWebExchange exchange, HttpHeaders headers, DataBuffer body, CallbackConfig cc, Map<String, ServiceInstance> service2instMap) { public Mono<Void> requestBackends(ServerWebExchange exchange, HttpHeaders headers, String body, CallbackConfig cc, Map<String, ServiceInstance> service2instMap) {
ServerHttpRequest req = exchange.getRequest(); ServerHttpRequest req = exchange.getRequest();
String traceId = WebUtils.getTraceId(exchange); String traceId = WebUtils.getTraceId(exchange);
HttpMethod method = req.getMethod(); HttpMethod method = req.getMethod();
@@ -140,21 +139,21 @@ public class CallbackService {
; ;
} }
private Function<Throwable, Mono<? extends ClientResponse>> crError(ServerWebExchange exchange, Receiver r, HttpMethod method, HttpHeaders headers, DataBuffer body) { private Function<Throwable, Mono<? extends ClientResponse>> crError(ServerWebExchange exchange, Receiver r, HttpMethod method, HttpHeaders headers, String body) {
return t -> { return t -> {
log(exchange, r, method, headers, body, t); log(exchange, r, method, headers, body, t);
return Mono.just(new FizzFailClientResponse(t)); return Mono.just(new FizzFailClientResponse(t));
}; };
} }
private Function<Throwable, Mono<AggregateResult>> arError(ServerWebExchange exchange, Receiver r, HttpMethod method, HttpHeaders headers, DataBuffer body) { private Function<Throwable, Mono<AggregateResult>> arError(ServerWebExchange exchange, Receiver r, HttpMethod method, HttpHeaders headers, String body) {
return t -> { return t -> {
log(exchange, r, method, headers, body, t); log(exchange, r, method, headers, body, t);
return Mono.just(new FailAggregateResult(t)); return Mono.just(new FailAggregateResult(t));
}; };
} }
private void log(ServerWebExchange exchange, Receiver r, HttpMethod method, HttpHeaders headers, DataBuffer body, Throwable t) { private void log(ServerWebExchange exchange, Receiver r, HttpMethod method, HttpHeaders headers, String body, Throwable t) {
StringBuilder b = ThreadContext.getStringBuilder(); StringBuilder b = ThreadContext.getStringBuilder();
WebUtils.request2stringBuilder(exchange, b); WebUtils.request2stringBuilder(exchange, b);
b.append(Consts.S.LINE_SEPARATOR).append(callback).append(Consts.S.LINE_SEPARATOR); b.append(Consts.S.LINE_SEPARATOR).append(callback).append(Consts.S.LINE_SEPARATOR);