Skip to content

Commit 4ba5ea3

Browse files
committed
GH-1018 Ensure AWS adapter can pass raw InputStream
Resolves #1018
1 parent 7365deb commit 4ba5ea3

File tree

3 files changed

+86
-1
lines changed

3 files changed

+86
-1
lines changed

spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.cloud.function.adapter.aws;
1818

19+
import java.io.IOException;
20+
import java.io.InputStream;
1921
import java.lang.reflect.Type;
2022
import java.nio.charset.StandardCharsets;
2123
import java.util.HashMap;
@@ -31,7 +33,9 @@
3133
import org.springframework.http.HttpStatus;
3234
import org.springframework.messaging.Message;
3335
import org.springframework.messaging.MessageHeaders;
36+
import org.springframework.messaging.support.GenericMessage;
3437
import org.springframework.messaging.support.MessageBuilder;
38+
import org.springframework.util.StreamUtils;
3539

3640
/**
3741
*
@@ -77,6 +81,23 @@ static boolean isSupportedAWSType(Type inputType) {
7781
|| typeName.equals("com.amazonaws.services.lambda.runtime.events.KinesisEvent");
7882
}
7983

84+
@SuppressWarnings("rawtypes")
85+
public static Message generateMessage(InputStream payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper, Context context) throws IOException {
86+
if (inputType != null && FunctionTypeUtils.isMessage(inputType)) {
87+
inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0);
88+
}
89+
if (inputType != null && InputStream.class.isAssignableFrom(FunctionTypeUtils.getRawType(inputType))) {
90+
MessageBuilder msgBuilder = MessageBuilder.withPayload(payload);
91+
if (context != null) {
92+
msgBuilder.setHeader(AWSLambdaUtils.AWS_CONTEXT, context);
93+
}
94+
return msgBuilder.build();
95+
}
96+
else {
97+
return generateMessage(StreamUtils.copyToByteArray(payload), inputType, isSupplier, jsonMapper, context);
98+
}
99+
}
100+
80101
public static Message<byte[]> generateMessage(byte[] payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper) {
81102
return generateMessage(payload, inputType, isSupplier, jsonMapper, null);
82103
}
@@ -87,6 +108,7 @@ public static Message<byte[]> generateMessage(byte[] payload, Type inputType, bo
87108
logger.info("Received: " + new String(payload, StandardCharsets.UTF_8));
88109
}
89110

111+
90112
Object structMessage = jsonMapper.fromJson(payload, Object.class);
91113
boolean isApiGateway = structMessage instanceof Map
92114
&& (((Map<String, Object>) structMessage).containsKey("httpMethod") ||

spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public FunctionInvoker() {
8080
@Override
8181
public void handleRequest(InputStream input, OutputStream output, Context context) throws IOException {
8282
Message requestMessage = AWSLambdaUtils
83-
.generateMessage(StreamUtils.copyToByteArray(input), this.function.getInputType(), this.function.isSupplier(), jsonMapper, context);
83+
.generateMessage(input, this.function.getInputType(), this.function.isSupplier(), jsonMapper, context);
8484

8585
Object response = this.function.apply(requestMessage);
8686
byte[] responseBytes = this.buildResult(requestMessage, response);

spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.springframework.messaging.converter.AbstractMessageConverter;
5858
import org.springframework.messaging.support.MessageBuilder;
5959
import org.springframework.util.MimeType;
60+
import org.springframework.util.StreamUtils;
6061

6162
import static org.assertj.core.api.Assertions.assertThat;
6263
import static org.junit.jupiter.api.Assertions.fail;
@@ -971,6 +972,40 @@ public void testApiGatewayAsSupplier() throws Exception {
971972
assertThat(result.get("body")).isEqualTo("\"boom\"");
972973
}
973974

975+
@SuppressWarnings({ "rawtypes", "unchecked" })
976+
@Test
977+
public void testApiGatewayInAndOutInputStream() throws Exception {
978+
System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
979+
System.setProperty("spring.cloud.function.definition", "echoInputStreamToString");
980+
FunctionInvoker invoker = new FunctionInvoker();
981+
982+
InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
983+
ByteArrayOutputStream output = new ByteArrayOutputStream();
984+
invoker.handleRequest(targetStream, output, null);
985+
986+
Map result = mapper.readValue(output.toByteArray(), Map.class);
987+
assertThat(result.get("body")).isEqualTo("hello");
988+
Map headers = (Map) result.get("headers");
989+
assertThat(headers).isNotEmpty();
990+
}
991+
992+
@SuppressWarnings({ "rawtypes", "unchecked" })
993+
@Test
994+
public void testApiGatewayInAndOutInputStreamMsg() throws Exception {
995+
System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
996+
System.setProperty("spring.cloud.function.definition", "echoInputStreamMsgToString");
997+
FunctionInvoker invoker = new FunctionInvoker();
998+
999+
InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
1000+
ByteArrayOutputStream output = new ByteArrayOutputStream();
1001+
invoker.handleRequest(targetStream, output, null);
1002+
1003+
Map result = mapper.readValue(output.toByteArray(), Map.class);
1004+
assertThat(result.get("body")).isEqualTo("hello");
1005+
Map headers = (Map) result.get("headers");
1006+
assertThat(headers).isNotEmpty();
1007+
}
1008+
9741009
@SuppressWarnings("rawtypes")
9751010
@Test
9761011
public void testApiGatewayInAndOut() throws Exception {
@@ -1400,6 +1435,34 @@ public Function<APIGatewayProxyRequestEvent, String> inputApiEvent() {
14001435
};
14011436
}
14021437

1438+
@Bean
1439+
1440+
public Function<InputStream, String> echoInputStreamToString() {
1441+
return is -> {
1442+
try {
1443+
String result = StreamUtils.copyToString(is, StandardCharsets.UTF_8);
1444+
return result;
1445+
}
1446+
catch (Exception e) {
1447+
throw new RuntimeException(e);
1448+
}
1449+
};
1450+
}
1451+
1452+
@Bean
1453+
1454+
public Function<Message<InputStream>, String> echoInputStreamMsgToString() {
1455+
return msg -> {
1456+
try {
1457+
String result = StreamUtils.copyToString(msg.getPayload(), StandardCharsets.UTF_8);
1458+
return result;
1459+
}
1460+
catch (Exception e) {
1461+
throw new RuntimeException(e);
1462+
}
1463+
};
1464+
}
1465+
14031466
@Bean
14041467
public Function<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> inputOutputApiEvent() {
14051468
return v -> {

0 commit comments

Comments
 (0)