This is an automated email from the ASF dual-hosted git repository.

apilloud pushed a commit to branch fad
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 72ee5d3547d3990fbf61e9ed61aec5958effcb03
Author: Andrew Pilloud <apill...@google.com>
AuthorDate: Wed Nov 3 14:33:18 2021 -0700

    [BEAM-13056] Expose FieldAccess in DoFnSchemaInformation
---
 .../beam/sdk/transforms/DoFnSchemaInformation.java |  33 +++++-
 .../java/org/apache/beam/sdk/transforms/ParDo.java |  11 ++
 .../org/apache/beam/sdk/transforms/ParDoTest.java  | 126 +++++++++++++++++++++
 .../extensions/sql/impl/rel/BeamCalcRelTest.java   |  29 +++--
 .../sql/zetasql/BeamZetaSqlCalcRelTest.java        |  29 +++--
 5 files changed, 194 insertions(+), 34 deletions(-)

diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java
index 54f6d19..9a8ccac 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java
@@ -51,10 +51,14 @@ public abstract class DoFnSchemaInformation implements 
Serializable {
    */
   public abstract List<SerializableFunction<?, ?>> getElementConverters();
 
+  /** Effective FieldAccessDescriptor applied by DoFn. */
+  public abstract FieldAccessDescriptor getFieldAccessDescriptor();
+
   /** Create an instance. */
   public static DoFnSchemaInformation create() {
     return new AutoValue_DoFnSchemaInformation.Builder()
         .setElementConverters(Collections.emptyList())
+        .setFieldAccessDescriptor(FieldAccessDescriptor.create())
         .build();
   }
 
@@ -63,6 +67,8 @@ public abstract class DoFnSchemaInformation implements 
Serializable {
   public abstract static class Builder {
     abstract Builder setElementConverters(List<SerializableFunction<?, ?>> 
converters);
 
+    abstract Builder setFieldAccessDescriptor(FieldAccessDescriptor 
descriptor);
+
     abstract DoFnSchemaInformation build();
   }
 
@@ -101,7 +107,10 @@ public abstract class DoFnSchemaInformation implements 
Serializable {
                     unbox))
             .build();
 
-    return toBuilder().setElementConverters(converters).build();
+    return toBuilder()
+        .setElementConverters(converters)
+        .setFieldAccessDescriptor(getFieldAccessDescriptor())
+        .build();
   }
 
   /**
@@ -141,7 +150,27 @@ public abstract class DoFnSchemaInformation implements 
Serializable {
                     elementT))
             .build();
 
-    return toBuilder().setElementConverters(converters).build();
+    return toBuilder()
+        .setElementConverters(converters)
+        .setFieldAccessDescriptor(getFieldAccessDescriptor())
+        .build();
+  }
+
+  /**
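+   * Adds a descriptor of fields accessed from an input schema, unioned with any existing one.
+   *
+   * @param selectDescriptor The descriptor describing which fields are selected.
+   * @return A copy whose descriptor is the union of the current one and {@code selectDescriptor}.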
+   */
+  DoFnSchemaInformation withFieldAccessDescriptor(FieldAccessDescriptor 
selectDescriptor) {
+
+    FieldAccessDescriptor descriptor =
+        
FieldAccessDescriptor.union(ImmutableList.of(getFieldAccessDescriptor(), 
selectDescriptor));
+
+    return toBuilder()
+        .setElementConverters(getElementConverters())
+        .setFieldAccessDescriptor(descriptor)
+        .build();
   }
 
   private static class ConversionFunction<InputT, OutputT>
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
index 9da0998..97cd6cd 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
@@ -49,6 +49,8 @@ import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.FieldAccessDeclaration;
 import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.MethodWithExtraParameters;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.OnTimerMethod;
+import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ElementParameter;
+import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ProcessContextParameter;
 import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter;
 import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SideInputParameter;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
@@ -637,6 +639,7 @@ public class ParDo {
               fn.getClass().getName()));
     }
   }
+
   /**
    * Extract information on how the DoFn uses schemas. In particular, if the 
schema of an element
    * parameter does not match the input PCollection's schema, convert.
@@ -662,6 +665,7 @@ public class ParDo {
               input.getSchema(),
               signature.fieldAccessDeclarations(),
               fn);
+      doFnSchemaInformation = 
doFnSchemaInformation.withFieldAccessDescriptor(accessDescriptor);
       Schema selectedSchema = SelectHelpers.getOutputSchema(input.getSchema(), 
accessDescriptor);
       ConvertHelpers.ConvertedSchemaInformation converted =
           ConvertHelpers.getConvertedSchemaInformation(selectedSchema, 
elementT, schemaRegistry);
@@ -683,6 +687,13 @@ public class ParDo {
                 (SchemaCoder<?>) input.getCoder(), accessDescriptor, 
selectedSchema, elementT);
       }
     }
+    for (DoFnSignature.Parameter p : processElementMethod.extraParameters()) {
+      if (p instanceof ProcessContextParameter || p instanceof 
ElementParameter) {
+        doFnSchemaInformation =
+            
doFnSchemaInformation.withFieldAccessDescriptor(FieldAccessDescriptor.withAllFields());
+        break;
+      }
+    }
 
     return doFnSchemaInformation;
   }
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 8becf9a..173bdad 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -76,6 +76,8 @@ import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.io.GenerateSequence;
 import org.apache.beam.sdk.options.Default;
 import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
+import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.state.BagState;
 import org.apache.beam.sdk.state.CombiningState;
 import org.apache.beam.sdk.state.GroupingState;
@@ -135,6 +137,7 @@ import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
@@ -6141,4 +6144,127 @@ public class ParDoTest implements Serializable {
       pipeline.run();
     }
   }
+
+  /** Tests to validate SchemaInformation. */
+  @RunWith(JUnit4.class)
+  public static class SchemaInformationTests extends SharedTestBase implements 
Serializable {
+
+    private static final Schema TEST_SCHEMA =
+        
Schema.builder().addInt32Field("f_int").addStringField("f_string").build();
+    private static final Row TEST_ROW = 
Row.withSchema(TEST_SCHEMA).addValues(10, "ten").build();
+
+    @Test
+    public void testUnboxFieldAccess() throws Exception {
+
+      DoFn<Row, Integer> fn =
+          new DoFn<Row, Integer>() {
+            @ProcessElement
+            public void processElement(
+                @FieldAccess("f_int") Integer value, OutputReceiver<Integer> 
r) {
+              r.output(value);
+            }
+          };
+      PCollection<Row> input = 
pipeline.apply(Create.of(TEST_ROW).withRowSchema(TEST_SCHEMA));
+
+      DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(fn, input);
+      assertEquals(info.getElementConverters().toString(), 1, 
info.getElementConverters().size());
+
+      FieldAccessDescriptor fieldAccessDescriptor = 
info.getFieldAccessDescriptor();
+      assertEquals(
+          fieldAccessDescriptor.toString(), 1, 
fieldAccessDescriptor.getFieldsAccessed().size());
+      assertEquals(
+          "f_int",
+          
Iterables.getOnlyElement(fieldAccessDescriptor.getFieldsAccessed()).getFieldName());
+      assertFalse(fieldAccessDescriptor.toString(), 
fieldAccessDescriptor.getAllFields());
+    }
+
+    @Test
+    public void testSingleFieldAccess() throws Exception {
+
+      DoFn<Row, Integer> fn =
+          new DoFn<Row, Integer>() {
+            @FieldAccess("foo")
+            final FieldAccessDescriptor fieldAccess = 
FieldAccessDescriptor.withFieldNames("f_int");
+
+            @ProcessElement
+            public void processElement(@FieldAccess("foo") Row row, 
OutputReceiver<Integer> r) {
+              r.output(row.getInt32("f_int"));
+            }
+          };
+      PCollection<Row> input = 
pipeline.apply(Create.of(TEST_ROW).withRowSchema(TEST_SCHEMA));
+
+      DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(fn, input);
+      assertEquals(info.getElementConverters().toString(), 1, 
info.getElementConverters().size());
+
+      FieldAccessDescriptor fieldAccessDescriptor = 
info.getFieldAccessDescriptor();
+      assertEquals(
+          fieldAccessDescriptor.toString(), 1, 
fieldAccessDescriptor.getFieldsAccessed().size());
+      assertEquals(
+          "f_int",
+          
Iterables.getOnlyElement(fieldAccessDescriptor.getFieldsAccessed()).getFieldName());
+      assertFalse(fieldAccessDescriptor.toString(), 
fieldAccessDescriptor.getAllFields());
+    }
+
+    @Test
+    public void testImplicitElement() throws Exception {
+
+      DoFn<Row, Integer> fn =
+          new DoFn<Row, Integer>() {
+            @ProcessElement
+            public void processElement(@Element Row row, 
OutputReceiver<Integer> r) {
+              r.output(row.getInt32("f_int"));
+            }
+          };
+      PCollection<Row> input = 
pipeline.apply(Create.of(TEST_ROW).withRowSchema(TEST_SCHEMA));
+
+      DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(fn, input);
+      assertEquals(info.getElementConverters().toString(), 0, 
info.getElementConverters().size());
+
+      FieldAccessDescriptor fieldAccessDescriptor = 
info.getFieldAccessDescriptor();
+      assertTrue(fieldAccessDescriptor.toString(), 
fieldAccessDescriptor.getAllFields());
+    }
+
+    @Test
+    public void testImplicitProcessContext() throws Exception {
+
+      DoFn<Row, Integer> fn =
+          new DoFn<Row, Integer>() {
+            @ProcessElement
+            public void processElement(ProcessContext c) {
+              c.output(c.element().getInt32("f_int"));
+            }
+          };
+      PCollection<Row> input = 
pipeline.apply(Create.of(TEST_ROW).withRowSchema(TEST_SCHEMA));
+
+      DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(fn, input);
+      assertEquals(info.getElementConverters().toString(), 0, 
info.getElementConverters().size());
+
+      FieldAccessDescriptor fieldAccessDescriptor = 
info.getFieldAccessDescriptor();
+      assertTrue(fieldAccessDescriptor.toString(), 
fieldAccessDescriptor.getAllFields());
+    }
+
+    @Test
+    public void testNoAccess() throws Exception {
+
+      DoFn<Row, Integer> fn =
+          new DoFn<Row, Integer>() {
+            @ProcessElement
+            public void processElement(OutputReceiver<Integer> r) {
+              r.output(1);
+            }
+          };
+      PCollection<Row> input = 
pipeline.apply(Create.of(TEST_ROW).withRowSchema(TEST_SCHEMA));
+
+      DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(fn, input);
+      assertEquals(info.getElementConverters().toString(), 0, 
info.getElementConverters().size());
+
+      FieldAccessDescriptor fieldAccessDescriptor = 
info.getFieldAccessDescriptor();
+      assertFalse(fieldAccessDescriptor.toString(), 
fieldAccessDescriptor.getAllFields());
+      assertTrue(
+          fieldAccessDescriptor.toString(), 
fieldAccessDescriptor.getFieldsAccessed().isEmpty());
+      assertTrue(
+          fieldAccessDescriptor.toString(),
+          fieldAccessDescriptor.getNestedFieldsAccessed().isEmpty());
+    }
+  }
 }
diff --git 
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java
 
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java
index 385d908..702dd9f 100644
--- 
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java
+++ 
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java
@@ -28,13 +28,13 @@ import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
 import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
-import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelNode;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.joda.time.DateTime;
 import org.joda.time.Duration;
 import org.junit.Assert;
@@ -213,18 +213,15 @@ public class BeamCalcRelTest extends BaseRelTest {
 
     ParDo.MultiOutput<Row, Row> pardo =
         (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform();
-    DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass());
+    PCollection<Row> input =
+        (PCollection<Row>) 
Iterables.getOnlyElement(nodeGetter.producer.getInputs().values());
 
-    Assert.assertEquals(1, sig.fieldAccessDeclarations().size());
-    DoFnSignature.FieldAccessDeclaration dec =
-        sig.fieldAccessDeclarations().values().iterator().next();
-    FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) 
dec.field().get(pardo.getFn());
+    DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(pardo.getFn(), 
input);
 
-    Assert.assertTrue(fieldAccess.referencesSingleField());
+    FieldAccessDescriptor fieldAccess = info.getFieldAccessDescriptor();
 
-    fieldAccess =
-        
fieldAccess.resolve(nodeGetter.producer.getInputs().values().iterator().next().getSchema());
-    Assert.assertEquals("order_id", 
fieldAccess.fieldNamesAccessed().iterator().next());
+    Assert.assertTrue(fieldAccess.referencesSingleField());
+    Assert.assertEquals("order_id", 
Iterables.getOnlyElement(fieldAccess.fieldNamesAccessed()));
 
     pipeline.run().waitUntilFinish();
   }
@@ -240,12 +237,12 @@ public class BeamCalcRelTest extends BaseRelTest {
 
     ParDo.MultiOutput<Row, Row> pardo =
         (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform();
-    DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass());
+    PCollection<Row> input =
+        (PCollection<Row>) 
Iterables.getOnlyElement(nodeGetter.producer.getInputs().values());
+
+    DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(pardo.getFn(), 
input);
 
-    Assert.assertEquals(1, sig.fieldAccessDeclarations().size());
-    DoFnSignature.FieldAccessDeclaration dec =
-        sig.fieldAccessDeclarations().values().iterator().next();
-    FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) 
dec.field().get(pardo.getFn());
+    FieldAccessDescriptor fieldAccess = info.getFieldAccessDescriptor();
 
     Assert.assertFalse(fieldAccess.getAllFields());
     Assert.assertTrue(fieldAccess.getFieldsAccessed().isEmpty());
diff --git 
a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java
 
b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java
index 352e83a..b490458 100644
--- 
a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java
+++ 
b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java
@@ -24,12 +24,12 @@ import 
org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
 import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
-import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.Row;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
@@ -80,18 +80,15 @@ public class BeamZetaSqlCalcRelTest extends ZetaSqlTestBase 
{
 
     ParDo.MultiOutput<Row, Row> pardo =
         (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform();
-    DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass());
+    PCollection<Row> input =
+        (PCollection<Row>) 
Iterables.getOnlyElement(nodeGetter.producer.getInputs().values());
 
-    Assert.assertEquals(1, sig.fieldAccessDeclarations().size());
-    DoFnSignature.FieldAccessDeclaration dec =
-        sig.fieldAccessDeclarations().values().iterator().next();
-    FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) 
dec.field().get(pardo.getFn());
+    DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(pardo.getFn(), 
input);
 
-    Assert.assertTrue(fieldAccess.referencesSingleField());
+    FieldAccessDescriptor fieldAccess = info.getFieldAccessDescriptor();
 
-    fieldAccess =
-        
fieldAccess.resolve(nodeGetter.producer.getInputs().values().iterator().next().getSchema());
-    Assert.assertEquals("Key", 
fieldAccess.fieldNamesAccessed().iterator().next());
+    Assert.assertTrue(fieldAccess.referencesSingleField());
+    Assert.assertEquals("Key", 
Iterables.getOnlyElement(fieldAccess.fieldNamesAccessed()));
 
     pipeline.run().waitUntilFinish();
   }
@@ -107,12 +104,12 @@ public class BeamZetaSqlCalcRelTest extends 
ZetaSqlTestBase {
 
     ParDo.MultiOutput<Row, Row> pardo =
         (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform();
-    DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass());
+    PCollection<Row> input =
+        (PCollection<Row>) 
Iterables.getOnlyElement(nodeGetter.producer.getInputs().values());
+
+    DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(pardo.getFn(), 
input);
 
-    Assert.assertEquals(1, sig.fieldAccessDeclarations().size());
-    DoFnSignature.FieldAccessDeclaration dec =
-        sig.fieldAccessDeclarations().values().iterator().next();
-    FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) 
dec.field().get(pardo.getFn());
+    FieldAccessDescriptor fieldAccess = info.getFieldAccessDescriptor();
 
     Assert.assertFalse(fieldAccess.getAllFields());
     Assert.assertTrue(fieldAccess.getFieldsAccessed().isEmpty());

Reply via email to