update: callback filter request all backends

This commit is contained in:
hongqiaowei
2021-01-29 11:58:56 +08:00
parent 71ed163909
commit 31de1b3b0c
2 changed files with 108 additions and 14 deletions

View File

@@ -41,7 +41,14 @@ import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import we.config.AggregateRedisConfig;
import we.constants.CommonConstants;
import we.fizz.AggregateResource;
import we.fizz.AggregateResult;
import we.fizz.ConfigLoader;
import we.fizz.Pipeline;
import we.fizz.input.Input;
import we.flume.clients.log4j2appender.LogService;
import we.plugin.auth.ApiConfig;
import we.plugin.auth.CallbackConfig;
@@ -50,6 +57,7 @@ import we.proxy.DiscoveryClientUriSelector;
import we.proxy.FizzWebClient;
import we.proxy.ServiceInstance;
import we.util.Constants;
import we.util.MapUtil;
import we.util.ThreadContext;
import we.util.WebUtils;
@@ -58,6 +66,7 @@ import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* @author hongqiaowei
@@ -94,6 +103,9 @@ public class CallbackFilter extends FizzWebFilter {
@Resource(name = AggregateRedisConfig.AGGREGATE_REACTIVE_REDIS_TEMPLATE)
private ReactiveStringRedisTemplate rt;
@Resource
private ConfigLoader aggregateResourceLoader;
@Override
public Mono<Void> doFilter(ServerWebExchange exchange, WebFilterChain chain) {
@@ -106,7 +118,9 @@ public class CallbackFilter extends FizzWebFilter {
DataBufferUtils.join(req.getBody()).defaultIfEmpty(emptyBody)
.flatMap(
b -> {
body[0] = b;
if (b != emptyBody) {
body[0] = b;
}
String bodyStr = body[0].toString(StandardCharsets.UTF_8);
HashMap<String, ServiceInstance> service2instMap = getService2instMap(ac);
HttpHeaders headers = WebUtils.mergeAppendHeaders(exchange);
@@ -120,7 +134,7 @@ public class CallbackFilter extends FizzWebFilter {
)
.doFinally(
s -> {
if (body[0] != emptyBody) {
if (body[0] != null) {
DataBufferUtils.release(body[0]);
}
}
@@ -146,8 +160,7 @@ public class CallbackFilter extends FizzWebFilter {
send = fizzWebClient.send(req.getId(), req.getMethod(), uri, headers, body);
}
} else {
// 如果是聚合接口需预处理再用fizzWebClient调用
send = fizzWebClient.send(req.getId(), req.getMethod(), "xxx", headers, body);
send = requestAggregateResource(exchange, headers, body, r);
}
monos[i] = send;
}
@@ -161,21 +174,80 @@ public class CallbackFilter extends FizzWebFilter {
)
.flatMap(
resps -> {
Object r = null;
for (int i = 1; i < resps.size(); i++) {
// complete and release resp
r = resps.get(i);
if (r instanceof ClientResponse) {
cleanup((ClientResponse) r);
}
}
Object o = resps.get(0);
if (o instanceof ClientResponse) {
ClientResponse remoteResp = (ClientResponse) o;
r = resps.get(0);
if (r instanceof ClientResponse) {
ClientResponse remoteResp = (ClientResponse) r;
return genServerResponse(exchange, remoteResp);
} else if (r instanceof AggregateResult) {
AggregateResult ar = (AggregateResult) r;
return genAggregateResponse(exchange, ar);
} else {
return Mono.empty();
} // else AggregateResult
return Mono.error(new RuntimeException("cant response client with " + r, null, false, false) {});
}
}
)
;
}
private Mono<AggregateResult> requestAggregateResource(ServerWebExchange exchange, HttpHeaders hdrs, DataBuffer body, Receiver receiver) {
// long start = System.currentTimeMillis();
ServerHttpRequest request = exchange.getRequest();
String path = WebUtils.getClientReqPathPrefix(exchange) + receiver.service + receiver.path;
String method = request.getMethodValue();
AggregateResource aggregateResource = aggregateResourceLoader.matchAggregateResource(method, path);
if (aggregateResource == null) {
return Mono.error(new RuntimeException("no aggregate resource: " + method + ' ' + path, null, false, false) {});
} else {
Pipeline pipeline = aggregateResource.getPipeline();
Input input = aggregateResource.getInput();
Map<String, Object> headers = MapUtil.toHashMap(hdrs);
String traceId = WebUtils.getTraceId(exchange);
LogService.setBizId(traceId);
log.debug("matched aggregation api: {}", path);
Map<String, Object> clientInput = new HashMap<>(9);
clientInput.put("path", path);
clientInput.put("method", method);
clientInput.put("headers", headers);
clientInput.put("params", MapUtil.toHashMap(request.getQueryParams()));
if (body != null) {
clientInput.put("body", JSON.parse(body.toString(StandardCharsets.UTF_8)));
}
return pipeline.run(input, clientInput, traceId).subscribeOn(Schedulers.elastic());
}
}
private Mono<? extends Void> genAggregateResponse(ServerWebExchange exchange, AggregateResult ar) {
ServerHttpResponse clientResp = exchange.getResponse();
String traceId = WebUtils.getTraceId(exchange);
LogService.setBizId(traceId);
String js = JSON.toJSONString(ar.getBody());
log.debug("aggregate response body: {}", js);
if (ar.getHeaders() != null && !ar.getHeaders().isEmpty()) {
ar.getHeaders().remove("Content-Length");
clientResp.getHeaders().addAll(ar.getHeaders());
}
if (!clientResp.getHeaders().containsKey("Content-Type")) {
// defalut content-type
clientResp.getHeaders().add("Content-Type", "application/json; charset=UTF-8");
}
List<String> headerTraceIds = clientResp.getHeaders().get(CommonConstants.HEADER_TRACE_ID);
if (headerTraceIds == null || !headerTraceIds.contains(traceId)) {
clientResp.getHeaders().add(CommonConstants.HEADER_TRACE_ID, traceId);
}
// long end = System.currentTimeMillis();
// pipeline.getStepContext().addElapsedTime("总耗时", end - start);
// log.info("ElapsedTimes={}", JSON.toJSONString(pipeline.getStepContext().getElapsedTimes()));
return clientResp
.writeWith(Flux.just(exchange.getResponse().bufferFactory().wrap(js.getBytes())));
}
private Mono<? extends Void> genServerResponse(ServerWebExchange exchange, ClientResponse remoteResp) {
ServerHttpResponse clientResp = exchange.getResponse();
clientResp.setStatusCode(remoteResp.statusCode());

View File

@@ -32,17 +32,23 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import we.constants.CommonConstants;
import we.filter.FilterResult;
import we.flume.clients.log4j2appender.LogService;
import we.legacy.RespEntity;
import we.plugin.auth.ApiConfig;
import we.plugin.auth.AuthPluginFilter;
import java.net.URI;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* @author hongqiaowei
*/
@@ -81,6 +87,8 @@ public abstract class WebUtils {
private static final String CLIENT_REQUEST_QUERY = "clientRequestQuery";
private static final String traceId = "traceId";
public static final String BACKEND_PATH = "backendPath";
public static boolean logResponseBody = false;
@@ -474,7 +482,7 @@ public abstract class WebUtils {
public static Mono<Void> responseError(ServerWebExchange exchange, String reporter, int code, String msg, Throwable t) {
return responseError(exchange, reporter, code, msg, t, false);
}
public static Mono<Void> responseErrorAndBindContext(ServerWebExchange exchange, String filter, HttpStatus httpStatus) {
ServerHttpResponse response = exchange.getResponse();
String rid = exchange.getRequest().getId();
@@ -486,8 +494,8 @@ public abstract class WebUtils {
transmitFailFilterResult(exchange, filter);
return buildDirectResponseAndBindContext(exchange, httpStatus, new HttpHeaders(), Constants.Symbol.EMPTY);
}
public static Mono<Void> responseErrorAndBindContext(ServerWebExchange exchange, String filter, HttpStatus httpStatus,
public static Mono<Void> responseErrorAndBindContext(ServerWebExchange exchange, String filter, HttpStatus httpStatus,
HttpHeaders headers, String content) {
ServerHttpResponse response = exchange.getResponse();
String rid = exchange.getRequest().getId();
@@ -521,4 +529,18 @@ public abstract class WebUtils {
}
return ip;
}
public static String getTraceId(ServerWebExchange exchange) {
String id = exchange.getAttribute(traceId);
if (id == null) {
ServerHttpRequest request = exchange.getRequest();
String v = request.getHeaders().getFirst(CommonConstants.HEADER_TRACE_ID);
if (StringUtils.isNotBlank(v)) {
id = v;
} else {
id = CommonConstants.TRACE_ID_PREFIX + request.getId();
}
}
return id;
}
}