背景
Spring Cloud Gateway 作为流量的统一入口,我们需要在网关上实现用户校验功能,并打印接口调用的请求地址及参数。
功能实现
打印接口调用的请求地址及参数
这个功能需要拦截所有请求,因此我们定义一个全局过滤器(GlobalFilter)。
/**
 * Global filter that logs every request's URL, method, headers and parameters.
 *
 * <p>For non-GET requests it also parses the body into a sorted parameter map and stores it in
 * the exchange attribute {@code "CacheRequestFilter"}, so later filters (e.g. signature checks)
 * can read the parameters without consuming the request body a second time.
 */
@Component("LogFilter")
public class Logfilter implements GlobalFilter, Ordered {

    /** Exchange attribute under which the parsed body parameters are cached. */
    private static final String CACHE_REQUEST_ATTR = "CacheRequestFilter";

    private static final Logger logger = LoggerFactory.getLogger(Logfilter.class);

    private final List<HttpMessageReader<?>> messageReaders = HandlerStrategies.withDefaults().messageReaders();

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        // Content type declared by the client; null when the header is absent.
        MediaType mediaType = exchange.getRequest().getHeaders().getContentType();
        ServerRequest serverRequest = ServerRequest.create(exchange, messageReaders);
        if (MediaType.APPLICATION_JSON.isCompatibleWith(mediaType)) {
            // JSON body: decode to a generic Object (Map or List) so it can be logged and cached.
            Mono<Object> modifiedBody = serverRequest.bodyToMono(Object.class)
                    .doOnNext(body -> recordLog(exchange, body))
                    // An empty body never emits, so log the request here instead.
                    .switchIfEmpty(Mono.defer(() -> {
                        recordLog(exchange, "");
                        return Mono.empty();
                    }));
            return getVoidMono(exchange, chain, Object.class, modifiedBody);
        } else if (MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(mediaType)) {
            // Form body: read it as the raw url-encoded string.
            Mono<String> modifiedBody = serverRequest.bodyToMono(String.class)
                    .doOnNext(body -> recordLog(exchange, body))
                    .switchIfEmpty(Mono.defer(() -> {
                        recordLog(exchange, "");
                        return Mono.empty();
                    }));
            return getVoidMono(exchange, chain, String.class, modifiedBody);
        }
        // Other content types (e.g. GET requests without a body): do not read the body at all.
        recordLog(exchange, "");
        return chain.filter(exchange);
    }

    /**
     * Re-inserts the already-decoded body into a cached output message and continues the chain
     * with a request decorator that replays it. Adapted from Spring Cloud Gateway's
     * {@code ModifyRequestBodyGatewayFilterFactory}.
     *
     * @param exchange     the current exchange
     * @param chain        the chain to continue
     * @param outClass     element type of {@code modifiedBody}
     * @param modifiedBody the decoded request body (already tapped for logging)
     * @return completion signal of the downstream chain
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Mono<Void> getVoidMono(ServerWebExchange exchange, GatewayFilterChain chain, Class outClass,
            Mono<?> modifiedBody) {
        BodyInserter bodyInserter = BodyInserters.fromPublisher(modifiedBody, outClass);
        HttpHeaders headers = new HttpHeaders();
        headers.putAll(exchange.getRequest().getHeaders());
        // The body is re-serialized, so the original Content-Length may no longer be accurate.
        headers.remove(HttpHeaders.CONTENT_LENGTH);
        CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(exchange, headers);
        return bodyInserter.insert(outputMessage, new BodyInserterContext())
                .then(Mono.defer(() -> {
                    // The original request body can be consumed only once, so downstream filters
                    // receive a decorated request that replays the cached body instead.
                    ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(exchange.getRequest()) {
                        @Override
                        public HttpHeaders getHeaders() {
                            long contentLength = headers.getContentLength();
                            HttpHeaders httpHeaders = new HttpHeaders();
                            // Copy the prepared headers (stale Content-Length removed), not
                            // super.getHeaders(); otherwise the old length would be forwarded
                            // alongside Transfer-Encoding: chunked.
                            httpHeaders.putAll(headers);
                            if (contentLength > 0) {
                                httpHeaders.setContentLength(contentLength);
                            } else {
                                // TODO: this causes a 'HTTP/1.1 411 Length Required' on httpbin.org
                                httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
                            }
                            return httpHeaders;
                        }

                        @Override
                        public Flux<DataBuffer> getBody() {
                            return outputMessage.getBody();
                        }
                    };
                    return chain.filter(exchange.mutate().request(decorator).build());
                }));
    }

    /**
     * Writes one log line describing the request.
     *
     * <p>For non-GET requests, this also caches the parsed body parameters under
     * {@code "CacheRequestFilter"}.
     *
     * @param exchange the current exchange
     * @param body     the decoded request body, or {@code ""} when no body was read
     */
    private void recordLog(ServerWebExchange exchange, Object body) {
        ServerHttpRequest request = exchange.getRequest();
        StringBuilder builder = new StringBuilder(" request url: ");
        builder.append(request.getURI().getRawPath());
        HttpMethod method = request.getMethod();
        if (null != method) {
            builder.append(", method: ").append(method.name());
        }
        // NOTE(review): headers are logged verbatim, including any credentials
        // (Authorization, Cookie, ...) — consider masking sensitive ones.
        builder.append(", header { ");
        for (Map.Entry<String, List<String>> entry : request.getHeaders().entrySet()) {
            builder.append(entry.getKey()).append(":").append(String.join(",", entry.getValue())).append(",");
        }
        builder.append("} param: ");
        if (null != method && HttpMethod.GET.matches(method.name())) {
            // GET requests: parameters come from the query string.
            MultiValueMap<String, String> queryParams = request.getQueryParams();
            for (Map.Entry<String, List<String>> entry : queryParams.entrySet()) {
                builder.append(entry.getKey()).append("=").append(String.join(",", entry.getValue())).append(",");
            }
        } else {
            // Parameters may be validated later, so cache them to avoid re-reading the body.
            builder.append(body);
            exchange.getAttributes().put(CACHE_REQUEST_ATTR, toParamMap(body));
        }
        logger.info(builder.toString());
    }

    /**
     * Converts a decoded body into a sorted parameter map.
     *
     * <p>Only JSON objects yield entries. A String body (form data or {@code ""}) or a JSON array
     * yields an empty map instead of failing with a ClassCastException.
     */
    private static TreeMap<String, Object> toParamMap(Object body) {
        TreeMap<String, Object> params = new TreeMap<>();
        Object json = JSONObject.toJSON(body);
        if (json instanceof JSONObject) {
            params.putAll(((JSONObject) json).getInnerMap());
        }
        return params;
    }

    /**
     * Returns the filter order; lower values run earlier.
     *
     * <p>The order is 10 so that this filter runs before route filters such as the signature
     * filter (order 100), which reads the cached parameters.
     */
    @Override
    public int getOrder() {
        return 10;
    }
}
用户校验
用户校验只需要针对部分接口进行,因此我们定义了一个路由级别的 GatewayFilter,并通过路由配置只把它挂载到需要校验的路由上:
spring.cloud.gateway.routes[0].id=web
spring.cloud.gateway.routes[0].uri=lb://api-server
spring.cloud.gateway.routes[0].predicates[0]=Path=/cp/**
spring.cloud.gateway.routes[0].filters[0]=StripPrefix=1
# 只有 path 以 /cp 开头的请求才会匹配该路由,从而执行下面的 Sign 校验(.properties 文件的注释需用 # 开头)
spring.cloud.gateway.routes[0].filters[1]=Sign=true
public class SignGetwayFilter implements GatewayFilter, Ordered {
private final Logger logger = LoggerFactory.getLogger(SignGetwayFilter.class);
private final List<HttpMessageReader<?>> messageReaders = HandlerStrategies.withDefaults().messageReaders();
private List<String> excludepath;
private SignConfig signatureConfig;
public SignGetwayFilter(SignConfig signatureConfig, List<String> excludepath) {
this.signatureConfig = signatureConfig;
this.excludepath = excludepath;
}
@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
HttpMethod method = request.getMethod();
String path = request.getURI().getPath();
//excludpath不需要验签
if (excludepath != null && !excludepath.isEmpty() && excludepath.contains(path)) {
return chain.filter(exchange);
}
if (method == HttpMethod.GET) {
return ResponseUtils.createErrorResult(response, -1, "不支持get请求");
}
String content_type = request.getHeaders().get(HttpHeaders.CONTENT_TYPE).get(0);
logger.debug("content_type={}", content_type);
if (StringUtils.isBlank(content_type) || !content_type.equals(MediaType.APPLICATION_JSON_VALUE)) {
return ResponseUtils.createErrorResult(response, -1, SignStatus.Failed.getMessage());
}
DataBufferFactory bufferFactory = response.bufferFactory();
//获取从Logfilter保存的请求参数
TreeMap<String,Object> treeMap = exchange.getAttribute("CacheRequestFilter");
//检验参数
if(!check(treeMap)){
//如果未通过返回错误信息
return ResponseUtils.createErrorResult(response, -1, "未通过");
}
return chain.filter(exchange);
}
@Override
public int getOrder() {
return 100;
}
}
/**
 * Factory for the {@code Sign} route filter (referenced as {@code Sign=...} in route config).
 * It creates a {@link SignGetwayFilter} for each route that declares it.
 */
@Component
public class SignGatewayFilterFactory extends
        AbstractGatewayFilterFactory<SignGatewayFilterFactory.Config> {

    @Autowired
    private SignConfig signatureConfig;

    public SignGatewayFilterFactory() {
        super(Config.class);
    }

    /**
     * Builds the filter for one route.
     *
     * <p>It calls the two-argument {@link SignGetwayFilter} constructor, the only one defined;
     * the previous three-argument call did not compile.
     */
    @Override
    public SignGetwayFilter apply(Config config) {
        return new SignGetwayFilter(this.signatureConfig, config.getExcludepath());
    }

    /**
     * Per-route configuration bound from the route definition.
     *
     * <p>NOTE(review): this class was missing from the original. Only {@code excludepath} is
     * modeled because it is the only value {@link SignGetwayFilter} accepts. A shortcut value
     * such as {@code Sign=true} is not mapped to any field, since {@code shortcutFieldOrder()}
     * is not overridden.
     */
    public static class Config {

        /** Request paths that skip signature verification; may be null. */
        private List<String> excludepath;

        public List<String> getExcludepath() {
            return excludepath;
        }

        public void setExcludepath(List<String> excludepath) {
            this.excludepath = excludepath;
        }
    }
}