Sepehr GH Sepehr GH - 23 days ago 10
Java Question

Resolving multipart/form-data request in spring filter

I'm trying to write my own CSRF filter in Spring MVC 3. (This is part of some training I'm doing, so please don't suggest using Spring Security instead. I know about it, thank you!)

My filter works fine with all forms except those that have enctype="multipart/form-data": for those, I cannot get the request parameters from the normal HttpServletRequest.

I've tried casting it to MultipartHttpServletRequest, but I found out I cannot do that either.

Note that my objective is not to get the files, only a simple form input named "csrf". Uploading files through my forms already works.

Here is my code till now:

CSRFilter

/**
 * Servlet filter that lets a request continue only when its CSRF check passes;
 * otherwise the client is redirected to the "/access-forbidden" page of this application.
 */
public class CSRFilter extends GenericFilterBean {

    /**
     * Runs the CSRF check for the incoming request and either continues the chain
     * or redirects to the forbidden page.
     *
     * @param req   the incoming request; cast to HttpServletRequest without a type check
     * @param res   the outgoing response; cast to HttpServletResponse without a type check
     * @param chain the remaining filter chain
     */
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
        final HttpServletRequest httpRequest = (HttpServletRequest) req;
        final HttpServletResponse httpResponse = (HttpServletResponse) res;

        // Happy path first: a passing check simply hands the untouched request on.
        if (new CSRF(req).isOk()) {
            chain.doFilter(req, res);
            return;
        }

        // TODO: show an error page
        final StringBuilder forbiddenUrl = new StringBuilder();
        forbiddenUrl.append(httpRequest.getScheme()).append("://")
                .append(httpRequest.getServerName()).append(":")
                .append(httpRequest.getServerPort())
                .append(httpRequest.getContextPath())
                .append("/access-forbidden");
        httpResponse.sendRedirect(forbiddenUrl.toString());
    }
}


CSRF

/**
 * Per-request CSRF check. Binding a request (via the constructor or setRequest) runs init():
 * a GET issues a fresh token and publishes it under "csrf" in the session and as a request
 * attribute; a POST is accepted only when the posted "csrf" value equals the session value.
 * The outcome is read back through isOk().
 */
public class CSRF {
// The same request object is kept twice: once cast to HttpServletRequest, once as passed in.
HttpServletRequest request;
ServletRequest req;
// Token produced by generateToken(); only assigned on the GET path.
String token;
// Result of the check; nothing sets it to true except the GET branch or a successful POST check.
boolean ok;
private static final Logger logger = Logger.getLogger(CSRF.class);


/**
 * Binds this check to the given request and evaluates it immediately.
 *
 * @param request the incoming request; cast to HttpServletRequest without a type check
 */
public CSRF(ServletRequest request) {
this.request = (HttpServletRequest) request;
this.req = request;
init();
}

/**
 * Creates an unbound check; isOk() returns false until setRequest is called.
 */
public CSRF() {
}


/**
 * Binds (or re-binds) this check to a request and evaluates it immediately.
 *
 * @param request the incoming request (the cast below is redundant, the parameter already has that type)
 */
public void setRequest(HttpServletRequest request) {
this.request = (HttpServletRequest) request;
this.req = request;
init();
}

/**
 * Dispatches on the HTTP method: GET generates and publishes a new token and passes;
 * POST passes only if checkPostedCsrfToken() succeeds. Any other method leaves ok untouched.
 */
private void init() {
if (request.getMethod().equals("GET")) {
generateToken();
addCSRFTokenToSession();
addCSRFTokenToModelAttribute();
ok = true;
} else if (request.getMethod().equals("POST")) {
if (checkPostedCsrfToken()) {
ok = true;
}
}
}

/**
 * Builds a new token from a random UUID followed by the current timestamp and stores
 * its SHA-1 hex digest in the token field.
 */
private void generateToken() {
// Local variable shadows the field of the same name; the field is written via this.token below.
String token;
java.util.Date date = new java.util.Date();
UUID uuid = UUID.randomUUID();
token = uuid.toString() + String.valueOf(new Timestamp(date.getTime()));
try {
this.token = sha1(token);
} catch (NoSuchAlgorithmException e) {
// NOTE(review): stack trace goes to stderr rather than the class logger.
// Fallback: use the un-hashed UUID + timestamp string as the token.
e.printStackTrace();
this.token = token;
}
}

/** Stores the current token in the HTTP session under the key "csrf". */
private void addCSRFTokenToSession() {
request.getSession().setAttribute("csrf", token);
}

/** Exposes the current token as the request attribute "csrf". */
private void addCSRFTokenToModelAttribute() {
request.setAttribute("csrf", token);
}

/**
 * Compares the posted "csrf" value with the one stored in the session.
 *
 * @return true when a posted value is found and equals the session attribute "csrf";
 *         otherwise the request is logged via log() and false is returned
 */
private boolean checkPostedCsrfToken() {
// Debug trace written to stdout on every POST check.
System.out.println("____ CSRF CHECK POST _____");
if (request.getParameterMap().containsKey("csrf")) {
String csrf = request.getParameter("csrf");
if (csrf.equals(request.getSession().getAttribute("csrf"))) {
return true;
}
}else {
//Check for multipart requests
// NOTE(review): this branch runs for every POST without a "csrf" parameter, multipart or not.
// The raw request is only wrapped here, not resolved by a MultipartResolver; this appears to be
// where the reported "Multipart request not initialized" IllegalStateException originates - confirm.

MultipartHttpServletRequest multiPartRequest = new DefaultMultipartHttpServletRequest((HttpServletRequest) req);
if (multiPartRequest.getParameterMap().containsKey("csrf")) {
String csrf = multiPartRequest.getParameter("csrf");
if (csrf.equals(request.getSession().getAttribute("csrf"))) {
return true;
}
}
}

// Reached whenever no matching token was found.
log();
return false;
}

/**
 * Reports a failed check: the same message (client IP, request URI, user agent and the
 * session's "username" attribute) is written to stdout and to the logger at error level.
 */
private void log() {
HttpSession session = request.getSession();
String username = (String) session.getAttribute("username");
if(username==null){
username = "unknown (not logged in)";
}
// Prefer the X-FORWARDED-FOR header; fall back to the remote address when the header is absent.
String ipAddress = request.getHeader("X-FORWARDED-FOR");
if (ipAddress == null) {
ipAddress = request.getRemoteAddr();
}
String userAgent = request.getHeader("User-Agent");
String address = request.getRequestURI();
System.out.println("a CSRF attack detected from IP: " + ipAddress + " in address \"" + address + "\" - Client User Agent : " + userAgent + " Username: " + username);

logger.error("a CSRF attack detected from IP: " + ipAddress + " in address \"" + address + "\" - Client User Agent : " + userAgent + " Username: " + username);
}

/**
 * @return the result of the last evaluation; false if no request has been evaluated yet
 */
public boolean isOk() {
return ok;
}

/**
 * Computes the SHA-1 digest of the input and renders it as a hex string.
 *
 * @param input the text to hash; converted with getBytes() without an explicit charset
 * @return the digest as hex, two digits per byte
 * @throws NoSuchAlgorithmException if no "SHA1" MessageDigest implementation is available
 */
static String sha1(String input) throws NoSuchAlgorithmException {
MessageDigest mDigest = MessageDigest.getInstance("SHA1");
byte[] result = mDigest.digest(input.getBytes());
StringBuffer sb = new StringBuffer();
for (int i = 0; i < result.length; i++) {
// Adding 0x100 and dropping the first character keeps the leading zero of values below 0x10.
sb.append(Integer.toString((result[i] & 0xff) + 0x100, 16).substring(1));
}
return sb.toString();
}
}


I have this line in my dispatcher too :

<bean id="multipartResolver" class="org.springframework.web.multipart.commons.CommonsMultipartResolver">
<!-- One of the available properties: the maximum file size in bytes (the value below is 40,000,000 bytes, roughly 40 MB) -->
<property name="maxUploadSize" value="40000000"/>
</bean>


I also use the springMultipartFilter filter:

<!--
  Spring's MultipartFilter, mapped to every URL.
  NOTE(review): servlet filters run in the order of their filter-mapping entries, so this
  mapping presumably has to come before the CSRF filter's mapping for that filter to see an
  already-resolved multipart request - confirm against the full web.xml.
  NOTE(review): per its Javadoc, MultipartFilter looks up its resolver by bean name (default
  "filterMultipartResolver") in the root application context - verify that such a bean exists;
  the "multipartResolver" bean of the dispatcher context is a separate lookup.
-->
<filter>
    <display-name>springMultipartFilter</display-name>
    <filter-name>springMultipartFilter</filter-name>
    <filter-class>org.springframework.web.multipart.support.MultipartFilter</filter-class>
</filter>
<filter-mapping>
    <filter-name>springMultipartFilter</filter-name>
    <url-pattern>/*</url-pattern>
</filter-mapping>


I get a
java.lang.IllegalStateException: Multipart request not initialized
exception when I try it on multipart/form-data forms.

I looked at many examples on the internet. Most of them were about file uploading and could not help me. I also tried different ways of casting HttpServletRequest to some other object that would give me the resolved multipart request, but I could not succeed.

How can I do it ?

Thanks.

Answer

You cannot cast HttpServletRequest to MultipartHttpServletRequest, because the request has to be resolved first.

I used the CommonsMultipartResolver class and obtained a MultipartHttpServletRequest by calling commonsMultipartResolver.resolveMultipart(request), where request is of type HttpServletRequest.

So, here is my CSRF class, checkPostedCsrfToken() method:

/**
     * Validates the "csrf" value posted with the current request against the token stored
     * in the session under "csrf".
     *
     * The value is taken from the ordinary request parameters when present. Otherwise, if
     * the content type indicates multipart/form-data, the request is resolved with a
     * CommonsMultipartResolver and the value is read from the resolved request.
     *
     * @return true when a posted token is found and matches the session token; otherwise
     *         the request is reported via log() and false is returned
     */
    private boolean checkPostedCsrfToken() {
        boolean valid = false;
        if (request.getParameterMap().containsKey("csrf")) {
            valid = matchesSessionToken(request.getParameter("csrf"));
        } else {
            String contentType = request.getContentType();
            // Locale.ROOT keeps the lower-casing independent of the server's default locale
            // (e.g. "MULTIPART" would not lower-case to "multipart" under a Turkish locale).
            if (contentType != null
                    && contentType.toLowerCase(java.util.Locale.ROOT).contains("multipart/form-data")) {
                CommonsMultipartResolver commonsMultipartResolver = new CommonsMultipartResolver();
                MultipartHttpServletRequest multipartRequest = commonsMultipartResolver.resolveMultipart(request);
                // getParameter returns null for a missing field, which the helper treats as a mismatch.
                valid = matchesSessionToken(multipartRequest.getParameter("csrf"));
            }
        }

        if (!valid) {
            log();
        }
        return valid;
    }

    /**
     * Compares a posted token with the session attribute "csrf".
     *
     * @param postedToken the value received from the client; may be null
     * @return true only when both values are present and identical
     */
    private boolean matchesSessionToken(String postedToken) {
        Object expected = request.getSession().getAttribute("csrf");
        if (postedToken == null || !(expected instanceof String)) {
            return false;
        }
        java.nio.charset.Charset utf8 = java.nio.charset.Charset.forName("UTF-8");
        // MessageDigest.isEqual is used instead of String.equals so that the comparison time
        // does not reveal how many leading characters of a guessed token were correct.
        return MessageDigest.isEqual(postedToken.getBytes(utf8), ((String) expected).getBytes(utf8));
    }

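But note that resolving the multipart request inside the filter consumes the request body, which can only be read once, so you will lose all request parameters and data for everything further down the filter chain. If it matters to you that the parameters don't get lost on their way through the filter chain, you have to extend the HttpServletRequestWrapper class so that the request bytes are cached and can be read again.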

Here is a good helper class I found on Stack Overflow (I can't find the question again; I will edit this if I find it).

MultiReadHttpServletRequest

/**
 * Request wrapper that buffers the request body the first time it is accessed, so that
 * getInputStream() and getReader() can be called repeatedly and each call starts again
 * at the beginning of the body.
 *
 * The whole body is held in memory. Only the stream and reader accessors are overridden;
 * all other methods delegate to the wrapped request.
 */
public class MultiReadHttpServletRequest extends HttpServletRequestWrapper {
    /** Buffered body; null until the body has been read from the wrapped request. */
    private byte[] cachedBytes;

    /**
     * @param request the request whose body should become re-readable
     */
    public MultiReadHttpServletRequest(HttpServletRequest request) {
        super(request);
    }

    /**
     * @return a new stream over the buffered body, positioned at the first byte
     * @throws IOException if reading the body from the wrapped request fails
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (cachedBytes == null) {
            cacheInputStream();
        }
        return new CachedServletInputStream();
    }

    /**
     * @return a new reader over the buffered body, decoded with the request's character
     *         encoding, or ISO-8859-1 (the servlet default) when the request declares none
     * @throws IOException if reading the body fails or the declared encoding is unsupported
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        if (encoding == null) {
            // Do not fall back to the JVM's platform charset: it varies between servers.
            encoding = "ISO-8859-1";
        }
        return new BufferedReader(new InputStreamReader(getInputStream(), encoding));
    }

    /**
     * Reads the wrapped request's body once and keeps the bytes for later reads.
     * Uses Apache Commons IO's IOUtils for the copy.
     */
    private void cacheInputStream() throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        IOUtils.copy(super.getInputStream(), buffer);
        cachedBytes = buffer.toByteArray();
    }

    /**
     * An input stream which reads the cached request body.
     * NOTE(review): newer Servlet API versions (3.1+) declare additional abstract methods on
     * ServletInputStream (isFinished, isReady, setReadListener) which are not implemented
     * here - confirm the target Servlet API version.
     */
    public class CachedServletInputStream extends ServletInputStream {
        private ByteArrayInputStream input;

        public CachedServletInputStream() {
            // The buffer is only ever read, so every stream can share the same array.
            input = new ByteArrayInputStream(cachedBytes);
        }

        @Override
        public int read() throws IOException {
            return input.read();
        }

        /** Bulk read, so callers are not forced through the single-byte path. */
        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            return input.read(b, off, len);
        }

        @Override
        public int available() throws IOException {
            return input.available();
        }
    }
}

Now all you need to do is use MultiReadHttpServletRequest instead of the normal HttpServletRequest in the filter:

/**
 * Servlet filter that wraps each request in a MultiReadHttpServletRequest, runs the CSRF
 * check on the wrapper and either continues the chain with it or redirects the client to
 * the "/access-forbidden" page of this application.
 */
public class CSRFilter extends GenericFilterBean {

    /**
     * @param req   the incoming request; cast to HttpServletRequest without a type check
     * @param res   the outgoing response; cast to HttpServletResponse without a type check
     * @param chain the remaining filter chain, which receives the wrapped request on success
     */
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
        final HttpServletRequest httpRequest = (HttpServletRequest) req;
        final HttpServletResponse httpResponse = (HttpServletResponse) res;

        // The same wrapper is used for the check and for the rest of the chain.
        final MultiReadHttpServletRequest rereadableRequest = new MultiReadHttpServletRequest(httpRequest);

        if (new CSRF(rereadableRequest).isOk()) {
            chain.doFilter(rereadableRequest, res);
            return;
        }

        // TODO: show an error page
        final StringBuilder forbiddenUrl = new StringBuilder();
        forbiddenUrl.append(httpRequest.getScheme()).append("://")
                .append(httpRequest.getServerName()).append(":")
                .append(httpRequest.getServerPort())
                .append(httpRequest.getContextPath())
                .append("/access-forbidden");
        httpResponse.sendRedirect(forbiddenUrl.toString());
    }
}

I hope this helps someone :)