This is an automated email from the ASF dual-hosted git repository.
mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new f7d7bca39b [CALCITE-6635] Refactor Arrow adapter Test
f7d7bca39b is described below
commit f7d7bca39bf313c6fe511e60601e3f20d4bc353a
Author: Cancai Cai <[email protected]>
AuthorDate: Sat Oct 19 13:38:51 2024 +0800
[CALCITE-6635] Refactor Arrow adapter Test
---
.../adapter/arrow/ArrowAdapterDataTypesTest.java | 179 +++++++++++++++++++++
.../apache/calcite/adapter/arrow/ArrowData.java | 172 ++++++++++++++++++++
2 files changed, 351 insertions(+)
diff --git a/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterDataTypesTest.java b/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterDataTypesTest.java
new file mode 100644
index 0000000000..52c3c74f47
--- /dev/null
+++ b/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterDataTypesTest.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.adapter.arrow;
+
+import org.apache.calcite.test.CalciteAssert;
+import org.apache.calcite.util.Sources;
+
+import com.google.common.collect.ImmutableMap;
+
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.File;
+import java.io.IOException;
+import java.net.URL;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.sql.SQLException;
+import java.util.Map;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Test cases for Arrow adapter data types.
+ *
+ * <p>Each test projects a single column of the generated ARROWDATATYPE table
+ * and checks both the first two values and that the projection is handled by
+ * {@code ArrowProject} on top of {@code ArrowTableScan}.
+ */
+public class ArrowAdapterDataTypesTest {
+
+  private static Map<String, String> arrow;
+
+  @BeforeAll
+  static void initializeArrowState(@TempDir Path sharedTempDir)
+      throws IOException, SQLException {
+    // Resolve the model through this class rather than another test class,
+    // so this test does not depend on ArrowAdapterTest being present.
+    URL modelUrl =
+        requireNonNull(
+            ArrowAdapterDataTypesTest.class.getResource("/arrow-model.json"),
+            "url");
+    Path sourceModelFilePath = Sources.of(modelUrl).file().toPath();
+    Path modelFileTarget = sharedTempDir.resolve("arrow-model.json");
+    Files.copy(sourceModelFilePath, modelFileTarget);
+
+    Path arrowFilesDirectory = sharedTempDir.resolve("arrow");
+    Files.createDirectory(arrowFilesDirectory);
+
+    File dataLocationFile =
+        arrowFilesDirectory.resolve("arrowdatatype.arrow").toFile();
+    ArrowData arrowDataGenerator = new ArrowData();
+    arrowDataGenerator.writeArrowDataType(dataLocationFile);
+
+    arrow =
+        ImmutableMap.of("model", modelFileTarget.toAbsolutePath().toString());
+  }
+
+  /**
+   * Projects one column of ARROWDATATYPE, checks its first two values, and
+   * checks that the plan is an {@code ArrowProject} over the table scan.
+   *
+   * @param field       Column name (quoted in the generated SQL)
+   * @param ordinal     Ordinal of the column in the ARROWDATATYPE schema
+   * @param firstValue  Expected rendering of the first row's value
+   * @param secondValue Expected rendering of the second row's value
+   */
+  private static void checkProject(String field, int ordinal,
+      String firstValue, String secondValue) {
+    String sql = "select \"" + field + "\" from arrowdatatype";
+    String plan = "PLAN=ArrowToEnumerableConverter\n"
+        + "  ArrowProject(" + field + "=[$" + ordinal + "])\n"
+        + "    ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], "
+        + "fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+    String result = field + "=" + firstValue + "\n"
+        + field + "=" + secondValue + "\n";
+    CalciteAssert.that()
+        .with(arrow)
+        .query(sql)
+        .limit(2)
+        .returns(result)
+        .explainContains(plan);
+  }
+
+  @Test void testTinyIntProject() {
+    checkProject("tinyIntField", 0, "0", "1");
+  }
+
+  @Test void testSmallIntProject() {
+    checkProject("smallIntField", 1, "0", "1");
+  }
+
+  @Test void testIntProject() {
+    checkProject("intField", 2, "0", "1");
+  }
+
+  @Test void testLongProject() {
+    checkProject("longField", 5, "0", "1");
+  }
+
+  @Test void testFloatProject() {
+    checkProject("floatField", 4, "0.0", "1.0");
+  }
+
+  @Test void testDoubleProject() {
+    checkProject("doubleField", 6, "0.0", "1.0");
+  }
+
+  @Test void testDecimalProject() {
+    checkProject("decimalField", 8, "0.00", "1.00");
+  }
+
+  @Test void testDateProject() {
+    checkProject("dateField", 9, "1970-01-01", "1970-01-02");
+  }
+}
diff --git a/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowData.java b/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowData.java
index 3870bb2e11..e85c78e087 100644
--- a/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowData.java
+++ b/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowData.java
@@ -23,12 +23,19 @@ import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder;
import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.FloatingPointVector;
import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
+import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
@@ -43,6 +50,7 @@ import net.hydromatic.scott.data.hsqldb.ScottHsqldb;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
+import java.math.BigDecimal;
import java.nio.file.Path;
import java.sql.Connection;
import java.sql.DriverManager;
@@ -59,18 +67,57 @@ public class ArrowData {
+ // Rows per record batch, and total rows written per generated file.
private final int batchSize;
private final int entries;
+ // Next value for each generated column. The fill helpers update their
+ // field after every row (booleanValue toggles rather than counts), so a
+ // new batch continues where the previous one stopped.
+ // NOTE(review): assumed to hold for the int/string/float/long helpers too;
+ // their bodies are not part of this patch.
+ private byte tinyIntValue;
+ private short smallIntValue;
private int intValue;
private int stringValue;
private float floatValue;
private long longValue;
+ private double doubleValue;
+ private boolean booleanValue;
+ private BigDecimal decimalValue;
+ /**
+ * Creates a generator that writes 50 rows in batches of 20, with every
+ * numeric counter at zero and booleanValue false.
+ */
public ArrowData() {
this.batchSize = 20;
this.entries = 50;
+ this.tinyIntValue = 0;
+ this.smallIntValue = 0;
this.intValue = 0;
this.stringValue = 0;
this.floatValue = 0;
this.longValue = 0;
+ this.doubleValue = 0;
+ this.booleanValue = false;
+ this.decimalValue = BigDecimal.ZERO;
+ }
+
+  /**
+   * Builds the schema of the ARROWDATATYPE test table.
+   *
+   * <p>Ten nullable columns, in ordinal order 0..9: tinyInt, smallInt, int,
+   * string, float, long, double, boolean, decimal(10, 2) and date (day
+   * unit). Despite its name, the method covers all data types, not only
+   * dates. The schema carries no metadata.
+   */
+  private Schema makeArrowDateTypeSchema() {
+    ImmutableList<Field> columns =
+        ImmutableList.of(
+            new Field("tinyIntField",
+                FieldType.nullable(new ArrowType.Int(8, true)), null),
+            new Field("smallIntField",
+                FieldType.nullable(new ArrowType.Int(16, true)), null),
+            new Field("intField",
+                FieldType.nullable(new ArrowType.Int(32, true)), null),
+            new Field("stringField",
+                FieldType.nullable(new ArrowType.Utf8()), null),
+            new Field("floatField",
+                FieldType.nullable(
+                    new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)),
+                null),
+            new Field("longField",
+                FieldType.nullable(new ArrowType.Int(64, true)), null),
+            new Field("doubleField",
+                FieldType.nullable(
+                    new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+                null),
+            new Field("booleanField",
+                FieldType.nullable(new ArrowType.Bool()), null),
+            new Field("decimalField",
+                FieldType.nullable(new ArrowType.Decimal(10, 2, 128)), null),
+            new Field("dateField",
+                FieldType.nullable(new ArrowType.Date(DateUnit.DAY)), null));
+    return new Schema(columns, null);
}
private Schema makeArrowSchema() {
@@ -89,6 +136,7 @@ public class ArrowData {
return new Schema(childrenBuilder.build(), null);
}
+
public void writeScottEmpData(Path arrowDataDirectory) throws IOException,
SQLException {
List<String> tableNames = ImmutableList.of("EMP", "DEPT", "SALGRADE");
@@ -174,6 +222,87 @@ public class ArrowData {
fileOutputStream.close();
}
+ public void writeArrowDataType(File file) throws IOException {
+ FileOutputStream fileOutputStream = new FileOutputStream(file);
+ Schema arrowSchema = makeArrowDateTypeSchema();
+ VectorSchemaRoot vectorSchemaRoot =
+ VectorSchemaRoot.create(arrowSchema, new
RootAllocator(Integer.MAX_VALUE));
+ ArrowFileWriter arrowFileWriter =
+ new ArrowFileWriter(vectorSchemaRoot, null,
fileOutputStream.getChannel());
+
+ arrowFileWriter.start();
+
+ for (int i = 0; i < this.entries;) {
+ int numRows = Math.min(this.batchSize, this.entries - i);
+ vectorSchemaRoot.setRowCount(numRows);
+ for (Field field : vectorSchemaRoot.getSchema().getFields()) {
+ FieldVector vector = vectorSchemaRoot.getVector(field.getName());
+ switch (vector.getMinorType()) {
+ case TINYINT:
+ tinyIntField(vector, numRows);
+ break;
+ case SMALLINT:
+ smallIntFiled(vector, numRows);
+ break;
+ case INT:
+ intField(vector, numRows);
+ break;
+ case FLOAT4:
+ floatField(vector, numRows);
+ break;
+ case VARCHAR:
+ varCharField(vector, numRows);
+ break;
+ case BIGINT:
+ longField(vector, numRows);
+ break;
+ case FLOAT8:
+ doubleField(vector, numRows);
+ break;
+ case BIT:
+ booleanField(vector, numRows);
+ break;
+ case DECIMAL:
+ decimalField(vector, numRows);
+ break;
+ case DATEDAY:
+ dateField(vector, numRows);
+ break;
+ default:
+ throw new IllegalStateException("Not supported type yet: " +
vector.getMinorType());
+ }
+ }
+ arrowFileWriter.writeBatch();
+ i += numRows;
+ }
+ arrowFileWriter.end();
+ arrowFileWriter.close();
+ fileOutputStream.flush();
+ fileOutputStream.close();
+ }
+
+ private void tinyIntField(FieldVector fieldVector, int rowCount) {
+ TinyIntVector tinyIntVector = (TinyIntVector) fieldVector;
+ tinyIntVector.setInitialCapacity(rowCount);
+ tinyIntVector.allocateNew();
+ for (int i = 0; i < rowCount; i++) {
+ tinyIntVector.set(i, this.tinyIntValue);
+ this.tinyIntValue++;
+ }
+ fieldVector.setValueCount(rowCount);
+ }
+
+ private void smallIntFiled(FieldVector fieldVector, int rowCount) {
+ SmallIntVector smallIntVector = (SmallIntVector) fieldVector;
+ smallIntVector.setInitialCapacity(rowCount);
+ smallIntVector.allocateNew();
+ for (int i = 0; i < rowCount; i++) {
+ smallIntVector.set(i, this.smallIntValue);
+ this.smallIntValue++;
+ }
+ fieldVector.setValueCount(rowCount);
+ }
+
private void intField(FieldVector fieldVector, int rowCount) {
IntVector intVector = (IntVector) fieldVector;
intVector.setInitialCapacity(rowCount);
@@ -219,4 +348,47 @@ public class ArrowData {
}
fieldVector.setValueCount(rowCount);
}
+
+ private void doubleField(FieldVector fieldVector, int rowCount) {
+ Float8Vector float8Vector = (Float8Vector) fieldVector;
+ float8Vector.setInitialCapacity(rowCount);
+ float8Vector.allocateNew();
+ for (int i = 0; i < rowCount; i++) {
+ float8Vector.set(i, this.doubleValue);
+ this.doubleValue++;
+ }
+ fieldVector.setValueCount(rowCount);
+ }
+
+ private void booleanField(FieldVector fieldVector, int rowCount) {
+ BitVector bitVector = (BitVector) fieldVector;
+ bitVector.setInitialCapacity(rowCount);
+ bitVector.allocateNew();
+ for (int i = 0; i < rowCount; i++) {
+ bitVector.set(i, this.booleanValue ? 1 : 0);
+ this.booleanValue = !this.booleanValue;
+ }
+ fieldVector.setValueCount(rowCount);
+ }
+
+ private void decimalField(FieldVector fieldVector, int rowCount) {
+ DecimalVector decimalVector = (DecimalVector) fieldVector;
+ decimalVector.setInitialCapacity(rowCount);
+ decimalVector.allocateNew();
+ for (int i = 0; i < rowCount; i++) {
+ decimalVector.set(i, this.decimalValue.setScale(2));
+ this.decimalValue = this.decimalValue.add(BigDecimal.ONE);
+ }
+ fieldVector.setValueCount(rowCount);
+ }
+
+  /**
+   * Next day count (days since 1970-01-01) written by {@code dateField}.
+   * It defaults to 0, so the first row is 1970-01-01.
+   */
+  private int dateValue;
+
+  /**
+   * Fills a DATEDAY column with {@code rowCount} consecutive day counts.
+   *
+   * <p>Like the other column helpers, it continues from where the previous
+   * batch stopped. Previously it wrote the row index within the batch, so
+   * every batch restarted at 1970-01-01.
+   */
+  private void dateField(FieldVector fieldVector, int rowCount) {
+    DateDayVector dateDayVector = (DateDayVector) fieldVector;
+    dateDayVector.setInitialCapacity(rowCount);
+    dateDayVector.allocateNew();
+    for (int i = 0; i < rowCount; i++) {
+      dateDayVector.set(i, this.dateValue);
+      this.dateValue++;
+    }
+    fieldVector.setValueCount(rowCount);
+  }
}