caldempsey commented on code in PR #152: URL: https://github.com/apache/spark-connect-go/pull/152#discussion_r2319805765
########## spark/client/client.go: ########## @@ -434,6 +443,151 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) { } } +func (c *ExecutePlanClient) ToRecordBatches(ctx context.Context) (<-chan arrow.Record, <-chan error, *types.StructType) { + recordChan := make(chan arrow.Record, 10) + errorChan := make(chan error, 1) + + go func() { + defer func() { + // Ensure channels are always closed to prevent goroutine leaks + close(recordChan) + close(errorChan) + }() + + // Explicitly needed when tracking re-attachable execution. + c.done = false + + for { + // Check for context cancellation before each iteration + select { + case <-ctx.Done(): + // Context cancelled - send the error and return immediately + select { + case errorChan <- ctx.Err(): + default: + // Channel might be full, but we're exiting anyway + } + return + default: + // Continue with normal processing + } + + resp, err := c.responseStream.Recv() + + // Check for context cancellation after potentially blocking operations + select { + case <-ctx.Done(): + select { + case errorChan <- ctx.Err(): + default: + } + return + default: + } + + // EOF is received when the last message has been processed and the stream + // finished normally. Handle this FIRST, before any other processing. + if errors.Is(err, io.EOF) { + return + } + + // If there's any other error, handle it + if err != nil { + if se := sparkerrors.FromRPCError(err); se != nil { + select { + case errorChan <- sparkerrors.WithType(se, sparkerrors.ExecutionError): + case <-ctx.Done(): + return + } + } else { + // Unknown error - still send it + select { + case errorChan <- err: + case <-ctx.Done(): + return + } + } + return + } + + // Only proceed if we have a valid response (no error) + if resp == nil { + continue + } + + // Check that the server returned the session ID that we were expecting + // and that it has not changed. 
+ if resp.GetSessionId() != c.sessionId { + select { + case errorChan <- sparkerrors.WithType(&sparkerrors.InvalidServerSideSessionDetailsError{ + OwnSessionId: c.sessionId, + ReceivedSessionId: resp.GetSessionId(), + }, sparkerrors.InvalidServerSideSessionError): + case <-ctx.Done(): + return + } + return + } + + // Check if the response has already the schema set and if yes, convert + // the proto DataType to a StructType. + if resp.Schema != nil { + c.schema, err = types.ConvertProtoDataTypeToStructType(resp.Schema) + if err != nil { + select { + case errorChan <- sparkerrors.WithType(err, sparkerrors.ExecutionError): + case <-ctx.Done(): + return + } + return + } + } + + switch x := resp.ResponseType.(type) { + case *proto.ExecutePlanResponse_SqlCommandResult_: + if val := x.SqlCommandResult.GetRelation(); val != nil { + c.properties["sql_command_result"] = val + } + + case *proto.ExecutePlanResponse_ArrowBatch_: + // This is what we want - stream the record batch + record, err := types.ReadArrowBatchToRecord(x.ArrowBatch.Data, c.schema) + if err != nil { + select { + case errorChan <- err: + case <-ctx.Done(): + return + } + return + } + + // Try to send the record, but respect context cancellation + select { + case recordChan <- record: + // Successfully sent Review Comment: Left this as is for now; I think we can punt this down the road. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org