This is an automated email from the ASF dual-hosted git repository. apilloud pushed a commit to branch fad in repository https://gitbox.apache.org/repos/asf/beam.git
commit 72ee5d3547d3990fbf61e9ed61aec5958effcb03 Author: Andrew Pilloud <apill...@google.com> AuthorDate: Wed Nov 3 14:33:18 2021 -0700 [BEAM-13056] Expose FieldAccess in DoFnSchemaInformation --- .../beam/sdk/transforms/DoFnSchemaInformation.java | 33 +++++- .../java/org/apache/beam/sdk/transforms/ParDo.java | 11 ++ .../org/apache/beam/sdk/transforms/ParDoTest.java | 126 +++++++++++++++++++++ .../extensions/sql/impl/rel/BeamCalcRelTest.java | 29 +++-- .../sql/zetasql/BeamZetaSqlCalcRelTest.java | 29 +++-- 5 files changed, 194 insertions(+), 34 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java index 54f6d19..9a8ccac 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java @@ -51,10 +51,14 @@ public abstract class DoFnSchemaInformation implements Serializable { */ public abstract List<SerializableFunction<?, ?>> getElementConverters(); + /** Effective FieldAccessDescriptor applied by DoFn. */ + public abstract FieldAccessDescriptor getFieldAccessDescriptor(); + /** Create an instance. */ public static DoFnSchemaInformation create() { return new AutoValue_DoFnSchemaInformation.Builder() .setElementConverters(Collections.emptyList()) + .setFieldAccessDescriptor(FieldAccessDescriptor.create()) .build(); } @@ -63,6 +67,8 @@ public abstract class DoFnSchemaInformation implements Serializable { public abstract static class Builder { abstract Builder setElementConverters(List<SerializableFunction<?, ?>> converters); + abstract Builder setFieldAccessDescriptor(FieldAccessDescriptor descriptor); + abstract DoFnSchemaInformation build(); } @@ -101,7 +107,10 @@ public abstract class DoFnSchemaInformation implements Serializable { unbox)) .build(); - return toBuilder().setElementConverters(converters).build(); + return toBuilder() + .setElementConverters(converters) + .setFieldAccessDescriptor(getFieldAccessDescriptor()) + .build(); } /** @@ -141,7 +150,27 @@ public abstract class DoFnSchemaInformation implements Serializable { elementT)) .build(); - return toBuilder().setElementConverters(converters).build(); + return toBuilder() + .setElementConverters(converters) + .setFieldAccessDescriptor(getFieldAccessDescriptor()) + .build(); + } + + /** + * Specified a descriptor of fields accessed from an input schema. + * + * @param selectDescriptor The descriptor describing which field to select. + * @return + */ + DoFnSchemaInformation withFieldAccessDescriptor(FieldAccessDescriptor selectDescriptor) { + + FieldAccessDescriptor descriptor = + FieldAccessDescriptor.union(ImmutableList.of(getFieldAccessDescriptor(), selectDescriptor)); + + return toBuilder() + .setElementConverters(getElementConverters()) + .setFieldAccessDescriptor(descriptor) + .build(); } private static class ConversionFunction<InputT, OutputT> diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index 9da0998..97cd6cd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -49,6 +49,8 @@ import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.FieldAccessDeclaration; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.MethodWithExtraParameters; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.OnTimerMethod; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ElementParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ProcessContextParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SideInputParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; @@ -637,6 +639,7 @@ public class ParDo { fn.getClass().getName())); } } + /** * Extract information on how the DoFn uses schemas. In particular, if the schema of an element * parameter does not match the input PCollection's schema, convert. @@ -662,6 +665,7 @@ public class ParDo { input.getSchema(), signature.fieldAccessDeclarations(), fn); + doFnSchemaInformation = doFnSchemaInformation.withFieldAccessDescriptor(accessDescriptor); Schema selectedSchema = SelectHelpers.getOutputSchema(input.getSchema(), accessDescriptor); ConvertHelpers.ConvertedSchemaInformation converted = ConvertHelpers.getConvertedSchemaInformation(selectedSchema, elementT, schemaRegistry); @@ -683,6 +687,13 @@ public class ParDo { (SchemaCoder<?>) input.getCoder(), accessDescriptor, selectedSchema, elementT); } } + for (DoFnSignature.Parameter p : processElementMethod.extraParameters()) { + if (p instanceof ProcessContextParameter || p instanceof ElementParameter) { + doFnSchemaInformation = + doFnSchemaInformation.withFieldAccessDescriptor(FieldAccessDescriptor.withAllFields()); + break; + } + } return doFnSchemaInformation; } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 8becf9a..173bdad 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -76,6 +76,8 @@ import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; +import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.GroupingState; @@ -135,6 +137,7 @@ import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -6141,4 +6144,127 @@ public class ParDoTest implements Serializable { pipeline.run(); } } + + /** Tests to validate SchemaInformation. */ + @RunWith(JUnit4.class) + public static class SchemaInformationTests extends SharedTestBase implements Serializable { + + private static final Schema TEST_SCHEMA = + Schema.builder().addInt32Field("f_int").addStringField("f_string").build(); + private static final Row TEST_ROW = Row.withSchema(TEST_SCHEMA).addValues(10, "ten").build(); + + @Test + public void testUnboxFieldAccess() throws Exception { + + DoFn<Row, Integer> fn = + new DoFn<Row, Integer>() { + @ProcessElement + public void processElement( + @FieldAccess("f_int") Integer value, OutputReceiver<Integer> r) { + r.output(value); + } + }; + PCollection<Row> input = pipeline.apply(Create.of(TEST_ROW).withRowSchema(TEST_SCHEMA)); + + DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(fn, input); + assertEquals(info.getElementConverters().toString(), 1, info.getElementConverters().size()); + + FieldAccessDescriptor fieldAccessDescriptor = info.getFieldAccessDescriptor(); + assertEquals( + fieldAccessDescriptor.toString(), 1, fieldAccessDescriptor.getFieldsAccessed().size()); + assertEquals( + "f_int", + Iterables.getOnlyElement(fieldAccessDescriptor.getFieldsAccessed()).getFieldName()); + assertFalse(fieldAccessDescriptor.toString(), fieldAccessDescriptor.getAllFields()); + } + + @Test + public void testSingleFieldAccess() throws Exception { + + DoFn<Row, Integer> fn = + new DoFn<Row, Integer>() { + @FieldAccess("foo") + final FieldAccessDescriptor fieldAccess = FieldAccessDescriptor.withFieldNames("f_int"); + + @ProcessElement + public void processElement(@FieldAccess("foo") Row row, OutputReceiver<Integer> r) { + r.output(row.getInt32("f_int")); + } + }; + PCollection<Row> input = pipeline.apply(Create.of(TEST_ROW).withRowSchema(TEST_SCHEMA)); + + DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(fn, input); + assertEquals(info.getElementConverters().toString(), 1, info.getElementConverters().size()); + + FieldAccessDescriptor fieldAccessDescriptor = info.getFieldAccessDescriptor(); + assertEquals( + fieldAccessDescriptor.toString(), 1, fieldAccessDescriptor.getFieldsAccessed().size()); + assertEquals( + "f_int", + Iterables.getOnlyElement(fieldAccessDescriptor.getFieldsAccessed()).getFieldName()); + assertFalse(fieldAccessDescriptor.toString(), fieldAccessDescriptor.getAllFields()); + } + + @Test + public void testImplicitElement() throws Exception { + + DoFn<Row, Integer> fn = + new DoFn<Row, Integer>() { + @ProcessElement + public void processElement(@Element Row row, OutputReceiver<Integer> r) { + r.output(row.getInt32("f_int")); + } + }; + PCollection<Row> input = pipeline.apply(Create.of(TEST_ROW).withRowSchema(TEST_SCHEMA)); + + DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(fn, input); + assertEquals(info.getElementConverters().toString(), 0, info.getElementConverters().size()); + + FieldAccessDescriptor fieldAccessDescriptor = info.getFieldAccessDescriptor(); + assertTrue(fieldAccessDescriptor.toString(), fieldAccessDescriptor.getAllFields()); + } + + @Test + public void testImplicitProcessContext() throws Exception { + + DoFn<Row, Integer> fn = + new DoFn<Row, Integer>() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element().getInt32("f_int")); + } + }; + PCollection<Row> input = pipeline.apply(Create.of(TEST_ROW).withRowSchema(TEST_SCHEMA)); + + DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(fn, input); + assertEquals(info.getElementConverters().toString(), 0, info.getElementConverters().size()); + + FieldAccessDescriptor fieldAccessDescriptor = info.getFieldAccessDescriptor(); + assertTrue(fieldAccessDescriptor.toString(), fieldAccessDescriptor.getAllFields()); + } + + @Test + public void testNoAccess() throws Exception { + + DoFn<Row, Integer> fn = + new DoFn<Row, Integer>() { + @ProcessElement + public void processElement(OutputReceiver<Integer> r) { + r.output(1); + } + }; + PCollection<Row> input = pipeline.apply(Create.of(TEST_ROW).withRowSchema(TEST_SCHEMA)); + + DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(fn, input); + assertEquals(info.getElementConverters().toString(), 0, info.getElementConverters().size()); + + FieldAccessDescriptor fieldAccessDescriptor = info.getFieldAccessDescriptor(); + assertFalse(fieldAccessDescriptor.toString(), fieldAccessDescriptor.getAllFields()); + assertTrue( + fieldAccessDescriptor.toString(), fieldAccessDescriptor.getFieldsAccessed().isEmpty()); + assertTrue( + fieldAccessDescriptor.toString(), + fieldAccessDescriptor.getNestedFieldsAccessed().isEmpty()); + } + } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java index 385d908..702dd9f 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java @@ -28,13 +28,13 @@ import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.schemas.FieldAccessDescriptor; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.reflect.DoFnSignature; -import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelNode; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.joda.time.DateTime; import org.joda.time.Duration; import org.junit.Assert; @@ -213,18 +213,15 @@ public class BeamCalcRelTest extends BaseRelTest { ParDo.MultiOutput<Row, Row> pardo = (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform(); - DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass()); + PCollection<Row> input = + (PCollection<Row>) Iterables.getOnlyElement(nodeGetter.producer.getInputs().values()); - Assert.assertEquals(1, sig.fieldAccessDeclarations().size()); - DoFnSignature.FieldAccessDeclaration dec = - sig.fieldAccessDeclarations().values().iterator().next(); - FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) dec.field().get(pardo.getFn()); + DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(pardo.getFn(), input); - Assert.assertTrue(fieldAccess.referencesSingleField()); + FieldAccessDescriptor fieldAccess = info.getFieldAccessDescriptor(); - fieldAccess = - fieldAccess.resolve(nodeGetter.producer.getInputs().values().iterator().next().getSchema()); - Assert.assertEquals("order_id", fieldAccess.fieldNamesAccessed().iterator().next()); + Assert.assertTrue(fieldAccess.referencesSingleField()); + Assert.assertEquals("order_id", Iterables.getOnlyElement(fieldAccess.fieldNamesAccessed())); pipeline.run().waitUntilFinish(); } @@ -240,12 +237,12 @@ public class BeamCalcRelTest extends BaseRelTest { ParDo.MultiOutput<Row, Row> pardo = (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform(); - DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass()); + PCollection<Row> input = + (PCollection<Row>) Iterables.getOnlyElement(nodeGetter.producer.getInputs().values()); + + DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(pardo.getFn(), input); - Assert.assertEquals(1, sig.fieldAccessDeclarations().size()); - DoFnSignature.FieldAccessDeclaration dec = - sig.fieldAccessDeclarations().values().iterator().next(); - FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) dec.field().get(pardo.getFn()); + FieldAccessDescriptor fieldAccess = info.getFieldAccessDescriptor(); Assert.assertFalse(fieldAccess.getAllFields()); Assert.assertTrue(fieldAccess.getFieldsAccessed().isEmpty()); diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java index 352e83a..b490458 100644 --- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java @@ -24,12 +24,12 @@ import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.schemas.FieldAccessDescriptor; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.reflect.DoFnSignature; -import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -80,18 +80,15 @@ public class BeamZetaSqlCalcRelTest extends ZetaSqlTestBase { ParDo.MultiOutput<Row, Row> pardo = (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform(); - DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass()); + PCollection<Row> input = + (PCollection<Row>) Iterables.getOnlyElement(nodeGetter.producer.getInputs().values()); - Assert.assertEquals(1, sig.fieldAccessDeclarations().size()); - DoFnSignature.FieldAccessDeclaration dec = - sig.fieldAccessDeclarations().values().iterator().next(); - FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) dec.field().get(pardo.getFn()); + DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(pardo.getFn(), input); - Assert.assertTrue(fieldAccess.referencesSingleField()); + FieldAccessDescriptor fieldAccess = info.getFieldAccessDescriptor(); - fieldAccess = - fieldAccess.resolve(nodeGetter.producer.getInputs().values().iterator().next().getSchema()); - Assert.assertEquals("Key", fieldAccess.fieldNamesAccessed().iterator().next()); + Assert.assertTrue(fieldAccess.referencesSingleField()); + Assert.assertEquals("Key", Iterables.getOnlyElement(fieldAccess.fieldNamesAccessed())); pipeline.run().waitUntilFinish(); } @@ -107,12 +104,12 @@ public class BeamZetaSqlCalcRelTest extends ZetaSqlTestBase { ParDo.MultiOutput<Row, Row> pardo = (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform(); - DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass()); + PCollection<Row> input = + (PCollection<Row>) Iterables.getOnlyElement(nodeGetter.producer.getInputs().values()); + + DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(pardo.getFn(), input); - Assert.assertEquals(1, sig.fieldAccessDeclarations().size()); - DoFnSignature.FieldAccessDeclaration dec = - sig.fieldAccessDeclarations().values().iterator().next(); - FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) dec.field().get(pardo.getFn()); + FieldAccessDescriptor fieldAccess = info.getFieldAccessDescriptor(); Assert.assertFalse(fieldAccess.getAllFields()); Assert.assertTrue(fieldAccess.getFieldsAccessed().isEmpty());