|
|
@@ -5,6 +5,7 @@ import com.lc.ibps.base.framework.table.ICommonDao;
|
|
|
import org.activiti.engine.impl.asyncexecutor.AcquireTimerJobsRunnable;
|
|
|
import org.slf4j.Logger;
|
|
|
import org.slf4j.LoggerFactory;
|
|
|
+import org.springframework.mock.web.MockHttpServletRequest;
|
|
|
import org.springframework.stereotype.Component;
|
|
|
|
|
|
import javax.annotation.Resource;
|
|
|
@@ -14,7 +15,8 @@ import javax.servlet.http.HttpServletResponse;
|
|
|
import java.io.IOException;
|
|
|
import java.net.URI;
|
|
|
import java.net.URISyntaxException;
|
|
|
-import java.util.Map;
|
|
|
+import java.util.*;
|
|
|
+
|
|
|
/**
|
|
|
* cros跨域访问和host头控制
|
|
|
*
|
|
|
@@ -26,42 +28,43 @@ public class CORSFilter implements Filter {
|
|
|
@Resource
|
|
|
private ICommonDao<?> commonDao;
|
|
|
|
|
|
+ private static final Set<String> ALLOWED_HOSTS = new HashSet<>(); // 缓存允许的Host
|
|
|
+ private static final Set<String> ALLOWED_CORSES = new HashSet<>(); // 缓存允许的Origin
|
|
|
+
|
|
|
+ private long lastRefreshTime_host = 0;
|
|
|
+ private long lastRefreshTime_cors = 0;
|
|
|
+ private long refreshInterval_host = 1 * 60 * 1000; //参数配置默认刷新时间
|
|
|
+ private long refreshInterval_cors = 1 * 60 * 1000; //参数配置默认刷新时间
|
|
|
+
|
|
|
@Override
|
|
|
public void init(FilterConfig filterConfig) throws ServletException {
|
|
|
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+
|
|
|
@Override
|
|
|
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
|
|
|
HttpServletRequest httpRequest = (HttpServletRequest) servletRequest;
|
|
|
HttpServletResponse response = (HttpServletResponse) servletResponse;
|
|
|
-
|
|
|
String requestPath = httpRequest.getRequestURI();
|
|
|
if ("/health".equals(requestPath) || "/ping".equals(requestPath)) {
|
|
|
// 是心跳检查请求
|
|
|
filterChain.doFilter(httpRequest, response);
|
|
|
return;
|
|
|
}
|
|
|
- /**Host头验证*/
|
|
|
- String hostHeader = httpRequest.getHeader("Host");
|
|
|
- if (hostHeader != null) {
|
|
|
- // 从配置中获取允许的host列表(可以同样使用数据库配置)
|
|
|
- String hostSql = " select id_,biao_ti_,can_shu_zhi_1_ from t_zlcsb where shi_fou_qi_yong_ = 1 and jian_zhi_='%s'";
|
|
|
- hostSql = String.format(hostSql,"HOST");
|
|
|
- Map<String,Object> hostzlcs = commonDao.queryOne(hostSql);
|
|
|
- if(BeanUtils.isNotEmpty(hostzlcs)){
|
|
|
- //获取白名单配置
|
|
|
- String bmd = BeanUtils.isNotEmpty(hostzlcs.get("can_shu_zhi_1_")) ?
|
|
|
- (String)hostzlcs.get("can_shu_zhi_1_") : "";
|
|
|
- // 提取主机名(去掉端口)
|
|
|
- String requestHost = hostHeader.split(":")[0];
|
|
|
- if (!bmd.contains(requestHost)) {
|
|
|
- log.warn("白名单{}->非法Host头: {}",bmd, hostHeader);
|
|
|
- response.setStatus(HttpServletResponse.SC_FORBIDDEN);//403
|
|
|
- response.getWriter().write("Invalid Host header not allowed");
|
|
|
- return;
|
|
|
- }
|
|
|
- }
|
|
|
+ //80端口已经禁止了,服务器端也应该禁止TRACE\TRACK方法
|
|
|
+ String method = httpRequest.getMethod();
|
|
|
+ log.warn("method:{}",method);
|
|
|
+ if ("TRACE".equalsIgnoreCase(method) || "TRACK".equalsIgnoreCase(method)) {
|
|
|
+ response.setStatus(HttpServletResponse.SC_METHOD_NOT_ALLOWED);//405
|
|
|
+ response.getWriter().write("method TRACE/TRACK not allowed");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ // Host头验证
|
|
|
+ if(!validateHostHeader(httpRequest, response)){
|
|
|
+ log.warn("Host头验证白名单失败",ALLOWED_HOSTS);
|
|
|
+ return; // 验证失败已设置响应
|
|
|
}
|
|
|
// === Host头验证结束 ===
|
|
|
// response.setHeader("Access-Control-Allow-Origin", "*");
|
|
|
@@ -74,22 +77,18 @@ public class CORSFilter implements Filter {
|
|
|
String origin = httpRequest.getHeader("Origin");
|
|
|
|
|
|
// 2. 不存在启用的CORS跨域白名单配置,直接跳过
|
|
|
- String sql = " select id_,biao_ti_,can_shu_zhi_1_ from t_zlcsb where shi_fou_qi_yong_ = 1 and jian_zhi_='%s'";
|
|
|
- sql = String.format(sql,"CORS");
|
|
|
- Map<String,Object> corszlcs = commonDao.queryOne(sql);
|
|
|
- if(BeanUtils.isEmpty(corszlcs)){
|
|
|
+ refreshHostWhitelistIfNeeded("CORS");
|
|
|
+ if(ALLOWED_CORSES.contains("N999999")){
|
|
|
+ log.warn("不存在启用的CORS跨域白名单配置,ALLOWED_CORSES:{}",ALLOWED_CORSES);
|
|
|
filterChain.doFilter(httpRequest, response);
|
|
|
return;
|
|
|
}
|
|
|
//3.无Origin头(同源请求或非浏览器请求),跳过
|
|
|
if (origin == null) {
|
|
|
+ log.warn("origin 值为null");
|
|
|
filterChain.doFilter(httpRequest, response);
|
|
|
return;
|
|
|
}
|
|
|
- //获取白名单配置
|
|
|
- String bmd = BeanUtils.isNotEmpty(corszlcs.get("can_shu_zhi_1_")) ?
|
|
|
- (String)corszlcs.get("can_shu_zhi_1_") : "";
|
|
|
-
|
|
|
//4.非同源请求且开启了跨域白名单配置,校验请求是否为白名单
|
|
|
// 提取请求来源的协议+域名(不含端口和路径)
|
|
|
String requestDomain = extractBaseDomain(origin);
|
|
|
@@ -101,14 +100,16 @@ public class CORSFilter implements Filter {
|
|
|
|
|
|
if (requestDomain.equals(currentDomain)) {
|
|
|
// 情况2:同源请求(协议+域名相同,端口不同也视为同源)
|
|
|
+ log.warn("同源请求requestDomain:{}->currentDomain:{}",requestDomain,currentDomain);
|
|
|
filterChain.doFilter(httpRequest, response);
|
|
|
- } else if (bmd.contains(requestDomain)) {
|
|
|
+ } else if (ALLOWED_CORSES.contains(requestDomain)) {
|
|
|
// 情况3:合法的跨域请求(白名单)
|
|
|
// 处理预检请求
|
|
|
/* if ("OPTIONS".equalsIgnoreCase(httpRequest.getMethod())) {
|
|
|
response.setStatus(HttpServletResponse.SC_OK);
|
|
|
return;
|
|
|
}*/
|
|
|
+ //log.warn("合法的跨域请求ALLOWED_CORSES:{}->->requestDomain:{}",requestDomain,currentDomain);
|
|
|
filterChain.doFilter(httpRequest, response);
|
|
|
} else {
|
|
|
// 情况4:非法的跨域请求
|
|
|
@@ -119,6 +120,131 @@ public class CORSFilter implements Filter {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /**Host头验证 - 改进版*/
|
|
|
+ private boolean validateHostHeader(HttpServletRequest httpRequest, HttpServletResponse response) throws IOException {
|
|
|
+ // 1. 检查并刷新Host白名单缓存
|
|
|
+ refreshHostWhitelistIfNeeded("HOST");
|
|
|
+
|
|
|
+ // 2. 获取Host头(优先考虑X-Forwarded-Host)
|
|
|
+ String hostHeader = getFirstNonEmptyHeader(httpRequest,
|
|
|
+ Arrays.asList("X-Forwarded-Host", "Host"));
|
|
|
+
|
|
|
+ if (hostHeader == null || hostHeader.isEmpty()) {
|
|
|
+ return true; // 没有Host头的情况由上层处理
|
|
|
+ }
|
|
|
+
|
|
|
+ // 3. 提取主机名(更严谨的方式)
|
|
|
+ String requestHost = extractHostname(hostHeader);
|
|
|
+
|
|
|
+ // 4. 验证白名单
|
|
|
+ if(ALLOWED_HOSTS.contains("N999999")){
|
|
|
+ return true;//没有配置的情况下不校验
|
|
|
+ }
|
|
|
+ if (!ALLOWED_HOSTS.contains(requestHost.toLowerCase())) {
|
|
|
+ String clientIP = httpRequest.getRemoteAddr();
|
|
|
+ log.warn("非法Host头拦截 - IP: {}, Host: {}, 白名单: {}",
|
|
|
+ clientIP, hostHeader, ALLOWED_HOSTS);
|
|
|
+ response.setStatus(HttpServletResponse.SC_FORBIDDEN);//403
|
|
|
+ response.getWriter().write("Invalid Host header");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ // 获取第一个非空的请求头
|
|
|
+ private String getFirstNonEmptyHeader(HttpServletRequest request, List<String> headerNames) {
|
|
|
+ for (String name : headerNames) {
|
|
|
+ String value = request.getHeader(name);
|
|
|
+ if (value != null && !value.trim().isEmpty()) {
|
|
|
+ return value;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ // 更严谨的Host提取方法
|
|
|
+ private String extractHostname(String hostHeader) {
|
|
|
+ try {
|
|
|
+ // 处理端口号
|
|
|
+ String host = hostHeader.split(":")[0];
|
|
|
+ // 处理IPv6地址
|
|
|
+ if (host.startsWith("[") && host.contains("]")) {
|
|
|
+ host = host.substring(1, host.indexOf("]"));
|
|
|
+ }
|
|
|
+ return host.trim().toLowerCase();
|
|
|
+ } catch (Exception e) {
|
|
|
+ return hostHeader.trim().toLowerCase();
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 刷新白名单缓存
|
|
|
+ private synchronized void refreshHostWhitelistIfNeeded(String key) {
|
|
|
+ long currentTime = System.currentTimeMillis();
|
|
|
+ long refreshTime = 0;
|
|
|
+ long lastRefreshTime = 0;
|
|
|
+ if("HOST".equals(key)){
|
|
|
+ refreshTime = refreshInterval_host;
|
|
|
+ lastRefreshTime = lastRefreshTime_host;
|
|
|
+ }
|
|
|
+ if("CORS".equals(key)){
|
|
|
+ refreshTime = refreshInterval_cors;
|
|
|
+ lastRefreshTime = lastRefreshTime_cors;
|
|
|
+ }
|
|
|
+ //每隔配置的时间就查询一次数据库
|
|
|
+ if (currentTime - lastRefreshTime > refreshTime) {
|
|
|
+ try {
|
|
|
+ Set<String> newIPs = new HashSet<>();
|
|
|
+ String paramSql = "select biao_ti_,can_shu_zhi_1_,can_shu_zhi_2_ from t_zlcsb where shi_fou_qi_yong_ = 1 and jian_zhi_='"+key+"'";
|
|
|
+ //String.format(paramSql,key);
|
|
|
+ Map<String,Object> paramzlcs = commonDao.queryOne(paramSql);
|
|
|
+ log.warn("paramzlcs的值为:{}---key值为{}",paramzlcs,key);
|
|
|
+
|
|
|
+ if (BeanUtils.isNotEmpty(paramzlcs)) {
|
|
|
+ String bmdIp = (String)paramzlcs.get("can_shu_zhi_1_");
|
|
|
+ if (bmdIp != null) {
|
|
|
+ // 更严谨的白名单分割方式
|
|
|
+ String[] bmds = bmdIp.split("[,\\s]+");
|
|
|
+ for (String ip : bmds) {
|
|
|
+ if (!ip.trim().isEmpty()) {
|
|
|
+ newIPs.add(ip.trim().toLowerCase());
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ log.warn("拆分后newIPs的值为:{}---key值为{}",newIPs,key);
|
|
|
+ if("HOST".equals(key)){
|
|
|
+ refreshInterval_host = Long.parseLong((String)paramzlcs.get("can_shu_zhi_2_"));
|
|
|
+ lastRefreshTime_host = currentTime;
|
|
|
+ ALLOWED_HOSTS.clear();
|
|
|
+ ALLOWED_HOSTS.addAll(newIPs);
|
|
|
+ log.warn("读取配置后添加的HOST的IP为:{}",ALLOWED_HOSTS);
|
|
|
+ }
|
|
|
+ if("CORS".equals(key)){
|
|
|
+ refreshInterval_cors = Long.parseLong((String)paramzlcs.get("can_shu_zhi_2_"));
|
|
|
+ lastRefreshTime_cors = currentTime;
|
|
|
+ ALLOWED_CORSES.clear();
|
|
|
+ ALLOWED_CORSES.addAll(newIPs);
|
|
|
+ log.warn("读取配置后添加的CORS的IP为:{}",ALLOWED_CORSES);
|
|
|
+ }
|
|
|
+ }else{
|
|
|
+ //没有启用参数,跳过
|
|
|
+ if("CORS".equals(key)){
|
|
|
+ ALLOWED_CORSES.clear();
|
|
|
+ ALLOWED_CORSES.add("N999999");
|
|
|
+ lastRefreshTime_cors = currentTime;
|
|
|
+ log.warn("没有启用参数,跳过CORS");
|
|
|
+ }
|
|
|
+ if("HOST".equals(key)){
|
|
|
+ ALLOWED_HOSTS.clear();
|
|
|
+ ALLOWED_HOSTS.add("N999999");
|
|
|
+ lastRefreshTime_host = currentTime;
|
|
|
+ log.warn("没有启用参数,跳过HOST");
|
|
|
+ }
|
|
|
+ //return;
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ log.error("刷新Host白名单失败", e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
@Override
|
|
|
public void destroy() {
|
|
|
|