Spring AbstractRequestLoggingFilter在大请求时因OOM而失败

时间:2017-07-20 13:35:05

标签: spring out-of-memory servlet-filters

如果我启用setIncludePayload(true)并向servlet发送一个很大的请求(例如文件上传),应用程序就会因OOM错误而失败。

我使用Spring 3.2.8。

有什么不对?

1 个答案:

答案 0 :(得分:0)

问题在于此过滤器不适合在生产环境中使用:它会把整个请求体缓存到一个字节数组缓冲区中,因此在处理文件上传之类的大请求时会导致OOM。

我更改了源代码以避免此问题,请参阅下文。

注意:有效负载(payload)只能在afterRequest方法中访问。原因是请求体只有在下游组件读取它时才会被捕获(最多捕获maxPayloadLength个字节),而调用beforeRequest时请求体尚未被读取。

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;

import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;

/**
 * Variant of {@code org.springframework.web.filter.AbstractRequestLoggingFilter} whose payload capture
 * is bounded in size.
 *
 * <p>With {@link #setIncludePayload(boolean) includePayload} enabled, the Spring 3.2 filter copies the
 * whole request body into an in-memory byte array, so large requests (e.g. file uploads) can exhaust
 * the heap. This filter keeps at most {@link #setMaxPayloadLength(int) maxPayloadLength} bytes of the
 * body. Further bytes are still passed through to the application but are not buffered, and the logged
 * payload is then suffixed with {@code "...(truncated)"}.
 *
 * <p>The body is captured only as downstream components read it. The payload can therefore only appear
 * in the message passed to {@link #afterRequest}; the message passed to {@link #beforeRequest} is built
 * before the filter chain runs and never contains it.
 */
public abstract class AbstractRequestLoggingWithMaxSizeCheckFilter extends OncePerRequestFilter {

    public static final String DEFAULT_BEFORE_MESSAGE_PREFIX = "Before request [";
    public static final String DEFAULT_BEFORE_MESSAGE_SUFFIX = "]";
    public static final String DEFAULT_AFTER_MESSAGE_PREFIX = "After request [";
    public static final String DEFAULT_AFTER_MESSAGE_SUFFIX = "]";
    private static final int DEFAULT_MAX_PAYLOAD_LENGTH = 50;

    /** Text appended to a logged payload when the body was longer than {@code maxPayloadLength}. */
    private static final String TRUNCATION_MARKER = "...(truncated)";

    private boolean includeQueryString = false;
    private boolean includeClientInfo = false;
    private boolean includePayload = false;
    private int maxPayloadLength = DEFAULT_MAX_PAYLOAD_LENGTH;
    private String beforeMessagePrefix = DEFAULT_BEFORE_MESSAGE_PREFIX;
    private String beforeMessageSuffix = DEFAULT_BEFORE_MESSAGE_SUFFIX;
    private String afterMessagePrefix = DEFAULT_AFTER_MESSAGE_PREFIX;
    private String afterMessageSuffix = DEFAULT_AFTER_MESSAGE_SUFFIX;

    public AbstractRequestLoggingWithMaxSizeCheckFilter() {
    }

    /** Whether the query string is appended to the logged URI. Default {@code false}. */
    public void setIncludeQueryString(boolean includeQueryString) {
        this.includeQueryString = includeQueryString;
    }

    protected boolean isIncludeQueryString() {
        return this.includeQueryString;
    }

    /** Whether client address, session id and remote user are logged. Default {@code false}. */
    public void setIncludeClientInfo(boolean includeClientInfo) {
        this.includeClientInfo = includeClientInfo;
    }

    protected boolean isIncludeClientInfo() {
        return this.includeClientInfo;
    }

    /** Whether the (size-capped) request body is logged in the after-request message. Default {@code false}. */
    public void setIncludePayload(boolean includePayload) {
        this.includePayload = includePayload;
    }

    protected boolean isIncludePayload() {
        return this.includePayload;
    }

    /**
     * Sets the maximum number of body bytes buffered for logging. Default 50.
     *
     * @param maxPayloadLength non-negative byte limit; {@code 0} disables payload capture
     * @throws IllegalArgumentException if {@code maxPayloadLength} is negative
     */
    public void setMaxPayloadLength(int maxPayloadLength) {
        Assert.isTrue(maxPayloadLength >= 0, "'maxPayloadLength' should be larger than or equal to 0");
        this.maxPayloadLength = maxPayloadLength;
    }

    protected int getMaxPayloadLength() {
        return this.maxPayloadLength;
    }

    public void setBeforeMessagePrefix(String beforeMessagePrefix) {
        this.beforeMessagePrefix = beforeMessagePrefix;
    }

    public void setBeforeMessageSuffix(String beforeMessageSuffix) {
        this.beforeMessageSuffix = beforeMessageSuffix;
    }

    public void setAfterMessagePrefix(String afterMessagePrefix) {
        this.afterMessagePrefix = afterMessagePrefix;
    }

    public void setAfterMessageSuffix(String afterMessageSuffix) {
        this.afterMessageSuffix = afterMessageSuffix;
    }

    /** Returns {@code false} so that async dispatches are also filtered. */
    protected boolean shouldNotFilterAsyncDispatch() {
        return false;
    }

    /**
     * Optionally wraps the request for payload capture, logs before the chain on the initial dispatch,
     * and logs after the chain unless async processing was started.
     */
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        boolean isFirstRequest = !isAsyncDispatch(request);
        HttpServletRequest requestToUse = request;

        // Only the initial (non-async) dispatch is wrapped for payload capture.
        if (isIncludePayload() && isFirstRequest) {
            requestToUse = new RequestCachingRequestWrapper(request, getMaxPayloadLength());
        }

        if (isFirstRequest) {
            beforeRequest(requestToUse, getBeforeMessage(requestToUse));
        }

        try {
            filterChain.doFilter(requestToUse, response);
        } finally {
            if (!isAsyncStarted(requestToUse)) {
                afterRequest(requestToUse, getAfterMessage(requestToUse));
            }
        }
    }

    private String getBeforeMessage(HttpServletRequest request) {
        return createMessage(request, this.beforeMessagePrefix, this.beforeMessageSuffix);
    }

    private String getAfterMessage(HttpServletRequest request) {
        return createMessage(request, this.afterMessagePrefix, this.afterMessageSuffix);
    }

    /**
     * Builds the log message: {@code prefix + "uri=..." [+ "?query"] [+ client info] [+ payload] + suffix}.
     *
     * @param request the request to describe; the payload is only included when this is the
     *                capturing wrapper created by this filter
     * @param prefix  text placed before the description
     * @param suffix  text placed after the description
     * @return the assembled message
     */
    protected String createMessage(HttpServletRequest request, String prefix, String suffix) {
        StringBuilder msg = new StringBuilder();
        msg.append(prefix);
        msg.append("uri=").append(request.getRequestURI());

        if (isIncludeQueryString()) {
            String queryString = request.getQueryString();
            // Omit the '?' entirely when there is no query string instead of logging "?null".
            if (queryString != null) {
                msg.append('?').append(queryString);
            }
        }

        if (isIncludeClientInfo()) {
            String client = request.getRemoteAddr();
            if (StringUtils.hasLength(client)) {
                msg.append(";client=").append(client);
            }

            // getSession(false) avoids creating a session just for logging.
            HttpSession session = request.getSession(false);
            if (session != null) {
                msg.append(";session=").append(session.getId());
            }

            String user = request.getRemoteUser();
            if (user != null) {
                msg.append(";user=").append(user);
            }
        }

        if (isIncludePayload() && request instanceof RequestCachingRequestWrapper) {
            RequestCachingRequestWrapper wrapper = (RequestCachingRequestWrapper) request;
            byte[] buf = wrapper.toByteArray();
            if (buf.length > 0) {
                String payload;
                try {
                    payload = new String(buf, wrapper.getCharacterEncoding());
                } catch (UnsupportedEncodingException ex) {
                    payload = "[unknown]";
                }
                msg.append(";payload=").append(payload);
                // Appended as text (not as bytes in the buffer) so it is never mis-decoded
                // by the request's character encoding.
                if (wrapper.isTruncated()) {
                    msg.append(TRUNCATION_MARKER);
                }
            }
        }

        msg.append(suffix);
        return msg.toString();
    }

    /**
     * Called on the initial (non-async) dispatch before the filter chain runs.
     *
     * @param request the current request (the capturing wrapper when payload logging is enabled)
     * @param message the message to log; it never contains the payload because nothing has been read yet
     */
    protected abstract void beforeRequest(HttpServletRequest request, String message);

    /**
     * Called after the filter chain returns (or throws), unless async processing was started.
     *
     * @param request the current request (the capturing wrapper when payload logging is enabled)
     * @param message the message to log, including the captured payload when enabled
     */
    protected abstract void afterRequest(HttpServletRequest request, String message);

    /**
     * Request wrapper that copies at most {@code maxPayloadLength} bytes of the body into memory
     * as the application reads it, and remembers whether more bytes than that were read.
     */
    private static class RequestCachingRequestWrapper extends HttpServletRequestWrapper {
        private final ByteArrayOutputStream bos;
        private final ServletInputStream inputStream;
        private final int maxPayloadLength;
        private BufferedReader reader;
        private boolean truncated;

        private RequestCachingRequestWrapper(HttpServletRequest request, int maxPayloadLength) throws IOException {
            super(request);
            this.maxPayloadLength = maxPayloadLength;
            this.bos = new ByteArrayOutputStream();
            this.inputStream = new RequestCachingInputStream(request.getInputStream());
        }

        public ServletInputStream getInputStream() throws IOException {
            return this.inputStream;
        }

        /** Falls back to ISO-8859-1 when the request declares no character encoding. */
        public String getCharacterEncoding() {
            String encoding = super.getCharacterEncoding();
            return encoding != null ? encoding : "ISO-8859-1";
        }

        public BufferedReader getReader() throws IOException {
            if (this.reader == null) {
                this.reader = new BufferedReader(new InputStreamReader(this.inputStream, getCharacterEncoding()));
            }
            return this.reader;
        }

        /** Returns a copy of the captured (possibly truncated) body bytes. */
        private byte[] toByteArray() {
            return this.bos.toByteArray();
        }

        /** Whether the application read more body bytes than were captured. */
        private boolean isTruncated() {
            return this.truncated;
        }

        /** Captures one byte read by the application, or flags truncation once the cap is full. */
        private void capture(int b) {
            if (this.bos.size() < this.maxPayloadLength) {
                this.bos.write(b);
            } else {
                this.truncated = true;
            }
        }

        /** Captures the part of a chunk that still fits under the cap; flags truncation for the rest. */
        private void capture(byte[] b, int off, int len) {
            int toKeep = Math.min(this.maxPayloadLength - this.bos.size(), len);
            if (toKeep > 0) {
                this.bos.write(b, off, toKeep);
            }
            if (len > toKeep) {
                this.truncated = true;
            }
        }

        /**
         * Pass-through stream that records what the application reads.
         *
         * <p>NOTE(review): when compiled against Servlet 3.1+, {@code ServletInputStream} also declares
         * {@code isFinished()}, {@code isReady()} and {@code setReadListener(...)}, which would need to be
         * implemented here (delegating to {@code is}).
         */
        private class RequestCachingInputStream extends ServletInputStream {
            private final ServletInputStream is;

            private RequestCachingInputStream(ServletInputStream is) {
                this.is = is;
            }

            public int read() throws IOException {
                int ch = this.is.read();
                if (ch != -1) {
                    capture(ch);
                }
                return ch;
            }

            // Bulk override so large bodies are not degraded to one virtual call per byte.
            public int read(byte[] b, int off, int len) throws IOException {
                int count = this.is.read(b, off, len);
                if (count > 0) {
                    capture(b, off, count);
                }
                return count;
            }
        }
    }
}