Skip to content

Commit 69883f6

Browse files
authored
Merge pull request #631 from 2012160085/main
Fix Asynchronous Dispatch Logic in AwsAsyncContext with Spring's DispatcherServlet
2 parents 7ca2f07 + 1fa314b commit 69883f6

File tree

19 files changed

+382
-51
lines changed

19 files changed

+382
-51
lines changed

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsAsyncContext.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,23 @@
3232
public class AwsAsyncContext implements AsyncContext {
3333
private HttpServletRequest req;
3434
private HttpServletResponse res;
35-
private AwsLambdaServletContainerHandler handler;
3635
private List<AsyncListenerHolder> listeners;
3736
private long timeout;
3837
private AtomicBoolean dispatched;
3938
private AtomicBoolean completed;
39+
private AtomicBoolean dispatchStarted;
4040

4141
private Logger log = LoggerFactory.getLogger(AwsAsyncContext.class);
4242

43-
public AwsAsyncContext(HttpServletRequest request, HttpServletResponse response, AwsLambdaServletContainerHandler servletHandler) {
43+
public AwsAsyncContext(HttpServletRequest request, HttpServletResponse response) {
4444
log.debug("Initializing async context for request: " + SecurityUtils.crlf(request.getPathInfo()) + " - " + SecurityUtils.crlf(request.getMethod()));
4545
req = request;
4646
res = response;
47-
handler = servletHandler;
4847
listeners = new ArrayList<>();
4948
timeout = 3000;
5049
dispatched = new AtomicBoolean(false);
5150
completed = new AtomicBoolean(false);
51+
dispatchStarted = new AtomicBoolean(false);
5252
}
5353

5454
@Override
@@ -68,16 +68,14 @@ public boolean hasOriginalRequestAndResponse() {
6868

6969
@Override
7070
public void dispatch() {
71-
try {
72-
log.debug("Dispatching request");
73-
if (dispatched.get()) {
74-
throw new IllegalStateException("Dispatching already started");
75-
}
71+
log.debug("Dispatching request");
72+
73+
if (dispatched.get()) {
74+
throw new IllegalStateException("Dispatching already started");
75+
}
76+
if (dispatchStarted.getAndSet(true)) {
7677
dispatched.set(true);
77-
handler.doFilter(req, res, ((AwsServletContext)req.getServletContext()).getServletForPath(req.getRequestURI()));
7878
notifyListeners(NotificationType.START_ASYNC, null);
79-
} catch (ServletException | IOException e) {
80-
notifyListeners(NotificationType.ERROR, e);
8179
}
8280
}
8381

@@ -154,6 +152,10 @@ public boolean isCompleted() {
154152
return completed.get();
155153
}
156154

155+
public boolean isDispatchStarted() {
156+
return dispatchStarted.get();
157+
}
158+
157159
private void notifyListeners(NotificationType type, Throwable t) {
158160
listeners.forEach((h) -> {
159161
try {

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpApiV2ProxyHttpServletRequest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,15 +442,15 @@ public boolean isAsyncStarted() {
442442

443443
@Override
444444
public AsyncContext startAsync() throws IllegalStateException {
445-
asyncContext = new AwsAsyncContext(this, response, containerHandler);
445+
asyncContext = new AwsAsyncContext(this, response);
446446
setAttribute(DISPATCHER_TYPE_ATTRIBUTE, DispatcherType.ASYNC);
447447
log.debug("Starting async context for request: " + SecurityUtils.crlf(request.getRequestContext().getRequestId()));
448448
return asyncContext;
449449
}
450450

451451
@Override
452452
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException {
453-
asyncContext = new AwsAsyncContext((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, containerHandler);
453+
asyncContext = new AwsAsyncContext((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse);
454454
setAttribute(DISPATCHER_TYPE_ATTRIBUTE, DispatcherType.ASYNC);
455455
log.debug("Starting async context for request: " + SecurityUtils.crlf(request.getRequestContext().getRequestId()));
456456
return asyncContext;

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsLambdaServletContainerHandler.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,25 @@ protected void doFilter(HttpServletRequest request, HttpServletResponse response
152152

153153
FilterChain chain = getFilterChain(request, servlet);
154154
chain.doFilter(request, response);
155-
155+
if(requiresAsyncReDispatch(request)) {
156+
chain = getFilterChain(request, servlet);
157+
chain.doFilter(request, response);
158+
}
156159
// if for some reason the response wasn't flushed yet, we force it here unless it's being processed asynchronously (WebFlux)
157160
if (!response.isCommitted() && request.getDispatcherType() != DispatcherType.ASYNC) {
158161
response.flushBuffer();
159162
}
160163
}
161164

165+
private boolean requiresAsyncReDispatch(HttpServletRequest request) {
166+
if (request.isAsyncStarted()) {
167+
AsyncContext asyncContext = request.getAsyncContext();
168+
return asyncContext instanceof AwsAsyncContext
169+
&& ((AwsAsyncContext) asyncContext).isDispatchStarted();
170+
}
171+
return false;
172+
}
173+
162174
@Override
163175
public void initialize() throws ContainerInitializationException {
164176
// we expect all servlets to be wrapped in an AwsServletRegistration

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ public boolean isAsyncStarted() {
495495
@Override
496496
public AsyncContext startAsync()
497497
throws IllegalStateException {
498-
asyncContext = new AwsAsyncContext(this, response, containerHandler);
498+
asyncContext = new AwsAsyncContext(this, response);
499499
setAttribute(DISPATCHER_TYPE_ATTRIBUTE, DispatcherType.ASYNC);
500500
log.debug("Starting async context for request: " + SecurityUtils.crlf(request.getRequestContext().getRequestId()));
501501
return asyncContext;
@@ -506,7 +506,7 @@ public AsyncContext startAsync()
506506
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse)
507507
throws IllegalStateException {
508508
servletRequest.setAttribute(DISPATCHER_TYPE_ATTRIBUTE, DispatcherType.ASYNC);
509-
asyncContext = new AwsAsyncContext((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, containerHandler);
509+
asyncContext = new AwsAsyncContext((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse);
510510
log.debug("Starting async context for request: " + SecurityUtils.crlf(request.getRequestContext().getRequestId()));
511511
return asyncContext;
512512
}

aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsAsyncContextTest.java

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
1111
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
1212
import com.amazonaws.services.lambda.runtime.Context;
13+
import org.junit.jupiter.api.Disabled;
1314
import org.junit.jupiter.api.Test;
1415

1516
import jakarta.servlet.AsyncContext;
@@ -32,48 +33,20 @@ public class AwsAsyncContextTest {
3233
private AwsServletContextTest.TestServlet srv2 = new AwsServletContextTest.TestServlet("srv2");
3334
private AwsServletContext ctx = getCtx();
3435

35-
@Test
36-
void dispatch_sendsToCorrectServlet() {
37-
AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(new AwsProxyRequestBuilder("/srv1/hello", "GET").build(), lambdaCtx, null);
38-
req.setResponse(handler.getContainerResponse(req, new CountDownLatch(1)));
39-
req.setServletContext(ctx);
40-
req.setContainerHandler(handler);
41-
42-
AsyncContext asyncCtx = req.startAsync();
43-
handler.setDesiredStatus(201);
44-
asyncCtx.dispatch();
45-
assertNotNull(handler.getSelectedServlet());
46-
assertEquals(srv1, handler.getSelectedServlet());
47-
assertEquals(201, handler.getResponse().getStatus());
48-
49-
req = new AwsProxyHttpServletRequest(new AwsProxyRequestBuilder("/srv5/hello", "GET").build(), lambdaCtx, null);
50-
req.setResponse(handler.getContainerResponse(req, new CountDownLatch(1)));
51-
req.setServletContext(ctx);
52-
req.setContainerHandler(handler);
53-
asyncCtx = req.startAsync();
54-
handler.setDesiredStatus(202);
55-
asyncCtx.dispatch();
56-
assertNotNull(handler.getSelectedServlet());
57-
assertEquals(srv2, handler.getSelectedServlet());
58-
assertEquals(202, handler.getResponse().getStatus());
59-
}
6036

6137
@Test
62-
void dispatchNewPath_sendsToCorrectServlet() throws InvalidRequestEventException {
38+
void dispatch_amendsPath() throws InvalidRequestEventException {
6339
AwsProxyHttpServletRequest req = (AwsProxyHttpServletRequest)reader.readRequest(new AwsProxyRequestBuilder("/srv1/hello", "GET").build(), null, lambdaCtx, LambdaContainerHandler.getContainerConfig());
6440
req.setResponse(handler.getContainerResponse(req, new CountDownLatch(1)));
6541
req.setServletContext(ctx);
6642
req.setContainerHandler(handler);
6743

6844
AsyncContext asyncCtx = req.startAsync();
69-
handler.setDesiredStatus(301);
7045
asyncCtx.dispatch("/srv4/hello");
71-
assertNotNull(handler.getSelectedServlet());
72-
assertEquals(srv2, handler.getSelectedServlet());
73-
assertNotNull(handler.getResponse());
74-
assertEquals(301, handler.getResponse().getStatus());
46+
assertEquals("/srv1/hello", req.getRequestURI());
7547
}
7648

49+
7750
private AwsServletContext getCtx() {
7851
AwsServletContext ctx = new AwsServletContext(handler);
7952
handler.setServletContext(ctx);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package com.amazonaws.serverless.proxy.spring;
2+
3+
import com.amazonaws.serverless.exceptions.ContainerInitializationException;
4+
import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;
5+
import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext;
6+
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
7+
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
8+
import com.amazonaws.serverless.proxy.spring.springapp.LambdaHandler;
9+
import com.amazonaws.serverless.proxy.spring.springapp.MessageController;
10+
import org.junit.jupiter.api.BeforeAll;
11+
import org.junit.jupiter.api.Test;
12+
13+
import static org.junit.jupiter.api.Assertions.assertEquals;
14+
import static org.junit.jupiter.api.Assertions.fail;
15+
16+
public class AsyncAppTest {
17+
18+
private static LambdaHandler handler;
19+
20+
@BeforeAll
21+
public static void setUp() {
22+
try {
23+
handler = new LambdaHandler();
24+
} catch (ContainerInitializationException e) {
25+
e.printStackTrace();
26+
fail();
27+
}
28+
}
29+
30+
@Test
31+
void springApp_helloRequest_returnsCorrect() {
32+
AwsProxyRequest req = new AwsProxyRequestBuilder("/hello", "GET").build();
33+
AwsProxyResponse resp = handler.handleRequest(req, new MockLambdaContext());
34+
assertEquals(200, resp.getStatusCode());
35+
assertEquals(MessageController.HELLO_MESSAGE, resp.getBody());
36+
}
37+
38+
@Test
39+
void springApp_asyncRequest_returnsCorrect() {
40+
AwsProxyRequest req = new AwsProxyRequestBuilder("/async", "GET").build();
41+
AwsProxyResponse resp = handler.handleRequest(req, new MockLambdaContext());
42+
assertEquals(200, resp.getStatusCode());
43+
assertEquals(MessageController.HELLO_MESSAGE, resp.getBody());
44+
}
45+
46+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package com.amazonaws.serverless.proxy.spring.springapp;
2+
3+
import org.springframework.context.annotation.Configuration;
4+
import org.springframework.context.annotation.Import;
5+
6+
@Configuration
7+
@Import({MessageController.class})
8+
public class AppConfig { }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package com.amazonaws.serverless.proxy.spring.springapp;
2+
3+
import com.amazonaws.serverless.exceptions.ContainerInitializationException;
4+
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
5+
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
6+
import com.amazonaws.serverless.proxy.spring.SpringLambdaContainerHandler;
7+
import com.amazonaws.serverless.proxy.spring.SpringProxyHandlerBuilder;
8+
import com.amazonaws.services.lambda.runtime.Context;
9+
import com.amazonaws.services.lambda.runtime.RequestHandler;
10+
11+
public class LambdaHandler implements RequestHandler<AwsProxyRequest, AwsProxyResponse> {
12+
private SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> handler;
13+
14+
public LambdaHandler() throws ContainerInitializationException {
15+
handler = new SpringProxyHandlerBuilder<AwsProxyRequest>()
16+
.defaultProxy()
17+
.asyncInit()
18+
.configurationClasses(AppConfig.class)
19+
.buildAndInitialize();
20+
}
21+
22+
@Override
23+
public AwsProxyResponse handleRequest(AwsProxyRequest awsProxyRequest, Context context) {
24+
return handler.proxy(awsProxyRequest, context);
25+
}
26+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package com.amazonaws.serverless.proxy.spring.springapp;
2+
3+
import org.springframework.web.bind.annotation.RequestMapping;
4+
import org.springframework.web.bind.annotation.RequestMethod;
5+
import org.springframework.web.bind.annotation.RestController;
6+
import org.springframework.web.context.request.async.DeferredResult;
7+
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
8+
9+
@RestController
10+
@EnableWebMvc
11+
public class MessageController {
12+
public static final String HELLO_MESSAGE = "Hello";
13+
14+
@RequestMapping(path="/hello", method= RequestMethod.GET)
15+
public String hello() {
16+
return HELLO_MESSAGE;
17+
}
18+
19+
@RequestMapping(path="/async", method= RequestMethod.GET)
20+
public DeferredResult<String> asyncHello() {
21+
DeferredResult<String> result = new DeferredResult<>();
22+
result.setResult(HELLO_MESSAGE);
23+
return result;
24+
}
25+
}

aws-serverless-java-container-springboot3/pom.xml

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,46 @@
191191
<scope>test</scope>
192192
</dependency>
193193

194+
<dependency>
195+
<groupId>org.springframework.boot</groupId>
196+
<artifactId>spring-boot-starter-data-jpa</artifactId>
197+
<version>3.2.1</version>
198+
<scope>test</scope>
199+
<exclusions>
200+
<exclusion>
201+
<groupId>org.springframework.boot</groupId>
202+
<artifactId>spring-boot-starter-aop</artifactId>
203+
</exclusion>
204+
<exclusion>
205+
<groupId>org.springframework.boot</groupId>
206+
<artifactId>spring-boot-starter-web</artifactId>
207+
</exclusion>
208+
<exclusion>
209+
<groupId>org.springframework.boot</groupId>
210+
<artifactId>spring-boot-starter-logging</artifactId>
211+
</exclusion>
212+
<exclusion>
213+
<groupId>org.springframework.boot</groupId>
214+
<artifactId>spring-boot-starter-tomcat</artifactId>
215+
</exclusion>
216+
<exclusion>
217+
<groupId>org.apache.tomcat.embed</groupId>
218+
<artifactId>tomcat-embed-core</artifactId>
219+
</exclusion>
220+
<exclusion>
221+
<groupId>org.apache.tomcat.embed</groupId>
222+
<artifactId>tomcat-embed-websocket</artifactId>
223+
</exclusion>
224+
</exclusions>
225+
</dependency>
226+
<dependency>
227+
<groupId>com.h2database</groupId>
228+
<artifactId>h2</artifactId>
229+
<version>2.2.222</version>
230+
<scope>test</scope>
231+
</dependency>
232+
233+
194234
</dependencies>
195235

196236
<build>
@@ -282,6 +322,14 @@
282322
<failOnError>false</failOnError>
283323
</configuration>
284324
</plugin>
325+
<plugin>
326+
<groupId>org.apache.maven.plugins</groupId>
327+
<artifactId>maven-compiler-plugin</artifactId>
328+
<configuration>
329+
<source>10</source>
330+
<target>10</target>
331+
</configuration>
332+
</plugin>
285333
</plugins>
286334
</build>
287335
<repositories>

0 commit comments

Comments
 (0)