diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java index daaef35a073a..24f63e9567c3 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java @@ -26,10 +26,13 @@ import java.security.Principal; import java.time.Instant; import java.util.AbstractMap; +import java.util.AbstractSet; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.Enumeration; +import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; @@ -47,6 +50,7 @@ import jakarta.servlet.http.HttpSession; import jakarta.servlet.http.Part; +import org.jetbrains.annotations.NotNull; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; import org.springframework.http.HttpHeaders; @@ -469,18 +473,75 @@ public boolean containsKey(Object key) { @Override public void clear() { - List attributeNames = Collections.list(this.servletRequest.getAttributeNames()); - attributeNames.forEach(this.servletRequest::removeAttribute); + this.servletRequest.getAttributeNames().asIterator().forEachRemaining(this.servletRequest::removeAttribute); } @Override public Set> entrySet() { - return Collections.list(this.servletRequest.getAttributeNames()).stream() - .map(name -> { - Object value = this.servletRequest.getAttribute(name); - return new SimpleImmutableEntry<>(name, value); - }) - .collect(Collectors.toSet()); + return new AbstractSet<>() { + @Override + public Iterator> iterator() { + return new Iterator<>() { + private final Iterator attributes = ServletAttributesMap.this.servletRequest.getAttributeNames().asIterator(); + @Override + public boolean hasNext() { + return attributes.hasNext(); + } + + @Override + public Entry next() { + String attribute = attributes.next(); + Object value = ServletAttributesMap.this.servletRequest.getAttribute(attribute); + return new SimpleImmutableEntry<>(attribute, value); + } + }; + } + + @Override + public boolean isEmpty() { + return ServletAttributesMap.this.isEmpty(); + } + + @Override + public int size() { + return ServletAttributesMap.this.size(); + } + + @Override + public boolean contains(Object o) { + if (!(o instanceof Map.Entry entry)) { + return false; + } + String attribute = (String) entry.getKey(); + Object value = ServletAttributesMap.this.servletRequest.getAttribute(attribute); + return value != null && value.equals(entry.getValue()); + } + + @Override + public boolean addAll(@NotNull Collection> c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean remove(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean removeAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean retainAll(@NotNull Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + }; } @Override @@ -503,6 +564,22 @@ public Object remove(Object key) { this.servletRequest.removeAttribute(name); return value; } + + @Override + public int size() { + Enumeration attributes = this.servletRequest.getAttributeNames(); + int size = 0; + while (attributes.hasMoreElements()) { + size++; + attributes.nextElement(); + } + return size; + } + + @Override + public boolean isEmpty() { + return !this.servletRequest.getAttributeNames().hasMoreElements(); + } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java index 1690193cf194..3c702ab5130c 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java @@ -27,10 +27,12 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.OptionalLong; +import java.util.Set; import jakarta.servlet.http.Cookie; import jakarta.servlet.http.Part; @@ -115,6 +117,46 @@ void attribute() { assertThat(request.attribute("foo")).contains("bar"); } + @Test + void attributes() { + MockHttpServletRequest servletRequest = PathPatternsTestUtils.initRequest("GET", "/", true); + servletRequest.setAttribute("foo", "bar"); + servletRequest.setAttribute("baz", "qux"); + + DefaultServerRequest request = new DefaultServerRequest(servletRequest, this.messageConverters); + + Map attributesMap = request.attributes(); + assertThat(attributesMap).isNotEmpty(); + assertThat(attributesMap).containsEntry("foo", "bar"); + assertThat(attributesMap).containsEntry("baz", "qux"); + assertThat(attributesMap).doesNotContainEntry("foo", "blah"); + + Set> entrySet = attributesMap.entrySet(); + assertThat(entrySet).isNotEmpty(); + assertThat(entrySet).hasSize(attributesMap.size()); + assertThat(entrySet).contains(Map.entry("foo", "bar")); + assertThat(entrySet).contains(Map.entry("baz", "qux")); + assertThat(entrySet).doesNotContain(Map.entry("foo", "blah")); + assertThat(entrySet).isUnmodifiable(); + + assertThat(entrySet.iterator()).toIterable().contains(Map.entry("foo", "bar"), Map.entry("baz", "qux")); + Iterator attributes = servletRequest.getAttributeNames().asIterator(); + Iterator> entrySetIterator = entrySet.iterator(); + while (attributes.hasNext()) { + attributes.next(); + assertThat(entrySetIterator).hasNext(); + entrySetIterator.next(); + } + assertThat(entrySetIterator).isExhausted(); + + attributesMap.clear(); + assertThat(attributesMap).isEmpty(); + assertThat(attributesMap).hasSize(0); + assertThat(entrySet).isEmpty(); + assertThat(entrySet).hasSize(0); + assertThat(entrySet.iterator()).isExhausted(); + } + @Test void params() { MockHttpServletRequest servletRequest = PathPatternsTestUtils.initRequest("GET", "/", true);