后端接口要做XSS攻击防御,从网上查一下,有很多防御方式。对于什么是XSS攻击,网上也有很多解释。本篇博客就针对自己项目需要做下记录。
框架:前后端分离、Spring Boot
场景:后端接口参数不定,有@RequestBody形式接收,有@RequestParam形式接收,所以会有不同处理。
下面贴上代码:
过滤器:
/**
 * Servlet filter that wraps each HTTP request in an
 * {@link XssHttpServletRequestWrapper} before passing it down the chain.
 */
public class XssFilter implements Filter {

    /** Nothing to initialise; the filter configuration is not used. */
    @Override
    public void init(FilterConfig config) throws ServletException {}

    /**
     * Wraps HTTP requests and forwards them; any other kind of request is
     * forwarded untouched instead of failing on an unchecked cast.
     *
     * @param request  the incoming request
     * @param response the outgoing response, passed through unchanged
     * @param chain    the remaining filter chain
     * @throws IOException      propagated from the rest of the chain
     * @throws ServletException propagated from the rest of the chain
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        if (!(request instanceof HttpServletRequest)) {
            // The wrapper only understands HTTP requests.
            chain.doFilter(request, response);
            return;
        }
        XssHttpServletRequestWrapper wrapped =
                new XssHttpServletRequestWrapper((HttpServletRequest) request);
        chain.doFilter(wrapped, response);
    }

    /** Nothing to release. */
    @Override
    public void destroy() {}
}
敏感字符转换类:
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
HttpServletRequest orgRequest = null;
private String body;
public XssHttpServletRequestWrapper(HttpServletRequest request) {
super(request);
orgRequest = request;
body = HttpGetBody.getBodyString(request);
}
/**
* 覆盖getParameter方法,将参数名和参数值都做xss过滤。
* 如果需要获得原始的值,则通过super.getParameterValues(name)来获取
* getParameterNames,getParameterValues和getParameterMap也可能需要覆盖
*/
@Override
public String getParameter(String name) {
String value = super.getParameter(xssEncode(name, 0));
if (null != value) {
value = xssEncode(value, 0);
}
return value;
}
@Override
public String[] getParameterValues(String name) {
String[] values = super.getParameterValues(xssEncode(name, 0));
if (values == null) {
return null;
}
int count = values.length;
String[] encodedValues = new String[count];
for (int i = 0; i < count; i++) {
encodedValues[i] = xssEncode(values[i], 0);
}
return encodedValues;
}
@Override
public Map getParameterMap() {
HashMap paramMap = (HashMap) super.getParameterMap();
paramMap = (HashMap) paramMap.clone();
for (Iterator iterator = paramMap.entrySet().iterator(); iterator.hasNext(); ) {
Map.Entry entry = (Map.Entry) iterator.next();
String[] values = (String[]) entry.getValue();
for (int i = 0; i < values.length; i++) {
if (values[i] instanceof String) {
values[i] = xssEncode(values[i], 0);
}
}
entry.setValue(values);
}
return paramMap;
}
@Override
public ServletInputStream getInputStream() throws IOException {
ServletInputStream inputStream = null;
if (StringUtil.isNotEmpty(body)) {
body = xssEncode(body, 1);
inputStream = new TranslateServletInputStream(body);
}
return inputStream;
}
/**
* 覆盖getHeader方法,将参数名和参数值都做xss过滤。
* 如果需要获得原始的值,则通过super.getHeaders(name)来获取
* getHeaderNames 也可能需要覆盖
*/
@Override
public String getHeader(String name) {
String value = super.getHeader(xssEncode(name, 0));
if (value != null) {
value = xssEncode(value, 0);
}
return value;
}
/**
* 将容易引起xss漏洞的半角字符直接替换成全角字符
*
* @param s
* @return
*/
private static String xssEncode(String s, int type) {
if (s == null || s.isEmpty()) {
return s;
}
StringBuilder sb = new StringBuilder(s.length() + 16);
for (int i = 0; i < s.length(); i++) {
char c = s.charAt(i);
if (type == 0) {
switch (c) {
case '\'':
// 全角单引号
sb.append('‘');
break;
case '\"':
// 全角双引号
sb.append('“');
break;
case '>':
// 全角大于号
sb.append('>');
break;
case '<':
// 全角小于号
sb.append('<');
break;
case '&':
// 全角&符号
sb.append('&');
break;
case '\\':
// 全角斜线
sb.append('\');
break;
case '#':
// 全角井号
sb.append('#');
break;
// < 字符的 URL 编码形式表示的 ASCII 字符(十六进制格式) 是: %3c
case '%':
processUrlEncoder(sb, s, i);
break;
default:
sb.append(c);
break;
}
} else {
switch (c) {
case '>':
// 全角大于号
sb.append('>');
break;
case '<':
// 全角小于号
sb.append('<');
break;
case '&':
// 全角&符号
sb.append('&');
break;
case '\\':
// 全角斜线
sb.append('\');
break;
case '#':
// 全角井号
sb.append('#');
break;
// < 字符的 URL 编码形式表示的 ASCII 字符(十六进制格式) 是: %3c
case '%':
processUrlEncoder(sb, s, i);
break;
default:
sb.append(c);
break;
}
}
}
return sb.toString();
}
public static void processUrlEncoder(StringBuilder sb, String s, int index) {
if (s.length() >= index + 2) {
// %3c, %3C
if (s.charAt(index + 1) == '3' && (s.charAt(index + 2) == 'c' || s.charAt(index + 2) == 'C')) {
sb.append('<');
return;
}
// %3c (0x3c=60)
if (s.charAt(index + 1) == '6' && s.charAt(index + 2) == '0') {
sb.append('<');
return;
}
// %3e, %3E
if (s.charAt(index + 1) == '3' && (s.charAt(index + 2) == 'e' || s.charAt(index + 2) == 'E')) {
sb.append('>');
return;
}
// %3e (0x3e=62)
if (s.charAt(index + 1) == '6' && s.charAt(index + 2) == '2') {
sb.append('>');
return;
}
}
sb.append(s.charAt(index));
}
/**
* 获取最原始的request
*
* @return
*/
public HttpServletRequest getOrgRequest() {
return orgRequest;
}
/**
* 获取最原始的request的静态方法
*
* @return
*/
public static HttpServletRequest getOrgRequest(HttpServletRequest req) {
if (req instanceof XssHttpServletRequestWrapper) {
return ((XssHttpServletRequestWrapper) req).getOrgRequest();
}
return req;
}
}
配置类:
/**
 * Spring configuration that creates the XSS filter and registers it with the
 * servlet container.
 */
@Configuration
public class XSSFilterConfig {

    /**
     * Creates the filter bean, published under the name "xssFilter".
     *
     * @return a new {@link XssFilter}
     */
    @Bean(name = "xssFilter")
    public Filter xssFilter() {
        return new XssFilter();
    }

    /**
     * Builds the registration that maps the XSS filter to every URL.
     *
     * @return the configured registration bean
     */
    @Bean
    public FilterRegistrationBean filterRegistrationBean() {
        FilterRegistrationBean xssRegistration = new FilterRegistrationBean();
        xssRegistration.setName("xssFilter");
        xssRegistration.setFilter(xssFilter());
        // XssFilter.init has an empty body, so this parameter is not read by the filter.
        xssRegistration.addInitParameter("paramName", "paramValue");
        xssRegistration.addUrlPatterns("/*");
        return xssRegistration;
    }
}
补充两个工具类的代码:
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
/**
 * {@link ServletInputStream} that serves an in-memory string, encoded as
 * UTF-8, so that an already-read request body can be handed to downstream
 * consumers again.
 *
 * <p>The backing stream is created on first use. After {@link #close()} it is
 * discarded, and any later access creates a fresh one positioned at the start
 * of the body.
 */
public class TranslateServletInputStream extends ServletInputStream {

    /** Lazily created stream over the UTF-8 bytes of {@link #body}. */
    private ByteArrayInputStream inputStream;

    /** The text to serve; never {@code null}. */
    private final String body;

    /**
     * Creates a stream over the given text.
     *
     * @param body the text to serve; {@code null} is treated as empty
     * @throws IOException never thrown here; kept for existing callers
     */
    public TranslateServletInputStream(String body) throws IOException {
        this.body = (body == null) ? "" : body;
        inputStream = null;
    }

    /**
     * The data is held in memory, so reading never blocks.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isReady() {
        return true;
    }

    /** Read listeners are not supported; the listener is ignored. */
    @Override
    public void setReadListener(ReadListener readListener) {
    }

    /**
     * Reports whether every byte of the current backing stream has been read.
     *
     * @return {@code true} when no bytes remain
     */
    @Override
    public boolean isFinished() {
        return acquireInputStream().available() == 0;
    }

    /** Returns the backing stream, building it from the body text if needed. */
    private ByteArrayInputStream acquireInputStream() {
        if (inputStream == null) {
            inputStream = new ByteArrayInputStream(body.getBytes(Charset.forName("UTF-8")));
        }
        return inputStream;
    }

    /**
     * Discards the backing stream.
     *
     * @throws IOException never thrown here; kept for the overridden signature
     */
    @Override
    public void close() throws IOException {
        // A byte-array stream holds no resources, so dropping the reference is enough.
        inputStream = null;
    }

    /** Mark/reset is not supported. */
    @Override
    public boolean markSupported() {
        return false;
    }

    /**
     * Always fails because marking is not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public synchronized void mark(int i) {
        throw new UnsupportedOperationException("mark not supported");
    }

    /**
     * Always fails because resetting is not supported.
     *
     * @throws IOException always
     */
    @Override
    public synchronized void reset() throws IOException {
        throw new IOException(new UnsupportedOperationException("reset not supported"));
    }

    /**
     * Reads the next byte of the body.
     *
     * @return the byte as an unsigned value, or {@code -1} at the end
     * @throws IOException never thrown here; kept for the overridden signature
     */
    @Override
    public int read() throws IOException {
        return acquireInputStream().read();
    }

    /**
     * Bulk read delegated to the backing stream instead of looping over
     * single-byte reads.
     *
     * @param b   destination buffer
     * @param off offset in {@code b} at which to start storing bytes
     * @param len maximum number of bytes to read
     * @return the number of bytes read, or {@code -1} at the end
     * @throws IOException never thrown here; kept for the overridden signature
     */
    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        return acquireInputStream().read(b, off, len);
    }
}
import javax.servlet.ServletRequest;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
/**
 * Helper for reading the body of a servlet request as text.
 */
public class HttpGetBody {

    /**
     * Reads the whole request body as UTF-8 text, keeping line terminators
     * exactly as sent.
     *
     * <p>Reading is best effort: if an {@link IOException} occurs, its stack
     * trace is printed and whatever was read up to that point is returned.
     *
     * @param request the request whose input stream is consumed and closed
     * @return the body text; empty if the request has no stream or no content
     */
    public static String getBodyString(ServletRequest request) {
        StringBuilder sb = new StringBuilder();
        try {
            InputStream inputStream = request.getInputStream();
            if (inputStream == null) {
                // Nothing to read; avoids a NullPointerException in the reader.
                return "";
            }
            // Closing the reader also closes the underlying request stream.
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(inputStream, Charset.forName("UTF-8")))) {
                char[] buffer = new char[4096];
                int count;
                // Copy raw characters; readLine() would drop the line terminators.
                while ((count = reader.read(buffer)) != -1) {
                    sb.append(buffer, 0, count);
                }
            }
        } catch (IOException e) {
            // Deliberately best effort: report and return what was read.
            e.printStackTrace();
        }
        return sb.toString();
    }
}
如上代码所示,测试时在请求参数或 URL 中传入敏感字符,即可看到它们被转换为安全字符。