apilloud commented on a change in pull request #14518:
URL: https://github.com/apache/beam/pull/14518#discussion_r614297271



##########
File path: sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java
##########
@@ -311,110 +285,124 @@ public void processElement(ProcessContext c) {
     return jarPaths.build();
   }
 
-  private static final Map<TypeName, Type> rawTypeMap =
-      ImmutableMap.<TypeName, Type>builder()
-          .put(TypeName.BYTE, Byte.class)
-          .put(TypeName.INT16, Short.class)
-          .put(TypeName.INT32, Integer.class)
-          .put(TypeName.INT64, Long.class)
-          .put(TypeName.FLOAT, Float.class)
-          .put(TypeName.DOUBLE, Double.class)
-          .build();
-
-  private static Expression castOutput(Expression value, FieldType toType) {
-    Expression returnValue = value;
-    if (value.getType() == Object.class || !(value.getType() instanceof Class)) {
-      // fast copy path, just pass object through
-      returnValue = value;
-    } else if (CalciteUtils.isDateTimeType(toType)
-        && !Types.isAssignableFrom(ReadableInstant.class, (Class) value.getType())) {
-      returnValue = castOutputTime(value, toType);
-    } else if (toType.getTypeName() == TypeName.DECIMAL
-        && !Types.isAssignableFrom(BigDecimal.class, (Class) value.getType())) {
-      returnValue = Expressions.new_(BigDecimal.class, value);
-    } else if (toType.getTypeName() == TypeName.BYTES
-        && Types.isAssignableFrom(ByteString.class, (Class) value.getType())) {
-      returnValue =
-          Expressions.condition(
-              Expressions.equal(value, Expressions.constant(null)),
-              Expressions.constant(null),
-              Expressions.call(value, "getBytes"));
-    } else if (((Class) value.getType()).isPrimitive()
-        || Types.isAssignableFrom(Number.class, (Class) value.getType())) {
-      Type rawType = rawTypeMap.get(toType.getTypeName());
-      if (rawType != null) {
-        returnValue = Types.castIfNecessary(rawType, value);
-      }
-    } else if (Types.isAssignableFrom(Iterable.class, value.getType())) {
-      // Passing an Iterable into newArrayList gets interpreted to mean copying each individual
-      // element. We want the
-      // entire Iterable to be treated as a single element, so we cast to Object.
-      returnValue = Expressions.convert_(value, Object.class);
+  static Object toBeamObject(Object value, FieldType fieldType, boolean verifyValues) {
+    if (value == null) {
+      return null;
+    }
+    switch (fieldType.getTypeName()) {
+      case BYTE:
+        return ((Number) value).byteValue();
+      case INT16:
+        return ((Number) value).shortValue();
+      case INT32:
+        return ((Number) value).intValue();
+      case INT64:
+        return ((Number) value).longValue();
+      case FLOAT:
+        return ((Number) value).floatValue();
+      case DOUBLE:
+        return ((Number) value).doubleValue();
+      case DECIMAL:
+        if (value instanceof BigDecimal) {
+          return (BigDecimal) value;
+        } else if (value instanceof Long) {
+          return BigDecimal.valueOf((Long) value);
+        } else if (value instanceof Integer) {
+          return BigDecimal.valueOf((Integer) value);
+        }
+        return new BigDecimal(((Number) value).toString());
+      case STRING:
+        return (String) value;
+      case BOOLEAN:
+        return (Boolean) value;
+      case DATETIME:
+        if (value instanceof Timestamp) {
+          value = SqlFunctions.toLong((Timestamp) value);
+        }
+        return Instant.ofEpochMilli(((Number) value).longValue());
+      case BYTES:
+        if (value instanceof byte[]) {
+          return value;
+        }
+        return ((ByteString) value).getBytes();
+      case ARRAY:
+        return toBeamList((List<Object>) value, fieldType.getCollectionElementType(), verifyValues);
+      case MAP:
+        return toBeamMap(
+            (Map<Object, Object>) value,
+            fieldType.getMapKeyType(),
+            fieldType.getMapValueType(),
+            verifyValues);
+      case ROW:
+        if (value instanceof Object[]) {
+          value = Arrays.asList((Object[]) value);
+        }
+        return toBeamRow((List<Object>) value, fieldType.getRowSchema(), verifyValues);
+      case LOGICAL_TYPE:
+        String identifier = fieldType.getLogicalType().getIdentifier();
+        if (CharType.IDENTIFIER.equals(identifier)) {
+          return (String) value;
+        } else if (TimeWithLocalTzType.IDENTIFIER.equals(identifier)) {
+          return Instant.ofEpochMilli(((Number) value).longValue());
+        } else if (SqlTypes.DATE.getIdentifier().equals(identifier)) {
+          if (value instanceof Date) {
+            value = SqlFunctions.toInt((Date) value);
+          }
+          // This should always be Integer, but it isn't.
+          return LocalDate.ofEpochDay(((Number) value).longValue());
+        } else if (SqlTypes.TIME.getIdentifier().equals(identifier)) {
+          if (value instanceof Time) {
+            value = SqlFunctions.toInt((Time) value);
+          }
+          // This should always be Integer, but it isn't.
+          return LocalTime.ofNanoOfDay(((Number) value).longValue() * NANOS_PER_MILLISECOND);
+        } else if (SqlTypes.DATETIME.getIdentifier().equals(identifier)) {
+          if (value instanceof Timestamp) {
+            value = SqlFunctions.toLong((Timestamp) value);
+          }
+          return LocalDateTime.of(
+              LocalDate.ofEpochDay(((Number) value).longValue() / MILLIS_PER_DAY),
+              LocalTime.ofNanoOfDay(
+                  (((Number) value).longValue() % MILLIS_PER_DAY) * NANOS_PER_MILLISECOND));
+        } else {
+          throw new UnsupportedOperationException("Unable to convert logical type " + identifier);
+        }
+      default:
+        throw new UnsupportedOperationException("Unable to convert " + fieldType.getTypeName());
     }
-    returnValue =
-        Expressions.condition(
-            Expressions.equal(value, Expressions.constant(null)),
-            Expressions.constant(null),
-            returnValue);
-    return returnValue;
   }
 
-  private static Expression castOutputTime(Expression value, FieldType toType) {
-    Expression valueDateTime = value;
+  private static List<Object> toBeamList(
+      List<Object> arrayValue, FieldType elementType, boolean verifyValues) {
+    return arrayValue.stream()
+        .map(e -> toBeamObject(e, elementType, verifyValues))
+        .collect(Collectors.toList());
+  }
 
-    if (CalciteUtils.TIMESTAMP.typesEqual(toType)
-        || CalciteUtils.NULLABLE_TIMESTAMP.typesEqual(toType)) {
-      // Convert TIMESTAMP to joda Instant
-      if (value.getType() == java.sql.Timestamp.class) {
-        valueDateTime = Expressions.call(BuiltInMethod.TIMESTAMP_TO_LONG.method, valueDateTime);
-      }
-      valueDateTime = Expressions.new_(Instant.class, valueDateTime);
-    } else if (CalciteUtils.TIME.typesEqual(toType)
-        || CalciteUtils.NULLABLE_TIME.typesEqual(toType)) {
-      // Convert TIME to LocalTime
-      if (value.getType() == java.sql.Time.class) {
-        valueDateTime = Expressions.call(BuiltInMethod.TIME_TO_INT.method, valueDateTime);
-      } else if (value.getType() == Integer.class || value.getType() == Long.class) {
-        valueDateTime = Expressions.unbox(valueDateTime);
-      }
-      valueDateTime =
-          Expressions.multiply(valueDateTime, Expressions.constant(NANOS_PER_MILLISECOND));
-      valueDateTime = Expressions.call(LocalTime.class, "ofNanoOfDay", valueDateTime);
-    } else if (CalciteUtils.DATE.typesEqual(toType)
-        || CalciteUtils.NULLABLE_DATE.typesEqual(toType)) {
-      // Convert DATE to LocalDate
-      if (value.getType() == java.sql.Date.class) {
-        valueDateTime = Expressions.call(BuiltInMethod.DATE_TO_INT.method, valueDateTime);
-      } else if (value.getType() == Integer.class || value.getType() == Long.class) {
-        valueDateTime = Expressions.unbox(valueDateTime);
-      }
-      valueDateTime = Expressions.call(LocalDate.class, "ofEpochDay", valueDateTime);
-    } else if (CalciteUtils.TIMESTAMP_WITH_LOCAL_TZ.typesEqual(toType)
-        || CalciteUtils.NULLABLE_TIMESTAMP_WITH_LOCAL_TZ.typesEqual(toType)) {
-      // Convert TimeStamp_With_Local_TimeZone to LocalDateTime
-      Expression dateValue =
-          Expressions.divide(valueDateTime, Expressions.constant(MILLIS_PER_DAY));
-      Expression date = Expressions.call(LocalDate.class, "ofEpochDay", dateValue);
-      Expression timeValue =
-          Expressions.multiply(
-              Expressions.modulo(valueDateTime, Expressions.constant(MILLIS_PER_DAY)),
-              Expressions.constant(NANOS_PER_MILLISECOND));
-      Expression time = Expressions.call(LocalTime.class, "ofNanoOfDay", timeValue);
-      valueDateTime = Expressions.call(LocalDateTime.class, "of", date, time);
-    } else {
-      throw new UnsupportedOperationException("Unknown DateTime type " + toType);
+  private static Map<Object, Object> toBeamMap(
+      Map<Object, Object> mapValue,
+      FieldType keyType,
+      FieldType elementType,
+      boolean verifyValues) {
+    Map<Object, Object> output = new HashMap<>(mapValue.size());
+    for (Map.Entry<Object, Object> entry : mapValue.entrySet()) {
+      output.put(
+          toBeamObject(entry.getKey(), keyType, verifyValues),
+          toBeamObject(entry.getValue(), elementType, verifyValues));
     }
+    return output;
+  }
 
-    // make conversion conditional on non-null input.
-    if (!((Class) value.getType()).isPrimitive()) {
-      valueDateTime =
-          Expressions.condition(
-              Expressions.equal(value, Expressions.constant(null)),
-              Expressions.constant(null),
-              valueDateTime);
+  private static Row toBeamRow(List<Object> structValue, Schema schema, boolean verifyValues) {
+    List<Object> objects = new ArrayList<>(schema.getFieldCount());
+    for (int i = 0; i < structValue.size(); i++) {

Review comment:
       Done

##########
File path: sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
##########
@@ -83,9 +87,40 @@ public void testJodaTimeUdf() throws Exception {
     pipeline.run().waitUntilFinish();
   }
 
-  /** Test Joda time UDAF. */
   @Test
-  public void testJodaTimeUdaf() throws Exception {
+  public void testDateUdf() throws Exception {
+    Schema resultType =
+        Schema.builder().addField("jodatime", FieldType.logicalType(SqlTypes.DATE)).build();
+
+    Row row = Row.withSchema(resultType).addValues(LocalDate.of(2016, 12, 31)).build();
+
+    String sql = "SELECT PRE_DATE(f_date) as jodatime FROM PCOLLECTION WHERE f_int=1";
+    PCollection<Row> result =
+        boundedInput1.apply(
+            "testTimeUdf", SqlTransform.query(sql).registerUdf("PRE_DATE", PreviousDate.class));
+    PAssert.that(result).containsInAnyOrder(row);
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testTimeUdf() throws Exception {
+    Schema resultType =
+        Schema.builder().addField("jodatime", FieldType.logicalType(SqlTypes.TIME)).build();

Review comment:
       Done

##########
File path: sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
##########
@@ -83,9 +87,40 @@ public void testJodaTimeUdf() throws Exception {
     pipeline.run().waitUntilFinish();
   }
 
-  /** Test Joda time UDAF. */
   @Test
-  public void testJodaTimeUdaf() throws Exception {
+  public void testDateUdf() throws Exception {
+    Schema resultType =
+        Schema.builder().addField("jodatime", FieldType.logicalType(SqlTypes.DATE)).build();
+
+    Row row = Row.withSchema(resultType).addValues(LocalDate.of(2016, 12, 31)).build();
+
+    String sql = "SELECT PRE_DATE(f_date) as jodatime FROM PCOLLECTION WHERE f_int=1";

Review comment:
       Done

##########
File path: sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
##########
@@ -83,9 +87,40 @@ public void testJodaTimeUdf() throws Exception {
     pipeline.run().waitUntilFinish();
   }
 
-  /** Test Joda time UDAF. */
   @Test
-  public void testJodaTimeUdaf() throws Exception {
+  public void testDateUdf() throws Exception {
+    Schema resultType =
+        Schema.builder().addField("jodatime", FieldType.logicalType(SqlTypes.DATE)).build();

Review comment:
       Done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to