niebayes commented on issue #7248:
URL: https://github.com/apache/arrow-rs/issues/7248#issuecomment-2756544811
@tustvold
Sorry for the late response. I have written a unit test that reproduces the issue.
``` rust
#[cfg(test)]
mod tests {
    //! Reproduction for apache/arrow-rs#7248: checks whether a Flight SQL server
    //! pulls batches from its `do_get` stream even though the client never polls
    //! the response stream.

    use std::net::SocketAddr;
    use std::str::FromStr;
    use std::sync::Arc;
    use std::time::Duration;

    use arrow::util::data_gen::create_random_batch;
    use arrow_flight::encode::FlightDataEncoderBuilder;
    use arrow_flight::flight_service_server::FlightService;
    use arrow_flight::sql::client::FlightSqlServiceClient;
    use arrow_flight::sql::server::FlightSqlService;
    use arrow_flight::sql::TicketStatementQuery;
    use arrow_flight::{
        flight_service_server::FlightServiceServer,
        sql::{Any, SqlInfo},
        Ticket,
    };
    use arrow_schema::{DataType, Field, Schema};
    use futures::{StreamExt, TryStreamExt};
    use prost::Message;
    use tokio::sync::oneshot;
    use tonic::transport::{Channel, Endpoint, Server};
    use tonic::{Request, Response, Status};

    /// Address the test server listens on and the client connects to.
    ///
    /// NOTE(review): a fixed port can collide with other processes or with
    /// tests run in parallel; binding to port 0 would avoid that but needs an
    /// incoming-connection adapter that is not imported here.
    const SERVER_ADDR: &str = "127.0.0.1:4000";

    /// Number of record batches the dummy server produces per `do_get`.
    const NUM_BATCHES: usize = 100;

    /// Rows per generated record batch.
    const ROWS_PER_BATCH: usize = 128;

    /// Minimal Flight SQL server that serves a fixed stream of random batches
    /// from `do_get_statement`; every other Flight SQL method uses the trait's
    /// default implementation.
    #[derive(Clone)]
    struct DummyFlightSqlServer;

    #[tonic::async_trait]
    impl FlightSqlService for DummyFlightSqlServer {
        type FlightService = DummyFlightSqlServer;

        /// Returns a FlightDataStream over `NUM_BATCHES` random batches.
        ///
        /// The ticket and request are ignored. A line is printed each time the
        /// encoder pulls a batch from the source stream, which makes it visible
        /// whether the server produces data the client has not asked for.
        async fn do_get_statement(
            &self,
            _ticket: TicketStatementQuery,
            _request: Request<Ticket>,
        ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
            let schema = Arc::new(Schema::new(vec![
                Field::new("a", DataType::Int64, false),
                Field::new("b", DataType::Float64, false),
            ]));
            // Batches are built eagerly; only the `inspect` below tracks when
            // each one is actually pulled by the encoder.
            let batches = (0..NUM_BATCHES)
                .map(|_| create_random_batch(schema.clone(), ROWS_PER_BATCH, 0.0, 0.0).unwrap())
                .collect::<Vec<_>>();
            let stream = futures::stream::iter(batches).map(Ok).inspect(|batch| {
                // This log should never print when the stream is not consumed.
                println!("consume batch of {} rows", batch.as_ref().unwrap().num_rows())
            });
            let output = FlightDataEncoderBuilder::new()
                .with_schema(schema)
                .build(stream);
            Ok(Response::new(output.map_err(Status::from).boxed()))
        }

        async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
    }

    /// Starts a server, issues a `do_get`, and then deliberately never polls the
    /// returned stream. If the server is lazy, no "consume batch" lines are
    /// printed during the observation window.
    #[tokio::test]
    async fn test_flight_sql_lazy_stream() {
        let addr: SocketAddr = SERVER_ADDR.parse().unwrap();
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

        // Start a Flight SQL server in the background; it stops when
        // `shutdown_tx` fires (or is dropped).
        let server_handle = tokio::spawn(async move {
            Server::builder()
                .add_service(FlightServiceServer::new(DummyFlightSqlServer))
                .serve_with_shutdown(addr, async {
                    shutdown_rx.await.ok();
                })
                .await
                .unwrap();
        });

        // Connect as soon as the server accepts connections rather than
        // sleeping for a fixed amount of time.
        let channel = connect_with_retry(&format!("http://{addr}")).await;
        let mut client = FlightSqlServiceClient::new(channel);

        // Call do_get to obtain a stream of record batches, but never poll it.
        let message = TicketStatementQuery {
            statement_handle: "SELECT * FROM t".into(),
        };
        let ticket = Ticket {
            ticket: Any::pack(&message).unwrap().encode_to_vec().into(),
        };
        let _output = client.do_get(ticket).await.unwrap();

        // Observation window: the server should never consume the stream
        // during this period of time.
        tokio::time::sleep(Duration::from_secs(5)).await;

        // Shut the server down without consuming the stream.
        shutdown_tx.send(()).unwrap();
        server_handle.await.unwrap();
    }

    /// Connects to `url`, retrying briefly while the server is still starting.
    ///
    /// # Panics
    /// Panics if `url` is not a valid endpoint, or if no connection succeeds
    /// within `CONNECT_ATTEMPTS` attempts spaced `RETRY_DELAY` apart.
    async fn connect_with_retry(url: &str) -> Channel {
        const CONNECT_ATTEMPTS: usize = 50;
        const RETRY_DELAY: Duration = Duration::from_millis(100);

        let endpoint = Endpoint::from_str(url).unwrap();
        let mut last_err = None;
        for _ in 0..CONNECT_ATTEMPTS {
            match endpoint.connect().await {
                Ok(channel) => return channel,
                Err(err) => {
                    last_err = Some(err);
                    tokio::time::sleep(RETRY_DELAY).await;
                }
            }
        }
        panic!("could not connect to {url}: {last_err:?}");
    }
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]