Skip to content

Commit aabe4d0

Browse files
committed
Pickup CoroutineContext saved by CoWebFilter in coRouter
Closes gh-31793
1 parent 5700742 commit aabe4d0

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.springframework.http.HttpMethod
2727
import org.springframework.http.HttpStatusCode
2828
import org.springframework.http.MediaType
2929
import org.springframework.web.reactive.function.server.RouterFunctions.nest
30+
import org.springframework.web.server.CoWebFilter
3031
import reactor.core.publisher.Mono
3132
import java.net.URI
3233
import kotlin.coroutines.CoroutineContext
@@ -731,7 +732,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
731732
) : HandlerFunction<T> {
732733

733734
override fun handle(request: ServerRequest): Mono<T> {
734-
return handle(Dispatchers.Unconfined, request)
735+
val context = request.attributes()[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext?
736+
return handle(context ?: Dispatchers.Unconfined, request)
735737
}
736738

737739
fun handle(context: CoroutineContext, request: ServerRequest) = asMono(request, context) {

spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ import org.springframework.http.HttpHeaders.CONTENT_TYPE
2525
import org.springframework.http.HttpMethod.PATCH
2626
import org.springframework.http.HttpStatus
2727
import org.springframework.http.MediaType.*
28+
import org.springframework.web.server.CoWebFilter
29+
import org.springframework.web.server.CoWebFilterChain
30+
import org.springframework.web.server.ServerWebExchange
2831
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.*
32+
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpResponse
2933
import org.springframework.web.testfixture.server.MockServerWebExchange
3034
import reactor.test.StepVerifier
3135

@@ -204,6 +208,16 @@ class CoRouterFunctionDslTests {
204208
.verifyComplete()
205209
}
206210

211+
@Test
212+
fun webFilterAndContext() {
213+
val strategies = HandlerStrategies.builder().webFilter(MyCoWebFilterWithContext()).build()
214+
val httpHandler = RouterFunctions.toHttpHandler(routerWithoutContext, strategies)
215+
val mockRequest = get("https://example.com/").build()
216+
val mockResponse = MockServerHttpResponse()
217+
StepVerifier.create(httpHandler.handle(mockRequest, mockResponse)).verifyComplete()
218+
assertThat(mockResponse.headers.getFirst("context")).contains("Filter context")
219+
}
220+
207221
@Test
208222
fun multipleContextProviders() {
209223
assertThatIllegalStateException().isThrownBy {
@@ -309,6 +323,12 @@ class CoRouterFunctionDslTests {
309323
}
310324
}
311325

326+
private val routerWithoutContext = coRouter {
327+
GET("/") {
328+
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
329+
}
330+
}
331+
312332
private val otherRouter = router {
313333
"/other" {
314334
ok().build()
@@ -369,3 +389,12 @@ class CoRouterFunctionDslTests {
369389

370390
@Suppress("UNUSED_PARAMETER")
371391
private suspend fun handle(req: ServerRequest) = ServerResponse.ok().buildAndAwait()
392+
393+
394+
private class MyCoWebFilterWithContext : CoWebFilter() {
395+
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
396+
withContext(CoroutineName("Filter context")) {
397+
chain.filter(exchange)
398+
}
399+
}
400+
}

0 commit comments

Comments
 (0)