kou commented on code in PR #36009:
URL: https://github.com/apache/arrow/pull/36009#discussion_r1235921540
##########
go/arrow/flight/client.go:
##########
@@ -348,10 +352,87 @@ func (c *client) Authenticate(ctx context.Context, opts ...grpc.CallOption) erro
return c.authHandler.Authenticate(ctx, &clientAuthConn{stream})
}
+// Ensure the result of a DoAction is fully consumed
+func ReadUntilEOF(stream FlightService_DoActionClient) error {
Review Comment:
Yes. I wanted to use this both here and in flightsql.
> Godoc format says that the comment should start with the name of the function.
Oh, sorry, I missed that. I'll use your suggestion as-is. Thanks!
##########
go/arrow/flight/flightsql/server.go:
##########
@@ -511,8 +511,36 @@ func (BaseServer) BeginSavepoint(context.Context, ActionBeginSavepointRequest) (
return nil, status.Error(codes.Unimplemented, "BeginSavepoint not implemented")
}
-func (BaseServer) CancelQuery(context.Context, ActionCancelQueryRequest) (CancelResult, error) {
- return CancelResultUnspecified, status.Error(codes.Unimplemented, "CancelQuery not implemented")
+func (b *BaseServer) CancelQuery(context context.Context, request ActionCancelQueryRequest) (CancelResult, error) {
+ result, err := b.CancelFlightInfo(context, request.GetInfo())
+ if err != nil {
+ return CancelResultUnspecified, err
+ }
Review Comment:
Oh, thanks.
I removed `CancelQuery` because it lets us simplify our implementation. Is this acceptable for existing Go users, given that `CancelQuery` is an experimental API?
##########
go/arrow/flight/flightsql/server.go:
##########
@@ -640,7 +668,17 @@ type Server interface {
// EndTransaction commits or rollsback a transaction
EndTransaction(context.Context, ActionEndTransactionRequest) error
// CancelQuery attempts to explicitly cancel a query
+ // Deprecated: Since 13.0.0. If you can require all clients
+ // use 13.0.0 or later, you can use only CancelFlightInfo and
+ // you don't need to use CancelQuery. Otherwise, you may need
+ // to use CancelQuery and/or CancelFlightInfo.
Review Comment:
You're right.
The documentation is wrong. Users only need to implement `CancelFlightInfo`; it handles both `CancelQuery` and `CancelFlightInfo` requests from clients.
##########
cpp/src/arrow/flight/integration_tests/test_integration.cc:
##########
@@ -410,6 +413,470 @@ class OrderedScenario : public Scenario {
}
};
+/// \brief The server used for testing FlightEndpoint.expiration_time.
+///
+/// GetFlightInfo() returns a FlightInfo that has the following
+/// three FlightEndpoints:
+///
+/// 1. No expiration time
+/// 2. 2 seconds expiration time
+/// 3. 3 seconds expiration time
+///
+/// The client can't read data from the first endpoint multiple times
+/// but can read data from the second and third endpoints. The client
+/// can't re-read data from the second endpoint 2 seconds later. The
+/// client can't re-read data from the third endpoint 3 seconds
+/// later.
+///
+/// The client can cancel a returned FlightInfo by pre-defined
+/// CancelFlightInfo action. The client can't read data from endpoints
+/// even within 3 seconds after the action.
+///
+/// The client can extend the expiration time of a FlightEndpoint in
+/// a returned FlightInfo by pre-defined RefreshFlightEndpoint
+/// action. The client can read data from endpoints multiple times
+/// within more 10 seconds after the action.
+///
+/// The client can close a returned FlightInfo explicitly by
+/// pre-defined CloseFlightInfo action. The client can't read data
+/// from endpoints even within 3 seconds after the action.
+class ExpirationTimeServer : public FlightServerBase {
+ private:
+ struct EndpointStatus {
+ explicit EndpointStatus(std::optional<Timestamp> expiration_time)
+ : expiration_time(expiration_time) {}
+
+ std::optional<Timestamp> expiration_time;
+ uint32_t num_gets = 0;
+ bool cancelled = false;
+ bool closed = false;
+ };
+
+ public:
+ ExpirationTimeServer() : FlightServerBase(), statuses_() {}
+
+ Status GetFlightInfo(const ServerCallContext& context,
+ const FlightDescriptor& descriptor,
+ std::unique_ptr<FlightInfo>* result) override {
+ statuses_.clear();
+ auto schema = BuildSchema();
+ std::vector<FlightEndpoint> endpoints;
+ AddEndpoint(endpoints, "No expiration time", std::nullopt);
+ AddEndpoint(endpoints, "2 seconds",
+ Timestamp::clock::now() + std::chrono::seconds{2});
+ AddEndpoint(endpoints, "3 seconds",
+ Timestamp::clock::now() + std::chrono::seconds{3});
+ ARROW_ASSIGN_OR_RAISE(
+ auto info, FlightInfo::Make(*schema, descriptor, endpoints, -1, -1, false));
+ *result = std::make_unique<FlightInfo>(info);
+ return Status::OK();
+ }
+
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* stream) override {
+ ARROW_ASSIGN_OR_RAISE(auto index, ExtractIndexFromTicket(request.ticket));
+ auto& status = statuses_[index];
+ if (status.closed) {
+ return Status::KeyError("Invalid flight: closed: ", request.ticket);
+ }
+ if (status.cancelled) {
+ return Status::KeyError("Invalid flight: canceled: ", request.ticket);
+ }
+ if (status.expiration_time.has_value()) {
+ auto expiration_time = status.expiration_time.value();
+ if (expiration_time < Timestamp::clock::now()) {
+ return Status::KeyError("Invalid flight: expired: ", request.ticket);
+ }
+ } else {
+ if (status.num_gets > 0) {
+ return Status::KeyError("Invalid flight: can't read multiple times: ",
+ request.ticket);
+ }
+ }
+ status.num_gets++;
+ ARROW_ASSIGN_OR_RAISE(auto builder, RecordBatchBuilder::Make(
+ BuildSchema(), arrow::default_memory_pool()));
+ auto number_builder = builder->GetFieldAs<UInt32Builder>(0);
+ ARROW_RETURN_NOT_OK(number_builder->Append(index));
+ ARROW_ASSIGN_OR_RAISE(auto record_batch, builder->Flush());
+ std::vector<std::shared_ptr<RecordBatch>> record_batches{record_batch};
+ ARROW_ASSIGN_OR_RAISE(auto record_batch_reader,
+ RecordBatchReader::Make(record_batches));
+ *stream = std::make_unique<RecordBatchStream>(record_batch_reader);
+ return Status::OK();
+ }
+
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result_stream) override {
+ std::vector<Result> results;
+ if (action.type == ActionType::kCancelFlightInfo.type) {
+ ARROW_ASSIGN_OR_RAISE(auto info,
+ FlightInfo::Deserialize(std::string_view(*action.body)));
+ for (const auto& endpoint : info->endpoints()) {
+ auto index_result = ExtractIndexFromTicket(endpoint.ticket.ticket);
+ auto cancel_status = CancelStatus::kUnspecified;
+ if (index_result.ok()) {
+ auto index = *index_result;
+ if (statuses_[index].cancelled) {
+ cancel_status = CancelStatus::kNotCancellable;
+ } else {
+ statuses_[index].cancelled = true;
+ cancel_status = CancelStatus::kCancelled;
+ }
+ } else {
+ cancel_status = CancelStatus::kNotCancellable;
+ }
+ auto cancel_result = CancelFlightInfoResult{cancel_status};
+ ARROW_ASSIGN_OR_RAISE(auto serialized, cancel_result.SerializeToString());
+ results.push_back(Result{Buffer::FromString(std::move(serialized))});
+ }
Review Comment:
Ah, sorry. It should return a single result for the whole FlightInfo, not one per endpoint. I'll fix it.
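The intended shape, one `CancelFlightInfoResult` per `CancelFlightInfo` action rather than one per endpoint, can be sketched in Go as follows. This is not the C++ fix itself: `endpointState` and `cancelAll` are hypothetical, and the rule used to fold per-endpoint outcomes into one status (any endpoint that can't be cancelled makes the whole result `NotCancellable`) is an assumption.

```go
package main

import "fmt"

// CancelStatus mirrors the Flight cancel statuses (illustrative names).
type CancelStatus int

const (
	CancelStatusUnspecified CancelStatus = iota
	CancelStatusCancelled
	CancelStatusCancelling
	CancelStatusNotCancellable
)

// endpointState is a hypothetical per-endpoint record, similar to
// EndpointStatus in the C++ test server.
type endpointState struct{ cancelled bool }

// cancelAll cancels every endpoint listed in indices and folds the outcomes
// into a single status, so the action yields exactly one result.
// Assumed rule: an unknown or already-cancelled endpoint makes the whole
// result NotCancellable; otherwise Cancelled; no endpoints => Unspecified.
func cancelAll(states []endpointState, indices []int) CancelStatus {
	status := CancelStatusUnspecified
	for _, i := range indices {
		if i < 0 || i >= len(states) || states[i].cancelled {
			status = CancelStatusNotCancellable
			continue
		}
		states[i].cancelled = true
		if status == CancelStatusUnspecified {
			status = CancelStatusCancelled
		}
	}
	return status
}

func main() {
	states := make([]endpointState, 3)
	fmt.Println(cancelAll(states, []int{0, 1, 2}) == CancelStatusCancelled) // true
	fmt.Println(cancelAll(states, []int{0}) == CancelStatusNotCancellable)  // true
}
```

The caller would then serialize the returned status once and push a single result onto the stream.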
##########
go/arrow/flight/client.go:
##########
@@ -348,10 +352,87 @@ func (c *client) Authenticate(ctx context.Context, opts ...grpc.CallOption) erro
return c.authHandler.Authenticate(ctx, &clientAuthConn{stream})
}
+// Ensure the result of a DoAction is fully consumed
+func ReadUntilEOF(stream FlightService_DoActionClient) error {
+ for {
+ _, err := stream.Recv()
+ if err == io.EOF {
+ return nil
+ } else if err != nil {
+ return err
+ }
+ }
+}
+
+func (c *client) CancelFlightInfo(ctx context.Context, info *FlightInfo, opts ...grpc.CallOption) (result CancelFlightInfoResult, err error) {
+ var action flight.Action
+ action.Type = CancelFlightInfoActionType
+ action.Body, err = proto.Marshal(info)
+ if err != nil {
+ return
+ }
+ stream, err := c.DoAction(ctx, &action, opts...)
+ if err != nil {
+ return
+ }
+ res, err := stream.Recv()
+ if err != nil {
+ return
+ }
+ if err = proto.Unmarshal(res.Body, &result); err != nil {
+ return
+ }
+ err = ReadUntilEOF(stream)
+ return
+}
+
func (c *client) Close() error {
c.FlightServiceClient = nil
if cl, ok := c.conn.(io.Closer); ok {
return cl.Close()
}
return nil
}
+
+func (c *client) CloseFlightInfo(ctx context.Context, info *FlightInfo, opts ...grpc.CallOption) (err error) {
+ var action flight.Action
+ action.Type = CloseFlightInfoActionType
+ action.Body, err = proto.Marshal(info)
+ if err != nil {
+ return
+ }
+ stream, err := c.DoAction(ctx, &action, opts...)
+ if err != nil {
+ return
+ }
+ err = ReadUntilEOF(stream)
+ return
Review Comment:
Thanks!
I didn't know that named return values could be mixed with the `return XXX` style.
##########
cpp/src/arrow/flight/integration_tests/test_integration.cc:
##########
@@ -410,6 +413,470 @@ class OrderedScenario : public Scenario {
}
};
+/// \brief The server used for testing FlightEndpoint.expiration_time.
+///
+/// GetFlightInfo() returns a FlightInfo that has the following
+/// three FlightEndpoints:
+///
+/// 1. No expiration time
+/// 2. 2 seconds expiration time
+/// 3. 3 seconds expiration time
+///
+/// The client can't read data from the first endpoint multiple times
+/// but can read data from the second and third endpoints. The client
+/// can't re-read data from the second endpoint 2 seconds later. The
+/// client can't re-read data from the third endpoint 3 seconds
+/// later.
+///
+/// The client can cancel a returned FlightInfo by pre-defined
+/// CancelFlightInfo action. The client can't read data from endpoints
+/// even within 3 seconds after the action.
+///
+/// The client can extend the expiration time of a FlightEndpoint in
+/// a returned FlightInfo by pre-defined RefreshFlightEndpoint
+/// action. The client can read data from endpoints multiple times
+/// within more 10 seconds after the action.
+///
+/// The client can close a returned FlightInfo explicitly by
+/// pre-defined CloseFlightInfo action. The client can't read data
+/// from endpoints even within 3 seconds after the action.
+class ExpirationTimeServer : public FlightServerBase {
+ private:
+ struct EndpointStatus {
+ explicit EndpointStatus(std::optional<Timestamp> expiration_time)
+ : expiration_time(expiration_time) {}
+
+ std::optional<Timestamp> expiration_time;
+ uint32_t num_gets = 0;
+ bool cancelled = false;
+ bool closed = false;
+ };
+
+ public:
+ ExpirationTimeServer() : FlightServerBase(), statuses_() {}
+
+ Status GetFlightInfo(const ServerCallContext& context,
+ const FlightDescriptor& descriptor,
+ std::unique_ptr<FlightInfo>* result) override {
+ statuses_.clear();
+ auto schema = BuildSchema();
+ std::vector<FlightEndpoint> endpoints;
+ AddEndpoint(endpoints, "No expiration time", std::nullopt);
+ AddEndpoint(endpoints, "2 seconds",
+ Timestamp::clock::now() + std::chrono::seconds{2});
+ AddEndpoint(endpoints, "3 seconds",
+ Timestamp::clock::now() + std::chrono::seconds{3});
+ ARROW_ASSIGN_OR_RAISE(
+ auto info, FlightInfo::Make(*schema, descriptor, endpoints, -1, -1, false));
+ *result = std::make_unique<FlightInfo>(info);
+ return Status::OK();
+ }
+
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* stream) override {
+ ARROW_ASSIGN_OR_RAISE(auto index, ExtractIndexFromTicket(request.ticket));
+ auto& status = statuses_[index];
+ if (status.closed) {
+ return Status::KeyError("Invalid flight: closed: ", request.ticket);
+ }
+ if (status.cancelled) {
+ return Status::KeyError("Invalid flight: canceled: ", request.ticket);
+ }
+ if (status.expiration_time.has_value()) {
+ auto expiration_time = status.expiration_time.value();
+ if (expiration_time < Timestamp::clock::now()) {
+ return Status::KeyError("Invalid flight: expired: ", request.ticket);
+ }
+ } else {
+ if (status.num_gets > 0) {
+ return Status::KeyError("Invalid flight: can't read multiple times: ",
+ request.ticket);
+ }
+ }
+ status.num_gets++;
+ ARROW_ASSIGN_OR_RAISE(auto builder, RecordBatchBuilder::Make(
+ BuildSchema(), arrow::default_memory_pool()));
+ auto number_builder = builder->GetFieldAs<UInt32Builder>(0);
+ ARROW_RETURN_NOT_OK(number_builder->Append(index));
+ ARROW_ASSIGN_OR_RAISE(auto record_batch, builder->Flush());
+ std::vector<std::shared_ptr<RecordBatch>> record_batches{record_batch};
+ ARROW_ASSIGN_OR_RAISE(auto record_batch_reader,
+ RecordBatchReader::Make(record_batches));
+ *stream = std::make_unique<RecordBatchStream>(record_batch_reader);
+ return Status::OK();
+ }
+
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result_stream) override {
+ std::vector<Result> results;
+ if (action.type == ActionType::kCancelFlightInfo.type) {
+ ARROW_ASSIGN_OR_RAISE(auto info,
+ FlightInfo::Deserialize(std::string_view(*action.body)));
+ for (const auto& endpoint : info->endpoints()) {
+ auto index_result = ExtractIndexFromTicket(endpoint.ticket.ticket);
+ auto cancel_status = CancelStatus::kUnspecified;
+ if (index_result.ok()) {
+ auto index = *index_result;
+ if (statuses_[index].cancelled) {
+ cancel_status = CancelStatus::kNotCancellable;
+ } else {
+ statuses_[index].cancelled = true;
+ cancel_status = CancelStatus::kCancelled;
+ }
+ } else {
+ cancel_status = CancelStatus::kNotCancellable;
+ }
+ auto cancel_result = CancelFlightInfoResult{cancel_status};
+ ARROW_ASSIGN_OR_RAISE(auto serialized, cancel_result.SerializeToString());
+ results.push_back(Result{Buffer::FromString(std::move(serialized))});
+ }
+ } else if (action.type == ActionType::kCloseFlightInfo.type) {
+ ARROW_ASSIGN_OR_RAISE(auto info,
+ FlightInfo::Deserialize(std::string_view(*action.body)));
+ for (const auto& endpoint : info->endpoints()) {
+ auto index_result = ExtractIndexFromTicket(endpoint.ticket.ticket);
+ if (!index_result.ok()) {
+ continue;
+ }
+ auto index = *index_result;
+ statuses_[index].closed = true;
+ }
+ } else if (action.type == ActionType::kRefreshFlightEndpoint.type) {
+ ARROW_ASSIGN_OR_RAISE(auto endpoint,
+ FlightEndpoint::Deserialize(std::string_view(*action.body)));
+ ARROW_ASSIGN_OR_RAISE(auto index, ExtractIndexFromTicket(endpoint.ticket.ticket));
+ if (statuses_[index].cancelled) {
+ return Status::Invalid("Invalid flight: canceled: ", endpoint.ticket.ticket);
+ }
+ endpoint.ticket.ticket += ": refreshed (+ 10 seconds)";
+ endpoint.expiration_time = Timestamp::clock::now() + std::chrono::seconds{10};
+ statuses_[index].expiration_time = endpoint.expiration_time.value();
+ ARROW_ASSIGN_OR_RAISE(auto serialized, endpoint.SerializeToString());
+ results.push_back(Result{Buffer::FromString(std::move(serialized))});
+ } else {
+ return Status::Invalid("Unknown action: ", action.type);
+ }
+ *result_stream = std::make_unique<SimpleResultStream>(std::move(results));
+ return Status::OK();
+ }
+
+ Status ListActions(const ServerCallContext& context,
+ std::vector<ActionType>* actions) override {
+ *actions = {
+ ActionType::kCancelFlightInfo,
+ ActionType::kCloseFlightInfo,
+ ActionType::kRefreshFlightEndpoint,
+ };
+ return Status::OK();
+ }
+
+ private:
+ void AddEndpoint(std::vector<FlightEndpoint>& endpoints, std::string ticket,
+ std::optional<Timestamp> expiration_time) {
+ endpoints.push_back(FlightEndpoint{
+ {std::to_string(statuses_.size()) + ": " + ticket}, {}, expiration_time});
+ statuses_.emplace_back(expiration_time);
+ }
+
+ arrow::Result<uint32_t> ExtractIndexFromTicket(const std::string& ticket) {
+ auto index_string = arrow::internal::SplitString(ticket, ':', 2)[0];
+ uint32_t index;
+ if (!arrow::internal::ParseUnsigned(index_string.data(), index_string.length(),
+ &index)) {
+ return Status::KeyError("Invalid flight: no index: ", ticket);
+ }
+ if (index >= statuses_.size()) {
+ return Status::KeyError("Invalid flight: out of index: ", ticket);
+ }
+ return index;
+ }
+
+ std::shared_ptr<Schema> BuildSchema() {
+ return arrow::schema({arrow::field("number", arrow::uint32(), false)});
+ }
+
+ std::vector<EndpointStatus> statuses_;
+};
+
+/// \brief The expiration time scenario - DoGet.
+///
+/// This tests that the client can read data that isn't expired yet
+/// multiple times and can't read data after it's expired.
+class ExpirationTimeDoGetScenario : public Scenario {
+ Status MakeServer(std::unique_ptr<FlightServerBase>* server,
+ FlightServerOptions* options) override {
+ *server = std::make_unique<ExpirationTimeServer>();
+ return Status::OK();
+ }
+
+ Status MakeClient(FlightClientOptions* options) override { return Status::OK(); }
+
+ Status RunClient(std::unique_ptr<FlightClient> client) override {
+ ARROW_ASSIGN_OR_RAISE(
+ auto info, client->GetFlightInfo(FlightDescriptor::Command("expiration_time")));
+ std::vector<std::shared_ptr<arrow::Table>> tables;
+ // First read from all endpoints
+ for (const auto& endpoint : info->endpoints()) {
+ ARROW_ASSIGN_OR_RAISE(auto reader, client->DoGet(endpoint.ticket));
+ ARROW_ASSIGN_OR_RAISE(auto table, reader->ToTable());
+ tables.push_back(table);
+ }
+ // Re-reads only from endpoints that have expiration time
+ for (const auto& endpoint : info->endpoints()) {
+ if (endpoint.expiration_time.has_value()) {
+ ARROW_ASSIGN_OR_RAISE(auto reader, client->DoGet(endpoint.ticket));
+ ARROW_ASSIGN_OR_RAISE(auto table, reader->ToTable());
+ tables.push_back(table);
+ } else {
+ auto reader = client->DoGet(endpoint.ticket);
+ if (reader.ok()) {
+ return Status::Invalid(
+ "Data that doesn't have expiration time "
+ "shouldn't be readable multiple times");
+ }
+ }
+ }
+ // Re-reads after expired
+ for (const auto& endpoint : info->endpoints()) {
+ if (!endpoint.expiration_time.has_value()) {
+ continue;
+ }
+ const auto& expiration_time = endpoint.expiration_time.value();
+ if (expiration_time > Timestamp::clock::now()) {
+ std::this_thread::sleep_for(expiration_time - Timestamp::clock::now());
+ }
Review Comment:
Yes, this might be flaky on a very slow single-core CI machine.
But I hope that 2 seconds is enough time to process one simple `GetFlightInfo()` request and five `DoGet()` requests (3 first reads + 2 re-reads) on most machines.
> Maybe there's no need to test the actual expiration of the ticket?
That's an option. How about adding a "We may remove this check if it is flaky in CI" comment here? We can then remove the check if it fails in our CI.
> Or, we can treat the expiration as a counter instead of a real expiration time to make things not dependent on the clock.
Hmm. That would work, but it may confuse us when we maintain our integration tests, because the expiration time is defined as a timestamp, not a counter.