martin-traverse commented on code in PR #638:
URL: https://github.com/apache/arrow-java/pull/638#discussion_r2017456583


##########
adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java:
##########
@@ -0,0 +1,718 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.adapter.avro;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.arrow.adapter.avro.producers.AvroBigIntProducer;
+import org.apache.arrow.adapter.avro.producers.AvroBooleanProducer;
+import org.apache.arrow.adapter.avro.producers.AvroBytesProducer;
+import org.apache.arrow.adapter.avro.producers.AvroFixedSizeBinaryProducer;
+import org.apache.arrow.adapter.avro.producers.AvroFixedSizeListProducer;
+import org.apache.arrow.adapter.avro.producers.AvroFloat2Producer;
+import org.apache.arrow.adapter.avro.producers.AvroFloat4Producer;
+import org.apache.arrow.adapter.avro.producers.AvroFloat8Producer;
+import org.apache.arrow.adapter.avro.producers.AvroIntProducer;
+import org.apache.arrow.adapter.avro.producers.AvroListProducer;
+import org.apache.arrow.adapter.avro.producers.AvroMapProducer;
+import org.apache.arrow.adapter.avro.producers.AvroNullProducer;
+import org.apache.arrow.adapter.avro.producers.AvroNullableProducer;
+import org.apache.arrow.adapter.avro.producers.AvroSmallIntProducer;
+import org.apache.arrow.adapter.avro.producers.AvroStringProducer;
+import org.apache.arrow.adapter.avro.producers.AvroStructProducer;
+import org.apache.arrow.adapter.avro.producers.AvroTinyIntProducer;
+import org.apache.arrow.adapter.avro.producers.AvroUint1Producer;
+import org.apache.arrow.adapter.avro.producers.AvroUint2Producer;
+import org.apache.arrow.adapter.avro.producers.AvroUint4Producer;
+import org.apache.arrow.adapter.avro.producers.AvroUint8Producer;
+import org.apache.arrow.adapter.avro.producers.BaseAvroProducer;
+import org.apache.arrow.adapter.avro.producers.CompositeAvroProducer;
+import org.apache.arrow.adapter.avro.producers.Producer;
+import org.apache.arrow.adapter.avro.producers.logical.AvroDateDayProducer;
+import org.apache.arrow.adapter.avro.producers.logical.AvroDateMilliProducer;
+import org.apache.arrow.adapter.avro.producers.logical.AvroDecimal256Producer;
+import org.apache.arrow.adapter.avro.producers.logical.AvroDecimalProducer;
+import org.apache.arrow.adapter.avro.producers.logical.AvroTimeMicroProducer;
+import org.apache.arrow.adapter.avro.producers.logical.AvroTimeMilliProducer;
+import org.apache.arrow.adapter.avro.producers.logical.AvroTimeNanoProducer;
+import org.apache.arrow.adapter.avro.producers.logical.AvroTimeSecProducer;
+import 
org.apache.arrow.adapter.avro.producers.logical.AvroTimestampMicroProducer;
+import 
org.apache.arrow.adapter.avro.producers.logical.AvroTimestampMicroTzProducer;
+import 
org.apache.arrow.adapter.avro.producers.logical.AvroTimestampMilliProducer;
+import 
org.apache.arrow.adapter.avro.producers.logical.AvroTimestampMilliTzProducer;
+import 
org.apache.arrow.adapter.avro.producers.logical.AvroTimestampNanoProducer;
+import 
org.apache.arrow.adapter.avro.producers.logical.AvroTimestampNanoTzProducer;
+import 
org.apache.arrow.adapter.avro.producers.logical.AvroTimestampSecProducer;
+import 
org.apache.arrow.adapter.avro.producers.logical.AvroTimestampSecTzProducer;
+import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.DateMilliVector;
+import org.apache.arrow.vector.Decimal256Vector;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
+import org.apache.arrow.vector.Float2Vector;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.NullVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TimeMicroVector;
+import org.apache.arrow.vector.TimeMilliVector;
+import org.apache.arrow.vector.TimeNanoVector;
+import org.apache.arrow.vector.TimeSecVector;
+import org.apache.arrow.vector.TimeStampMicroTZVector;
+import org.apache.arrow.vector.TimeStampMicroVector;
+import org.apache.arrow.vector.TimeStampMilliTZVector;
+import org.apache.arrow.vector.TimeStampMilliVector;
+import org.apache.arrow.vector.TimeStampNanoTZVector;
+import org.apache.arrow.vector.TimeStampNanoVector;
+import org.apache.arrow.vector.TimeStampSecTZVector;
+import org.apache.arrow.vector.TimeStampSecVector;
+import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.UInt1Vector;
+import org.apache.arrow.vector.UInt2Vector;
+import org.apache.arrow.vector.UInt4Vector;
+import org.apache.arrow.vector.UInt8Vector;
+import org.apache.arrow.vector.VarBinaryVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.complex.FixedSizeListVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.TimeUnit;
+import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.avro.Schema;
+import org.apache.avro.SchemaBuilder;
+
+public class ArrowToAvroUtils {
+
+  /** Record type name used by the convenience overloads when no explicit type name is given. */
+  public static final String GENERIC_RECORD_TYPE_NAME = "GenericRecord";
+
+  /**
+   * Create an Avro record schema for a given list of Arrow fields.
+   *
+   * <p>This method currently performs the following type mapping from Arrow data types to the
+   * corresponding Avro encodings.
+   *
+   * <table>
+   *   <thead><tr><th>Arrow type</th><th>Avro encoding</th></tr></thead>
+   *   <tbody>
+   *     <tr><td>ArrowType.Null</td><td>NULL</td></tr>
+   *     <tr><td>ArrowType.Bool</td><td>BOOLEAN</td></tr>
+   *     <tr><td>ArrowType.Int(64 bit, unsigned 32 bit)</td><td>LONG</td></tr>
+   *     <tr><td>ArrowType.Int(signed 32 bit, &lt; 32 bit)</td><td>INT</td></tr>
+   *     <tr><td>ArrowType.FloatingPoint(double)</td><td>DOUBLE</td></tr>
+   *     <tr><td>ArrowType.FloatingPoint(single, half)</td><td>FLOAT</td></tr>
+   *     <tr><td>ArrowType.Utf8</td><td>STRING</td></tr>
+   *     <tr><td>ArrowType.LargeUtf8</td><td>STRING</td></tr>
+   *     <tr><td>ArrowType.Binary</td><td>BYTES</td></tr>
+   *     <tr><td>ArrowType.LargeBinary</td><td>BYTES</td></tr>
+   *     <tr><td>ArrowType.FixedSizeBinary</td><td>FIXED</td></tr>
+   *     <tr><td>ArrowType.Decimal</td><td>decimal (FIXED)</td></tr>
+   *     <tr><td>ArrowType.Date</td><td>date (INT)</td></tr>
+   *     <tr><td>ArrowType.Time (SEC | MILLI)</td><td>time-millis (INT)</td></tr>
+   *     <tr><td>ArrowType.Time (MICRO | NANO)</td><td>time-micros (LONG)</td></tr>
+   *     <tr><td>ArrowType.Timestamp (NANOSECONDS, TZ != NULL)</td>
+   *         <td>timestamp-nanos (LONG)</td></tr>
+   *     <tr><td>ArrowType.Timestamp (MICROSECONDS, TZ != NULL)</td>
+   *         <td>timestamp-micros (LONG)</td></tr>
+   *     <tr><td>ArrowType.Timestamp (MILLISECONDS | SECONDS, TZ != NULL)</td>
+   *         <td>timestamp-millis (LONG)</td></tr>
+   *     <tr><td>ArrowType.Timestamp (NANOSECONDS, TZ == NULL)</td>
+   *         <td>local-timestamp-nanos (LONG)</td></tr>
+   *     <tr><td>ArrowType.Timestamp (MICROSECONDS, TZ == NULL)</td>
+   *         <td>local-timestamp-micros (LONG)</td></tr>
+   *     <tr><td>ArrowType.Timestamp (MILLISECONDS | SECONDS, TZ == NULL)</td>
+   *         <td>local-timestamp-millis (LONG)</td></tr>
+   *     <tr><td>ArrowType.Duration</td><td>duration (FIXED)</td></tr>
+   *     <tr><td>ArrowType.Interval</td><td>duration (FIXED)</td></tr>
+   *     <tr><td>ArrowType.Struct</td><td>record</td></tr>
+   *     <tr><td>ArrowType.List</td><td>array</td></tr>
+   *     <tr><td>ArrowType.LargeList</td><td>array</td></tr>
+   *     <tr><td>ArrowType.FixedSizeList</td><td>array</td></tr>
+   *     <tr><td>ArrowType.Map</td><td>map</td></tr>
+   *     <tr><td>ArrowType.Union</td><td>union</td></tr>
+   *   </tbody>
+   * </table>
+   *
+   * <p>Nullable fields are represented as a union of [null | base-type]. Special treatment is
+   * given to nullability of unions - a union is considered nullable if any of its child fields
+   * are nullable. The schema for a nullable union will always contain a null type as its first
+   * member; none of the direct child types will be nullable.
+   *
+   * <p>List fields must contain precisely one child field, which may be nullable. Map fields are
+   * represented as a list of structs, where the struct fields are "key" and "value". The key
+   * field must always be of type STRING (Utf8) and cannot be nullable. The value can be of any
+   * type and may be nullable. Record types must contain at least one child field and cannot
+   * contain multiple fields with the same name.
+   *
+   * @param arrowFields The arrow fields used to generate the Avro schema
+   * @param typeName Name of the top level Avro record type
+   * @param namespace Namespace of the top level Avro record type
+   * @return An Avro record schema for the given list of fields, with the specified name and
+   *     namespace
+   * @throws IllegalArgumentException if any field, or any nested child field, has a structure or
+   *     type that cannot be represented in Avro
+   */
+  public static Schema createAvroSchema(
+      List<Field> arrowFields, String typeName, String namespace) {
+    // buildRecordSchema applies the namespace to the record, so it is not set twice here
+    SchemaBuilder.RecordBuilder<Schema> recordBuilder = SchemaBuilder.record(typeName);
+    return buildRecordSchema(recordBuilder, arrowFields, namespace);
+  }
+
+  /**
+   * Convenience overload of {@link #createAvroSchema(List, String, String)} that builds the
+   * record schema without a namespace.
+   */
+  public static Schema createAvroSchema(List<Field> arrowFields, String typeName) {
+    return createAvroSchema(arrowFields, typeName, /* namespace= */ null);
+  }
+
+  /**
+   * Convenience overload of {@link #createAvroSchema(List, String, String)} that names the record
+   * {@link #GENERIC_RECORD_TYPE_NAME} and builds it without a namespace.
+   */
+  public static Schema createAvroSchema(List<Field> arrowFields) {
+    return createAvroSchema(arrowFields, GENERIC_RECORD_TYPE_NAME, /* namespace= */ null);
+  }
+
+  /**
+   * Populates a record builder with one Avro field per Arrow field and closes the record.
+   *
+   * @param builder record builder to populate; its namespace is set to {@code namespace}
+   * @param fields Arrow fields to convert, in order; must not be empty
+   * @param namespace namespace for the record, also passed down to the fields' nested types
+   * @return the value produced by {@code endRecord()}
+   * @throws IllegalArgumentException if {@code fields} is empty
+   */
+  private static <T> T buildRecordSchema(
+      SchemaBuilder.RecordBuilder<T> builder, List<Field> fields, String namespace) {
+    if (fields.isEmpty()) {
+      throw new IllegalArgumentException("Record field must have at least one child field");
+    }
+    SchemaBuilder.RecordBuilder<T> namespacedRecord = builder.namespace(namespace);
+    SchemaBuilder.FieldAssembler<T> fieldAssembler = namespacedRecord.fields();
+    for (Field child : fields) {
+      fieldAssembler = buildFieldSchema(fieldAssembler, child, namespace);
+    }
+    return fieldAssembler.endRecord();
+  }
+
+  private static <T> SchemaBuilder.FieldAssembler<T> buildFieldSchema(
+      SchemaBuilder.FieldAssembler<T> assembler, Field field, String 
namespace) {
+
+    SchemaBuilder.FieldTypeBuilder<T> builder = 
assembler.name(field.getName()).type();
+
+    // Nullable unions need special handling, since union types cannot be 
directly nested
+    if (field.getType().getTypeID() == ArrowType.ArrowTypeID.Union) {
+      boolean unionNullable = 
field.getChildren().stream().anyMatch(Field::isNullable);
+      if (unionNullable) {
+        SchemaBuilder.UnionAccumulator<SchemaBuilder.NullDefault<T>> union =
+            builder.unionOf().nullType();
+        return addTypesToUnion(union, field.getChildren(), 
namespace).nullDefault();
+      } else {
+        Field headType = field.getChildren().get(0);
+        List<Field> tailTypes = field.getChildren().subList(1, 
field.getChildren().size());
+        SchemaBuilder.UnionAccumulator<SchemaBuilder.FieldDefault<T, ?>> union 
=
+            buildUnionFieldSchema(builder.unionOf(), headType, namespace);
+        return addTypesToUnion(union, tailTypes, namespace).noDefault();
+      }
+    } else if (field.isNullable()) {
+      return buildBaseFieldSchema(builder.nullable(), field, namespace);
+    } else {
+      return buildBaseFieldSchema(builder, field, namespace);
+    }
+  }
+
+  /**
+   * Completes an Avro array schema whose items are described by the single child of a list field.
+   *
+   * <p>Used for both List and FixedSizeList fields; only the child field is consulted, so a fixed
+   * list size is not carried into the Avro schema.
+   *
+   * @throws IllegalArgumentException if the list field does not have exactly one child field
+   */
+  private static <T> T buildArraySchema(
+      SchemaBuilder.ArrayBuilder<T> builder, Field listField, String namespace) {
+    List<Field> children = listField.getChildren();
+    if (children.size() == 1) {
+      return buildTypeSchema(builder.items(), children.get(0), namespace);
+    }
+    throw new IllegalArgumentException("List field must have exactly one child field");
+  }
+
+  /**
+   * Completes an Avro map schema from an Arrow map field.
+   *
+   * <p>The map field must have a single entries child, whose first child is the key and second
+   * child is the value. The key must be a non-nullable Utf8 field; the value may be of any
+   * supported type and may be nullable.
+   *
+   * @throws IllegalArgumentException if the map or entries structure is malformed, or the key is
+   *     not a non-nullable string
+   */
+  private static <T> T buildMapSchema(
+      SchemaBuilder.MapBuilder<T> builder, Field mapField, String namespace) {
+    if (mapField.getChildren().size() != 1) {
+      throw new IllegalArgumentException("Map field must have exactly one child field");
+    }
+    Field entriesField = mapField.getChildren().get(0);
+    // Validate the entries struct itself before indexing its key and value children
+    if (entriesField.getChildren().size() != 2) {
+      throw new IllegalArgumentException("Map entries must have exactly two child fields");
+    }
+    Field keyField = entriesField.getChildren().get(0);
+    Field valueField = entriesField.getChildren().get(1);
+    if (keyField.getType().getTypeID() != ArrowType.ArrowTypeID.Utf8 || keyField.isNullable()) {
+      throw new IllegalArgumentException(
+          "Map keys must be of type string and cannot be nullable for conversion to Avro");
+    }
+    return buildTypeSchema(builder.values(), valueField, namespace);
+  }
+
+  private static <T> T buildTypeSchema(
+      SchemaBuilder.TypeBuilder<T> builder, Field field, String namespace) {
+
+    // Nullable unions need special handling, since union types cannot be 
directly nested
+    if (field.getType().getTypeID() == ArrowType.ArrowTypeID.Union) {
+      boolean unionNullable = 
field.getChildren().stream().anyMatch(Field::isNullable);
+      if (unionNullable) {
+        SchemaBuilder.UnionAccumulator<T> union = builder.unionOf().nullType();
+        return addTypesToUnion(union, field.getChildren(), namespace);
+      } else {
+        Field headType = field.getChildren().get(0);
+        List<Field> tailTypes = field.getChildren().subList(1, 
field.getChildren().size());
+        SchemaBuilder.UnionAccumulator<T> union =
+            buildBaseTypeSchema(builder.unionOf(), headType, namespace);
+        return addTypesToUnion(union, tailTypes, namespace);
+      }
+    } else if (field.isNullable()) {
+      return buildBaseTypeSchema(builder.nullable(), field, namespace);
+    } else {
+      return buildBaseTypeSchema(builder, field, namespace);
+    }
+  }
+
+  /**
+   * Builds the Avro schema for the base type of an Arrow field.
+   *
+   * <p>This method has no case for unions; callers handle unions and nullable wrapping before
+   * delegating here. Named Avro types (fixed, record) take their name from the Arrow field name,
+   * and struct children are placed in a namespace formed from the parent namespace and the
+   * struct's field name.
+   *
+   * @param builder builder positioned where the base type should be written
+   * @param field Arrow field whose type is converted
+   * @param namespace namespace of the enclosing record, or null
+   * @return the value produced by the builder once the type is complete
+   * @throws IllegalArgumentException if the Arrow type has no Avro mapping here
+   */
+  private static <T> T buildBaseTypeSchema(
+      SchemaBuilder.BaseTypeBuilder<T> builder, Field field, String namespace) {
+
+    ArrowType.ArrowTypeID typeID = field.getType().getTypeID();
+
+    switch (typeID) {
+      case Null:
+        return builder.nullType();
+
+      case Bool:
+        return builder.booleanType();
+
+      case Int:
+        ArrowType.Int intType = (ArrowType.Int) field.getType();
+        // Unsigned 32-bit values do not fit in a signed Avro INT, so they widen to LONG
+        if (intType.getBitWidth() > 32 || (intType.getBitWidth() == 32 && !intType.getIsSigned())) {
+          return builder.longType();
+        } else {
+          return builder.intType();
+        }
+
+      case FloatingPoint:
+        ArrowType.FloatingPoint floatType = (ArrowType.FloatingPoint) field.getType();
+        if (floatType.getPrecision() == FloatingPointPrecision.DOUBLE) {
+          return builder.doubleType();
+        } else {
+          return builder.floatType();
+        }
+
+      case Utf8:
+        return builder.stringType();
+
+      case Binary:
+        return builder.bytesType();
+
+      case FixedSizeBinary:
+        ArrowType.FixedSizeBinary fixedType = (ArrowType.FixedSizeBinary) field.getType();
+        String fixedTypeName = field.getName();
+        int fixedTypeWidth = fixedType.getByteWidth();
+        return builder.fixed(fixedTypeName).size(fixedTypeWidth);
+
+      case Decimal:
+        ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType();
+        // The fixed size is the Arrow decimal bit width expressed in bytes
+        return builder
+            .fixed(field.getName())
+            .prop("logicalType", "decimal")
+            .prop("precision", decimalType.getPrecision())
+            .prop("scale", decimalType.getScale())
+            .size(decimalType.getBitWidth() / 8);
+
+      case Date:
+        return builder.intBuilder().prop("logicalType", "date").endInt();
+
+      case Time:
+        ArrowType.Time timeType = (ArrowType.Time) field.getType();
+        if ((timeType.getUnit() == TimeUnit.SECOND || timeType.getUnit() == TimeUnit.MILLISECOND)) {
+          // SECOND and MILLISECOND times are encoded as time-millis (INT)
+          return builder.intBuilder().prop("logicalType", "time-millis").endInt();
+        } else {
+          // All other time types (micro, nano) are encoded as time-micros (LONG)
+          return builder.longBuilder().prop("logicalType", "time-micros").endLong();
+        }
+
+      case Timestamp:
+        ArrowType.Timestamp timestampType = (ArrowType.Timestamp) field.getType();
+        String timestampLogicalType = timestampLogicalType(timestampType);
+        return builder.longBuilder().prop("logicalType", timestampLogicalType).endLong();
+
+      case Struct:
+        String childNamespace =
+            namespace == null ? field.getName() : namespace + "." + field.getName();
+        return buildRecordSchema(
+            builder.record(field.getName()), field.getChildren(), childNamespace);
+
+      case List:
+      case FixedSizeList:
+        return buildArraySchema(builder.array(), field, namespace);
+
+      case Map:
+        return buildMapSchema(builder.map(), field, namespace);
+
+      default:
+        throw new IllegalArgumentException(
+            "Element type not supported for Avro conversion: " + typeID.name());
+    }
+  }
+
+  private static <T> SchemaBuilder.FieldAssembler<T> buildBaseFieldSchema(
+      SchemaBuilder.BaseFieldTypeBuilder<T> builder, Field field, String 
namespace) {
+
+    ArrowType.ArrowTypeID typeID = field.getType().getTypeID();
+
+    switch (typeID) {
+      case Null:
+        return builder.nullType().noDefault();
+
+      case Bool:
+        return builder.booleanType().noDefault();
+
+      case Int:
+        ArrowType.Int intType = (ArrowType.Int) field.getType();
+        if (intType.getBitWidth() > 32 || (intType.getBitWidth() == 32 && 
!intType.getIsSigned())) {
+          return builder.longType().noDefault();
+        } else {
+          return builder.intType().noDefault();
+        }
+
+      case FloatingPoint:
+        ArrowType.FloatingPoint floatType = (ArrowType.FloatingPoint) 
field.getType();
+        if (floatType.getPrecision() == FloatingPointPrecision.DOUBLE) {
+          return builder.doubleType().noDefault();
+        } else {
+          return builder.floatType().noDefault();
+        }
+
+      case Utf8:
+        return builder.stringType().noDefault();
+
+      case Binary:
+        return builder.bytesType().noDefault();
+
+      case FixedSizeBinary:
+        ArrowType.FixedSizeBinary fixedType = (ArrowType.FixedSizeBinary) 
field.getType();
+        return 
builder.fixed(field.getName()).size(fixedType.getByteWidth()).noDefault();
+
+      case Decimal:
+        ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType();
+        return builder
+            .fixed(field.getName())
+            .prop("logicalType", "decimal")
+            .prop("precision", decimalType.getPrecision())
+            .prop("scale", decimalType.getScale())
+            .size(decimalType.getBitWidth() / 8)
+            .noDefault();
+
+      case Date:
+        return builder.intBuilder().prop("logicalType", 
"date").endInt().noDefault();
+
+      case Time:
+        ArrowType.Time timeType = (ArrowType.Time) field.getType();
+        if ((timeType.getUnit() == TimeUnit.SECOND || timeType.getUnit() == 
TimeUnit.MILLISECOND)) {
+          return builder.intBuilder().prop("logicalType", 
"time-millis").endInt().noDefault();
+        } else {
+          // All other time types (sec, micro, nano) are encoded as 
time-micros (LONG)

Review Comment:
   Comments updated



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to