Spring Boot配置XSS防御过滤

后端接口要做XSS攻击防御,从网上查一下,有很多防御方式。对于什么是XSS攻击,网上也有很多解释。本篇博客就针对自己项目需要做下记录。

框架:前后端分离、Spring Boot

场景:后端接口参数不定,有@RequestBody形式接收,有@RequestParam形式接收,所以会有不同处理。

下面贴上代码:

过滤器:

/**
 * Servlet filter that wraps every HTTP request in an
 * {@link XssHttpServletRequestWrapper} so that parameters, headers and the
 * request body are XSS-encoded before they reach downstream handlers.
 */
public class XssFilter implements Filter {

    /** No initialisation is required; the filter configuration is not read. */
    @Override
    public void init(FilterConfig config) throws ServletException {}

    /**
     * Wraps the request and continues the chain with the wrapper.
     *
     * @param request  the incoming request
     * @param response the outgoing response, passed through untouched
     * @param chain    the remaining filter chain
     * @throws IOException      if a downstream filter or servlet fails on I/O
     * @throws ServletException if a downstream filter or servlet fails
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        // The wrapper only understands HTTP requests; anything else is passed
        // through unchanged instead of failing with a ClassCastException.
        if (!(request instanceof HttpServletRequest)) {
            chain.doFilter(request, response);
            return;
        }
        XssHttpServletRequestWrapper wrappedRequest =
                new XssHttpServletRequestWrapper((HttpServletRequest) request);
        chain.doFilter(wrappedRequest, response);
    }

    /** Nothing to release. */
    @Override
    public void destroy() {}
}

敏感字符转换类:

public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
    HttpServletRequest orgRequest = null;

    private String body;

    public XssHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
        orgRequest = request;
        body = HttpGetBody.getBodyString(request);
    }

    /**
     * 覆盖getParameter方法,将参数名和参数值都做xss过滤。
* 如果需要获得原始的值,则通过super.getParameterValues(name)来获取
* getParameterNames,getParameterValues和getParameterMap也可能需要覆盖 */ @Override public String getParameter(String name) { String value = super.getParameter(xssEncode(name, 0)); if (null != value) { value = xssEncode(value, 0); } return value; } @Override public String[] getParameterValues(String name) { String[] values = super.getParameterValues(xssEncode(name, 0)); if (values == null) { return null; } int count = values.length; String[] encodedValues = new String[count]; for (int i = 0; i < count; i++) { encodedValues[i] = xssEncode(values[i], 0); } return encodedValues; } @Override public Map getParameterMap() { HashMap paramMap = (HashMap) super.getParameterMap(); paramMap = (HashMap) paramMap.clone(); for (Iterator iterator = paramMap.entrySet().iterator(); iterator.hasNext(); ) { Map.Entry entry = (Map.Entry) iterator.next(); String[] values = (String[]) entry.getValue(); for (int i = 0; i < values.length; i++) { if (values[i] instanceof String) { values[i] = xssEncode(values[i], 0); } } entry.setValue(values); } return paramMap; } @Override public ServletInputStream getInputStream() throws IOException { ServletInputStream inputStream = null; if (StringUtil.isNotEmpty(body)) { body = xssEncode(body, 1); inputStream = new TranslateServletInputStream(body); } return inputStream; } /** * 覆盖getHeader方法,将参数名和参数值都做xss过滤。
* 如果需要获得原始的值,则通过super.getHeaders(name)来获取
* getHeaderNames 也可能需要覆盖 */ @Override public String getHeader(String name) { String value = super.getHeader(xssEncode(name, 0)); if (value != null) { value = xssEncode(value, 0); } return value; } /** * 将容易引起xss漏洞的半角字符直接替换成全角字符 * * @param s * @return */ private static String xssEncode(String s, int type) { if (s == null || s.isEmpty()) { return s; } StringBuilder sb = new StringBuilder(s.length() + 16); for (int i = 0; i < s.length(); i++) { char c = s.charAt(i); if (type == 0) { switch (c) { case '\'': // 全角单引号 sb.append('‘'); break; case '\"': // 全角双引号 sb.append('“'); break; case '>': // 全角大于号 sb.append('>'); break; case '<': // 全角小于号 sb.append('<'); break; case '&': // 全角&符号 sb.append('&'); break; case '\\': // 全角斜线 sb.append('\'); break; case '#': // 全角井号 sb.append('#'); break; // < 字符的 URL 编码形式表示的 ASCII 字符(十六进制格式) 是: %3c case '%': processUrlEncoder(sb, s, i); break; default: sb.append(c); break; } } else { switch (c) { case '>': // 全角大于号 sb.append('>'); break; case '<': // 全角小于号 sb.append('<'); break; case '&': // 全角&符号 sb.append('&'); break; case '\\': // 全角斜线 sb.append('\'); break; case '#': // 全角井号 sb.append('#'); break; // < 字符的 URL 编码形式表示的 ASCII 字符(十六进制格式) 是: %3c case '%': processUrlEncoder(sb, s, i); break; default: sb.append(c); break; } } } return sb.toString(); } public static void processUrlEncoder(StringBuilder sb, String s, int index) { if (s.length() >= index + 2) { // %3c, %3C if (s.charAt(index + 1) == '3' && (s.charAt(index + 2) == 'c' || s.charAt(index + 2) == 'C')) { sb.append('<'); return; } // %3c (0x3c=60) if (s.charAt(index + 1) == '6' && s.charAt(index + 2) == '0') { sb.append('<'); return; } // %3e, %3E if (s.charAt(index + 1) == '3' && (s.charAt(index + 2) == 'e' || s.charAt(index + 2) == 'E')) { sb.append('>'); return; } // %3e (0x3e=62) if (s.charAt(index + 1) == '6' && s.charAt(index + 2) == '2') { sb.append('>'); return; } } sb.append(s.charAt(index)); } /** * 获取最原始的request * * @return */ public HttpServletRequest getOrgRequest() { 
return orgRequest; } /** * 获取最原始的request的静态方法 * * @return */ public static HttpServletRequest getOrgRequest(HttpServletRequest req) { if (req instanceof XssHttpServletRequestWrapper) { return ((XssHttpServletRequestWrapper) req).getOrgRequest(); } return req; } }

配置类:

/**
 * Spring configuration that installs {@link XssFilter} in the servlet filter
 * chain.
 */
@Configuration
public class XSSFilterConfig {

    /**
     * Builds the registration that maps the XSS filter to every URL.
     *
     * @return the registration named {@code xssFilter}, mapped to {@code /*}
     */
    @Bean
    public FilterRegistrationBean filterRegistrationBean() {
        FilterRegistrationBean xssRegistration = new FilterRegistrationBean();
        xssRegistration.setName("xssFilter");
        xssRegistration.setFilter(xssFilter());
        xssRegistration.addUrlPatterns("/*");
        // NOTE(review): this looks like a placeholder init parameter; confirm
        // whether anything reads it.
        xssRegistration.addInitParameter("paramName", "paramValue");
        return xssRegistration;
    }

    /**
     * Exposes the filter instance as a bean.
     *
     * @return a new {@link XssFilter}
     */
    @Bean(name = "xssFilter")
    public Filter xssFilter() {
        Filter filter = new XssFilter();
        return filter;
    }
}

补充两个工具类的代码:

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;


/**
 * In-memory {@link ServletInputStream} that serves a string as UTF-8 bytes.
 *
 * <p>The backing byte stream is created lazily on first access. After
 * {@link #close()} the backing stream is dropped, and a later read creates a
 * new one that starts again from the beginning of the text.
 */
public class TranslateServletInputStream extends ServletInputStream {
    /** Lazily created backing stream; {@code null} until first use and after close. */
    private ByteArrayInputStream inputStream;
    /**
     * The (already processed) text to serve.
     */
    private String body;

    /**
     * Creates a stream over {@code body}.
     *
     * @param body the text to serve; {@code null} is treated as empty
     * @throws IOException never thrown by this implementation; declared for
     *                     compatibility with existing callers
     */
    public TranslateServletInputStream(String body) throws IOException {
        this.body = (body == null) ? "" : body;
        inputStream = null;
    }

    /**
     * The data is held in memory, so a read never blocks.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isReady() {
        return true;
    }

    /** Read listeners are not supported; the call is ignored. */
    @Override
    public void setReadListener(ReadListener readListener) {

    }

    /**
     * Reports whether all bytes of the current backing stream have been read.
     *
     * @return {@code true} once no bytes remain
     */
    @Override
    public boolean isFinished() {
        return acquireInputStream().available() == 0;
    }

    /** Returns the backing stream, creating it from the text on first use. */
    private ByteArrayInputStream acquireInputStream() {
        if (inputStream == null) {
            // Build the stream from the processed text so that downstream
            // controllers can read it like a normal request body.
            inputStream = new ByteArrayInputStream(body.getBytes(Charset.forName("UTF-8")));
        }

        return inputStream;
    }

    /**
     * Closes and drops the backing stream.
     *
     * @throws IOException if closing the backing stream fails
     */
    @Override
    public void close() throws IOException {
        try {
            if (inputStream != null) {
                inputStream.close();
            }
        } finally {
            inputStream = null;
        }
    }

    /** @return always {@code false}; mark/reset is not supported */
    @Override
    public boolean markSupported() {
        return false;
    }

    /** @throws UnsupportedOperationException always */
    @Override
    public synchronized void mark(int i) {
        throw new UnsupportedOperationException("mark not supported");
    }

    /** @throws IOException always, wrapping an UnsupportedOperationException */
    @Override
    public synchronized void reset() throws IOException {
        throw new IOException(new UnsupportedOperationException("reset not supported"));
    }

    /** @return the number of bytes that can still be read */
    @Override
    public int available() throws IOException {
        return acquireInputStream().available();
    }

    /**
     * Reads the next byte.
     *
     * @return the byte as an int in the range 0-255, or -1 at end of stream
     */
    @Override
    public int read() throws IOException {
        return acquireInputStream().read();

    }

    /**
     * Bulk read that delegates to the backing stream instead of falling back
     * to the inherited byte-by-byte loop.
     *
     * @return the number of bytes read, or -1 at end of stream
     */
    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        return acquireInputStream().read(b, off, len);
    }

}
import javax.servlet.ServletRequest;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;


/**
 * Helper for reading a request body into a string.
 */
public class HttpGetBody {

    /**
     * Reads the whole request body as UTF-8 text, preserving line terminators.
     *
     * <p>The request's input stream is consumed and closed by this call. On an
     * I/O error the stack trace is printed and whatever was read up to that
     * point is returned.
     *
     * @param request the request whose body should be read
     * @return the body text; empty if the request has no body or no stream
     */
    public static String getBodyString(ServletRequest request) {
        StringBuilder sb = new StringBuilder();
        try {
            InputStream inputStream = request.getInputStream();
            if (inputStream == null) {
                return "";
            }
            // Closing the reader also closes the underlying input stream.
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(inputStream, Charset.forName("UTF-8")))) {
                // Copy raw characters rather than using readLine(), which
                // would silently drop the line terminators from the body.
                char[] buffer = new char[4096];
                int count;
                while ((count = reader.read(buffer)) != -1) {
                    sb.append(buffer, 0, count);
                }
            }
        } catch (IOException e) {
            // Best effort, as before: report the failure and return what was read.
            e.printStackTrace();
        }
        return sb.toString();
    }
}

如上代码所示,测试时在参数或URL中传入敏感字符,即可被转换为安全字符。

你可能感兴趣的:(Java)