diff --git a/generated/src/aws-cpp-sdk-transcribestreaming/include/aws/transcribestreaming/TranscribeStreamingServiceClient.h b/generated/src/aws-cpp-sdk-transcribestreaming/include/aws/transcribestreaming/TranscribeStreamingServiceClient.h index ed31da78e82..a40f00da82d 100644 --- a/generated/src/aws-cpp-sdk-transcribestreaming/include/aws/transcribestreaming/TranscribeStreamingServiceClient.h +++ b/generated/src/aws-cpp-sdk-transcribestreaming/include/aws/transcribestreaming/TranscribeStreamingServiceClient.h @@ -6,15 +6,19 @@ #pragma once #include #include -#include #include -#include #include +#include +#include +#include +#include +#include namespace Aws { namespace TranscribeStreamingService { + AWS_TRANSCRIBESTREAMINGSERVICE_API extern const char SERVICE_NAME[]; /** *

Amazon Transcribe streaming offers four main types of real-time * transcription: Standard, Medical, Call Analytics, and @@ -31,12 +35,20 @@ namespace TranscribeStreamingService * patient-clinician conversations using generative AI. Refer to [here] for * details.

*/ - class AWS_TRANSCRIBESTREAMINGSERVICE_API TranscribeStreamingServiceClient : public Aws::Client::AWSJsonClient, public Aws::Client::ClientWithAsyncTemplateMethods + class AWS_TRANSCRIBESTREAMINGSERVICE_API TranscribeStreamingServiceClient : Aws::Client::ClientWithAsyncTemplateMethods, + smithy::client::AwsSmithyClientT, + Aws::Crt::Variant, + TranscribeStreamingServiceEndpointProviderBase, + smithy::client::JsonOutcomeSerializer, + smithy::client::JsonOutcome, + Aws::Client::TranscribeStreamingServiceErrorMarshaller> { public: - typedef Aws::Client::AWSJsonClient BASECLASS; static const char* GetServiceName(); static const char* GetAllocationTag(); + inline const char* GetServiceClientName() const override { return "Transcribe Streaming"; } typedef TranscribeStreamingServiceClientConfiguration ClientConfigurationType; typedef TranscribeStreamingServiceEndpointProvider EndpointProviderType; @@ -229,10 +241,7 @@ namespace TranscribeStreamingService std::shared_ptr& accessEndpointProvider(); private: friend class Aws::Client::ClientWithAsyncTemplateMethods; - void init(const TranscribeStreamingServiceClientConfiguration& clientConfiguration); - TranscribeStreamingServiceClientConfiguration m_clientConfiguration; - std::shared_ptr m_endpointProvider; }; } // namespace TranscribeStreamingService diff --git a/generated/src/aws-cpp-sdk-transcribestreaming/include/aws/transcribestreaming/model/AudioStream.h b/generated/src/aws-cpp-sdk-transcribestreaming/include/aws/transcribestreaming/model/AudioStream.h index e2c5782695b..e542ef97216 100644 --- a/generated/src/aws-cpp-sdk-transcribestreaming/include/aws/transcribestreaming/model/AudioStream.h +++ b/generated/src/aws-cpp-sdk-transcribestreaming/include/aws/transcribestreaming/model/AudioStream.h @@ -9,6 +9,7 @@ #include #include #include +#include namespace Aws { @@ -25,7 +26,7 @@ namespace Model * href="http://docs.aws.amazon.com/goto/WebAPI/transcribe-streaming-2017-10-26/AudioStream">AWS * API Reference

*/ - class AWS_TRANSCRIBESTREAMINGSERVICE_API AudioStream : public Aws::Utils::Event::EventEncoderStream + class AWS_TRANSCRIBESTREAMINGSERVICE_API AudioStream : public Aws::Utils::Event::SmithyEventEncoderStream { public: AudioStream& WriteAudioEvent(const AudioEvent& value) diff --git a/generated/src/aws-cpp-sdk-transcribestreaming/include/aws/transcribestreaming/model/MedicalScribeInputStream.h b/generated/src/aws-cpp-sdk-transcribestreaming/include/aws/transcribestreaming/model/MedicalScribeInputStream.h index 399df4cd9fc..1b965b39059 100644 --- a/generated/src/aws-cpp-sdk-transcribestreaming/include/aws/transcribestreaming/model/MedicalScribeInputStream.h +++ b/generated/src/aws-cpp-sdk-transcribestreaming/include/aws/transcribestreaming/model/MedicalScribeInputStream.h @@ -10,6 +10,7 @@ #include #include #include +#include namespace Aws { @@ -29,7 +30,7 @@ namespace Model * href="http://docs.aws.amazon.com/goto/WebAPI/transcribe-streaming-2017-10-26/MedicalScribeInputStream">AWS * API Reference

*/ - class AWS_TRANSCRIBESTREAMINGSERVICE_API MedicalScribeInputStream : public Aws::Utils::Event::EventEncoderStream + class AWS_TRANSCRIBESTREAMINGSERVICE_API MedicalScribeInputStream : public Aws::Utils::Event::SmithyEventEncoderStream { public: MedicalScribeInputStream& WriteMedicalScribeAudioEvent(const MedicalScribeAudioEvent& value) diff --git a/generated/src/aws-cpp-sdk-transcribestreaming/source/TranscribeStreamingServiceClient.cpp b/generated/src/aws-cpp-sdk-transcribestreaming/source/TranscribeStreamingServiceClient.cpp index d50e38dad0b..9709579b7c9 100644 --- a/generated/src/aws-cpp-sdk-transcribestreaming/source/TranscribeStreamingServiceClient.cpp +++ b/generated/src/aws-cpp-sdk-transcribestreaming/source/TranscribeStreamingServiceClient.cpp @@ -4,19 +4,17 @@ */ #include -#include #include #include #include -#include #include #include -#include #include #include #include #include #include + #include #include @@ -31,6 +29,9 @@ #include +#include +#include +#include using namespace Aws; using namespace Aws::Auth; @@ -46,100 +47,100 @@ namespace Aws { namespace TranscribeStreamingService { - const char SERVICE_NAME[] = "transcribe"; const char ALLOCATION_TAG[] = "TranscribeStreamingServiceClient"; + const char SERVICE_NAME[] = "transcribe"; } } const char* TranscribeStreamingServiceClient::GetServiceName() {return SERVICE_NAME;} const char* TranscribeStreamingServiceClient::GetAllocationTag() {return ALLOCATION_TAG;} TranscribeStreamingServiceClient::TranscribeStreamingServiceClient(const TranscribeStreamingService::TranscribeStreamingServiceClientConfiguration& clientConfiguration, - std::shared_ptr endpointProvider) : - BASECLASS(clientConfiguration, - Aws::MakeShared(ALLOCATION_TAG, - Aws::MakeShared(ALLOCATION_TAG), - SERVICE_NAME, - Aws::Region::ComputeSignerRegion(clientConfiguration.region)), - Aws::MakeShared(ALLOCATION_TAG)), - m_clientConfiguration(clientConfiguration), - m_endpointProvider(endpointProvider ? std::move(endpointProvider) : Aws::MakeShared(ALLOCATION_TAG)) -{ - init(m_clientConfiguration); -} + std::shared_ptr endpointProvider) : + AwsSmithyClientT(clientConfiguration, + GetServiceName(), + "Transcribe Streaming", + Aws::Http::CreateHttpClient(clientConfiguration), + Aws::MakeShared(ALLOCATION_TAG), + endpointProvider ? endpointProvider : Aws::MakeShared(ALLOCATION_TAG), + Aws::MakeShared>(ALLOCATION_TAG), + { + {smithy::SigV4AuthSchemeOption::sigV4AuthSchemeOption.schemeId, smithy::SigV4AuthScheme{GetServiceName(), clientConfiguration.region}}, + }) +{} TranscribeStreamingServiceClient::TranscribeStreamingServiceClient(const AWSCredentials& credentials, - std::shared_ptr endpointProvider, - const TranscribeStreamingService::TranscribeStreamingServiceClientConfiguration& clientConfiguration) : - BASECLASS(clientConfiguration, - Aws::MakeShared(ALLOCATION_TAG, - Aws::MakeShared(ALLOCATION_TAG, credentials), - SERVICE_NAME, - Aws::Region::ComputeSignerRegion(clientConfiguration.region)), - Aws::MakeShared(ALLOCATION_TAG)), - m_clientConfiguration(clientConfiguration), - m_endpointProvider(endpointProvider ? std::move(endpointProvider) : Aws::MakeShared(ALLOCATION_TAG)) -{ - init(m_clientConfiguration); -} + std::shared_ptr endpointProvider, + const TranscribeStreamingService::TranscribeStreamingServiceClientConfiguration& clientConfiguration) : + AwsSmithyClientT(clientConfiguration, + GetServiceName(), + "Transcribe Streaming", + Aws::Http::CreateHttpClient(clientConfiguration), + Aws::MakeShared(ALLOCATION_TAG), + endpointProvider ? endpointProvider : Aws::MakeShared(ALLOCATION_TAG), + Aws::MakeShared>(ALLOCATION_TAG), + { + {smithy::SigV4AuthSchemeOption::sigV4AuthSchemeOption.schemeId, smithy::SigV4AuthScheme{Aws::MakeShared(ALLOCATION_TAG, credentials), GetServiceName(), clientConfiguration.region}}, + }) +{} TranscribeStreamingServiceClient::TranscribeStreamingServiceClient(const std::shared_ptr& credentialsProvider, - std::shared_ptr endpointProvider, - const TranscribeStreamingService::TranscribeStreamingServiceClientConfiguration& clientConfiguration) : - BASECLASS(clientConfiguration, - Aws::MakeShared(ALLOCATION_TAG, - credentialsProvider, - SERVICE_NAME, - Aws::Region::ComputeSignerRegion(clientConfiguration.region)), - Aws::MakeShared(ALLOCATION_TAG)), - m_clientConfiguration(clientConfiguration), - m_endpointProvider(endpointProvider ? std::move(endpointProvider) : Aws::MakeShared(ALLOCATION_TAG)) -{ - init(m_clientConfiguration); -} + std::shared_ptr endpointProvider, + const TranscribeStreamingService::TranscribeStreamingServiceClientConfiguration& clientConfiguration) : + AwsSmithyClientT(clientConfiguration, + GetServiceName(), + "Transcribe Streaming", + Aws::Http::CreateHttpClient(clientConfiguration), + Aws::MakeShared(ALLOCATION_TAG), + endpointProvider ? endpointProvider : Aws::MakeShared(ALLOCATION_TAG), + Aws::MakeShared>(ALLOCATION_TAG), + { + {smithy::SigV4AuthSchemeOption::sigV4AuthSchemeOption.schemeId, smithy::SigV4AuthScheme{ Aws::MakeShared(ALLOCATION_TAG, credentialsProvider), GetServiceName(), clientConfiguration.region}} + }) +{} - /* Legacy constructors due deprecation */ - TranscribeStreamingServiceClient::TranscribeStreamingServiceClient(const Client::ClientConfiguration& clientConfiguration) : - BASECLASS(clientConfiguration, - Aws::MakeShared(ALLOCATION_TAG, - Aws::MakeShared(ALLOCATION_TAG), - SERVICE_NAME, - Aws::Region::ComputeSignerRegion(clientConfiguration.region)), - Aws::MakeShared(ALLOCATION_TAG)), - m_clientConfiguration(clientConfiguration), - m_endpointProvider(Aws::MakeShared(ALLOCATION_TAG)) -{ - init(m_clientConfiguration); -} +/* Legacy constructors due deprecation */ +TranscribeStreamingServiceClient::TranscribeStreamingServiceClient(const Client::ClientConfiguration& clientConfiguration) : + AwsSmithyClientT(clientConfiguration, + GetServiceName(), + "Transcribe Streaming", + Aws::Http::CreateHttpClient(clientConfiguration), + Aws::MakeShared(ALLOCATION_TAG), + Aws::MakeShared(ALLOCATION_TAG), + Aws::MakeShared>(ALLOCATION_TAG), + { + {smithy::SigV4AuthSchemeOption::sigV4AuthSchemeOption.schemeId, smithy::SigV4AuthScheme{Aws::MakeShared(ALLOCATION_TAG), GetServiceName(), clientConfiguration.region}} + }) +{} TranscribeStreamingServiceClient::TranscribeStreamingServiceClient(const AWSCredentials& credentials, - const Client::ClientConfiguration& clientConfiguration) : - BASECLASS(clientConfiguration, - Aws::MakeShared(ALLOCATION_TAG, - Aws::MakeShared(ALLOCATION_TAG, credentials), - SERVICE_NAME, - Aws::Region::ComputeSignerRegion(clientConfiguration.region)), - Aws::MakeShared(ALLOCATION_TAG)), - m_clientConfiguration(clientConfiguration), - m_endpointProvider(Aws::MakeShared(ALLOCATION_TAG)) -{ - init(m_clientConfiguration); -} + const Client::ClientConfiguration& clientConfiguration) : + AwsSmithyClientT(clientConfiguration, + GetServiceName(), + "Transcribe Streaming", + Aws::Http::CreateHttpClient(clientConfiguration), + Aws::MakeShared(ALLOCATION_TAG), + Aws::MakeShared(ALLOCATION_TAG), + Aws::MakeShared>(ALLOCATION_TAG), + { + {smithy::SigV4AuthSchemeOption::sigV4AuthSchemeOption.schemeId, smithy::SigV4AuthScheme{Aws::MakeShared(ALLOCATION_TAG, credentials), GetServiceName(), clientConfiguration.region}} + }) +{} TranscribeStreamingServiceClient::TranscribeStreamingServiceClient(const std::shared_ptr& credentialsProvider, - const Client::ClientConfiguration& clientConfiguration) : - BASECLASS(clientConfiguration, - Aws::MakeShared(ALLOCATION_TAG, - credentialsProvider, - SERVICE_NAME, - Aws::Region::ComputeSignerRegion(clientConfiguration.region)), - Aws::MakeShared(ALLOCATION_TAG)), - m_clientConfiguration(clientConfiguration), - m_endpointProvider(Aws::MakeShared(ALLOCATION_TAG)) -{ - init(m_clientConfiguration); -} + const Client::ClientConfiguration& clientConfiguration) : + AwsSmithyClientT(clientConfiguration, + GetServiceName(), + "Transcribe Streaming", + Aws::Http::CreateHttpClient(clientConfiguration), + Aws::MakeShared(ALLOCATION_TAG), + Aws::MakeShared(ALLOCATION_TAG), + Aws::MakeShared>(ALLOCATION_TAG), + { + {smithy::SigV4AuthSchemeOption::sigV4AuthSchemeOption.schemeId, smithy::SigV4AuthScheme{Aws::MakeShared(ALLOCATION_TAG, credentialsProvider), GetServiceName(), clientConfiguration.region}} + }) +{} +/* End of legacy constructors due deprecation */ - /* End of legacy constructors due deprecation */ TranscribeStreamingServiceClient::~TranscribeStreamingServiceClient() { ShutdownSdkClient(this, -1); @@ -150,27 +151,11 @@ std::shared_ptr& TranscribeStrea return m_endpointProvider; } -void TranscribeStreamingServiceClient::init(const TranscribeStreamingService::TranscribeStreamingServiceClientConfiguration& config) -{ - AWSClient::SetServiceClientName("Transcribe Streaming"); - if (!m_clientConfiguration.executor) { - if (!m_clientConfiguration.configFactories.executorCreateFn()) { - AWS_LOGSTREAM_FATAL(ALLOCATION_TAG, "Failed to initialize client: config is missing Executor or executorCreateFn"); - m_isInitialized = false; - return; - } - m_clientConfiguration.executor = m_clientConfiguration.configFactories.executorCreateFn(); - } - AWS_CHECK_PTR(SERVICE_NAME, m_endpointProvider); - m_endpointProvider->InitBuiltInParameters(config); -} - void TranscribeStreamingServiceClient::OverrideEndpoint(const Aws::String& endpoint) { - AWS_CHECK_PTR(SERVICE_NAME, m_endpointProvider); - m_endpointProvider->OverrideEndpoint(endpoint); + AWS_CHECK_PTR(SERVICE_NAME, m_endpointProvider); + m_endpointProvider->OverrideEndpoint(endpoint); } - GetMedicalScribeStreamOutcome TranscribeStreamingServiceClient::GetMedicalScribeStream(const GetMedicalScribeStreamRequest& request) const { AWS_OPERATION_GUARD(GetMedicalScribeStream); @@ -180,24 +165,19 @@ GetMedicalScribeStreamOutcome TranscribeStreamingServiceClient::GetMedicalScribe AWS_LOGSTREAM_ERROR("GetMedicalScribeStream", "Required field: SessionId, is not set"); return GetMedicalScribeStreamOutcome(Aws::Client::AWSError(TranscribeStreamingServiceErrors::MISSING_PARAMETER, "MISSING_PARAMETER", "Missing required field [SessionId]", false)); } - AWS_OPERATION_CHECK_PTR(m_telemetryProvider, GetMedicalScribeStream, CoreErrors, CoreErrors::NOT_INITIALIZED); - auto tracer = m_telemetryProvider->getTracer(this->GetServiceClientName(), {}); - auto meter = m_telemetryProvider->getMeter(this->GetServiceClientName(), {}); + AWS_OPERATION_CHECK_PTR(m_clientConfiguration.telemetryProvider, GetMedicalScribeStream, CoreErrors, CoreErrors::NOT_INITIALIZED); + auto tracer = m_clientConfiguration.telemetryProvider->getTracer(this->GetServiceClientName(), {}); + auto meter = m_clientConfiguration.telemetryProvider->getMeter(this->GetServiceClientName(), {}); AWS_OPERATION_CHECK_PTR(meter, GetMedicalScribeStream, CoreErrors, CoreErrors::NOT_INITIALIZED); auto span = tracer->CreateSpan(Aws::String(this->GetServiceClientName()) + ".GetMedicalScribeStream", {{ TracingUtils::SMITHY_METHOD_DIMENSION, request.GetServiceRequestName() }, { TracingUtils::SMITHY_SERVICE_DIMENSION, this->GetServiceClientName() }, { TracingUtils::SMITHY_SYSTEM_DIMENSION, TracingUtils::SMITHY_METHOD_AWS_VALUE }}, smithy::components::tracing::SpanKind::CLIENT); return TracingUtils::MakeCallWithTiming( [&]()-> GetMedicalScribeStreamOutcome { - auto endpointResolutionOutcome = TracingUtils::MakeCallWithTiming( - [&]() -> ResolveEndpointOutcome { return m_endpointProvider->ResolveEndpoint(request.GetEndpointContextParams()); }, - TracingUtils::SMITHY_CLIENT_ENDPOINT_RESOLUTION_METRIC, - *meter, - {{TracingUtils::SMITHY_METHOD_DIMENSION, request.GetServiceRequestName()}, {TracingUtils::SMITHY_SERVICE_DIMENSION, this->GetServiceClientName()}}); - AWS_OPERATION_CHECK_SUCCESS(endpointResolutionOutcome, GetMedicalScribeStream, CoreErrors, CoreErrors::ENDPOINT_RESOLUTION_FAILURE, endpointResolutionOutcome.GetError().GetMessage()); - endpointResolutionOutcome.GetResult().AddPathSegments("/medical-scribe-stream/"); - endpointResolutionOutcome.GetResult().AddPathSegment(request.GetSessionId()); - return GetMedicalScribeStreamOutcome(MakeRequest(request, endpointResolutionOutcome.GetResult(), Aws::Http::HttpMethod::HTTP_GET, Aws::Auth::SIGV4_SIGNER)); + return GetMedicalScribeStreamOutcome(MakeRequestDeserialize(&request, request.GetServiceRequestName(), Aws::Http::HttpMethod::HTTP_GET, [&](Aws::Endpoint::AWSEndpoint& resolvedEndpoint) -> void { + resolvedEndpoint.AddPathSegments("/medical-scribe-stream/"); + resolvedEndpoint.AddPathSegment(request.GetSessionId()); + })); }, TracingUtils::SMITHY_CLIENT_DURATION_METRIC, *meter, @@ -232,35 +212,43 @@ void TranscribeStreamingServiceClient::StartCallAnalyticsStreamTranscriptionAsyn handler(this, request, StartCallAnalyticsStreamTranscriptionOutcome(Aws::Client::AWSError(TranscribeStreamingServiceErrors::MISSING_PARAMETER, "MISSING_PARAMETER", "Missing required field [MediaEncoding]", false)), handlerContext); return; } - auto meter = m_telemetryProvider->getMeter(this->GetServiceClientName(), {}); - auto endpointResolutionOutcome = TracingUtils::MakeCallWithTiming( - [&]() -> ResolveEndpointOutcome { return m_endpointProvider->ResolveEndpoint(request.GetEndpointContextParams()); }, - TracingUtils::SMITHY_CLIENT_ENDPOINT_RESOLUTION_METRIC, - *meter, - {{TracingUtils::SMITHY_METHOD_DIMENSION, request.GetServiceRequestName()}, {TracingUtils::SMITHY_SERVICE_DIMENSION, this->GetServiceClientName()}}); - if (!endpointResolutionOutcome.IsSuccess()) { - handler(this, request, StartCallAnalyticsStreamTranscriptionOutcome(Aws::Client::AWSError( - CoreErrors::ENDPOINT_RESOLUTION_FAILURE, "ENDPOINT_RESOLUTION_FAILURE", endpointResolutionOutcome.GetError().GetMessage(), false)), handlerContext); - return; + request.SetResponseStreamFactory( + [&] { request.GetEventStreamDecoder().Reset(); return Aws::New(ALLOCATION_TAG, request.GetEventStreamDecoder()); } + ); + if (!request.GetHeadersReceivedEventHandler()) { + request.SetHeadersReceivedEventHandler([&request](const Http::HttpRequest*, Http::HttpResponse* response) { + AWS_CHECK_PTR("StartCallAnalyticsStreamTranscription", response); + if (const auto initialResponseHandler = request.GetEventStreamHandler().GetInitialResponseCallbackEx()) { + initialResponseHandler({response->GetHeaders()}, Utils::Event::InitialResponseType::ON_RESPONSE); + } + }); } - endpointResolutionOutcome.GetResult().AddPathSegments("/call-analytics-stream-transcription"); auto eventEncoderStream = Aws::MakeShared(ALLOCATION_TAG); - eventEncoderStream->SetSigner(GetSignerByName(Aws::Auth::EVENTSTREAM_SIGV4_SIGNER)); - auto requestCopy = Aws::MakeShared("StartCallAnalyticsStreamTranscription", request); - requestCopy->SetAudioStream(eventEncoderStream); // this becomes the body of the request - request.SetAudioStream(eventEncoderStream); - - auto asyncTask = CreateBidirectionalEventStreamTask(this, - endpointResolutionOutcome.GetResultWithOwnership(), - requestCopy, - handler, - handlerContext, - eventEncoderStream); - auto sem = asyncTask.GetSemaphore(); - m_clientConfiguration.executor->Submit(std::move(asyncTask)); - sem->WaitOne(); - streamReadyHandler(*eventEncoderStream); + request.SetAudioStream(eventEncoderStream); // this becomes the body of the request + auto streamReadySemaphore = Aws::MakeShared(ALLOCATION_TAG, 0, 1); + request.SetRequestSignedHandler([eventEncoderStream, streamReadySemaphore](const Aws::Http::HttpRequest& httpRequest) { eventEncoderStream->SetSignatureSeed(Aws::Client::GetAuthorizationHeader(httpRequest)); streamReadySemaphore->ReleaseAll(); }); + m_clientConfiguration.executor->Submit([this, &request, handler, handlerContext, eventEncoderStream, streamReadySemaphore] () mutable { + auto eventStreamHandler = [&](std::shared_ptr encoder) -> void{ + eventEncoderStream->SetEncoder(encoder); + }; + JsonOutcome outcome = MakeEventStreamRequestDeserialize(&request, request.GetServiceRequestName(), Aws::Http::HttpMethod::HTTP_POST, [&](Aws::Endpoint::AWSEndpoint& resolvedEndpoint) -> void { + resolvedEndpoint.AddPathSegments("/call-analytics-stream-transcription"); + }, + std::move(eventStreamHandler)); + if(outcome.IsSuccess()) + { + handler(this, request, StartCallAnalyticsStreamTranscriptionOutcome(NoResult()), handlerContext); + } + else + { + request.GetAudioStream()->Close(); + handler(this, request, StartCallAnalyticsStreamTranscriptionOutcome(outcome.GetError()), handlerContext); + } + return StartCallAnalyticsStreamTranscriptionOutcome(NoResult()); + }); + streamReadySemaphore->WaitOne(); + streamReadyHandler(*request.GetAudioStream()); } void TranscribeStreamingServiceClient::StartMedicalScribeStreamAsync(Model::StartMedicalScribeStreamRequest& request, const StartMedicalScribeStreamStreamReadyHandler& streamReadyHandler, @@ -290,35 +278,43 @@ void TranscribeStreamingServiceClient::StartMedicalScribeStreamAsync(Model::Star handler(this, request, StartMedicalScribeStreamOutcome(Aws::Client::AWSError(TranscribeStreamingServiceErrors::MISSING_PARAMETER, "MISSING_PARAMETER", "Missing required field [MediaEncoding]", false)), handlerContext); return; } - auto meter = m_telemetryProvider->getMeter(this->GetServiceClientName(), {}); - auto endpointResolutionOutcome = TracingUtils::MakeCallWithTiming( - [&]() -> ResolveEndpointOutcome { return m_endpointProvider->ResolveEndpoint(request.GetEndpointContextParams()); }, - TracingUtils::SMITHY_CLIENT_ENDPOINT_RESOLUTION_METRIC, - *meter, - {{TracingUtils::SMITHY_METHOD_DIMENSION, request.GetServiceRequestName()}, {TracingUtils::SMITHY_SERVICE_DIMENSION, this->GetServiceClientName()}}); - if (!endpointResolutionOutcome.IsSuccess()) { - handler(this, request, StartMedicalScribeStreamOutcome(Aws::Client::AWSError( - CoreErrors::ENDPOINT_RESOLUTION_FAILURE, "ENDPOINT_RESOLUTION_FAILURE", endpointResolutionOutcome.GetError().GetMessage(), false)), handlerContext); - return; + request.SetResponseStreamFactory( + [&] { request.GetEventStreamDecoder().Reset(); return Aws::New(ALLOCATION_TAG, request.GetEventStreamDecoder()); } + ); + if (!request.GetHeadersReceivedEventHandler()) { + request.SetHeadersReceivedEventHandler([&request](const Http::HttpRequest*, Http::HttpResponse* response) { + AWS_CHECK_PTR("StartMedicalScribeStream", response); + if (const auto initialResponseHandler = request.GetEventStreamHandler().GetInitialResponseCallbackEx()) { + initialResponseHandler({response->GetHeaders()}, Utils::Event::InitialResponseType::ON_RESPONSE); + } + }); } - endpointResolutionOutcome.GetResult().AddPathSegments("/medical-scribe-stream"); auto eventEncoderStream = Aws::MakeShared(ALLOCATION_TAG); - eventEncoderStream->SetSigner(GetSignerByName(Aws::Auth::EVENTSTREAM_SIGV4_SIGNER)); - auto requestCopy = Aws::MakeShared("StartMedicalScribeStream", request); - requestCopy->SetInputStream(eventEncoderStream); // this becomes the body of the request - request.SetInputStream(eventEncoderStream); - - auto asyncTask = CreateBidirectionalEventStreamTask(this, - endpointResolutionOutcome.GetResultWithOwnership(), - requestCopy, - handler, - handlerContext, - eventEncoderStream); - auto sem = asyncTask.GetSemaphore(); - m_clientConfiguration.executor->Submit(std::move(asyncTask)); - sem->WaitOne(); - streamReadyHandler(*eventEncoderStream); + request.SetInputStream(eventEncoderStream); // this becomes the body of the request + auto streamReadySemaphore = Aws::MakeShared(ALLOCATION_TAG, 0, 1); + request.SetRequestSignedHandler([eventEncoderStream, streamReadySemaphore](const Aws::Http::HttpRequest& httpRequest) { eventEncoderStream->SetSignatureSeed(Aws::Client::GetAuthorizationHeader(httpRequest)); streamReadySemaphore->ReleaseAll(); }); + m_clientConfiguration.executor->Submit([this, &request, handler, handlerContext, eventEncoderStream, streamReadySemaphore] () mutable { + auto eventStreamHandler = [&](std::shared_ptr encoder) -> void{ + eventEncoderStream->SetEncoder(encoder); + }; + JsonOutcome outcome = MakeEventStreamRequestDeserialize(&request, request.GetServiceRequestName(), Aws::Http::HttpMethod::HTTP_POST, [&](Aws::Endpoint::AWSEndpoint& resolvedEndpoint) -> void { + resolvedEndpoint.AddPathSegments("/medical-scribe-stream"); + }, + std::move(eventStreamHandler)); + if(outcome.IsSuccess()) + { + handler(this, request, StartMedicalScribeStreamOutcome(NoResult()), handlerContext); + } + else + { + request.GetInputStream()->Close(); + handler(this, request, StartMedicalScribeStreamOutcome(outcome.GetError()), handlerContext); + } + return StartMedicalScribeStreamOutcome(NoResult()); + }); + streamReadySemaphore->WaitOne(); + streamReadyHandler(*request.GetInputStream()); } void TranscribeStreamingServiceClient::StartMedicalStreamTranscriptionAsync(Model::StartMedicalStreamTranscriptionRequest& request, const StartMedicalStreamTranscriptionStreamReadyHandler& streamReadyHandler, @@ -360,35 +356,43 @@ void TranscribeStreamingServiceClient::StartMedicalStreamTranscriptionAsync(Mode handler(this, request, StartMedicalStreamTranscriptionOutcome(Aws::Client::AWSError(TranscribeStreamingServiceErrors::MISSING_PARAMETER, "MISSING_PARAMETER", "Missing required field [Type]", false)), handlerContext); return; } - auto meter = m_telemetryProvider->getMeter(this->GetServiceClientName(), {}); - auto endpointResolutionOutcome = TracingUtils::MakeCallWithTiming( - [&]() -> ResolveEndpointOutcome { return m_endpointProvider->ResolveEndpoint(request.GetEndpointContextParams()); }, - TracingUtils::SMITHY_CLIENT_ENDPOINT_RESOLUTION_METRIC, - *meter, - {{TracingUtils::SMITHY_METHOD_DIMENSION, request.GetServiceRequestName()}, {TracingUtils::SMITHY_SERVICE_DIMENSION, this->GetServiceClientName()}}); - if (!endpointResolutionOutcome.IsSuccess()) { - handler(this, request, StartMedicalStreamTranscriptionOutcome(Aws::Client::AWSError( - CoreErrors::ENDPOINT_RESOLUTION_FAILURE, "ENDPOINT_RESOLUTION_FAILURE", endpointResolutionOutcome.GetError().GetMessage(), false)), handlerContext); - return; + request.SetResponseStreamFactory( + [&] { request.GetEventStreamDecoder().Reset(); return Aws::New(ALLOCATION_TAG, request.GetEventStreamDecoder()); } + ); + if (!request.GetHeadersReceivedEventHandler()) { + request.SetHeadersReceivedEventHandler([&request](const Http::HttpRequest*, Http::HttpResponse* response) { + AWS_CHECK_PTR("StartMedicalStreamTranscription", response); + if (const auto initialResponseHandler = request.GetEventStreamHandler().GetInitialResponseCallbackEx()) { + initialResponseHandler({response->GetHeaders()}, Utils::Event::InitialResponseType::ON_RESPONSE); + } + }); } - endpointResolutionOutcome.GetResult().AddPathSegments("/medical-stream-transcription"); auto eventEncoderStream = Aws::MakeShared(ALLOCATION_TAG); - eventEncoderStream->SetSigner(GetSignerByName(Aws::Auth::EVENTSTREAM_SIGV4_SIGNER)); - auto requestCopy = Aws::MakeShared("StartMedicalStreamTranscription", request); - requestCopy->SetAudioStream(eventEncoderStream); // this becomes the body of the request - request.SetAudioStream(eventEncoderStream); - - auto asyncTask = CreateBidirectionalEventStreamTask(this, - endpointResolutionOutcome.GetResultWithOwnership(), - requestCopy, - handler, - handlerContext, - eventEncoderStream); - auto sem = asyncTask.GetSemaphore(); - m_clientConfiguration.executor->Submit(std::move(asyncTask)); - sem->WaitOne(); - streamReadyHandler(*eventEncoderStream); + request.SetAudioStream(eventEncoderStream); // this becomes the body of the request + auto streamReadySemaphore = Aws::MakeShared(ALLOCATION_TAG, 0, 1); + request.SetRequestSignedHandler([eventEncoderStream, streamReadySemaphore](const Aws::Http::HttpRequest& httpRequest) { eventEncoderStream->SetSignatureSeed(Aws::Client::GetAuthorizationHeader(httpRequest)); streamReadySemaphore->ReleaseAll(); }); + m_clientConfiguration.executor->Submit([this, &request, handler, handlerContext, eventEncoderStream, streamReadySemaphore] () mutable { + auto eventStreamHandler = [&](std::shared_ptr encoder) -> void{ + eventEncoderStream->SetEncoder(encoder); + }; + JsonOutcome outcome = MakeEventStreamRequestDeserialize(&request, request.GetServiceRequestName(), Aws::Http::HttpMethod::HTTP_POST, [&](Aws::Endpoint::AWSEndpoint& resolvedEndpoint) -> void { + resolvedEndpoint.AddPathSegments("/medical-stream-transcription"); + }, + std::move(eventStreamHandler)); + if(outcome.IsSuccess()) + { + handler(this, request, StartMedicalStreamTranscriptionOutcome(NoResult()), handlerContext); + } + else + { + request.GetAudioStream()->Close(); + handler(this, request, StartMedicalStreamTranscriptionOutcome(outcome.GetError()), handlerContext); + } + return StartMedicalStreamTranscriptionOutcome(NoResult()); + }); + streamReadySemaphore->WaitOne(); + streamReadyHandler(*request.GetAudioStream()); } void TranscribeStreamingServiceClient::StartStreamTranscriptionAsync(Model::StartStreamTranscriptionRequest& request, const StartStreamTranscriptionStreamReadyHandler& streamReadyHandler, @@ -412,33 +416,42 @@ void TranscribeStreamingServiceClient::StartStreamTranscriptionAsync(Model::Star handler(this, request, StartStreamTranscriptionOutcome(Aws::Client::AWSError(TranscribeStreamingServiceErrors::MISSING_PARAMETER, "MISSING_PARAMETER", "Missing required field [MediaEncoding]", false)), handlerContext); return; } - auto meter = m_telemetryProvider->getMeter(this->GetServiceClientName(), {}); - auto endpointResolutionOutcome = TracingUtils::MakeCallWithTiming( - [&]() -> ResolveEndpointOutcome { return m_endpointProvider->ResolveEndpoint(request.GetEndpointContextParams()); }, - TracingUtils::SMITHY_CLIENT_ENDPOINT_RESOLUTION_METRIC, - *meter, - {{TracingUtils::SMITHY_METHOD_DIMENSION, request.GetServiceRequestName()}, {TracingUtils::SMITHY_SERVICE_DIMENSION, this->GetServiceClientName()}}); - if (!endpointResolutionOutcome.IsSuccess()) { - handler(this, request, StartStreamTranscriptionOutcome(Aws::Client::AWSError( - CoreErrors::ENDPOINT_RESOLUTION_FAILURE, "ENDPOINT_RESOLUTION_FAILURE", endpointResolutionOutcome.GetError().GetMessage(), false)), handlerContext); - return; + request.SetResponseStreamFactory( + [&] { request.GetEventStreamDecoder().Reset(); return Aws::New(ALLOCATION_TAG, request.GetEventStreamDecoder()); } + ); + if (!request.GetHeadersReceivedEventHandler()) { + request.SetHeadersReceivedEventHandler([&request](const Http::HttpRequest*, Http::HttpResponse* response) { + AWS_CHECK_PTR("StartStreamTranscription", response); + if (const auto initialResponseHandler = request.GetEventStreamHandler().GetInitialResponseCallbackEx()) { + initialResponseHandler({response->GetHeaders()}, Utils::Event::InitialResponseType::ON_RESPONSE); + } + }); } - endpointResolutionOutcome.GetResult().AddPathSegments("/stream-transcription"); auto eventEncoderStream = Aws::MakeShared(ALLOCATION_TAG); - eventEncoderStream->SetSigner(GetSignerByName(Aws::Auth::EVENTSTREAM_SIGV4_SIGNER)); - auto requestCopy = Aws::MakeShared("StartStreamTranscription", request); - requestCopy->SetAudioStream(eventEncoderStream); // this becomes the body of the request - request.SetAudioStream(eventEncoderStream); - - auto asyncTask = CreateBidirectionalEventStreamTask(this, - endpointResolutionOutcome.GetResultWithOwnership(), - requestCopy, - handler, - handlerContext, - eventEncoderStream); - auto sem = asyncTask.GetSemaphore(); - m_clientConfiguration.executor->Submit(std::move(asyncTask)); - sem->WaitOne(); - streamReadyHandler(*eventEncoderStream); + request.SetAudioStream(eventEncoderStream); // this becomes the body of the request + auto streamReadySemaphore = Aws::MakeShared(ALLOCATION_TAG, 0, 1); + request.SetRequestSignedHandler([eventEncoderStream, streamReadySemaphore](const Aws::Http::HttpRequest& httpRequest) { eventEncoderStream->SetSignatureSeed(Aws::Client::GetAuthorizationHeader(httpRequest)); streamReadySemaphore->ReleaseAll(); }); + m_clientConfiguration.executor->Submit([this, &request, handler, handlerContext, eventEncoderStream, streamReadySemaphore] () mutable { + auto eventStreamHandler = [&](std::shared_ptr encoder) -> void{ + eventEncoderStream->SetEncoder(encoder); + }; + JsonOutcome outcome = MakeEventStreamRequestDeserialize(&request, request.GetServiceRequestName(), Aws::Http::HttpMethod::HTTP_POST, [&](Aws::Endpoint::AWSEndpoint& resolvedEndpoint) -> void { + resolvedEndpoint.AddPathSegments("/stream-transcription"); + }, + std::move(eventStreamHandler)); + if(outcome.IsSuccess()) + { + handler(this, request, StartStreamTranscriptionOutcome(NoResult()), handlerContext); + } + else + { + request.GetAudioStream()->Close(); + handler(this, request, StartStreamTranscriptionOutcome(outcome.GetError()), handlerContext); + } + return StartStreamTranscriptionOutcome(NoResult()); + }); + streamReadySemaphore->WaitOne(); + streamReadyHandler(*request.GetAudioStream()); } + diff --git a/src/aws-cpp-sdk-core/include/aws/core/auth/signer/AWSAuthEventStreamV4Signer.h b/src/aws-cpp-sdk-core/include/aws/core/auth/signer/AWSAuthEventStreamV4Signer.h index 27083efb506..40f2e70d5f8 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/auth/signer/AWSAuthEventStreamV4Signer.h +++ b/src/aws-cpp-sdk-core/include/aws/core/auth/signer/AWSAuthEventStreamV4Signer.h @@ -55,20 +55,11 @@ namespace Aws bool SignEventMessage(Aws::Utils::Event::Message&, Aws::String& priorSignature) const override; - bool SignRequest(Aws::Http::HttpRequest& request) const override - { - return SignRequest(request, m_region.c_str(), m_serviceName.c_str(), true); - } + bool SignRequest(Aws::Http::HttpRequest& request) const override; - bool SignRequest(Aws::Http::HttpRequest& request, bool signBody) const override - { - return SignRequest(request, m_region.c_str(), m_serviceName.c_str(), signBody); - } + bool SignRequest(Aws::Http::HttpRequest& request, bool signBody) const override; - bool SignRequest(Aws::Http::HttpRequest& request, const char* region, bool signBody) const override - { - return SignRequest(request, region, m_serviceName.c_str(), signBody); - } + bool SignRequest(Aws::Http::HttpRequest& request, const char* region, bool signBody) const override; bool SignRequest(Aws::Http::HttpRequest& request, const char* region, const char* serviceName, bool signBody) const override; @@ -88,6 +79,12 @@ namespace Aws bool PresignRequest(Aws::Http::HttpRequest&, const char*, const char*, long long) const override { return false; } bool ShouldSignHeader(const Aws::String& header) const; + + bool SignRequest(Aws::Http::HttpRequest& request, const char* region, const char* serviceName, bool /* signBody */, + const Aws::Auth::AWSCredentials& credentials) const; + + bool SignEventMessage(Aws::Utils::Event::Message&, Aws::String& priorSignature, const Aws::Auth::AWSCredentials& creds) const; + private: Utils::ByteBuffer GenerateSignature(const Aws::Auth::AWSCredentials& credentials, const Aws::String& stringToSign, const Aws::String& simpleDate, const Aws::String& region, const Aws::String& serviceName) const; diff --git a/src/aws-cpp-sdk-core/include/aws/core/utils/event/EventEncoderStream.h b/src/aws-cpp-sdk-core/include/aws/core/utils/event/EventEncoderStream.h index 84503da6e4d..b6b3d878157 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/utils/event/EventEncoderStream.h +++ b/src/aws-cpp-sdk-core/include/aws/core/utils/event/EventEncoderStream.h @@ -41,7 +41,7 @@ namespace Aws * Every event uses its previous event's signature to calculate its own signature. * Setting this value affects the signature calculation of the first event. */ - void SetSignatureSeed(const Aws::String& seed) { m_encoder.SetSignatureSeed(seed); } + virtual void SetSignatureSeed(const Aws::String& seed) { m_encoder.SetSignatureSeed(seed); } /** * Writes an event-stream message to the underlying buffer. @@ -67,10 +67,41 @@ namespace Aws */ bool WaitForDrain(int64_t timeoutMs = 1000); + virtual ~EventEncoderStream() {} + protected: + virtual Aws::Vector EncodeAndSign(const Aws::Utils::Event::Message& msg); private: Stream::ConcurrentStreamBuf m_streambuf; EventStreamEncoder m_encoder; }; + + template + class AWS_CORE_API SmithyEventEncoderStream : public EventEncoderStream { + public: + explicit SmithyEventEncoderStream(size_t bufferSize = DEFAULT_BUF_SIZE) : EventEncoderStream(bufferSize) {} + virtual ~SmithyEventEncoderStream() {} + + /*void SetSigner(std::shared_ptr > signer) { + m_evtEncoder.SetSigner(signer); + } + + void SetRequestContext(std::shared_ptr pRequestCtx) { + m_evtEncoder.SetRequestContext(std::move(pRequestCtx)); + }*/ + + void SetEncoder(std::shared_ptr encoder) + { + m_evtEncoder = encoder; + } + + void SetSignatureSeed(const Aws::String& seed) override { m_evtEncoder->SetSignatureSeed(seed); } + + protected: + Aws::Vector EncodeAndSign(const Aws::Utils::Event::Message& msg) override { + return m_evtEncoder->EncodeAndSign(msg); + } + std::shared_ptr m_evtEncoder; + }; } } } diff --git a/src/aws-cpp-sdk-core/include/aws/core/utils/event/EventStreamEncoder.h b/src/aws-cpp-sdk-core/include/aws/core/utils/event/EventStreamEncoder.h index 656388b8e87..3e3da0d64f1 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/utils/event/EventStreamEncoder.h +++ b/src/aws-cpp-sdk-core/include/aws/core/utils/event/EventStreamEncoder.h @@ -8,6 +8,7 @@ #include #include #include +#include namespace Aws { @@ -40,6 +41,11 @@ namespace Aws * The signing is done via the signer member. */ Aws::Vector EncodeAndSign(const Aws::Utils::Event::Message& msg); + + virtual ~EventStreamEncoder(){}; + protected: + virtual bool SignEventMessage(Event::Message& msg); + Aws::String m_signatureSeed; private: /** * Initialize C struct based on C++ object. @@ -57,7 +63,28 @@ namespace Aws bool InitSignedStruct(const aws_event_stream_message* payload, aws_event_stream_message* signedmsg); Aws::Client::AWSAuthSigner* m_signer; - Aws::String m_signatureSeed; + }; + + template + class AWS_CORE_API SmithyEventStreamEncoder : public EventStreamEncoder { + public: + using SIGNER_TYPE = smithy::AwsSignerBase; + SmithyEventStreamEncoder(std::shared_ptr signer, std::shared_ptr awsIdentity) : EventStreamEncoder(), m_smithySigner(signer), m_awsIdentity{awsIdentity}{}; + //SmithyEventStreamEncoder() : EventStreamEncoder(){}; + + protected: + bool SignEventMessage(Event::Message& signedMessage) override { + + //resolved identity + const auto& identity = *static_cast(m_awsIdentity.get()); + + //@to do: if identity expired, resolve it again + return (m_smithySigner->SignEventMessage(signedMessage, m_signatureSeed, identity)); + } + + private: + std::shared_ptr m_smithySigner; + std::shared_ptr m_awsIdentity; }; } } diff --git a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClient.h b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClient.h index 5be45b494fc..5794861835e 100644 --- a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClient.h +++ b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClient.h @@ -24,6 +24,7 @@ #include #include #include +#include namespace smithy { namespace client @@ -228,6 +229,27 @@ namespace client auto httpResponseOutcome = MakeRequestSync(request, requestName, method, std::move(endpointCallback)); return m_serializer->Deserialize(std::move(httpResponseOutcome), GetServiceClientName(), requestName); } + + ResponseT MakeEventStreamRequestDeserialize( + Aws::AmazonWebServiceRequest const* const request, + const char* requestName, + Aws::Http::HttpMethod method, + EndpointUpdateCallback&& endpointCallback, + std::function) >&& eventEncoderStreamHandler + ) const { + std::shared_ptr pExecutor = + Aws::MakeShared("AwsSmithyClient"); + assert(pExecutor); + + HttpResponseOutcome outcome = ClientError(CoreErrors::INTERNAL_FAILURE, "", "Response handler was not called", false); + ResponseHandlerFunc responseHandler = [&outcome](HttpResponseOutcome&& asyncOutcome) { outcome = std::move(asyncOutcome); }; + pExecutor->Submit([&]() { + this->MakeRequestAsync(request, requestName, method, std::move(endpointCallback), std::move(responseHandler), pExecutor, + std::move(eventEncoderStreamHandler)); + }); + pExecutor->WaitUntilStopped(); + return m_serializer->Deserialize(std::move(outcome), GetServiceClientName(), requestName); + } Aws::String GeneratePresignedUrl( EndpointUpdateCallback&& endpointCallback, diff --git a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientAsyncRequestContext.h b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientAsyncRequestContext.h index 6cf9ecd559d..ed81afe0f70 100644 --- a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientAsyncRequestContext.h +++ b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientAsyncRequestContext.h @@ -12,6 +12,7 @@ #include #include #include +#include namespace smithy { @@ -71,6 +72,8 @@ namespace smithy std::shared_ptr m_pExecutor; std::shared_ptr m_interceptorContext; std::shared_ptr m_awsIdentity; + //std::shared_ptr m_eventEncoderStream; + std::function )> m_eventEncoderStreamHandler; }; } // namespace client } // namespace smithy diff --git a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h index 0d499f4b2bf..a9007ba84f4 100644 --- a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h +++ b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h @@ -23,6 +23,7 @@ #include #include #include +#include namespace Aws { @@ -137,7 +138,9 @@ namespace client Aws::Http::HttpMethod method, EndpointUpdateCallback&& endpointCallback, ResponseHandlerFunc&& responseHandler, - std::shared_ptr pExecutor) const; + std::shared_ptr pExecutor, + std::function) >&& eventEncoderStreamHandler + ) const; HttpResponseOutcome MakeRequestSync(Aws::AmazonWebServiceRequest const * const request, const char* requestName, diff --git a/src/aws-cpp-sdk-core/include/smithy/client/common/AwsSmithyRequestSigning.h b/src/aws-cpp-sdk-core/include/smithy/client/common/AwsSmithyRequestSigning.h index c517ae90f40..33d1c838f97 100644 --- a/src/aws-cpp-sdk-core/include/smithy/client/common/AwsSmithyRequestSigning.h +++ b/src/aws-cpp-sdk-core/include/smithy/client/common/AwsSmithyRequestSigning.h @@ -222,6 +222,11 @@ namespace smithy return; } + if(m_httpRequest->IsEventStreamRequest() && m_requestContext.m_eventEncoderStreamHandler) + { + m_requestContext.m_eventEncoderStreamHandler(Aws::MakeShared>("", signer, m_requestContext.m_awsIdentity)); + } + result.emplace(signer->sign(m_httpRequest, *static_cast(m_requestContext.m_awsIdentity.get()), m_requestContext.m_authSchemeOption.signerProperties())); diff --git a/src/aws-cpp-sdk-core/include/smithy/identity/auth/built-in/SigV4MultiAuthResolver.h b/src/aws-cpp-sdk-core/include/smithy/identity/auth/built-in/SigV4MultiAuthResolver.h index 5cd1b372a81..0aac0c6be67 100644 --- a/src/aws-cpp-sdk-core/include/smithy/identity/auth/built-in/SigV4MultiAuthResolver.h +++ b/src/aws-cpp-sdk-core/include/smithy/identity/auth/built-in/SigV4MultiAuthResolver.h @@ -43,6 +43,8 @@ namespace smithy { } } + bool isRequestEventStream = identityProperties.additionalProperties.find(SignerProperties::EVENT_STREAM_REQUEST) != identityProperties.additionalProperties.end(); + //resolve endpoint first time to fetch auth schemes if (m_endpointProviderForAuth) { auto epResolutionOutcome = m_endpointProviderForAuth->ResolveEndpoint(epParams); diff --git a/src/aws-cpp-sdk-core/include/smithy/identity/signer/AwsSignerBase.h b/src/aws-cpp-sdk-core/include/smithy/identity/signer/AwsSignerBase.h index b5509539ad2..81eeaeeecc9 100644 --- a/src/aws-cpp-sdk-core/include/smithy/identity/signer/AwsSignerBase.h +++ b/src/aws-cpp-sdk-core/include/smithy/identity/signer/AwsSignerBase.h @@ -47,6 +47,7 @@ namespace smithy { // signer may copy the original httpRequest or create a new one virtual SigningFutureOutcome sign(std::shared_ptr httpRequest, const IdentityT& identity, SigningProperties properties) = 0; virtual SigningFutureOutcome presign(std::shared_ptr httpRequest, const IdentityT& identity, SigningProperties properties, const Aws::String& region, const Aws::String& serviceName, long long expirationTimeInSeconds) = 0; + virtual bool SignEventMessage(Aws::Utils::Event::Message& , Aws::String& , const IDENTITY_T& ) const {return false;} virtual ~AwsSignerBase() {}; }; diff --git a/src/aws-cpp-sdk-core/include/smithy/identity/signer/built-in/SigV4Signer.h b/src/aws-cpp-sdk-core/include/smithy/identity/signer/built-in/SigV4Signer.h index df405a63987..2289a59df31 100644 --- a/src/aws-cpp-sdk-core/include/smithy/identity/signer/built-in/SigV4Signer.h +++ b/src/aws-cpp-sdk-core/include/smithy/identity/signer/built-in/SigV4Signer.h @@ -25,7 +25,8 @@ namespace smithy { explicit AwsSigV4Signer(const Aws::String& serviceName, const Aws::String& region) : m_serviceName(serviceName), m_region(region), - legacySigner(nullptr, serviceName.c_str(), region, Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Always) + legacySigner(nullptr, serviceName.c_str(), region, Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Always), + legacyEventStreamSigner{Aws::MakeShared("SigV4AuthScheme", nullptr, serviceName.c_str(), region)} { } /* @@ -34,7 +35,8 @@ namespace smithy { explicit AwsSigV4Signer(const Aws::String& serviceName, const Aws::String& region, Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy policy, bool urlEscapePath) : m_serviceName(serviceName), m_region(region), - legacySigner(nullptr, serviceName.c_str(), region, policy, urlEscapePath) + legacySigner(nullptr, serviceName.c_str(), region, policy, urlEscapePath), + legacyEventStreamSigner{Aws::MakeShared("SigV4AuthScheme", nullptr, serviceName.c_str(), region)} { } @@ -63,7 +65,9 @@ namespace smithy { assert(httpRequest); - bool success = legacySigner.SignRequestWithCreds(*httpRequest, legacyCreds, region, svcName, signPayload); + + bool success = httpRequest->IsEventStreamRequest()? legacyEventStreamSigner->SignRequest(*httpRequest, region, svcName, signPayload, legacyCreds) : + legacySigner.SignRequestWithCreds(*httpRequest, legacyCreds, region, svcName, signPayload); if (success) { return SigningFutureOutcome(std::move(httpRequest)); @@ -108,8 +112,26 @@ namespace smithy { return legacySigner; } + bool SignEventMessage(Aws::Utils::Event::Message& em, Aws::String& sig, const AwsCredentialIdentityBase& identity) const override { + + //get legacy credentials + const auto legacyCreds = [&identity]() -> Aws::Auth::AWSCredentials { + if(identity.sessionToken().has_value() && identity.expiration().has_value()) + { + return {identity.accessKeyId(), identity.secretAccessKey(), *identity.sessionToken(), *identity.expiration()}; + } + if(identity.sessionToken().has_value()) + { + return {identity.accessKeyId(), identity.secretAccessKey(), *identity.sessionToken()}; + } + return {identity.accessKeyId(), identity.secretAccessKey()}; + }(); + return legacyEventStreamSigner->SignEventMessage(em, sig, legacyCreds); + } + Aws::String m_serviceName; Aws::String m_region; Aws::Client::AWSAuthV4Signer legacySigner; + std::shared_ptr legacyEventStreamSigner; }; } diff --git a/src/aws-cpp-sdk-core/source/auth/signer/AWSAuthEventStreamV4Signer.cpp b/src/aws-cpp-sdk-core/source/auth/signer/AWSAuthEventStreamV4Signer.cpp index 619f122938b..e1c84f175fd 100644 --- a/src/aws-cpp-sdk-core/source/auth/signer/AWSAuthEventStreamV4Signer.cpp +++ b/src/aws-cpp-sdk-core/source/auth/signer/AWSAuthEventStreamV4Signer.cpp @@ -58,93 +58,10 @@ AWSAuthEventStreamV4Signer::AWSAuthEventStreamV4Signer(const std::shared_ptrGetAWSCredentials(); - - //don't sign anonymous requests - if (credentials.GetAWSAccessKeyId().empty() || credentials.GetAWSSecretKey().empty()) - { - return true; - } - - if (!credentials.GetSessionToken().empty()) - { - request.SetAwsSessionToken(credentials.GetSessionToken()); - } - - request.SetHeaderValue(Aws::Auth::AWSAuthHelper::X_AMZ_CONTENT_SHA256, EVENT_STREAM_CONTENT_SHA256); - - //calculate date header to use in internal signature (this also goes into date header). - DateTime now = GetSigningTimestamp(); - Aws::String dateHeaderValue = now.ToGmtString(DateFormat::ISO_8601_BASIC); - request.SetHeaderValue(AWS_DATE_HEADER, dateHeaderValue); - - Aws::StringStream headersStream; - Aws::StringStream signedHeadersStream; - - for (const auto& header : Aws::Auth::AWSAuthHelper::CanonicalizeHeaders(request.GetHeaders())) - { - if(ShouldSignHeader(header.first)) - { - headersStream << header.first.c_str() << ":" << header.second.c_str() << Aws::Auth::AWSAuthHelper::NEWLINE; - signedHeadersStream << header.first.c_str() << ";"; - } - } - - Aws::String canonicalHeadersString = headersStream.str(); - AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Canonical Header String: " << canonicalHeadersString); - - //calculate signed headers parameter - Aws::String signedHeadersValue = signedHeadersStream.str(); - //remove that last semi-colon - if (!signedHeadersValue.empty()) - { - signedHeadersValue.pop_back(); - } - - AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Signed Headers value:" << signedHeadersValue); - - //generate generalized canonicalized request string. - Aws::String canonicalRequestString = Aws::Auth::AWSAuthHelper::CanonicalizeRequestSigningString(request, true/* m_urlEscapePath */); - - //append v4 stuff to the canonical request string. - canonicalRequestString.append(canonicalHeadersString); - canonicalRequestString.append(Aws::Auth::AWSAuthHelper::NEWLINE); - canonicalRequestString.append(signedHeadersValue); - canonicalRequestString.append(Aws::Auth::AWSAuthHelper::NEWLINE); - canonicalRequestString.append(EVENT_STREAM_CONTENT_SHA256); - - AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Canonical Request String: " << canonicalRequestString); - - //now compute sha256 on that request string - auto sha256Digest = HashingUtils::CalculateSHA256(canonicalRequestString); - if (sha256Digest.GetLength() == 0) - { - AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Failed to hash (sha256) request string"); - AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "The request string is: \"" << canonicalRequestString << "\""); - return false; - } - - Aws::String canonicalRequestHash = HashingUtils::HexEncode(sha256Digest); - Aws::String simpleDate = now.ToGmtString(Aws::Auth::AWSAuthHelper::SIMPLE_DATE_FORMAT_STR); - - Aws::String signingRegion = region ? region : m_region; - Aws::String signingServiceName = serviceName ? serviceName : m_serviceName; - Aws::String stringToSign = GenerateStringToSign(dateHeaderValue, simpleDate, canonicalRequestHash, signingRegion, signingServiceName); - auto finalSignature = GenerateSignature(credentials, stringToSign, simpleDate, signingRegion, signingServiceName); - - Aws::StringStream ss; - ss << Aws::Auth::AWSAuthHelper::AWS_HMAC_SHA256 << " " << Aws::Auth::AWSAuthHelper::CREDENTIAL << Aws::Auth::AWSAuthHelper::EQ << credentials.GetAWSAccessKeyId() << "/" << simpleDate - << "/" << signingRegion << "/" << signingServiceName << "/" << Aws::Auth::AWSAuthHelper::AWS4_REQUEST << ", " << Aws::Auth::AWSAuthHelper::SIGNED_HEADERS << Aws::Auth::AWSAuthHelper::EQ - << signedHeadersValue << ", " << SIGNATURE << Aws::Auth::AWSAuthHelper::EQ << HashingUtils::HexEncode(finalSignature); - - auto awsAuthString = ss.str(); - AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Signing request with: " << awsAuthString); - request.SetAwsAuthorization(awsAuthString); - request.SetSigningAccessKey(credentials.GetAWSAccessKeyId()); - request.SetSigningRegion(signingRegion); - return true; + return SignRequest(request, region, serviceName, signBody, credentials); } // this works regardless if the current machine is Big/Little Endian @@ -160,70 +77,8 @@ static void WriteBigEndian(Aws::String& str, uint64_t n) bool AWSAuthEventStreamV4Signer::SignEventMessage(Event::Message& message, Aws::String& priorSignature) const { - using Event::EventHeaderValue; - - Aws::StringStream stringToSign; - stringToSign << EVENT_STREAM_PAYLOAD << Aws::Auth::AWSAuthHelper::NEWLINE; - const DateTime now = GetSigningTimestamp(); - const auto simpleDate = now.ToGmtString(Aws::Auth::AWSAuthHelper::SIMPLE_DATE_FORMAT_STR); - stringToSign << now.ToGmtString(DateFormat::ISO_8601_BASIC) << Aws::Auth::AWSAuthHelper::NEWLINE - << simpleDate << "/" << m_region << "/" - << m_serviceName << "/aws4_request" << Aws::Auth::AWSAuthHelper::NEWLINE << priorSignature << Aws::Auth::AWSAuthHelper::NEWLINE; - - - Aws::String nonSignatureHeaders; - nonSignatureHeaders.push_back(char(sizeof(EVENTSTREAM_DATE_HEADER) - 1)); // length of the string - nonSignatureHeaders += EVENTSTREAM_DATE_HEADER; - nonSignatureHeaders.push_back(static_cast(EventHeaderValue::EventHeaderType::TIMESTAMP)); // type of the value - WriteBigEndian(nonSignatureHeaders, static_cast(now.Millis())); // the value of the timestamp in big-endian - - auto nonSignatureHeadersHash = HashingUtils::CalculateSHA256(nonSignatureHeaders); - if (nonSignatureHeadersHash.GetLength() == 0) - { - AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Failed to hash (sha256) non-signature headers."); - return false; - } - - stringToSign << HashingUtils::HexEncode(nonSignatureHeadersHash) << Aws::Auth::AWSAuthHelper::NEWLINE; - - ByteBuffer payloadHash; - if (!message.GetEventPayload().empty()) - { - // use a preallocatedStreamBuf to avoid making a copy. - // The Hashing API requires either Aws::String or IStream as input. - // TODO: the hashing API should be accept 'unsigned char*' as input. - Utils::Stream::PreallocatedStreamBuf streamBuf(message.GetEventPayload().data(), message.GetEventPayload().size()); - Aws::IOStream payload(&streamBuf); - payloadHash = HashingUtils::CalculateSHA256(payload); - } - else - { - // only a signature and a date will be in a frame - AWS_LOGSTREAM_INFO(v4StreamingLogTag, "Signing an event with an empty payload"); - - payloadHash = HashingUtils::CalculateSHA256(""); // SHA256 of an empty buffer - } - - if (payloadHash.GetLength() == 0) - { - AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Failed to hash (sha256) non-signature headers."); - return false; - } - stringToSign << HashingUtils::HexEncode(payloadHash); - AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Payload hash - " << HashingUtils::HexEncode(payloadHash)); - - Aws::String canonicalRequestString = stringToSign.str(); - AWS_LOGSTREAM_TRACE(v4StreamingLogTag, "EventStream Event Canonical Request String: " << canonicalRequestString); - Aws::Utils::ByteBuffer finalSignatureDigest = GenerateSignature(m_credentialsProvider->GetAWSCredentials(), canonicalRequestString, simpleDate, m_region, m_serviceName); - const auto finalSignature = HashingUtils::HexEncode(finalSignatureDigest); - AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Final computed signing hash: " << finalSignature); - priorSignature = finalSignature; - - message.InsertEventHeader(EVENTSTREAM_DATE_HEADER, EventHeaderValue(now.Millis(), EventHeaderValue::EventHeaderType::TIMESTAMP)); - message.InsertEventHeader(EVENTSTREAM_SIGNATURE_HEADER, std::move(finalSignatureDigest)); - - AWS_LOGSTREAM_INFO(v4StreamingLogTag, "Event chunk final signature - " << finalSignature); - return true; + const auto& creds = m_credentialsProvider->GetAWSCredentials(); + return SignEventMessage(message, priorSignature, creds); } bool AWSAuthEventStreamV4Signer::ShouldSignHeader(const Aws::String& header) const @@ -317,3 +172,162 @@ Aws::Utils::ByteBuffer AWSAuthEventStreamV4Signer::ComputeHash(const Aws::String } return hashResult; } + +bool AWSAuthEventStreamV4Signer::SignRequest(Aws::Http::HttpRequest& request, const char* region, const char* serviceName, + bool /* signBody */, const AWSCredentials& credentials) const { + // don't sign anonymous requests + if (credentials.GetAWSAccessKeyId().empty() || credentials.GetAWSSecretKey().empty()) { + return true; + } + + if (!credentials.GetSessionToken().empty()) { + request.SetAwsSessionToken(credentials.GetSessionToken()); + } + + request.SetHeaderValue(Aws::Auth::AWSAuthHelper::X_AMZ_CONTENT_SHA256, EVENT_STREAM_CONTENT_SHA256); + + // calculate date header to use in internal signature (this also goes into date header). + DateTime now = GetSigningTimestamp(); + Aws::String dateHeaderValue = now.ToGmtString(DateFormat::ISO_8601_BASIC); + request.SetHeaderValue(AWS_DATE_HEADER, dateHeaderValue); + + Aws::StringStream headersStream; + Aws::StringStream signedHeadersStream; + + for (const auto& header : Aws::Auth::AWSAuthHelper::CanonicalizeHeaders(request.GetHeaders())) { + if (ShouldSignHeader(header.first)) { + headersStream << header.first.c_str() << ":" << header.second.c_str() << Aws::Auth::AWSAuthHelper::NEWLINE; + signedHeadersStream << header.first.c_str() << ";"; + } + } + + Aws::String canonicalHeadersString = headersStream.str(); + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Canonical Header String: " << canonicalHeadersString); + + // calculate signed headers parameter + Aws::String signedHeadersValue = signedHeadersStream.str(); + // remove that last semi-colon + if (!signedHeadersValue.empty()) { + signedHeadersValue.pop_back(); + } + + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Signed Headers value:" << signedHeadersValue); + + // generate generalized canonicalized request string. + Aws::String canonicalRequestString = Aws::Auth::AWSAuthHelper::CanonicalizeRequestSigningString(request, true /* m_urlEscapePath */); + + // append v4 stuff to the canonical request string. + canonicalRequestString.append(canonicalHeadersString); + canonicalRequestString.append(Aws::Auth::AWSAuthHelper::NEWLINE); + canonicalRequestString.append(signedHeadersValue); + canonicalRequestString.append(Aws::Auth::AWSAuthHelper::NEWLINE); + canonicalRequestString.append(EVENT_STREAM_CONTENT_SHA256); + + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Canonical Request String: " << canonicalRequestString); + + // now compute sha256 on that request string + auto sha256Digest = HashingUtils::CalculateSHA256(canonicalRequestString); + if (sha256Digest.GetLength() == 0) { + AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Failed to hash (sha256) request string"); + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "The request string is: \"" << canonicalRequestString << "\""); + return false; + } + + Aws::String canonicalRequestHash = HashingUtils::HexEncode(sha256Digest); + Aws::String simpleDate = now.ToGmtString(Aws::Auth::AWSAuthHelper::SIMPLE_DATE_FORMAT_STR); + + Aws::String signingRegion = region ? region : m_region; + Aws::String signingServiceName = serviceName ? serviceName : m_serviceName; + Aws::String stringToSign = GenerateStringToSign(dateHeaderValue, simpleDate, canonicalRequestHash, signingRegion, signingServiceName); + auto finalSignature = GenerateSignature(credentials, stringToSign, simpleDate, signingRegion, signingServiceName); + + Aws::StringStream ss; + ss << Aws::Auth::AWSAuthHelper::AWS_HMAC_SHA256 << " " << Aws::Auth::AWSAuthHelper::CREDENTIAL << Aws::Auth::AWSAuthHelper::EQ + << credentials.GetAWSAccessKeyId() << "/" << simpleDate << "/" << signingRegion << "/" << signingServiceName << "/" + << Aws::Auth::AWSAuthHelper::AWS4_REQUEST << ", " << Aws::Auth::AWSAuthHelper::SIGNED_HEADERS << Aws::Auth::AWSAuthHelper::EQ + << signedHeadersValue << ", " << SIGNATURE << Aws::Auth::AWSAuthHelper::EQ << HashingUtils::HexEncode(finalSignature); + + auto awsAuthString = ss.str(); + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Signing request with: " << awsAuthString); + request.SetAwsAuthorization(awsAuthString); + request.SetSigningAccessKey(credentials.GetAWSAccessKeyId()); + request.SetSigningRegion(signingRegion); + return true; +} + +bool AWSAuthEventStreamV4Signer::SignEventMessage(Event::Message& message, Aws::String& priorSignature, const AWSCredentials& creds) const { + using Event::EventHeaderValue; + + Aws::StringStream stringToSign; + stringToSign << EVENT_STREAM_PAYLOAD << Aws::Auth::AWSAuthHelper::NEWLINE; + const DateTime now = GetSigningTimestamp(); + const auto simpleDate = now.ToGmtString(Aws::Auth::AWSAuthHelper::SIMPLE_DATE_FORMAT_STR); + stringToSign << now.ToGmtString(DateFormat::ISO_8601_BASIC) << Aws::Auth::AWSAuthHelper::NEWLINE << simpleDate << "/" << m_region << "/" + << m_serviceName << "/aws4_request" << Aws::Auth::AWSAuthHelper::NEWLINE << priorSignature + << Aws::Auth::AWSAuthHelper::NEWLINE; + + Aws::String nonSignatureHeaders; + nonSignatureHeaders.push_back(char(sizeof(EVENTSTREAM_DATE_HEADER) - 1)); // length of the string + nonSignatureHeaders += EVENTSTREAM_DATE_HEADER; + nonSignatureHeaders.push_back(static_cast(EventHeaderValue::EventHeaderType::TIMESTAMP)); // type of the value + WriteBigEndian(nonSignatureHeaders, static_cast(now.Millis())); // the value of the timestamp in big-endian + + auto nonSignatureHeadersHash = HashingUtils::CalculateSHA256(nonSignatureHeaders); + if (nonSignatureHeadersHash.GetLength() == 0) { + AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Failed to hash (sha256) non-signature headers."); + return false; + } + + stringToSign << HashingUtils::HexEncode(nonSignatureHeadersHash) << Aws::Auth::AWSAuthHelper::NEWLINE; + + ByteBuffer payloadHash; + if (!message.GetEventPayload().empty()) { + // use a preallocatedStreamBuf to avoid making a copy. + // The Hashing API requires either Aws::String or IStream as input. + // TODO: the hashing API should be accept 'unsigned char*' as input. + Utils::Stream::PreallocatedStreamBuf streamBuf(message.GetEventPayload().data(), message.GetEventPayload().size()); + Aws::IOStream payload(&streamBuf); + payloadHash = HashingUtils::CalculateSHA256(payload); + } else { + // only a signature and a date will be in a frame + AWS_LOGSTREAM_INFO(v4StreamingLogTag, "Signing an event with an empty payload"); + + payloadHash = HashingUtils::CalculateSHA256(""); // SHA256 of an empty buffer + } + + if (payloadHash.GetLength() == 0) { + AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Failed to hash (sha256) non-signature headers."); + return false; + } + stringToSign << HashingUtils::HexEncode(payloadHash); + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Payload hash - " << HashingUtils::HexEncode(payloadHash)); + + Aws::String canonicalRequestString = stringToSign.str(); + AWS_LOGSTREAM_TRACE(v4StreamingLogTag, "EventStream Event Canonical Request String: " << canonicalRequestString); + Aws::Utils::ByteBuffer finalSignatureDigest = GenerateSignature(creds, canonicalRequestString, simpleDate, m_region, m_serviceName); + const auto finalSignature = HashingUtils::HexEncode(finalSignatureDigest); + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Final computed signing hash: " << finalSignature); + priorSignature = finalSignature; + + message.InsertEventHeader(EVENTSTREAM_DATE_HEADER, EventHeaderValue(now.Millis(), EventHeaderValue::EventHeaderType::TIMESTAMP)); + message.InsertEventHeader(EVENTSTREAM_SIGNATURE_HEADER, std::move(finalSignatureDigest)); + + AWS_LOGSTREAM_INFO(v4StreamingLogTag, "Event chunk final signature - " << finalSignature); + return true; + } + + +bool AWSAuthEventStreamV4Signer::SignRequest(Aws::Http::HttpRequest& request) const +{ + return SignRequest(request, m_region.c_str(), m_serviceName.c_str(), true, m_credentialsProvider->GetAWSCredentials()); +} + +bool AWSAuthEventStreamV4Signer::SignRequest(Aws::Http::HttpRequest& request, bool signBody) const +{ + return SignRequest(request, m_region.c_str(), m_serviceName.c_str(), signBody, m_credentialsProvider->GetAWSCredentials()); +} + +bool AWSAuthEventStreamV4Signer::SignRequest(Aws::Http::HttpRequest& request, const char* region, bool signBody) const +{ + return SignRequest(request, region, m_serviceName.c_str(), signBody, m_credentialsProvider->GetAWSCredentials()); +} \ No newline at end of file diff --git a/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp b/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp index 83022a6b011..7debda1153f 100644 --- a/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp +++ b/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp @@ -182,7 +182,9 @@ void AwsSmithyClientBase::MakeRequestAsync(Aws::AmazonWebServiceRequest const* c Aws::Http::HttpMethod method, EndpointUpdateCallback&& endpointCallback, ResponseHandlerFunc&& responseHandler, - std::shared_ptr pExecutor) const + std::shared_ptr pExecutor, + std::function) >&& eventEncoderStreamHandler + ) const { if(!responseHandler) { @@ -212,6 +214,8 @@ void AwsSmithyClientBase::MakeRequestAsync(Aws::AmazonWebServiceRequest const* c pRequestCtx->m_method = method; pRequestCtx->m_retryCount = 0; pRequestCtx->m_invocationId = Aws::Utils::UUID::PseudoRandomUUID(); + pRequestCtx->m_eventEncoderStreamHandler = std::move(eventEncoderStreamHandler); + auto authSchemeOptionOutcome = this->SelectAuthSchemeOption(*pRequestCtx); if (!authSchemeOptionOutcome.IsSuccess()) { @@ -643,7 +647,7 @@ AwsSmithyClientBase::MakeRequestSync(Aws::AmazonWebServiceRequest const * const pExecutor->Submit([&]() { - this->MakeRequestAsync(request, requestName, method, std::move(endpointCallback), std::move(responseHandler), pExecutor); + this->MakeRequestAsync(request, requestName, method, std::move(endpointCallback), std::move(responseHandler), pExecutor, nullptr); }); pExecutor->WaitUntilStopped(); diff --git a/src/aws-cpp-sdk-core/source/utils/event/EventEncoderStream.cpp b/src/aws-cpp-sdk-core/source/utils/event/EventEncoderStream.cpp index c413a09ede7..2f3cc4c8c0c 100644 --- a/src/aws-cpp-sdk-core/source/utils/event/EventEncoderStream.cpp +++ b/src/aws-cpp-sdk-core/source/utils/event/EventEncoderStream.cpp @@ -26,7 +26,7 @@ namespace Aws EventEncoderStream& EventEncoderStream::WriteEvent(const Aws::Utils::Event::Message& msg) { - auto bits = m_encoder.EncodeAndSign(msg); + auto bits = EncodeAndSign(msg); AWS_LOGSTREAM_TRACE("EventEncoderStream::WriteEvent", "Encoded event (base64 encoded): " << Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::ByteBuffer(bits.data(), bits.size()))); @@ -39,6 +39,10 @@ namespace Aws flush(); return *this; } + + Aws::Vector EventEncoderStream::EncodeAndSign(const Aws::Utils::Event::Message& msg) { + return m_encoder.EncodeAndSign(msg); + } } } } diff --git a/src/aws-cpp-sdk-core/source/utils/event/EventStreamEncoder.cpp b/src/aws-cpp-sdk-core/source/utils/event/EventStreamEncoder.cpp index 97a48ac40b9..b3a5c93a264 100644 --- a/src/aws-cpp-sdk-core/source/utils/event/EventStreamEncoder.cpp +++ b/src/aws-cpp-sdk-core/source/utils/event/EventStreamEncoder.cpp @@ -147,8 +147,7 @@ namespace Aws signedMessage.WriteEventPayload(msgbuf, msglen); } - assert(m_signer); - if (m_signer->SignEventMessage(signedMessage, m_signatureSeed)) + if (SignEventMessage(signedMessage)) { aws_array_list headers; EncodeHeaders(signedMessage, &headers); @@ -173,6 +172,12 @@ namespace Aws return success; } + bool EventStreamEncoder::SignEventMessage(Event::Message& signedMessage) { + assert(m_signer); + assert(!m_signatureSeed.empty()); + return (m_signer->SignEventMessage(signedMessage, m_signatureSeed)); + } + } // namespace Event } // namespace Utils } // namespace Aws diff --git a/tools/code-generation/generator/src/main/java/com/amazonaws/util/awsclientgenerator/generators/cpp/JsonCppClientGenerator.java b/tools/code-generation/generator/src/main/java/com/amazonaws/util/awsclientgenerator/generators/cpp/JsonCppClientGenerator.java index 0a11969471f..ac4a90ddd58 100644 --- a/tools/code-generation/generator/src/main/java/com/amazonaws/util/awsclientgenerator/generators/cpp/JsonCppClientGenerator.java +++ b/tools/code-generation/generator/src/main/java/com/amazonaws/util/awsclientgenerator/generators/cpp/JsonCppClientGenerator.java @@ -72,7 +72,14 @@ protected SdkFileEntry generateModelHeaderFile(ServiceModel serviceModel, Map.En template = velocityEngine.getTemplate("/com/amazonaws/util/awsclientgenerator/velocity/cpp/json/JsonResultHeader.vm", StandardCharsets.UTF_8.name()); } } else if (shape.isEventStream() && shape.isOutgoingEventStream()) { - template = velocityEngine.getTemplate("/com/amazonaws/util/awsclientgenerator/velocity/cpp/json/EventStreamHeader.vm", StandardCharsets.UTF_8.name()); + if(serviceModel.isUseSmithyClient()) + { + template = velocityEngine.getTemplate("/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyEventStreamHeader.vm", StandardCharsets.UTF_8.name()); + } + else + { + template = velocityEngine.getTemplate("/com/amazonaws/util/awsclientgenerator/velocity/cpp/json/EventStreamHeader.vm", StandardCharsets.UTF_8.name()); + } } else if (shape.isStructure()) { template = velocityEngine.getTemplate("/com/amazonaws/util/awsclientgenerator/velocity/cpp/json/JsonSubObjectHeader.vm", StandardCharsets.UTF_8.name()); } @@ -153,7 +160,7 @@ else if (shape.isResult()) { @Override protected SdkFileEntry generateClientHeaderFile(final ServiceModel serviceModel) throws Exception { - if (serviceModel.isUseSmithyClient() && !serviceModel.hasEventStreamingRequestShapes()) { + if (serviceModel.isUseSmithyClient() && serviceModel.hasEventStreamingRequestShapes()) { return generateClientSmithyHeaderFile(serviceModel); } @@ -175,7 +182,7 @@ protected List generateClientSourceFile( List servic return serviceModelsIndices.stream().map(index -> { - if(serviceModels.get(index).isUseSmithyClient() && !serviceModels.get(index).hasEventStreamingRequestShapes()) + if(serviceModels.get(index).isUseSmithyClient() && serviceModels.get(index).hasEventStreamingRequestShapes()) { return GenerateSmithyClientSourceFile(serviceModels.get(index), index, Optional.empty()); } diff --git a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyEndpointClosure.vm b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyEndpointClosure.vm index a85295c21d9..0e0db24fa85 100644 --- a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyEndpointClosure.vm +++ b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyEndpointClosure.vm @@ -1,7 +1,4 @@ #set($indent = " ") -#if($operation.request.shape.hasEventStreamMembers()) -${indent}streamReadySemaphore->ReleaseAll(); -#end #parse("/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyUriRequestQueryParams.vm") #if($metadata.hasEndpointDiscoveryTrait) #if($operation.hasEndpointDiscoveryTrait) diff --git a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyEventStreamHeader.vm b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyEventStreamHeader.vm index aac66739df0..101ae98c401 100644 --- a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyEventStreamHeader.vm +++ b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyEventStreamHeader.vm @@ -9,6 +9,11 @@ \#include $header #end \#include +#if($serviceModel.hasOnlyBearerAuth()) +\#include +#else +\#include +#end #foreach($entry in $shape.members.entrySet()) #if($entry.value.shape.isEvent()) #foreach($eventMemberEntry in $entry.value.shape.members.entrySet()) @@ -40,7 +45,11 @@ namespace Model * $shape.documentation */ #end - class $typeInfo.exportValue $typeInfo.className : public Aws::Utils::Event::EventEncoderStream +#if($serviceModel.hasOnlyBearerAuth()) + class $typeInfo.exportValue $typeInfo.className : public Aws::Utils::Event::SmithyEventEncoderStream +#else + class $typeInfo.exportValue $typeInfo.className : public Aws::Utils::Event::SmithyEventEncoderStream +#end { public: #foreach($entry in $shape.members.entrySet()) diff --git a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyJsonServiceEventStreamOperationsSource.vm b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyJsonServiceEventStreamOperationsSource.vm index 572f6812d92..bb54238b850 100644 --- a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyJsonServiceEventStreamOperationsSource.vm +++ b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyJsonServiceEventStreamOperationsSource.vm @@ -6,7 +6,6 @@ void ${className}::${operation.name}Async(Model::${operation.request.shape.name} AWS_ASYNC_OPERATION_GUARD(${operation.name}); #parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyServiceClientOperationsEndpointPrepareCommonBody.vm") #parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/common/ServiceClientOperationRequestRequiredMemberValidate.vm") -#parse("/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyUriRequestQueryParams.vm") #if($operation.result && $operation.result.shape.hasEventStreamMembers()) request.SetResponseStreamFactory( [&] { request.GetEventStreamDecoder().Reset(); return Aws::New(ALLOCATION_TAG, request.GetEventStreamDecoder()); } @@ -38,12 +37,15 @@ void ${className}::${operation.name}Async(Model::${operation.request.shape.name} auto eventEncoderStream = Aws::MakeShared(ALLOCATION_TAG); request.Set${streamModelNameWithFirstLetterCapitalized}(eventEncoderStream); // this becomes the body of the request auto streamReadySemaphore = Aws::MakeShared(ALLOCATION_TAG, 0, 1); - m_clientConfiguration.executor->Submit([this, &request, handler, handlerContext, #if($hasEndPointOverrides) endpointOverrides #end, eventEncoderStream, streamReadySemaphore] () mutable { - JsonOutcome outcome = MakeEventStreamRequestDeserialize(&request, request.GetServiceRequestName(), Aws::Http::HttpMethod::HTTP_${operation.http.method}, [&](Aws::Endpoint::AWSEndpoint& resolvedEndpoint) -> void { + request.SetRequestSignedHandler([eventEncoderStream, streamReadySemaphore](const Aws::Http::HttpRequest& httpRequest) { eventEncoderStream->SetSignatureSeed(Aws::Client::GetAuthorizationHeader(httpRequest)); streamReadySemaphore->ReleaseAll(); }); + m_clientConfiguration.executor->Submit([this, &request, handler, handlerContext, #if($hasEndPointOverrides) endpointOverrides, #end eventEncoderStream, streamReadySemaphore] () mutable { + auto eventStreamHandler = [&](std::shared_ptr encoder) -> void{ + eventEncoderStream->SetEncoder(encoder); + }; + JsonOutcome outcome = MakeEventStreamRequestDeserialize(&request, request.GetServiceRequestName(), Aws::Http::HttpMethod::HTTP_${operation.http.method}, [&](Aws::Endpoint::AWSEndpoint& resolvedEndpoint) -> void { #parse("/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyEndpointClosure.vm") }, - eventEncoderStream - ); + std::move(eventStreamHandler)); if(outcome.IsSuccess()) { handler(this, request, ${operation.name}Outcome(NoResult()), handlerContext); diff --git a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyJsonServiceOperationsSource.vm b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyJsonServiceOperationsSource.vm index a4afd49fb06..cc9d699d159 100644 --- a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyJsonServiceOperationsSource.vm +++ b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyJsonServiceOperationsSource.vm @@ -1,6 +1,9 @@ #foreach($operation in $serviceModel.operations) #set($hasEndPointOverrides = false) -## todo: add support for request stream +#if($operation.request.shape.hasEventStreamMembers()) +#set($constText = "") +#parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyJsonServiceEventStreamOperationsSource.vm") +#else #if($operation.result.shape.hasEventStreamMembers()) #set($constText = "") #else @@ -8,6 +11,7 @@ #end #parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyServiceOperationsSource.vm") #end +#end #foreach($presigner in $serviceModel.presigners) Aws::Utils::Outcome ${className}::${presigner.functionName}(const Aws::String& ${presigner.hostNameVarName}, const Aws::String& ${presigner.regionVarName},#foreach($arg in $presigner.queryParams)#if(${arg.variableName}) ${arg.type} ${arg.variableName},#end#end long long expiresIn) { diff --git a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyUriRequestQueryParams.vm b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyUriRequestQueryParams.vm index 13885b3eba4..f1587a7e713 100644 --- a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyUriRequestQueryParams.vm +++ b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyUriRequestQueryParams.vm @@ -1,4 +1,4 @@ -#if(($serviceNamespace == "S3Crt" && $operation.s3CrtEnabled) || $operation.getRequest().getShape().hasEventStreamMembers()) +#if(($serviceNamespace == "S3Crt" && $operation.s3CrtEnabled)) #set($meterNeeded = true) #set($indent = "") #else diff --git a/tools/scripts/codegen/model_utils.py b/tools/scripts/codegen/model_utils.py index 9e3ce817af4..80a48ade68a 100644 --- a/tools/scripts/codegen/model_utils.py +++ b/tools/scripts/codegen/model_utils.py @@ -39,7 +39,7 @@ # bidirectional streaming , "lexv2-runtime" , "qbusiness" - , "transcribestreaming" + #, "transcribestreaming" , "s3-crt" , "s3" , "s3control" @@ -180,8 +180,8 @@ def is_smithy_enabled(service_id, models_dir, c2j_model_filename): with open(models_dir + "/" + c2j_model_filename, 'r') as json_file: model = json.load(json_file) model_protocol = model.get("metadata", dict()).get("protocol", "UNKNOWN_PROTOCOL") - #if model_protocol in {"json", "rest-json", "rest-xml", "query"}: - # use_smithy = True + if model_protocol in {"json", "rest-json", "rest-xml", "query"}: + use_smithy = True return use_smithy @staticmethod