This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 3c013da56f GH-37795: [Java][FlightSQL] Add mock FlightSqlProducer and
tests (#37837)
3c013da56f is described below
commit 3c013da56fd55122072ffaca3e23afb12c290075
Author: James Duong <[email protected]>
AuthorDate: Mon Sep 25 06:17:26 2023 -0700
GH-37795: [Java][FlightSQL] Add mock FlightSqlProducer and tests (#37837)
### Rationale for this change
Clarify how to write a FlightSqlProducer with examples and helper classes.
This is more in line with what is already available
to help developers write a FlightProducer.
### What changes are included in this PR?
Add helper classes for creating a no-op Flight SQL producer and a partially
implemented FlightSqlProducer that can process metadata requests.
Add a mock Flight SQL producer and tests for it, based on the new
partial FlightSqlProducer implementations.
Clean up FlightStreams that were not being closed in TestFlightSql.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
No.
* Closes: #37795
Authored-by: James Duong <[email protected]>
Signed-off-by: David Li <[email protected]>
---
.../arrow/flight/sql/BasicFlightSqlProducer.java | 109 ++++++
.../arrow/flight/sql/NoOpFlightSqlProducer.java | 221 +++++++++++
.../org/apache/arrow/flight/TestFlightSql.java | 432 ++++++++-------------
.../apache/arrow/flight/TestFlightSqlStreams.java | 288 ++++++++++++++
.../arrow/flight/sql/util/FlightStreamUtils.java | 129 ++++++
5 files changed, 917 insertions(+), 262 deletions(-)
diff --git
a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/BasicFlightSqlProducer.java
b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/BasicFlightSqlProducer.java
new file mode 100644
index 0000000000..ea99191f28
--- /dev/null
+++
b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/BasicFlightSqlProducer.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.sql;
+
+import java.util.List;
+
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.sql.impl.FlightSql;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+import com.google.protobuf.Message;
+
+/**
+ * A {@link FlightSqlProducer} that implements getting FlightInfo for each metadata request.
+ *
+ * <p>Every metadata {@code getFlightInfo*} method (SQL info, XDBC type info, catalogs, DB schemas,
+ * tables, table types, primary keys, exported keys, imported keys and cross reference) delegates
+ * to {@code generateFlightInfo}, passing the matching {@code Schemas} constant as the result schema.
+ * For tables, {@code Schemas.GET_TABLES_SCHEMA} is used when the request asks to include table
+ * schemas ({@code getIncludeSchema()}), and {@code Schemas.GET_TABLES_SCHEMA_NO_SCHEMA} otherwise.
+ *
+ * <p>{@code generateFlightInfo} asks {@code determineEndpoints} for the endpoints and builds a
+ * FlightInfo from the schema, the caller's descriptor and those endpoints, passing {@code -1} for
+ * both the byte count and the record count (i.e. not reported).
+ *
+ * <p>Subclasses must implement {@code determineEndpoints}. Every other operation, including the
+ * {@code getStream*} methods that deliver the metadata itself, keeps the behavior inherited from
+ * {@link NoOpFlightSqlProducer} unless the subclass overrides it.
+ */
+public abstract class BasicFlightSqlProducer extends NoOpFlightSqlProducer {
+
+ @Override
+ public FlightInfo getFlightInfoSqlInfo(FlightSql.CommandGetSqlInfo request,
CallContext context,
+ FlightDescriptor descriptor) {
+ return generateFlightInfo(request, descriptor,
Schemas.GET_SQL_INFO_SCHEMA);
+ }
+
+ @Override
+ public FlightInfo getFlightInfoTypeInfo(FlightSql.CommandGetXdbcTypeInfo
request, CallContext context,
+ FlightDescriptor descriptor) {
+ return generateFlightInfo(request, descriptor,
Schemas.GET_TYPE_INFO_SCHEMA);
+ }
+
+ @Override
+ public FlightInfo getFlightInfoCatalogs(FlightSql.CommandGetCatalogs
request, CallContext context,
+ FlightDescriptor descriptor) {
+ return generateFlightInfo(request, descriptor,
Schemas.GET_CATALOGS_SCHEMA);
+ }
+
+ @Override
+ public FlightInfo getFlightInfoSchemas(FlightSql.CommandGetDbSchemas
request, CallContext context,
+ FlightDescriptor descriptor) {
+ return generateFlightInfo(request, descriptor, Schemas.GET_SCHEMAS_SCHEMA);
+ }
+
+ @Override
+ public FlightInfo getFlightInfoTables(FlightSql.CommandGetTables request,
CallContext context,
+ FlightDescriptor descriptor) {
+ if (request.getIncludeSchema()) {
+ return generateFlightInfo(request, descriptor,
Schemas.GET_TABLES_SCHEMA);
+ }
+ return generateFlightInfo(request, descriptor,
Schemas.GET_TABLES_SCHEMA_NO_SCHEMA);
+ }
+
+ @Override
+ public FlightInfo getFlightInfoTableTypes(FlightSql.CommandGetTableTypes
request, CallContext context,
+ FlightDescriptor descriptor) {
+ return generateFlightInfo(request, descriptor,
Schemas.GET_TABLE_TYPES_SCHEMA);
+ }
+
+ @Override
+ public FlightInfo getFlightInfoPrimaryKeys(FlightSql.CommandGetPrimaryKeys
request, CallContext context,
+ FlightDescriptor descriptor) {
+ return generateFlightInfo(request, descriptor,
Schemas.GET_PRIMARY_KEYS_SCHEMA);
+ }
+
+ @Override
+ public FlightInfo getFlightInfoExportedKeys(FlightSql.CommandGetExportedKeys
request, CallContext context,
+ FlightDescriptor descriptor) {
+ return generateFlightInfo(request, descriptor,
Schemas.GET_EXPORTED_KEYS_SCHEMA);
+ }
+
+ @Override
+ public FlightInfo getFlightInfoImportedKeys(FlightSql.CommandGetImportedKeys
request, CallContext context,
+ FlightDescriptor descriptor) {
+ return generateFlightInfo(request, descriptor,
Schemas.GET_IMPORTED_KEYS_SCHEMA);
+ }
+
+ @Override
+ public FlightInfo
getFlightInfoCrossReference(FlightSql.CommandGetCrossReference request,
CallContext context,
+ FlightDescriptor descriptor) {
+ return generateFlightInfo(request, descriptor,
Schemas.GET_CROSS_REFERENCE_SCHEMA);
+ }
+
+ /**
+ * Return a list of FlightEndpoints for the given request and FlightDescriptor. This method should validate that
+ * the request is supported by this FlightSqlProducer.
+ *
+ * <p>Invoked by {@code generateFlightInfo} for each metadata {@code getFlightInfo*} call handled by this class.
+ *
+ * @param request the Flight SQL command for which a FlightInfo is being generated.
+ * @param flightDescriptor the descriptor passed to the {@code getFlightInfo*} call.
+ * @param schema the schema that will be reported in the resulting FlightInfo.
+ * @param <T> the protobuf message type of {@code request}.
+ * @return the endpoints to include in the resulting FlightInfo.
+ */
+ protected abstract <T extends Message>
+ List<FlightEndpoint> determineEndpoints(T request, FlightDescriptor
flightDescriptor, Schema schema);
+
+ protected <T extends Message> FlightInfo generateFlightInfo(T request,
FlightDescriptor descriptor, Schema schema) {
+ final List<FlightEndpoint> endpoints = determineEndpoints(request,
descriptor, schema);
+ return new FlightInfo(schema, descriptor, endpoints, -1, -1);
+ }
+}
diff --git
a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoOpFlightSqlProducer.java
b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoOpFlightSqlProducer.java
new file mode 100644
index 0000000000..a02cee64bd
--- /dev/null
+++
b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoOpFlightSqlProducer.java
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.sql;
+
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.Criteria;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.PutResult;
+import org.apache.arrow.flight.Result;
+import org.apache.arrow.flight.SchemaResult;
+import org.apache.arrow.flight.sql.impl.FlightSql;
+
+/**
+ * A {@link FlightSqlProducer} whose operations all fail with {@code CallStatus.UNIMPLEMENTED}
+ * (description {@code "Not implemented."}), except {@code close()}, which does nothing.
+ *
+ * <p>How the error surfaces depends on the method:
+ * <ul>
+ *   <li>methods with a return value ({@code getFlightInfo*}, {@code getSchemaStatement},
+ *   {@code acceptPut*}) and {@code listFlights} throw it as a runtime exception;</li>
+ *   <li>{@code createPreparedStatement} and {@code closePreparedStatement} pass it to
+ *   {@code listener.onError(...)};</li>
+ *   <li>the {@code getStream*} methods pass it to {@code listener.error(...)}.</li>
+ * </ul>
+ *
+ * <p>The class is not final, so producers can extend it and override only the operations they
+ * support, as {@link BasicFlightSqlProducer} does for the metadata {@code getFlightInfo*} calls.
+ */
+public class NoOpFlightSqlProducer implements FlightSqlProducer {
+ @Override
+ public void
createPreparedStatement(FlightSql.ActionCreatePreparedStatementRequest request,
+ CallContext context,
StreamListener<Result> listener) {
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public void
closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest request,
+ CallContext context,
StreamListener<Result> listener) {
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public FlightInfo getFlightInfoStatement(FlightSql.CommandStatementQuery
command,
+ CallContext context,
FlightDescriptor descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public FlightInfo
getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatementQuery command,
+ CallContext context,
FlightDescriptor descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public SchemaResult getSchemaStatement(FlightSql.CommandStatementQuery
command,
+ CallContext context, FlightDescriptor
descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public void getStreamStatement(FlightSql.TicketStatementQuery ticket,
+ CallContext context, ServerStreamListener
listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public void
getStreamPreparedStatement(FlightSql.CommandPreparedStatementQuery command,
+ CallContext context,
ServerStreamListener listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public Runnable acceptPutStatement(FlightSql.CommandStatementUpdate command,
CallContext context,
+ FlightStream flightStream,
StreamListener<PutResult> ackStream) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public Runnable
acceptPutPreparedStatementUpdate(FlightSql.CommandPreparedStatementUpdate
command,
+ CallContext context,
FlightStream flightStream,
+ StreamListener<PutResult>
ackStream) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public Runnable
acceptPutPreparedStatementQuery(FlightSql.CommandPreparedStatementQuery
command, CallContext context,
+ FlightStream flightStream,
StreamListener<PutResult> ackStream) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public FlightInfo getFlightInfoSqlInfo(FlightSql.CommandGetSqlInfo request,
CallContext context,
+ FlightDescriptor descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public void getStreamSqlInfo(FlightSql.CommandGetSqlInfo command,
CallContext context,
+ ServerStreamListener listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public FlightInfo getFlightInfoTypeInfo(FlightSql.CommandGetXdbcTypeInfo
request,
+ CallContext context,
FlightDescriptor descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public void getStreamTypeInfo(FlightSql.CommandGetXdbcTypeInfo request,
+ CallContext context, ServerStreamListener
listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public FlightInfo getFlightInfoCatalogs(FlightSql.CommandGetCatalogs request,
+ CallContext context,
FlightDescriptor descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public void getStreamCatalogs(CallContext context, ServerStreamListener
listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public FlightInfo getFlightInfoSchemas(FlightSql.CommandGetDbSchemas request,
+ CallContext context, FlightDescriptor
descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public void getStreamSchemas(FlightSql.CommandGetDbSchemas command,
+ CallContext context, ServerStreamListener
listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public FlightInfo getFlightInfoTables(FlightSql.CommandGetTables request,
+ CallContext context, FlightDescriptor
descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public void getStreamTables(FlightSql.CommandGetTables command, CallContext
context, ServerStreamListener listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public FlightInfo getFlightInfoTableTypes(FlightSql.CommandGetTableTypes
request, CallContext context,
+ FlightDescriptor descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public void getStreamTableTypes(CallContext context, ServerStreamListener
listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public FlightInfo getFlightInfoPrimaryKeys(FlightSql.CommandGetPrimaryKeys
request,
+ CallContext context,
FlightDescriptor descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public void getStreamPrimaryKeys(FlightSql.CommandGetPrimaryKeys command,
+ CallContext context, ServerStreamListener
listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public FlightInfo getFlightInfoExportedKeys(FlightSql.CommandGetExportedKeys
request,
+ CallContext context,
FlightDescriptor descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public FlightInfo getFlightInfoImportedKeys(FlightSql.CommandGetImportedKeys
request,
+ CallContext context,
FlightDescriptor descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public FlightInfo
getFlightInfoCrossReference(FlightSql.CommandGetCrossReference request,
+ CallContext context,
FlightDescriptor descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public void getStreamExportedKeys(FlightSql.CommandGetExportedKeys command,
+ CallContext context, ServerStreamListener
listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public void getStreamImportedKeys(FlightSql.CommandGetImportedKeys command,
CallContext context,
+ ServerStreamListener listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public void getStreamCrossReference(FlightSql.CommandGetCrossReference
command, CallContext context,
+ ServerStreamListener listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ @Override
+ public void close() throws Exception {
+
+ }
+
+ @Override
+ public void listFlights(CallContext context, Criteria criteria,
StreamListener<FlightInfo> listener) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+}
diff --git
a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java
b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java
index 6da915a8ff..7635b80ecd 100644
---
a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java
+++
b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java
@@ -20,7 +20,7 @@ package org.apache.arrow.flight;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
-import static java.util.Objects.isNull;
+import static org.apache.arrow.flight.sql.util.FlightStreamUtils.getResults;
import static org.apache.arrow.util.AutoCloseables.close;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.is;
@@ -29,16 +29,12 @@ import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
-import java.io.ByteArrayInputStream;
-import java.io.IOException;
-import java.nio.channels.Channels;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
import java.util.Optional;
import java.util.stream.IntStream;
@@ -52,18 +48,9 @@ import
org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity;
import org.apache.arrow.flight.sql.util.TableRef;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
-import org.apache.arrow.vector.BitVector;
-import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
-import org.apache.arrow.vector.UInt1Vector;
-import org.apache.arrow.vector.UInt4Vector;
-import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
-import org.apache.arrow.vector.complex.DenseUnionVector;
-import org.apache.arrow.vector.complex.ListVector;
-import org.apache.arrow.vector.ipc.ReadChannel;
-import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
@@ -657,197 +644,202 @@ public class TestFlightSql {
}
@Test
- public void testGetCommandExportedKeys() {
- final FlightStream stream =
+ public void testGetCommandExportedKeys() throws Exception {
+ try (final FlightStream stream =
sqlClient.getStream(
sqlClient.getExportedKeys(TableRef.of(null, null, "FOREIGNTABLE"))
- .getEndpoints().get(0).getTicket());
-
- final List<List<String>> results = getResults(stream);
-
- final List<Matcher<String>> matchers = asList(
- nullValue(String.class), // pk_catalog_name
- is("APP"), // pk_schema_name
- is("FOREIGNTABLE"), // pk_table_name
- is("ID"), // pk_column_name
- nullValue(String.class), // fk_catalog_name
- is("APP"), // fk_schema_name
- is("INTTABLE"), // fk_table_name
- is("FOREIGNID"), // fk_column_name
- is("1"), // key_sequence
- containsString("SQL"), // fk_key_name
- containsString("SQL"), // pk_key_name
- is("3"), // update_rule
- is("3")); // delete_rule
-
- final List<Executable> assertions = new ArrayList<>();
- Assertions.assertEquals(1, results.size());
- for (int i = 0; i < matchers.size(); i++) {
- final String actual = results.get(0).get(i);
- final Matcher<String> expected = matchers.get(i);
- assertions.add(() -> MatcherAssert.assertThat(actual, expected));
+ .getEndpoints().get(0).getTicket())) {
+
+ final List<List<String>> results = getResults(stream);
+
+ final List<Matcher<String>> matchers = asList(
+ nullValue(String.class), // pk_catalog_name
+ is("APP"), // pk_schema_name
+ is("FOREIGNTABLE"), // pk_table_name
+ is("ID"), // pk_column_name
+ nullValue(String.class), // fk_catalog_name
+ is("APP"), // fk_schema_name
+ is("INTTABLE"), // fk_table_name
+ is("FOREIGNID"), // fk_column_name
+ is("1"), // key_sequence
+ containsString("SQL"), // fk_key_name
+ containsString("SQL"), // pk_key_name
+ is("3"), // update_rule
+ is("3")); // delete_rule
+
+ final List<Executable> assertions = new ArrayList<>();
+ Assertions.assertEquals(1, results.size());
+ for (int i = 0; i < matchers.size(); i++) {
+ final String actual = results.get(0).get(i);
+ final Matcher<String> expected = matchers.get(i);
+ assertions.add(() -> MatcherAssert.assertThat(actual, expected));
+ }
+ Assertions.assertAll(assertions);
}
- Assertions.assertAll(assertions);
}
@Test
- public void testGetCommandImportedKeys() {
- final FlightStream stream =
+ public void testGetCommandImportedKeys() throws Exception {
+ try (final FlightStream stream =
sqlClient.getStream(
sqlClient.getImportedKeys(TableRef.of(null, null, "INTTABLE"))
- .getEndpoints().get(0).getTicket());
-
- final List<List<String>> results = getResults(stream);
-
- final List<Matcher<String>> matchers = asList(
- nullValue(String.class), // pk_catalog_name
- is("APP"), // pk_schema_name
- is("FOREIGNTABLE"), // pk_table_name
- is("ID"), // pk_column_name
- nullValue(String.class), // fk_catalog_name
- is("APP"), // fk_schema_name
- is("INTTABLE"), // fk_table_name
- is("FOREIGNID"), // fk_column_name
- is("1"), // key_sequence
- containsString("SQL"), // fk_key_name
- containsString("SQL"), // pk_key_name
- is("3"), // update_rule
- is("3")); // delete_rule
-
- Assertions.assertEquals(1, results.size());
- final List<Executable> assertions = new ArrayList<>();
- for (int i = 0; i < matchers.size(); i++) {
- final String actual = results.get(0).get(i);
- final Matcher<String> expected = matchers.get(i);
- assertions.add(() -> MatcherAssert.assertThat(actual, expected));
+ .getEndpoints().get(0).getTicket())) {
+
+ final List<List<String>> results = getResults(stream);
+
+ final List<Matcher<String>> matchers = asList(
+ nullValue(String.class), // pk_catalog_name
+ is("APP"), // pk_schema_name
+ is("FOREIGNTABLE"), // pk_table_name
+ is("ID"), // pk_column_name
+ nullValue(String.class), // fk_catalog_name
+ is("APP"), // fk_schema_name
+ is("INTTABLE"), // fk_table_name
+ is("FOREIGNID"), // fk_column_name
+ is("1"), // key_sequence
+ containsString("SQL"), // fk_key_name
+ containsString("SQL"), // pk_key_name
+ is("3"), // update_rule
+ is("3")); // delete_rule
+
+ Assertions.assertEquals(1, results.size());
+ final List<Executable> assertions = new ArrayList<>();
+ for (int i = 0; i < matchers.size(); i++) {
+ final String actual = results.get(0).get(i);
+ final Matcher<String> expected = matchers.get(i);
+ assertions.add(() -> MatcherAssert.assertThat(actual, expected));
+ }
+ Assertions.assertAll(assertions);
}
- Assertions.assertAll(assertions);
}
@Test
- public void testGetTypeInfo() {
+ public void testGetTypeInfo() throws Exception {
FlightInfo flightInfo = sqlClient.getXdbcTypeInfo();
- FlightStream stream =
sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket());
-
- final List<List<String>> results = getResults(stream);
-
- final List<List<String>> matchers = ImmutableList.of(
- asList("BIGINT", "-5", "19", null, null, emptyList().toString(), "1",
"false", "2", "false", "false", "true",
- "BIGINT", "0", "0",
- null, null, "10", null),
- asList("LONG VARCHAR FOR BIT DATA", "-4", "32700", "X'", "'",
emptyList().toString(), "1", "false", "0", "true",
- "false", "false",
- "LONG VARCHAR FOR BIT DATA", null, null, null, null, null, null),
- asList("VARCHAR () FOR BIT DATA", "-3", "32672", "X'", "'",
singletonList("length").toString(), "1", "false",
- "2", "true", "false",
- "false", "VARCHAR () FOR BIT DATA", null, null, null, null, null,
null),
- asList("CHAR () FOR BIT DATA", "-2", "254", "X'", "'",
singletonList("length").toString(), "1", "false", "2",
- "true", "false", "false",
- "CHAR () FOR BIT DATA", null, null, null, null, null, null),
- asList("LONG VARCHAR", "-1", "32700", "'", "'",
emptyList().toString(), "1", "true", "1", "true", "false",
- "false",
- "LONG VARCHAR", null, null, null, null, null, null),
- asList("CHAR", "1", "254", "'", "'",
singletonList("length").toString(), "1", "true", "3", "true", "false",
- "false", "CHAR", null, null,
- null, null, null, null),
- asList("NUMERIC", "2", "31", null, null, Arrays.asList("precision",
"scale").toString(), "1", "false", "2",
- "false", "true", "false",
- "NUMERIC", "0", "31", null, null, "10", null),
- asList("DECIMAL", "3", "31", null, null, Arrays.asList("precision",
"scale").toString(), "1", "false", "2",
- "false", "true", "false",
- "DECIMAL", "0", "31", null, null, "10", null),
- asList("INTEGER", "4", "10", null, null, emptyList().toString(), "1",
"false", "2", "false", "false", "true",
- "INTEGER", "0", "0",
- null, null, "10", null),
- asList("SMALLINT", "5", "5", null, null, emptyList().toString(), "1",
"false", "2", "false", "false", "true",
- "SMALLINT", "0",
- "0", null, null, "10", null),
- asList("FLOAT", "6", "52", null, null,
singletonList("precision").toString(), "1", "false", "2", "false",
- "false", "false", "FLOAT", null,
- null, null, null, "2", null),
- asList("REAL", "7", "23", null, null, emptyList().toString(), "1",
"false", "2", "false", "false", "false",
- "REAL", null, null,
- null, null, "2", null),
- asList("DOUBLE", "8", "52", null, null, emptyList().toString(), "1",
"false", "2", "false", "false", "false",
- "DOUBLE", null,
- null, null, null, "2", null),
- asList("VARCHAR", "12", "32672", "'", "'",
singletonList("length").toString(), "1", "true", "3", "true",
- "false", "false", "VARCHAR",
- null, null, null, null, null, null),
- asList("BOOLEAN", "16", "1", null, null, emptyList().toString(), "1",
"false", "2", "true", "false", "false",
- "BOOLEAN", null,
- null, null, null, null, null),
- asList("DATE", "91", "10", "DATE'", "'", emptyList().toString(), "1",
"false", "2", "true", "false", "false",
- "DATE", "0", "0",
- null, null, "10", null),
- asList("TIME", "92", "8", "TIME'", "'", emptyList().toString(), "1",
"false", "2", "true", "false", "false",
- "TIME", "0", "0",
- null, null, "10", null),
- asList("TIMESTAMP", "93", "29", "TIMESTAMP'", "'",
emptyList().toString(), "1", "false", "2", "true", "false",
- "false",
- "TIMESTAMP", "0", "9", null, null, "10", null),
- asList("OBJECT", "2000", null, null, null, emptyList().toString(),
"1", "false", "2", "true", "false", "false",
- "OBJECT", null,
- null, null, null, null, null),
- asList("BLOB", "2004", "2147483647", null, null,
singletonList("length").toString(), "1", "false", "0", null,
- "false", null, "BLOB", null,
- null, null, null, null, null),
- asList("CLOB", "2005", "2147483647", "'", "'",
singletonList("length").toString(), "1", "true", "1", null,
- "false", null, "CLOB", null,
- null, null, null, null, null),
- asList("XML", "2009", null, null, null, emptyList().toString(), "1",
"true", "0", "false", "false", "false",
- "XML", null, null,
- null, null, null, null));
- MatcherAssert.assertThat(results, is(matchers));
+ try (FlightStream stream =
sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket())) {
+
+ final List<List<String>> results = getResults(stream);
+
+ final List<List<String>> matchers = ImmutableList.of(
+ asList("BIGINT", "-5", "19", null, null, emptyList().toString(),
"1", "false", "2", "false", "false", "true",
+ "BIGINT", "0", "0",
+ null, null, "10", null),
+ asList("LONG VARCHAR FOR BIT DATA", "-4", "32700", "X'", "'",
emptyList().toString(), "1", "false", "0",
+ "true", "false", "false",
+ "LONG VARCHAR FOR BIT DATA", null, null, null, null, null, null),
+ asList("VARCHAR () FOR BIT DATA", "-3", "32672", "X'", "'",
singletonList("length").toString(), "1", "false",
+ "2", "true", "false",
+ "false", "VARCHAR () FOR BIT DATA", null, null, null, null,
null, null),
+ asList("CHAR () FOR BIT DATA", "-2", "254", "X'", "'",
singletonList("length").toString(), "1", "false", "2",
+ "true", "false", "false",
+ "CHAR () FOR BIT DATA", null, null, null, null, null, null),
+ asList("LONG VARCHAR", "-1", "32700", "'", "'",
emptyList().toString(), "1", "true", "1", "true", "false",
+ "false",
+ "LONG VARCHAR", null, null, null, null, null, null),
+ asList("CHAR", "1", "254", "'", "'",
singletonList("length").toString(), "1", "true", "3", "true", "false",
+ "false", "CHAR", null, null,
+ null, null, null, null),
+ asList("NUMERIC", "2", "31", null, null, Arrays.asList("precision",
"scale").toString(), "1", "false", "2",
+ "false", "true", "false",
+ "NUMERIC", "0", "31", null, null, "10", null),
+ asList("DECIMAL", "3", "31", null, null, Arrays.asList("precision",
"scale").toString(), "1", "false", "2",
+ "false", "true", "false",
+ "DECIMAL", "0", "31", null, null, "10", null),
+ asList("INTEGER", "4", "10", null, null, emptyList().toString(),
"1", "false", "2", "false", "false", "true",
+ "INTEGER", "0", "0",
+ null, null, "10", null),
+ asList("SMALLINT", "5", "5", null, null, emptyList().toString(),
"1", "false", "2", "false", "false", "true",
+ "SMALLINT", "0",
+ "0", null, null, "10", null),
+ asList("FLOAT", "6", "52", null, null,
singletonList("precision").toString(), "1", "false", "2", "false",
+ "false", "false", "FLOAT", null,
+ null, null, null, "2", null),
+ asList("REAL", "7", "23", null, null, emptyList().toString(), "1",
"false", "2", "false", "false", "false",
+ "REAL", null, null,
+ null, null, "2", null),
+ asList("DOUBLE", "8", "52", null, null, emptyList().toString(), "1",
"false", "2", "false", "false", "false",
+ "DOUBLE", null,
+ null, null, null, "2", null),
+ asList("VARCHAR", "12", "32672", "'", "'",
singletonList("length").toString(), "1", "true", "3", "true",
+ "false", "false", "VARCHAR",
+ null, null, null, null, null, null),
+ asList("BOOLEAN", "16", "1", null, null, emptyList().toString(),
"1", "false", "2", "true", "false", "false",
+ "BOOLEAN", null,
+ null, null, null, null, null),
+ asList("DATE", "91", "10", "DATE'", "'", emptyList().toString(),
"1", "false", "2", "true", "false", "false",
+ "DATE", "0", "0",
+ null, null, "10", null),
+ asList("TIME", "92", "8", "TIME'", "'", emptyList().toString(), "1",
"false", "2", "true", "false", "false",
+ "TIME", "0", "0",
+ null, null, "10", null),
+ asList("TIMESTAMP", "93", "29", "TIMESTAMP'", "'",
emptyList().toString(), "1", "false", "2", "true", "false",
+ "false",
+ "TIMESTAMP", "0", "9", null, null, "10", null),
+ asList("OBJECT", "2000", null, null, null, emptyList().toString(),
"1", "false", "2", "true", "false",
+ "false", "OBJECT", null,
+ null, null, null, null, null),
+ asList("BLOB", "2004", "2147483647", null, null,
singletonList("length").toString(), "1", "false", "0", null,
+ "false", null, "BLOB", null,
+ null, null, null, null, null),
+ asList("CLOB", "2005", "2147483647", "'", "'",
singletonList("length").toString(), "1", "true", "1", null,
+ "false", null, "CLOB", null,
+ null, null, null, null, null),
+ asList("XML", "2009", null, null, null, emptyList().toString(), "1",
"true", "0", "false", "false", "false",
+ "XML", null, null,
+ null, null, null, null));
+ MatcherAssert.assertThat(results, is(matchers));
+ }
}
@Test
- public void testGetTypeInfoWithFiltering() {
+ public void testGetTypeInfoWithFiltering() throws Exception {
FlightInfo flightInfo = sqlClient.getXdbcTypeInfo(-5);
- FlightStream stream =
sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket());
+ try (FlightStream stream =
sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket())) {
- final List<List<String>> results = getResults(stream);
+ final List<List<String>> results = getResults(stream);
- final List<List<String>> matchers = ImmutableList.of(
- asList("BIGINT", "-5", "19", null, null, emptyList().toString(), "1",
"false", "2", "false", "false", "true",
- "BIGINT", "0", "0",
- null, null, "10", null));
- MatcherAssert.assertThat(results, is(matchers));
+ final List<List<String>> matchers = ImmutableList.of(
+ asList("BIGINT", "-5", "19", null, null, emptyList().toString(),
"1", "false", "2", "false", "false", "true",
+ "BIGINT", "0", "0",
+ null, null, "10", null));
+ MatcherAssert.assertThat(results, is(matchers));
+ }
}
@Test
- public void testGetCommandCrossReference() {
+ public void testGetCommandCrossReference() throws Exception {
+ // Exactly one row is expected: INTTABLE.FOREIGNID (fk side) referencing FOREIGNTABLE.ID (pk side).
final FlightInfo flightInfo =
sqlClient.getCrossReference(TableRef.of(null, null,
"FOREIGNTABLE"), TableRef.of(null, null, "INTTABLE"));
- final FlightStream stream =
sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket());
-
- final List<List<String>> results = getResults(stream);
-
- final List<Matcher<String>> matchers = asList(
- nullValue(String.class), // pk_catalog_name
- is("APP"), // pk_schema_name
- is("FOREIGNTABLE"), // pk_table_name
- is("ID"), // pk_column_name
- nullValue(String.class), // fk_catalog_name
- is("APP"), // fk_schema_name
- is("INTTABLE"), // fk_table_name
- is("FOREIGNID"), // fk_column_name
- is("1"), // key_sequence
- containsString("SQL"), // fk_key_name
- containsString("SQL"), // pk_key_name
- is("3"), // update_rule
- is("3")); // delete_rule
-
- Assertions.assertEquals(1, results.size());
- final List<Executable> assertions = new ArrayList<>();
- for (int i = 0; i < matchers.size(); i++) {
- final String actual = results.get(0).get(i);
- final Matcher<String> expected = matchers.get(i);
- assertions.add(() -> MatcherAssert.assertThat(actual, expected));
+ try (final FlightStream stream =
sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket())) {
+
+ final List<List<String>> results = getResults(stream);
+
+ final List<Matcher<String>> matchers = asList(
+ nullValue(String.class), // pk_catalog_name
+ is("APP"), // pk_schema_name
+ is("FOREIGNTABLE"), // pk_table_name
+ is("ID"), // pk_column_name
+ nullValue(String.class), // fk_catalog_name
+ is("APP"), // fk_schema_name
+ is("INTTABLE"), // fk_table_name
+ is("FOREIGNID"), // fk_column_name
+ is("1"), // key_sequence
+ containsString("SQL"), // fk_key_name
+ containsString("SQL"), // pk_key_name
+ is("3"), // update_rule
+ is("3")); // delete_rule
+
+ Assertions.assertEquals(1, results.size());
+ final List<Executable> assertions = new ArrayList<>();
+ for (int i = 0; i < matchers.size(); i++) {
+ final String actual = results.get(0).get(i);
+ final Matcher<String> expected = matchers.get(i);
+ assertions.add(() -> MatcherAssert.assertThat(actual, expected));
+ }
+ Assertions.assertAll(assertions);
}
- Assertions.assertAll(assertions);
}
@Test
@@ -878,90 +870,6 @@ public class TestFlightSql {
}
}
- List<List<String>> getResults(FlightStream stream) {
- final List<List<String>> results = new ArrayList<>();
- while (stream.next()) {
- try (final VectorSchemaRoot root = stream.getRoot()) {
- final long rowCount = root.getRowCount();
- for (int i = 0; i < rowCount; ++i) {
- results.add(new ArrayList<>());
- }
-
- root.getSchema().getFields().forEach(field -> {
- try (final FieldVector fieldVector =
root.getVector(field.getName())) {
- if (fieldVector instanceof VarCharVector) {
- final VarCharVector varcharVector = (VarCharVector) fieldVector;
- for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
- final Text data = varcharVector.getObject(rowIndex);
- results.get(rowIndex).add(isNull(data) ? null :
data.toString());
- }
- } else if (fieldVector instanceof IntVector) {
- for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
- Object data = fieldVector.getObject(rowIndex);
- results.get(rowIndex).add(isNull(data) ? null :
Objects.toString(data));
- }
- } else if (fieldVector instanceof VarBinaryVector) {
- final VarBinaryVector varbinaryVector = (VarBinaryVector)
fieldVector;
- for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
- final byte[] data = varbinaryVector.getObject(rowIndex);
- final String output;
- try {
- output = isNull(data) ?
- null :
- MessageSerializer.deserializeSchema(
- new ReadChannel(Channels.newChannel(new
ByteArrayInputStream(data)))).toJson();
- } catch (final IOException e) {
- throw new RuntimeException("Failed to deserialize schema",
e);
- }
- results.get(rowIndex).add(output);
- }
- } else if (fieldVector instanceof DenseUnionVector) {
- final DenseUnionVector denseUnionVector = (DenseUnionVector)
fieldVector;
- for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
- final Object data = denseUnionVector.getObject(rowIndex);
- results.get(rowIndex).add(isNull(data) ? null :
Objects.toString(data));
- }
- } else if (fieldVector instanceof ListVector) {
- for (int i = 0; i < fieldVector.getValueCount(); i++) {
- if (!fieldVector.isNull(i)) {
- List<Text> elements = (List<Text>) ((ListVector)
fieldVector).getObject(i);
- List<String> values = new ArrayList<>();
-
- for (Text element : elements) {
- values.add(element.toString());
- }
- results.get(i).add(values.toString());
- }
- }
-
- } else if (fieldVector instanceof UInt4Vector) {
- final UInt4Vector uInt4Vector = (UInt4Vector) fieldVector;
- for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
- final Object data = uInt4Vector.getObject(rowIndex);
- results.get(rowIndex).add(isNull(data) ? null :
Objects.toString(data));
- }
- } else if (fieldVector instanceof UInt1Vector) {
- final UInt1Vector uInt1Vector = (UInt1Vector) fieldVector;
- for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
- final Object data = uInt1Vector.getObject(rowIndex);
- results.get(rowIndex).add(isNull(data) ? null :
Objects.toString(data));
- }
- } else if (fieldVector instanceof BitVector) {
- for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
- Object data = fieldVector.getObject(rowIndex);
- results.get(rowIndex).add(isNull(data) ? null :
Objects.toString(data));
- }
- } else {
- throw new UnsupportedOperationException("Not yet implemented");
- }
- }
- });
- }
- }
-
- return results;
- }
-
@Test
public void testExecuteUpdate() {
Assertions.assertAll(
diff --git
a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSqlStreams.java
b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSqlStreams.java
new file mode 100644
index 0000000000..4672e0a141
--- /dev/null
+++
b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSqlStreams.java
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.emptyList;
+import static java.util.Collections.singletonList;
+import static org.apache.arrow.flight.sql.util.FlightStreamUtils.getResults;
+import static org.apache.arrow.util.AutoCloseables.close;
+import static org.apache.arrow.vector.types.Types.MinorType.INT;
+import static org.hamcrest.CoreMatchers.is;
+
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.arrow.flight.sql.BasicFlightSqlProducer;
+import org.apache.arrow.flight.sql.FlightSqlClient;
+import org.apache.arrow.flight.sql.FlightSqlProducer;
+import org.apache.arrow.flight.sql.impl.FlightSql;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.Text;
+import org.hamcrest.MatcherAssert;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.Any;
+import com.google.protobuf.Message;
+
+public class TestFlightSqlStreams {
+
+ /**
+ * A limited {@link FlightSqlProducer} for testing GetTables, GetTableTypes,
GetSqlInfo, and limited SQL commands.
+ */
+ private static class FlightSqlTestProducer extends BasicFlightSqlProducer {
+
+ // Note that for simplicity the getStream* implementations are blocking,
but a proper FlightSqlProducer should
+ // have non-blocking implementations of getStream*.
+
+ private static final String FIXED_QUERY = "SELECT 1 AS c1 FROM test_table";
+ private static final Schema FIXED_SCHEMA = new Schema(asList(
+ Field.nullable("c1", Types.MinorType.INT.getType())));
+
+ private BufferAllocator allocator;
+
+ FlightSqlTestProducer(BufferAllocator allocator) {
+ this.allocator = allocator;
+ }
+
+ @Override
+ protected <T extends Message> List<FlightEndpoint> determineEndpoints(T
request, FlightDescriptor flightDescriptor,
+
Schema schema) {
+ if (request instanceof FlightSql.CommandGetTables ||
+ request instanceof FlightSql.CommandGetTableTypes ||
+ request instanceof FlightSql.CommandGetXdbcTypeInfo ||
+ request instanceof FlightSql.CommandGetSqlInfo) {
+ return Collections.singletonList(new FlightEndpoint(new
Ticket(Any.pack(request).toByteArray())));
+ } else if (request instanceof FlightSql.CommandStatementQuery &&
+ ((FlightSql.CommandStatementQuery)
request).getQuery().equals(FIXED_QUERY)) {
+
+ // Tickets from CommandStatementQuery requests should be built using
TicketStatementQuery then packed() into
+ // a ticket. The content of the statement handle is specific to the
FlightSqlProducer. It does not need to
+ // be the query. It can be a query ID for example.
+ FlightSql.TicketStatementQuery ticketStatementQuery =
FlightSql.TicketStatementQuery.newBuilder()
+ .setStatementHandle(((FlightSql.CommandStatementQuery)
request).getQueryBytes())
+ .build();
+ return Collections.singletonList(new FlightEndpoint(new
Ticket(Any.pack(ticketStatementQuery).toByteArray())));
+ }
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException();
+ }
+
+ @Override
+ public FlightInfo getFlightInfoStatement(FlightSql.CommandStatementQuery
command,
+ CallContext context,
FlightDescriptor descriptor) {
+ return generateFlightInfo(command, descriptor, FIXED_SCHEMA);
+ }
+
+ @Override
+ public void getStreamStatement(FlightSql.TicketStatementQuery ticket,
+ CallContext context, ServerStreamListener
listener) {
+ final String query = ticket.getStatementHandle().toStringUtf8();
+ if (!query.equals(FIXED_QUERY)) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not
implemented.").toRuntimeException());
+ }
+
+ try (VectorSchemaRoot root = VectorSchemaRoot.create(FIXED_SCHEMA,
allocator)) {
+ root.setRowCount(1);
+ ((IntVector) root.getVector("c1")).setSafe(0, 1);
+ listener.start(root);
+ listener.putNext();
+ listener.completed();
+ }
+ }
+
+ @Override
+ public void getStreamSqlInfo(FlightSql.CommandGetSqlInfo command,
CallContext context,
+ ServerStreamListener listener) {
+ try (VectorSchemaRoot root =
VectorSchemaRoot.create(Schemas.GET_SQL_INFO_SCHEMA, allocator)) {
+ root.setRowCount(0);
+ listener.start(root);
+ listener.putNext();
+ listener.completed();
+ }
+ }
+
+ @Override
+ public void getStreamTypeInfo(FlightSql.CommandGetXdbcTypeInfo request,
+ CallContext context, ServerStreamListener
listener) {
+ try (VectorSchemaRoot root =
VectorSchemaRoot.create(Schemas.GET_TYPE_INFO_SCHEMA, allocator)) {
+ root.setRowCount(1);
+ ((VarCharVector) root.getVector("type_name")).setSafe(0, new
Text("Integer"));
+ ((IntVector) root.getVector("data_type")).setSafe(0, INT.ordinal());
+ ((IntVector) root.getVector("column_size")).setSafe(0, 400);
+ root.getVector("literal_prefix").setNull(0);
+ root.getVector("literal_suffix").setNull(0);
+ root.getVector("create_params").setNull(0);
+ ((IntVector) root.getVector("nullable")).setSafe(0,
FlightSql.Nullable.NULLABILITY_NULLABLE.getNumber());
+ ((BitVector) root.getVector("case_sensitive")).setSafe(0, 1);
+ ((IntVector) root.getVector("nullable")).setSafe(0,
FlightSql.Searchable.SEARCHABLE_FULL.getNumber());
+ ((BitVector) root.getVector("unsigned_attribute")).setSafe(0, 1);
+ root.getVector("fixed_prec_scale").setNull(0);
+ ((BitVector) root.getVector("auto_increment")).setSafe(0, 1);
+ ((VarCharVector) root.getVector("local_type_name")).setSafe(0, new
Text("Integer"));
+ root.getVector("minimum_scale").setNull(0);
+ root.getVector("maximum_scale").setNull(0);
+ ((IntVector) root.getVector("sql_data_type")).setSafe(0,
INT.ordinal());
+ root.getVector("datetime_subcode").setNull(0);
+ ((IntVector) root.getVector("num_prec_radix")).setSafe(0, 10);
+ root.getVector("interval_precision").setNull(0);
+
+ listener.start(root);
+ listener.putNext();
+ listener.completed();
+ }
+ }
+
+ @Override
+ public void getStreamTables(FlightSql.CommandGetTables command,
CallContext context,
+ ServerStreamListener listener) {
+ try (VectorSchemaRoot root =
VectorSchemaRoot.create(Schemas.GET_TABLES_SCHEMA_NO_SCHEMA, allocator)) {
+ root.setRowCount(1);
+ root.getVector("catalog_name").setNull(0);
+ root.getVector("db_schema_name").setNull(0);
+ ((VarCharVector) root.getVector("table_name")).setSafe(0, new
Text("test_table"));
+ ((VarCharVector) root.getVector("table_type")).setSafe(0, new
Text("TABLE"));
+
+ listener.start(root);
+ listener.putNext();
+ listener.completed();
+ }
+ }
+
+ @Override
+ public void getStreamTableTypes(CallContext context, ServerStreamListener
listener) {
+ try (VectorSchemaRoot root =
VectorSchemaRoot.create(Schemas.GET_TABLE_TYPES_SCHEMA, allocator)) {
+ root.setRowCount(1);
+ ((VarCharVector) root.getVector("table_type")).setSafe(0, new
Text("TABLE"));
+
+ listener.start(root);
+ listener.putNext();
+ listener.completed();
+ }
+ }
+ }
+
+ private static BufferAllocator allocator;
+
+ private static FlightServer server;
+ private static FlightSqlClient sqlClient;
+
+ @BeforeAll
+ public static void setUp() throws Exception {
+ allocator = new RootAllocator(Integer.MAX_VALUE);
+
+ final Location serverLocation = Location.forGrpcInsecure("localhost", 0);
+ server = FlightServer.builder(allocator, serverLocation, new
FlightSqlTestProducer(allocator))
+ .build()
+ .start();
+
+ final Location clientLocation = Location.forGrpcInsecure("localhost",
server.getPort());
+ sqlClient = new FlightSqlClient(FlightClient.builder(allocator,
clientLocation).build());
+ }
+
+ @AfterAll
+ public static void tearDown() throws Exception {
+ close(sqlClient, server, allocator);
+ }
+
+ @Test
+ public void testGetTablesResultNoSchema() throws Exception {
+ try (final FlightStream stream =
+ sqlClient.getStream(
+ sqlClient.getTables(null, null, null, null, false)
+ .getEndpoints().get(0).getTicket())) {
+ Assertions.assertAll(
+ () -> MatcherAssert.assertThat(stream.getSchema(),
is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)),
+ () -> {
+ final List<List<String>> results = getResults(stream);
+ final List<List<String>> expectedResults = ImmutableList.of(
+ // catalog_name | schema_name | table_name | table_type |
table_schema
+ asList(null, null, "test_table", "TABLE"));
+ MatcherAssert.assertThat(results, is(expectedResults));
+ }
+ );
+ }
+ }
+
+ @Test
+ public void testGetTableTypesResult() throws Exception {
+ try (final FlightStream stream =
+
sqlClient.getStream(sqlClient.getTableTypes().getEndpoints().get(0).getTicket()))
{
+ Assertions.assertAll(
+ () -> MatcherAssert.assertThat(stream.getSchema(),
is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)),
+ () -> {
+ final List<List<String>> tableTypes = getResults(stream);
+ final List<List<String>> expectedTableTypes = ImmutableList.of(
+ // table_type
+ singletonList("TABLE")
+ );
+ MatcherAssert.assertThat(tableTypes, is(expectedTableTypes));
+ }
+ );
+ }
+ }
+
+ @Test
+ public void testGetSqlInfoResults() throws Exception {
+ final FlightInfo info = sqlClient.getSqlInfo();
+ try (final FlightStream stream =
sqlClient.getStream(info.getEndpoints().get(0).getTicket())) {
+ Assertions.assertAll(
+ () -> MatcherAssert.assertThat(stream.getSchema(),
is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)),
+ () -> MatcherAssert.assertThat(getResults(stream), is(emptyList()))
+ );
+ }
+ }
+
+ @Test
+ public void testGetTypeInfo() throws Exception {
+ FlightInfo flightInfo = sqlClient.getXdbcTypeInfo();
+
+ try (FlightStream stream =
sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket())) {
+
+ final List<List<String>> results = getResults(stream);
+
+ final List<List<String>> matchers = ImmutableList.of(
+ asList("Integer", "4", "400", null, null, "3", "true", null, "true",
null, "true",
+ "Integer", null, null, "4", null, "10", null));
+
+ MatcherAssert.assertThat(results, is(matchers));
+ }
+ }
+
+ @Test
+ public void testExecuteQuery() throws Exception {
+ try (final FlightStream stream = sqlClient
+
.getStream(sqlClient.execute(FlightSqlTestProducer.FIXED_QUERY).getEndpoints().get(0).getTicket()))
{
+ Assertions.assertAll(
+ () -> MatcherAssert.assertThat(stream.getSchema(),
is(FlightSqlTestProducer.FIXED_SCHEMA)),
+ () -> MatcherAssert.assertThat(getResults(stream),
is(singletonList(singletonList("1"))))
+ );
+ }
+ }
+}
diff --git
a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/FlightStreamUtils.java
b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/FlightStreamUtils.java
new file mode 100644
index 0000000000..fbbe9ef018
--- /dev/null
+++
b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/FlightStreamUtils.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.sql.util;
+
+import static java.util.Objects.isNull;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.nio.channels.Channels;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.UInt1Vector;
+import org.apache.arrow.vector.UInt4Vector;
+import org.apache.arrow.vector.VarBinaryVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.DenseUnionVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.ipc.ReadChannel;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.arrow.vector.util.Text;
+
+/**
+ * Test helpers for reading the contents of a {@link FlightStream} as strings.
+ */
+public final class FlightStreamUtils {
+
+  private FlightStreamUtils() {
+    // Static helpers only; not meant to be instantiated.
+  }
+
+  /**
+   * Drains {@code stream} and converts every cell it delivers to a string.
+   *
+   * <p>Rows from all batches are returned in arrival order. Cells are rendered as follows:
+   * <ul>
+   *   <li>VarChar, Int, UInt1, UInt4, Bit and DenseUnion cells use the string form of {@code getObject},
+   *       or {@code null} for a null cell;</li>
+   *   <li>VarBinary cells are decoded as IPC-serialized schemas and rendered as schema JSON;</li>
+   *   <li>List cells are rendered as {@code [e1, e2, ...]}. A null list cell is omitted from its row entirely,
+   *       so that row has fewer entries than the schema has columns.</li>
+   * </ul>
+   *
+   * <p>The stream's root and its vectors are closed after each batch has been read.
+   *
+   * @param stream the stream to consume; it has been advanced to its end when this method returns.
+   * @return one list of cell strings per row.
+   * @throws UnsupportedOperationException if a column uses a vector type not listed above.
+   * @throws RuntimeException if a VarBinary cell cannot be deserialized as a schema.
+   */
+  public static List<List<String>> getResults(FlightStream stream) {
+    final List<List<String>> results = new ArrayList<>();
+    while (stream.next()) {
+      try (final VectorSchemaRoot root = stream.getRoot()) {
+        final int rowCount = root.getRowCount();
+        // This batch's rows are appended after any rows collected from earlier batches. The column writers
+        // therefore work on a view of just this batch's rows, not on indices counted from 0.
+        final int batchStart = results.size();
+        for (int i = 0; i < rowCount; ++i) {
+          results.add(new ArrayList<>());
+        }
+        final List<List<String>> batchRows = results.subList(batchStart, batchStart + rowCount);
+
+        root.getSchema().getFields().forEach(field -> {
+          try (final FieldVector fieldVector = root.getVector(field.getName())) {
+            appendColumn(fieldVector, rowCount, batchRows);
+          }
+        });
+      }
+    }
+    return results;
+  }
+
+  /** Appends the string form of the first {@code rowCount} cells of {@code vector} to the matching rows. */
+  private static void appendColumn(FieldVector vector, int rowCount, List<List<String>> rows) {
+    if (vector instanceof VarCharVector || vector instanceof IntVector || vector instanceof UInt1Vector ||
+        vector instanceof UInt4Vector || vector instanceof BitVector || vector instanceof DenseUnionVector) {
+      for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
+        final Object data = vector.getObject(rowIndex);
+        rows.get(rowIndex).add(isNull(data) ? null : Objects.toString(data));
+      }
+    } else if (vector instanceof VarBinaryVector) {
+      final VarBinaryVector varBinaryVector = (VarBinaryVector) vector;
+      for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
+        rows.get(rowIndex).add(schemaBytesToJson(varBinaryVector.getObject(rowIndex)));
+      }
+    } else if (vector instanceof ListVector) {
+      final ListVector listVector = (ListVector) vector;
+      for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
+        // Null list cells are skipped rather than recorded as null (no placeholder entry is added).
+        if (!listVector.isNull(rowIndex)) {
+          rows.get(rowIndex).add(listToString(listVector.getObject(rowIndex)));
+        }
+      }
+    } else {
+      throw new UnsupportedOperationException("Not yet implemented");
+    }
+  }
+
+  /** Renders list elements as {@code [e1, e2, ...]} using each element's string form ("null" for null). */
+  private static String listToString(List<?> elements) {
+    final List<String> values = new ArrayList<>(elements.size());
+    for (final Object element : elements) {
+      values.add(String.valueOf(element));
+    }
+    return values.toString();
+  }
+
+  /** Deserializes IPC-encoded schema bytes to the schema's JSON form, or returns null for a null cell. */
+  private static String schemaBytesToJson(byte[] data) {
+    if (isNull(data)) {
+      return null;
+    }
+    try {
+      return MessageSerializer.deserializeSchema(
+          new ReadChannel(Channels.newChannel(new ByteArrayInputStream(data)))).toJson();
+    } catch (final IOException e) {
+      throw new RuntimeException("Failed to deserialize schema", e);
+    }
+  }
+}