diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/WebRequestDataBinder.java b/spring-web/src/main/java/org/springframework/web/bind/support/WebRequestDataBinder.java index bc870ea5d556..a4fbe06e86c9 100644 --- a/spring-web/src/main/java/org/springframework/web/bind/support/WebRequestDataBinder.java +++ b/spring-web/src/main/java/org/springframework/web/bind/support/WebRequestDataBinder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2013 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,12 +17,20 @@ package org.springframework.web.bind.support; import org.springframework.beans.MutablePropertyValues; +import org.springframework.util.ClassUtils; +import org.springframework.util.StringUtils; import org.springframework.validation.BindException; import org.springframework.web.bind.WebDataBinder; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.WebRequest; +import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartRequest; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.Part; +import java.io.IOException; + /** * Special {@link org.springframework.validation.DataBinder} to perform data binding * from web request parameters to JavaBeans, including support for multipart files. @@ -59,6 +67,10 @@ */ public class WebRequestDataBinder extends WebDataBinder { + private static final String MULTIPART_CONTENT_TYPE = "multipart"; + + private static final String CONTENT_TYPE = "Content-Type"; + /** * Create a new WebRequestDataBinder instance, with default object name. * @param target the target object to bind onto (or {@code null} @@ -89,21 +101,27 @@ public WebRequestDataBinder(Object target, String objectName) { *

Multipart files are bound via their parameter name, just like normal * HTTP parameters: i.e. "uploadedFile" to an "uploadedFile" bean property, * invoking a "setUploadedFile" setter method. - *

The type of the target property for a multipart file can be MultipartFile, + *

The type of the target property for a multipart file can be Part, MultipartFile, * byte[], or String. The latter two receive the contents of the uploaded file; * all metadata like original file name, content type, etc are lost in those cases. * @param request request with parameters to bind (can be multipart) * @see org.springframework.web.multipart.MultipartRequest * @see org.springframework.web.multipart.MultipartFile - * @see #bindMultipartFiles + * @see javax.servlet.http.Part * @see #bind(org.springframework.beans.PropertyValues) */ public void bind(WebRequest request) { MutablePropertyValues mpvs = new MutablePropertyValues(request.getParameterMap()); - if (request instanceof NativeWebRequest) { + + if(isMultiPartNativeRequest(request)) { MultipartRequest multipartRequest = ((NativeWebRequest) request).getNativeRequest(MultipartRequest.class); if (multipartRequest != null) { bindMultipart(multipartRequest.getMultiFileMap(), mpvs); + } else if (ClassUtils.hasMethod(HttpServletRequest.class, "getParts")) { + HttpServletRequest multipartSerlvetRequest = ((NativeWebRequest) request) + .getNativeRequest(HttpServletRequest.class); + Servlet3MultiPartBinder binder = new Servlet3MultiPartBinder(); + binder.bindParts(multipartSerlvetRequest, mpvs); } } doBind(mpvs); @@ -121,4 +139,38 @@ public void closeNoCatch() throws BindException { } } + /** + * Check if the request given as parameter is a NativeWebRequest + * and a multipart request (by checking its Content-Type header). + *

If so, this request will be parsed to bind its multiple parts. + * @param request request with parameters to bind + * @see org.springframework.web.context.request.NativeWebRequest + */ + protected boolean isMultiPartNativeRequest(WebRequest request) { + + return (StringUtils.startsWithIgnoreCase( + request.getHeader(CONTENT_TYPE),MULTIPART_CONTENT_TYPE) + && request instanceof NativeWebRequest); + } + + /** + * Encapsulate Part binding code for Servlet 3.0+ only containers. + * @see javax.servlet.http.Part + */ + private class Servlet3MultiPartBinder { + + public void bindParts(HttpServletRequest request, MutablePropertyValues mpvs) { + try { + for(Part part : request.getParts()) { + mpvs.add(part.getName(),part); + } + } catch (IOException ex) { + throw new MultipartException("Could not parse multipart servlet request", ex); + } catch(ServletException ex) { + throw new MultipartException("Could not parse multipart servlet request", ex); + } + } + + } + } diff --git a/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java new file mode 100644 index 000000000000..355c191ceffa --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java @@ -0,0 +1,196 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.support; + +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.http.MediaType; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.mock.web.test.MockMultipartFile; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.SocketUtils; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.context.request.ServletWebRequest; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.Part; +import java.io.IOException; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author Brian Clozel + */ +public class WebRequestDataBinderIntegrationTests { + + protected static String baseUrl; + + protected static MediaType contentType; + + private static Server jettyServer; + + private RestTemplate template; + + private static PartsServlet partsServlet; + + private static PartListServlet partListServlet; + + @Before + public void createTemplate() { + template = new RestTemplate(new HttpComponentsClientHttpRequestFactory()); + } + + @BeforeClass + public static void startJettyServer() throws Exception { + int port = SocketUtils.findAvailableTcpPort(); + jettyServer = new Server(port); + baseUrl = "http://localhost:" + port; + ServletContextHandler handler = new ServletContextHandler(); + + partsServlet = new PartsServlet(); + partListServlet = new PartListServlet(); + + handler.addServlet(new ServletHolder(partsServlet), "/parts"); + handler.addServlet(new ServletHolder(partListServlet), "/partlist"); + jettyServer.setHandler(handler); + jettyServer.start(); + } + + @AfterClass + public static void stopJettyServer() throws Exception { + if (jettyServer != null) { + jettyServer.stop(); + } + } + + @SuppressWarnings("serial") + private abstract static class AbstractStandardMultipartServlet extends HttpServlet { + + private T bean; + + @Override + public void service(HttpServletRequest request, HttpServletResponse response) throws + ServletException, IOException { + + WebRequestDataBinder binder = new WebRequestDataBinder(bean); + ServletWebRequest webRequest = new ServletWebRequest(request, response); + + binder.bind(webRequest); + + response.setStatus(HttpServletResponse.SC_OK); + } + + public void setBean(T bean) { + this.bean = bean; + } + } + + private static class PartsBean { + + public Part firstPart; + + public Part secondPart; + + public Part getFirstPart() { + return firstPart; + } + + public void setFirstPart(Part firstPart) { + this.firstPart = firstPart; + } + + public Part getSecondPart() { + return secondPart; + } + + public void setSecondPart(Part secondPart) { + this.secondPart = secondPart; + } + } + + @SuppressWarnings("serial") + private static class PartsServlet extends AbstractStandardMultipartServlet { + + } + + private static class PartListBean { + + public List partList; + + public List getPartList() { + return partList; + } + + public void setPartList(List partList) { + this.partList = partList; + } + } + + @SuppressWarnings("serial") + private static class PartListServlet extends AbstractStandardMultipartServlet { + + } + + @Test + public void testPartsBinding() { + + PartsBean bean = new PartsBean(); + partsServlet.setBean(bean); + + MultiValueMap parts = new LinkedMultiValueMap(); + MockMultipartFile firstPart = new MockMultipartFile("fileName", "aValue".getBytes()); + parts.add("firstPart", firstPart); + parts.add("secondPart", "secondValue"); + + template.postForLocation(baseUrl + "/parts", parts); + + assertNotNull(bean.getFirstPart()); + assertNotNull(bean.getSecondPart()); + } + + @Test + public void testPartListBinding() { + + PartListBean bean = new PartListBean(); + partListServlet.setBean(bean); + + MultiValueMap parts = new LinkedMultiValueMap(); + parts.add("partList", "first value"); + parts.add("partList", "second value"); + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + parts.add("partList", logo); + + template.postForLocation(baseUrl + "/partlist", parts); + + assertNotNull(bean.getPartList()); + assertEquals(parts.size(), bean.getPartList().size()); + } +}