From c47ffcf1936e3507c022df92249f91437c3c46a4 Mon Sep 17 00:00:00 2001 From: wu-hui Date: Mon, 24 Nov 2025 15:50:22 -0500 Subject: [PATCH] Use grpc streaming reader for pipeline execution --- Firestore/core/src/remote/datastore.cc | 49 ++++++---- .../core/src/remote/grpc_streaming_reader.cc | 7 +- .../core/src/remote/grpc_streaming_reader.h | 5 +- .../core/src/remote/remote_objc_bridge.cc | 44 +++++++++ .../core/src/remote/remote_objc_bridge.h | 4 + .../unit/remote/grpc_streaming_reader_test.cc | 95 ++++++++++++++++--- 6 files changed, 168 insertions(+), 36 deletions(-) diff --git a/Firestore/core/src/remote/datastore.cc b/Firestore/core/src/remote/datastore.cc index 60d8d6e0764..c8b58e09325 100644 --- a/Firestore/core/src/remote/datastore.cc +++ b/Firestore/core/src/remote/datastore.cc @@ -321,7 +321,7 @@ void Datastore::RunPipeline( const StatusOr& auth_token, const std::string& app_check_token) mutable { if (!auth_token.ok()) { - // result_callback(auth_token.status()); + result_callback(auth_token.status()); return; } RunPipelineWithCredentials(auth_token.ValueOrDie(), app_check_token, @@ -338,27 +338,40 @@ void Datastore::RunPipelineWithCredentials( LOG_DEBUG("Run Pipeline: %s", request.ToString()); grpc::ByteBuffer message = MakeByteBuffer(request); - std::unique_ptr call_owning = grpc_connection_.CreateUnaryCall( - kRpcNameExecutePipeline, auth_token, app_check_token, std::move(message)); - GrpcUnaryCall* call = call_owning.get(); + std::unique_ptr call_owning = + grpc_connection_.CreateStreamingReader(kRpcNameExecutePipeline, + auth_token, app_check_token, + std::move(message)); + GrpcStreamingReader* call = call_owning.get(); active_calls_.push_back(std::move(call_owning)); - call->Start( - [this, db = pipeline.firestore(), call, callback = std::move(callback)]( - const StatusOr& result) { - LogGrpcCallFinished("ExecutePipeline", call, result.status()); - HandleCallStatus(result.status()); + auto responses_callback = [this, db = pipeline.firestore(), callback]( + const std::vector& result) { + if (result.empty()) { + callback(util::Status(Error::kErrorInternal, + "Received empty response for RunPipeline")); + return; + } - if (result.ok()) { - auto response = datastore_serializer_.DecodeExecutePipelineResponse( - result.ValueOrDie(), std::move(db)); - callback(response); - } else { - callback(result.status()); - } + auto response = datastore_serializer_.MergeExecutePipelineResponses( + result, std::move(db)); + callback(response); + }; - RemoveGrpcCall(call); - }); + auto close_callback = [this, call, callback](const util::Status& status, + bool callback_fired) { + if (!callback_fired) { + callback(status); + } + if (!status.ok()) { + LogGrpcCallFinished("ExecutePipeline", call, status); + HandleCallStatus(status); + } + RemoveGrpcCall(call); + }; + + call->Start(util::Status(Error::kErrorUnknown, "Unknown response count"), + responses_callback, close_callback); } void Datastore::ResumeRpcWithCredentials(const OnCredentials& on_credentials) { diff --git a/Firestore/core/src/remote/grpc_streaming_reader.cc b/Firestore/core/src/remote/grpc_streaming_reader.cc index 7f10bc2be4c..ee581666213 100644 --- a/Firestore/core/src/remote/grpc_streaming_reader.cc +++ b/Firestore/core/src/remote/grpc_streaming_reader.cc @@ -45,10 +45,10 @@ GrpcStreamingReader::GrpcStreamingReader( request_{request} { } -void GrpcStreamingReader::Start(size_t expected_response_count, +void GrpcStreamingReader::Start(util::StatusOr expected_response_count, ResponsesCallback&& responses_callback, CloseCallback&& close_callback) { - expected_response_count_ = expected_response_count; + expected_response_count_ = std::move(expected_response_count); responses_callback_ = std::move(responses_callback); close_callback_ = std::move(close_callback); stream_->Start(); @@ -72,7 +72,8 @@ void GrpcStreamingReader::OnStreamRead(const grpc::ByteBuffer& message) { // Accumulate responses, responses_callback_ will be fired if // GrpcStreamingReader has received all the responses. responses_.push_back(message); - if (responses_.size() == expected_response_count_) { + if (expected_response_count_.ok() && + responses_.size() == expected_response_count_.ValueOrDie()) { callback_fired_ = true; responses_callback_(responses_); } diff --git a/Firestore/core/src/remote/grpc_streaming_reader.h b/Firestore/core/src/remote/grpc_streaming_reader.h index 6fbe4837e0f..658faf3f7dc 100644 --- a/Firestore/core/src/remote/grpc_streaming_reader.h +++ b/Firestore/core/src/remote/grpc_streaming_reader.h @@ -26,6 +26,7 @@ #include "Firestore/core/src/remote/grpc_stream_observer.h" #include "Firestore/core/src/util/status.h" #include "Firestore/core/src/util/status_fwd.h" +#include "Firestore/core/src/util/statusor.h" #include "Firestore/core/src/util/warnings.h" #include "grpcpp/client_context.h" #include "grpcpp/support/byte_buffer.h" @@ -62,7 +63,7 @@ class GrpcStreamingReader : public GrpcCall, public GrpcStreamObserver { * results of the call. If the call fails, the `callback` will be invoked with * a non-ok status. */ - void Start(size_t expected_response_count, + void Start(util::StatusOr expected_response_count, ResponsesCallback&& responses_callback, CloseCallback&& close_callback); @@ -103,7 +104,7 @@ class GrpcStreamingReader : public GrpcCall, public GrpcStreamObserver { std::unique_ptr stream_; grpc::ByteBuffer request_; - size_t expected_response_count_; + util::StatusOr expected_response_count_; bool callback_fired_ = false; ResponsesCallback responses_callback_; CloseCallback close_callback_; diff --git a/Firestore/core/src/remote/remote_objc_bridge.cc b/Firestore/core/src/remote/remote_objc_bridge.cc index b0ab0b5aab1..27faaa171d4 100644 --- a/Firestore/core/src/remote/remote_objc_bridge.cc +++ b/Firestore/core/src/remote/remote_objc_bridge.cc @@ -426,6 +426,50 @@ DatastoreSerializer::DecodeExecutePipelineResponse( return snapshot; } +util::StatusOr +DatastoreSerializer::MergeExecutePipelineResponses( + const std::vector& responses, + std::shared_ptr db) const { + std::vector all_results; + model::SnapshotVersion execution_time = model::SnapshotVersion::None(); + + for (const auto& response : responses) { + ByteBufferReader reader{response}; + auto message = + Message::TryParse(&reader); + if (!reader.ok()) { + return reader.status(); + } + + // DecodePipelineResponse decodes the whole message into a Snapshot. + // We can reuse it to get the partial results and execution time. + auto partial_snapshot = + serializer_.DecodePipelineResponse(reader.context(), message); + if (!reader.ok()) { + return reader.status(); + } + + // Accumulate results + // PipelineSnapshot::results() returns a const ref. We need to copy. + // But PipelineResult should be copyable/movable. + for (const auto& result : partial_snapshot.results()) { + all_results.push_back(result); + } + + // Update execution time if present. + // DecodePipelineResponse returns SnapshotVersion::None() if not present? + // Let's assume the last non-None execution time is the correct one, or just + // update it. + if (partial_snapshot.execution_time() != model::SnapshotVersion::None()) { + execution_time = partial_snapshot.execution_time(); + } + } + + api::PipelineSnapshot merged_snapshot{std::move(all_results), execution_time}; + merged_snapshot.SetFirestore(std::move(db)); + return merged_snapshot; +} + } // namespace remote } // namespace firestore } // namespace firebase diff --git a/Firestore/core/src/remote/remote_objc_bridge.h b/Firestore/core/src/remote/remote_objc_bridge.h index 2d25487e9ec..962ea7e3644 100644 --- a/Firestore/core/src/remote/remote_objc_bridge.h +++ b/Firestore/core/src/remote/remote_objc_bridge.h @@ -164,6 +164,10 @@ class DatastoreSerializer { const grpc::ByteBuffer& response, std::shared_ptr db) const; + util::StatusOr MergeExecutePipelineResponses( + const std::vector& responses, + std::shared_ptr db) const; + private: Serializer serializer_; }; diff --git a/Firestore/core/test/unit/remote/grpc_streaming_reader_test.cc b/Firestore/core/test/unit/remote/grpc_streaming_reader_test.cc index 461bbed5d14..45171b398d1 100644 --- a/Firestore/core/test/unit/remote/grpc_streaming_reader_test.cc +++ b/Firestore/core/test/unit/remote/grpc_streaming_reader_test.cc @@ -74,10 +74,10 @@ class GrpcStreamingReaderTest : public testing::Test { tester.KeepPollingGrpcQueue(); } - void StartReader(size_t expected_response_count) { + void StartReader(util::StatusOr expected_response_count) { worker_queue->EnqueueBlocking([&] { reader->Start( - expected_response_count, + std::move(expected_response_count), [&](std::vector result) { responses = std::move(result); }, @@ -101,7 +101,7 @@ TEST_F(GrpcStreamingReaderTest, FinishImmediatelyIsIdempotent) { worker_queue->EnqueueBlocking( [&] { EXPECT_NO_THROW(reader->FinishImmediately()); }); - StartReader(0); + StartReader(util::StatusOr(0)); KeepPollingGrpcQueue(); worker_queue->EnqueueBlocking([&] { @@ -114,12 +114,12 @@ TEST_F(GrpcStreamingReaderTest, FinishImmediatelyIsIdempotent) { // Method prerequisites -- correct usage of `GetResponseHeaders` TEST_F(GrpcStreamingReaderTest, CanGetResponseHeadersAfterStarting) { - StartReader(0); + StartReader(util::StatusOr(0)); EXPECT_NO_THROW(reader->GetResponseHeaders()); } TEST_F(GrpcStreamingReaderTest, CanGetResponseHeadersAfterFinishing) { - StartReader(0); + StartReader(util::StatusOr(0)); KeepPollingGrpcQueue(); worker_queue->EnqueueBlocking([&] { @@ -139,7 +139,7 @@ TEST_F(GrpcStreamingReaderTest, CannotFinishAndNotifyBeforeStarting) { // Normal operation TEST_F(GrpcStreamingReaderTest, OneSuccessfulRead) { - StartReader(1); + StartReader(util::StatusOr(1)); ForceFinishAnyTypeOrder({ {Type::Write, CompletionResult::Ok}, @@ -158,7 +158,7 @@ TEST_F(GrpcStreamingReaderTest, OneSuccessfulRead) { } TEST_F(GrpcStreamingReaderTest, TwoSuccessfulReads) { - StartReader(2); + StartReader(util::StatusOr(2)); ForceFinishAnyTypeOrder({ {Type::Write, CompletionResult::Ok}, @@ -178,7 +178,7 @@ TEST_F(GrpcStreamingReaderTest, TwoSuccessfulReads) { } TEST_F(GrpcStreamingReaderTest, FinishWhileReading) { - StartReader(1); + StartReader(util::StatusOr(1)); ForceFinishAnyTypeOrder({{Type::Write, CompletionResult::Ok}, {Type::Read, CompletionResult::Ok}}); @@ -194,7 +194,7 @@ TEST_F(GrpcStreamingReaderTest, FinishWhileReading) { // Errors TEST_F(GrpcStreamingReaderTest, ErrorOnWrite) { - StartReader(1); + StartReader(util::StatusOr(1)); bool failed_write = false; auto future = tester.ForceFinishAsync([&](GrpcCompletion* completion) { @@ -230,7 +230,7 @@ TEST_F(GrpcStreamingReaderTest, ErrorOnWrite) { } TEST_F(GrpcStreamingReaderTest, ErrorOnFirstRead) { - StartReader(1); + StartReader(util::StatusOr(1)); ForceFinishAnyTypeOrder({ {Type::Write, CompletionResult::Ok}, @@ -245,7 +245,7 @@ TEST_F(GrpcStreamingReaderTest, ErrorOnFirstRead) { } TEST_F(GrpcStreamingReaderTest, ErrorOnSecondRead) { - StartReader(2); + StartReader(util::StatusOr(2)); ForceFinishAnyTypeOrder({ {Type::Write, CompletionResult::Ok}, @@ -259,12 +259,81 @@ TEST_F(GrpcStreamingReaderTest, ErrorOnSecondRead) { EXPECT_TRUE(responses.empty()); } +TEST_F(GrpcStreamingReaderTest, + UnknownResponseCountReceivesAllMessagesOnFinish) { + // Use Status(Error::kErrorUnknown) to signify unknown response count + StartReader(util::Status(Error::kErrorUnknown, "Unknown response count")); + + // Send some messages + ForceFinishAnyTypeOrder({ + {Type::Write, CompletionResult::Ok}, + {Type::Read, MakeByteBuffer("msg1")}, + {Type::Read, MakeByteBuffer("msg2")}, + /*Read after last*/ {Type::Read, CompletionResult::Error}, + }); + + // At this point, responses_callback_ should NOT have been fired because + // expected_response_count_ is not 'ok'. + EXPECT_TRUE(responses.empty()); + EXPECT_FALSE(status.has_value()); + + // Now, finish the stream successfully. This should trigger the + // responses_callback_ with all accumulated messages. + ForceFinish({{Type::Finish, grpc::Status::OK}}); + + ASSERT_TRUE(status.has_value()); + EXPECT_EQ(status.value(), Status::OK()); + ASSERT_EQ(responses.size(), 2); + EXPECT_EQ(ByteBufferToString(responses[0]), std::string{"msg1"}); + EXPECT_EQ(ByteBufferToString(responses[1]), std::string{"msg2"}); +} + +TEST_F(GrpcStreamingReaderTest, + UnknownResponseCountReceivesEmptyOnFinishWithNoReads) { + StartReader(util::Status(Error::kErrorUnknown, "Unknown response count")); + + ForceFinishAnyTypeOrder({ + {Type::Write, CompletionResult::Ok}, + /*Read after last*/ {Type::Read, CompletionResult::Error}, + }); + + EXPECT_TRUE(responses.empty()); + EXPECT_FALSE(status.has_value()); + + ForceFinish({{Type::Finish, grpc::Status::OK}}); + + ASSERT_TRUE(status.has_value()); + EXPECT_EQ(status.value(), Status::OK()); + ASSERT_TRUE(responses.empty()); // Should still be empty, but callback fired +} + +TEST_F(GrpcStreamingReaderTest, UnknownResponseCountErrorOnFinish) { + StartReader(util::Status(Error::kErrorUnknown, "Unknown response count")); + + ForceFinishAnyTypeOrder({ + {Type::Write, CompletionResult::Ok}, + {Type::Read, MakeByteBuffer("msg1")}, + /*Read after last*/ {Type::Read, CompletionResult::Error}, + }); + + EXPECT_TRUE(responses.empty()); + EXPECT_FALSE(status.has_value()); + + grpc::Status error_status{grpc::StatusCode::DATA_LOSS, "Bad stream"}; + ForceFinish({{Type::Finish, error_status}}); + + ASSERT_TRUE(status.has_value()); + EXPECT_EQ(status.value().code(), Error::kErrorDataLoss); + EXPECT_TRUE( + responses.empty()); // responses_callback_ should not be fired on error +} + // Callback destroys reader TEST_F(GrpcStreamingReaderTest, CallbackCanDestroyReaderOnSuccess) { worker_queue->EnqueueBlocking([&] { reader->Start( - 1, [&](std::vector) {}, + util::StatusOr(1), [&](std::vector) {}, [&](const util::Status&, bool) { reader.reset(); }); }); @@ -282,7 +351,7 @@ TEST_F(GrpcStreamingReaderTest, CallbackCanDestroyReaderOnSuccess) { TEST_F(GrpcStreamingReaderTest, CallbackCanDestroyReaderOnError) { worker_queue->EnqueueBlocking([&] { reader->Start( - 1, [&](std::vector) {}, + util::StatusOr(1), [&](std::vector) {}, [&](const util::Status&, bool) { reader.reset(); }); });