This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
     new b09a028  Merge pull request #15954 from [BEAM-960][BEAM-1675] Improvements to JdbcIO coder inference
b09a028 is described below

commit b09a028121b7ffca6bb6409e419df61d6c6c3dc1
Author: Pablo <pabl...@users.noreply.github.com>
AuthorDate: Mon Nov 29 13:17:06 2021 -0500

    Merge pull request #15954 from [BEAM-960][BEAM-1675] Improvements to JdbcIO coder inference

    * [BEAM-960] support coder inference for JdbcIO

    * [BEAM-1675] Deprecate withCoder in JdbcIO

    * Update JdbcIO integration test

    * Address comment
---
 .../java/org/apache/beam/sdk/io/jdbc/JdbcIO.java   | 112 +++++++++++++++------
 .../java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java |   4 +-
 .../org/apache/beam/sdk/io/jdbc/JdbcIOTest.java    |  19 ++++
 3 files changed, 101 insertions(+), 34 deletions(-)

diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
index 83f620f..0f6b9c3 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
@@ -47,7 +47,9 @@ import java.util.stream.IntStream;
 import javax.sql.DataSource;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.RowCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.io.jdbc.JdbcIO.WriteFn.WriteFnSpec;
@@ -84,6 +86,8 @@ import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.sdk.values.TypeDescriptors.TypeVariableExtractor;
 import org.apache.commons.dbcp2.BasicDataSource;
 import org.apache.commons.dbcp2.DataSourceConnectionFactory;
 import org.apache.commons.dbcp2.PoolableConnectionFactory;
@@ -118,7 +122,6 @@ import org.slf4j.LoggerFactory;
  *        .withUsername("username")
  *        .withPassword("password"))
  *   .withQuery("select id,name from Person")
- *   .withCoder(KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of()))
  *   .withRowMapper(new JdbcIO.RowMapper<KV<Integer, String>>() {
  *     public KV<Integer, String> mapRow(ResultSet resultSet) throws Exception {
  *       return KV.of(resultSet.getInt(1), resultSet.getString(2));
@@ -136,7 +139,6 @@ import org.slf4j.LoggerFactory;
  *         "com.mysql.jdbc.Driver", "jdbc:mysql://hostname:3306/mydb",
  *         "username", "password"))
  *   .withQuery("select id,name from Person where name = ?")
- *   .withCoder(KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of()))
  *   .withStatementPreparator(new JdbcIO.StatementPreparator() {
  *     public void setParameters(PreparedStatement preparedStatement) throws Exception {
  *       preparedStatement.setString(1, "Darwin");
@@ -202,7 +204,6 @@ import org.slf4j.LoggerFactory;
  *   .withLowerBound(0)
  *   .withUpperBound(1000)
  *   .withNumPartitions(5)
- *   .withCoder(KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of()))
 *   .withRowMapper(new JdbcIO.RowMapper<KV<Integer, String>>() {
 *     public KV<Integer, String> mapRow(ResultSet resultSet) throws Exception {
 *       return KV.of(resultSet.getInt(1), resultSet.getString(2));
@@ -226,7 +227,6 @@ import org.slf4j.LoggerFactory;
 *   .withLowerBound(0)
 *   .withUpperBound(1000)
 *   .withNumPartitions(5)
- *   .withCoder(KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of()))
 *   .withRowMapper(new JdbcIO.RowMapper<KV<Integer, String>>() {
 *     public KV<Integer, String> mapRow(ResultSet resultSet) throws Exception {
 *       return KV.of(resultSet.getInt(1), resultSet.getString(2));
@@ -737,6 +737,11 @@ public class JdbcIO {
       return toBuilder().setRowMapper(rowMapper).build();
     }

+    /**
+     * @deprecated
+     *     <p>{@link JdbcIO} is able to infer appropriate coders from other parameters.
+     */
+    @Deprecated
     public Read<T> withCoder(Coder<T> coder) {
       checkArgument(coder != null, "coder can not be null");
       return toBuilder().setCoder(coder).build();
@@ -764,27 +769,28 @@ public class JdbcIO {
     public PCollection<T> expand(PBegin input) {
       checkArgument(getQuery() != null, "withQuery() is required");
       checkArgument(getRowMapper() != null, "withRowMapper() is required");
-      checkArgument(getCoder() != null, "withCoder() is required");
       checkArgument(
           (getDataSourceProviderFn() != null),
           "withDataSourceConfiguration() or withDataSourceProviderFn() is required");

-      return input
-          .apply(Create.of((Void) null))
-          .apply(
-              JdbcIO.<Void, T>readAll()
-                  .withDataSourceProviderFn(getDataSourceProviderFn())
-                  .withQuery(getQuery())
-                  .withCoder(getCoder())
-                  .withRowMapper(getRowMapper())
-                  .withFetchSize(getFetchSize())
-                  .withOutputParallelization(getOutputParallelization())
-                  .withParameterSetter(
-                      (element, preparedStatement) -> {
-                        if (getStatementPreparator() != null) {
-                          getStatementPreparator().setParameters(preparedStatement);
-                        }
-                      }));
+      JdbcIO.ReadAll<Void, T> readAll =
+          JdbcIO.<Void, T>readAll()
+              .withDataSourceProviderFn(getDataSourceProviderFn())
+              .withQuery(getQuery())
+              .withRowMapper(getRowMapper())
+              .withFetchSize(getFetchSize())
+              .withOutputParallelization(getOutputParallelization())
+              .withParameterSetter(
+                  (element, preparedStatement) -> {
+                    if (getStatementPreparator() != null) {
+                      getStatementPreparator().setParameters(preparedStatement);
+                    }
+                  });
+
+      if (getCoder() != null) {
+        readAll = readAll.withCoder(getCoder());
+      }
+      return input.apply(Create.of((Void) null)).apply(readAll);
     }

     @Override
@@ -792,7 +798,9 @@ public class JdbcIO {
       super.populateDisplayData(builder);
       builder.add(DisplayData.item("query", getQuery()));
       builder.add(DisplayData.item("rowMapper", getRowMapper().getClass().getName()));
-      builder.add(DisplayData.item("coder", getCoder().getClass().getName()));
+      if (getCoder() != null) {
+        builder.add(DisplayData.item("coder", getCoder().getClass().getName()));
+      }
       if (getDataSourceProviderFn() instanceof HasDisplayData) {
         ((HasDisplayData) getDataSourceProviderFn()).populateDisplayData(builder);
       }
@@ -882,6 +890,11 @@ public class JdbcIO {
       return toBuilder().setRowMapper(rowMapper).build();
     }

+    /**
+     * @deprecated
+     *     <p>{@link JdbcIO} is able to infer appropriate coders from other parameters.
+     */
+    @Deprecated
     public ReadAll<ParameterT, OutputT> withCoder(Coder<OutputT> coder) {
       checkArgument(coder != null, "JdbcIO.readAll().withCoder(coder) called with null coder");
       return toBuilder().setCoder(coder).build();
@@ -905,8 +918,33 @@ public class JdbcIO {
       return toBuilder().setOutputParallelization(outputParallelization).build();
     }

+    private Coder<OutputT> inferCoder(CoderRegistry registry) {
+      if (getCoder() != null) {
+        return getCoder();
+      } else {
+        RowMapper<OutputT> rowMapper = getRowMapper();
+        TypeDescriptor<OutputT> outputType =
+            TypeDescriptors.extractFromTypeParameters(
+                rowMapper,
+                RowMapper.class,
+                new TypeVariableExtractor<RowMapper<OutputT>, OutputT>() {});
+        try {
+          return registry.getCoder(outputType);
+        } catch (CannotProvideCoderException e) {
+          LOG.warn("Unable to infer a coder for type {}", outputType);
+          return null;
+        }
+      }
+    }
+
     @Override
     public PCollection<OutputT> expand(PCollection<ParameterT> input) {
+      Coder<OutputT> coder = inferCoder(input.getPipeline().getCoderRegistry());
+      checkNotNull(
+          coder,
+          "Unable to infer a coder for JdbcIO.readAll() transform. "
+              + "Provide a coder via withCoder, or ensure that one can be inferred from the"
+              + " provided RowMapper.");
       PCollection<OutputT> output =
           input
               .apply(
@@ -917,14 +955,14 @@ public class JdbcIO {
                           getParameterSetter(),
                           getRowMapper(),
                           getFetchSize())))
-              .setCoder(getCoder());
+              .setCoder(coder);

       if (getOutputParallelization()) {
         output = output.apply(new Reparallelize<>());
       }

       try {
-        TypeDescriptor<OutputT> typeDesc = getCoder().getEncodedTypeDescriptor();
+        TypeDescriptor<OutputT> typeDesc = coder.getEncodedTypeDescriptor();
         SchemaRegistry registry = input.getPipeline().getSchemaRegistry();
         Schema schema = registry.getSchema(typeDesc);
         output.setSchema(
@@ -944,7 +982,9 @@ public class JdbcIO {
       super.populateDisplayData(builder);
       builder.add(DisplayData.item("query", getQuery()));
       builder.add(DisplayData.item("rowMapper", getRowMapper().getClass().getName()));
-      builder.add(DisplayData.item("coder", getCoder().getClass().getName()));
+      if (getCoder() != null) {
+        builder.add(DisplayData.item("coder", getCoder().getClass().getName()));
+      }
       if (getDataSourceProviderFn() instanceof HasDisplayData) {
         ((HasDisplayData) getDataSourceProviderFn()).populateDisplayData(builder);
       }
@@ -1010,6 +1050,11 @@ public class JdbcIO {
       return toBuilder().setRowMapper(rowMapper).build();
     }

+    /**
+     * @deprecated
+     *     <p>{@link JdbcIO} is able to infer appropriate coders from other parameters.
+     */
+    @Deprecated
     public ReadWithPartitions<T> withCoder(Coder<T> coder) {
       checkNotNull(coder, "coder can not be null");
       return toBuilder().setCoder(coder).build();
@@ -1048,7 +1093,6 @@ public class JdbcIO {
     @Override
     public PCollection<T> expand(PBegin input) {
       checkNotNull(getRowMapper(), "withRowMapper() is required");
-      checkNotNull(getCoder(), "withCoder() is required");
       checkNotNull(
           getDataSourceProviderFn(),
           "withDataSourceConfiguration() or withDataSourceProviderFn() is required");
@@ -1074,15 +1118,13 @@ public class JdbcIO {
               .apply("Partitioning", ParDo.of(new PartitioningFn()))
               .apply("Group partitions", GroupByKey.create());

-      return ranges.apply(
-          "Read ranges",
+      JdbcIO.ReadAll<KV<String, Iterable<Long>>, T> readAll =
           JdbcIO.<KV<String, Iterable<Long>>, T>readAll()
               .withDataSourceProviderFn(getDataSourceProviderFn())
               .withQuery(
                   String.format(
                       "select * from %1$s where %2$s >= ? and %2$s < ?",
                       getTable(), getPartitionColumn()))
-              .withCoder(getCoder())
               .withRowMapper(getRowMapper())
               .withParameterSetter(
                   (PreparedStatementSetter<KV<String, Iterable<Long>>>)
@@ -1091,14 +1133,22 @@ public class JdbcIO {
                         preparedStatement.setLong(1, Long.parseLong(range[0]));
                         preparedStatement.setLong(2, Long.parseLong(range[1]));
                       })
-              .withOutputParallelization(false));
+              .withOutputParallelization(false);
+
+      if (getCoder() != null) {
+        readAll = readAll.withCoder(getCoder());
+      }
+
+      return ranges.apply("Read ranges", readAll);
     }

     @Override
     public void populateDisplayData(DisplayData.Builder builder) {
       super.populateDisplayData(builder);
       builder.add(DisplayData.item("rowMapper", getRowMapper().getClass().getName()));
-      builder.add(DisplayData.item("coder", getCoder().getClass().getName()));
+      if (getCoder() != null) {
+        builder.add(DisplayData.item("coder", getCoder().getClass().getName()));
+      }
       builder.add(DisplayData.item("partitionColumn", getPartitionColumn()));
       builder.add(DisplayData.item("table", getTable()));
       builder.add(DisplayData.item("numPartitions", getNumPartitions()));
diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
index c5f1362..8ea7086 100644
--- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
+++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
@@ -32,7 +32,6 @@ import java.util.Set;
 import java.util.UUID;
 import java.util.function.Function;
 import org.apache.beam.sdk.PipelineResult;
-import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.io.GenerateSequence;
 import org.apache.beam.sdk.io.common.DatabaseTestHelper;
 import org.apache.beam.sdk.io.common.HashingFn;
@@ -234,8 +233,7 @@ public class JdbcIOIT {
             JdbcIO.<TestRow>read()
                 .withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
                 .withQuery(String.format("select name,id from %s;", tableName))
-                .withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId())
-                .withCoder(SerializableCoder.of(TestRow.class)))
+                .withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId()))
         .apply(ParDo.of(new TimeMonitor<>(NAMESPACE, "read_time")));

     PAssert.thatSingleton(namesAndIds.apply("Count All", Count.globally()))
diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
index f223910..695ad03 100644
--- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
+++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
@@ -251,6 +251,25 @@ public class JdbcIOTest implements Serializable {
   }

   @Test
+  public void testReadWithCoderInference() {
+    PCollection<TestRow> rows =
+        pipeline.apply(
+            JdbcIO.<TestRow>read()
+                .withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
+                .withQuery(String.format("select name,id from %s where name = ?", READ_TABLE_NAME))
+                .withStatementPreparator(
+                    preparedStatement -> preparedStatement.setString(1, TestRow.getNameForSeed(1)))
+                .withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId()));
+
+    PAssert.thatSingleton(rows.apply("Count All", Count.globally())).isEqualTo(1L);
+
+    Iterable<TestRow> expectedValues = Collections.singletonList(TestRow.fromSeed(1));
+    PAssert.that(rows).containsInAnyOrder(expectedValues);
+
+    pipeline.run();
+  }
+
+  @Test
   public void testReadRowsWithDataSourceConfiguration() {
     PCollection<Row> rows =
         pipeline.apply(
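
The practical effect of the change, for anyone skimming the diff: a JdbcIO read no longer needs an explicit coder when one can be inferred from the RowMapper's type parameter. Below is a minimal post-commit usage sketch, adapted from the Javadoc examples in the diff above; the driver, URL, credentials, and Person table are illustrative placeholders, not part of the commit.

import java.sql.ResultSet;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.jdbc.JdbcIO;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;

public class JdbcCoderInferenceExample {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    // No .withCoder(...) call: JdbcIO extracts KV<Integer, String> from the
    // RowMapper's type parameter and asks the pipeline's CoderRegistry for a coder.
    PCollection<KV<Integer, String>> people =
        pipeline.apply(
            JdbcIO.<KV<Integer, String>>read()
                .withDataSourceConfiguration(
                    JdbcIO.DataSourceConfiguration.create(
                            "com.mysql.jdbc.Driver", "jdbc:mysql://hostname:3306/mydb")
                        .withUsername("username")
                        .withPassword("password"))
                .withQuery("select id,name from Person")
                .withRowMapper(
                    new JdbcIO.RowMapper<KV<Integer, String>>() {
                      @Override
                      public KV<Integer, String> mapRow(ResultSet resultSet) throws Exception {
                        return KV.of(resultSet.getInt(1), resultSet.getString(2));
                      }
                    }));

    pipeline.run().waitUntilFinish();
  }
}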
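One design note: inference goes through TypeDescriptors.extractFromTypeParameters on the RowMapper instance, so an anonymous subclass like the one above keeps OutputT recoverable at runtime, whereas a bare lambda may erase it. If the CoderRegistry lookup then fails, inferCoder returns null and expand() fails with the checkNotNull message quoted in the diff, which still points at the deprecated withCoder as the explicit escape hatch.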