This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b09a028  Merge pull request #15954 from [BEAM-960][BEAM-1675] 
Improvements to JdbcIO coder inference
b09a028 is described below

commit b09a028121b7ffca6bb6409e419df61d6c6c3dc1
Author: Pablo <pabl...@users.noreply.github.com>
AuthorDate: Mon Nov 29 13:17:06 2021 -0500

    Merge pull request #15954 from [BEAM-960][BEAM-1675] Improvements to JdbcIO 
coder inference
    
    * [BEAM-960] support coder inference for JdbcIO
    
    * [BEAM-1675] Deprecate withCoder in JdbcIO
    
    * Update JdbcIO integration test
    
    * Address comment
---
 .../java/org/apache/beam/sdk/io/jdbc/JdbcIO.java   | 112 +++++++++++++++------
 .../java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java |   4 +-
 .../org/apache/beam/sdk/io/jdbc/JdbcIOTest.java    |  19 ++++
 3 files changed, 101 insertions(+), 34 deletions(-)

diff --git 
a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java 
b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
index 83f620f..0f6b9c3 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
@@ -47,7 +47,9 @@ import java.util.stream.IntStream;
 import javax.sql.DataSource;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.RowCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.io.jdbc.JdbcIO.WriteFn.WriteFnSpec;
@@ -84,6 +86,8 @@ import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.sdk.values.TypeDescriptors.TypeVariableExtractor;
 import org.apache.commons.dbcp2.BasicDataSource;
 import org.apache.commons.dbcp2.DataSourceConnectionFactory;
 import org.apache.commons.dbcp2.PoolableConnectionFactory;
@@ -118,7 +122,6 @@ import org.slf4j.LoggerFactory;
  *        .withUsername("username")
  *        .withPassword("password"))
  *   .withQuery("select id,name from Person")
- *   .withCoder(KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of()))
  *   .withRowMapper(new JdbcIO.RowMapper<KV<Integer, String>>() {
  *     public KV<Integer, String> mapRow(ResultSet resultSet) throws Exception 
{
  *       return KV.of(resultSet.getInt(1), resultSet.getString(2));
@@ -136,7 +139,6 @@ import org.slf4j.LoggerFactory;
  *       "com.mysql.jdbc.Driver", "jdbc:mysql://hostname:3306/mydb",
  *       "username", "password"))
  *   .withQuery("select id,name from Person where name = ?")
- *   .withCoder(KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of()))
  *   .withStatementPreparator(new JdbcIO.StatementPreparator() {
  *     public void setParameters(PreparedStatement preparedStatement) throws 
Exception {
  *       preparedStatement.setString(1, "Darwin");
@@ -202,7 +204,6 @@ import org.slf4j.LoggerFactory;
  *  .withLowerBound(0)
  *  .withUpperBound(1000)
  *  .withNumPartitions(5)
- *  .withCoder(KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of()))
  *  .withRowMapper(new JdbcIO.RowMapper<KV<Integer, String>>() {
  *    public KV<Integer, String> mapRow(ResultSet resultSet) throws Exception {
  *      return KV.of(resultSet.getInt(1), resultSet.getString(2));
@@ -226,7 +227,6 @@ import org.slf4j.LoggerFactory;
  *  .withLowerBound(0)
  *  .withUpperBound(1000)
  *  .withNumPartitions(5)
- *  .withCoder(KvCoder.of(BigEndianIntegerCoder.of(), StringUtf8Coder.of()))
  *  .withRowMapper(new JdbcIO.RowMapper<KV<Integer, String>>() {
  *    public KV<Integer, String> mapRow(ResultSet resultSet) throws Exception {
  *      return KV.of(resultSet.getInt(1), resultSet.getString(2));
@@ -737,6 +737,11 @@ public class JdbcIO {
       return toBuilder().setRowMapper(rowMapper).build();
     }
 
+    /**
+     * @deprecated
+     *     <p>{@link JdbcIO} is able to infer appropriate coders from other 
parameters.
+     */
+    @Deprecated
     public Read<T> withCoder(Coder<T> coder) {
       checkArgument(coder != null, "coder can not be null");
       return toBuilder().setCoder(coder).build();
@@ -764,27 +769,28 @@ public class JdbcIO {
     public PCollection<T> expand(PBegin input) {
       checkArgument(getQuery() != null, "withQuery() is required");
       checkArgument(getRowMapper() != null, "withRowMapper() is required");
-      checkArgument(getCoder() != null, "withCoder() is required");
       checkArgument(
           (getDataSourceProviderFn() != null),
           "withDataSourceConfiguration() or withDataSourceProviderFn() is 
required");
 
-      return input
-          .apply(Create.of((Void) null))
-          .apply(
-              JdbcIO.<Void, T>readAll()
-                  .withDataSourceProviderFn(getDataSourceProviderFn())
-                  .withQuery(getQuery())
-                  .withCoder(getCoder())
-                  .withRowMapper(getRowMapper())
-                  .withFetchSize(getFetchSize())
-                  .withOutputParallelization(getOutputParallelization())
-                  .withParameterSetter(
-                      (element, preparedStatement) -> {
-                        if (getStatementPreparator() != null) {
-                          
getStatementPreparator().setParameters(preparedStatement);
-                        }
-                      }));
+      JdbcIO.ReadAll<Void, T> readAll =
+          JdbcIO.<Void, T>readAll()
+              .withDataSourceProviderFn(getDataSourceProviderFn())
+              .withQuery(getQuery())
+              .withRowMapper(getRowMapper())
+              .withFetchSize(getFetchSize())
+              .withOutputParallelization(getOutputParallelization())
+              .withParameterSetter(
+                  (element, preparedStatement) -> {
+                    if (getStatementPreparator() != null) {
+                      
getStatementPreparator().setParameters(preparedStatement);
+                    }
+                  });
+
+      if (getCoder() != null) {
+        readAll = readAll.withCoder(getCoder());
+      }
+      return input.apply(Create.of((Void) null)).apply(readAll);
     }
 
     @Override
@@ -792,7 +798,9 @@ public class JdbcIO {
       super.populateDisplayData(builder);
       builder.add(DisplayData.item("query", getQuery()));
       builder.add(DisplayData.item("rowMapper", 
getRowMapper().getClass().getName()));
-      builder.add(DisplayData.item("coder", getCoder().getClass().getName()));
+      if (getCoder() != null) {
+        builder.add(DisplayData.item("coder", 
getCoder().getClass().getName()));
+      }
       if (getDataSourceProviderFn() instanceof HasDisplayData) {
         ((HasDisplayData) 
getDataSourceProviderFn()).populateDisplayData(builder);
       }
@@ -882,6 +890,11 @@ public class JdbcIO {
       return toBuilder().setRowMapper(rowMapper).build();
     }
 
+    /**
+     * @deprecated
+     *     <p>{@link JdbcIO} is able to infer appropriate coders from other 
parameters.
+     */
+    @Deprecated
     public ReadAll<ParameterT, OutputT> withCoder(Coder<OutputT> coder) {
       checkArgument(coder != null, "JdbcIO.readAll().withCoder(coder) called 
with null coder");
       return toBuilder().setCoder(coder).build();
@@ -905,8 +918,33 @@ public class JdbcIO {
       return 
toBuilder().setOutputParallelization(outputParallelization).build();
     }
 
+    private Coder<OutputT> inferCoder(CoderRegistry registry) {
+      if (getCoder() != null) {
+        return getCoder();
+      } else {
+        RowMapper<OutputT> rowMapper = getRowMapper();
+        TypeDescriptor<OutputT> outputType =
+            TypeDescriptors.extractFromTypeParameters(
+                rowMapper,
+                RowMapper.class,
+                new TypeVariableExtractor<RowMapper<OutputT>, OutputT>() {});
+        try {
+          return registry.getCoder(outputType);
+        } catch (CannotProvideCoderException e) {
+          LOG.warn("Unable to infer a coder for type {}", outputType);
+          return null;
+        }
+      }
+    }
+
     @Override
     public PCollection<OutputT> expand(PCollection<ParameterT> input) {
+      Coder<OutputT> coder = 
inferCoder(input.getPipeline().getCoderRegistry());
+      checkNotNull(
+          coder,
+          "Unable to infer a coder for JdbcIO.readAll() transform. "
+              + "Provide a coder via withCoder, or ensure that one can be 
inferred from the"
+              + " provided RowMapper.");
       PCollection<OutputT> output =
           input
               .apply(
@@ -917,14 +955,14 @@ public class JdbcIO {
                           getParameterSetter(),
                           getRowMapper(),
                           getFetchSize())))
-              .setCoder(getCoder());
+              .setCoder(coder);
 
       if (getOutputParallelization()) {
         output = output.apply(new Reparallelize<>());
       }
 
       try {
-        TypeDescriptor<OutputT> typeDesc = 
getCoder().getEncodedTypeDescriptor();
+        TypeDescriptor<OutputT> typeDesc = coder.getEncodedTypeDescriptor();
         SchemaRegistry registry = input.getPipeline().getSchemaRegistry();
         Schema schema = registry.getSchema(typeDesc);
         output.setSchema(
@@ -944,7 +982,9 @@ public class JdbcIO {
       super.populateDisplayData(builder);
       builder.add(DisplayData.item("query", getQuery()));
       builder.add(DisplayData.item("rowMapper", 
getRowMapper().getClass().getName()));
-      builder.add(DisplayData.item("coder", getCoder().getClass().getName()));
+      if (getCoder() != null) {
+        builder.add(DisplayData.item("coder", 
getCoder().getClass().getName()));
+      }
       if (getDataSourceProviderFn() instanceof HasDisplayData) {
         ((HasDisplayData) 
getDataSourceProviderFn()).populateDisplayData(builder);
       }
@@ -1010,6 +1050,11 @@ public class JdbcIO {
       return toBuilder().setRowMapper(rowMapper).build();
     }
 
+    /**
+     * @deprecated
+     *     <p>{@link JdbcIO} is able to infer appropriate coders from other 
parameters.
+     */
+    @Deprecated
     public ReadWithPartitions<T> withCoder(Coder<T> coder) {
       checkNotNull(coder, "coder can not be null");
       return toBuilder().setCoder(coder).build();
@@ -1048,7 +1093,6 @@ public class JdbcIO {
     @Override
     public PCollection<T> expand(PBegin input) {
       checkNotNull(getRowMapper(), "withRowMapper() is required");
-      checkNotNull(getCoder(), "withCoder() is required");
       checkNotNull(
           getDataSourceProviderFn(),
           "withDataSourceConfiguration() or withDataSourceProviderFn() is 
required");
@@ -1074,15 +1118,13 @@ public class JdbcIO {
               .apply("Partitioning", ParDo.of(new PartitioningFn()))
               .apply("Group partitions", GroupByKey.create());
 
-      return ranges.apply(
-          "Read ranges",
+      JdbcIO.ReadAll<KV<String, Iterable<Long>>, T> readAll =
           JdbcIO.<KV<String, Iterable<Long>>, T>readAll()
               .withDataSourceProviderFn(getDataSourceProviderFn())
               .withQuery(
                   String.format(
                       "select * from %1$s where %2$s >= ? and %2$s < ?",
                       getTable(), getPartitionColumn()))
-              .withCoder(getCoder())
               .withRowMapper(getRowMapper())
               .withParameterSetter(
                   (PreparedStatementSetter<KV<String, Iterable<Long>>>)
@@ -1091,14 +1133,22 @@ public class JdbcIO {
                         preparedStatement.setLong(1, Long.parseLong(range[0]));
                         preparedStatement.setLong(2, Long.parseLong(range[1]));
                       })
-              .withOutputParallelization(false));
+              .withOutputParallelization(false);
+
+      if (getCoder() != null) {
+        readAll = readAll.withCoder(getCoder());
+      }
+
+      return ranges.apply("Read ranges", readAll);
     }
 
     @Override
     public void populateDisplayData(DisplayData.Builder builder) {
       super.populateDisplayData(builder);
       builder.add(DisplayData.item("rowMapper", 
getRowMapper().getClass().getName()));
-      builder.add(DisplayData.item("coder", getCoder().getClass().getName()));
+      if (getCoder() != null) {
+        builder.add(DisplayData.item("coder", 
getCoder().getClass().getName()));
+      }
       builder.add(DisplayData.item("partitionColumn", getPartitionColumn()));
       builder.add(DisplayData.item("table", getTable()));
       builder.add(DisplayData.item("numPartitions", getNumPartitions()));
diff --git 
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java 
b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
index c5f1362..8ea7086 100644
--- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
+++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
@@ -32,7 +32,6 @@ import java.util.Set;
 import java.util.UUID;
 import java.util.function.Function;
 import org.apache.beam.sdk.PipelineResult;
-import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.io.GenerateSequence;
 import org.apache.beam.sdk.io.common.DatabaseTestHelper;
 import org.apache.beam.sdk.io.common.HashingFn;
@@ -234,8 +233,7 @@ public class JdbcIOIT {
                 JdbcIO.<TestRow>read()
                     
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
                     .withQuery(String.format("select name,id from %s;", 
tableName))
-                    .withRowMapper(new 
JdbcTestHelper.CreateTestRowOfNameAndId())
-                    .withCoder(SerializableCoder.of(TestRow.class)))
+                    .withRowMapper(new 
JdbcTestHelper.CreateTestRowOfNameAndId()))
             .apply(ParDo.of(new TimeMonitor<>(NAMESPACE, "read_time")));
 
     PAssert.thatSingleton(namesAndIds.apply("Count All", Count.globally()))
diff --git 
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java 
b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
index f223910..695ad03 100644
--- 
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
+++ 
b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
@@ -251,6 +251,25 @@ public class JdbcIOTest implements Serializable {
   }
 
   @Test
+  public void testReadWithCoderInference() {
+    PCollection<TestRow> rows =
+        pipeline.apply(
+            JdbcIO.<TestRow>read()
+                .withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
+                .withQuery(String.format("select name,id from %s where name = 
?", READ_TABLE_NAME))
+                .withStatementPreparator(
+                    preparedStatement -> preparedStatement.setString(1, 
TestRow.getNameForSeed(1)))
+                .withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId()));
+
+    PAssert.thatSingleton(rows.apply("Count All", 
Count.globally())).isEqualTo(1L);
+
+    Iterable<TestRow> expectedValues = 
Collections.singletonList(TestRow.fromSeed(1));
+    PAssert.that(rows).containsInAnyOrder(expectedValues);
+
+    pipeline.run();
+  }
+
+  @Test
   public void testReadRowsWithDataSourceConfiguration() {
     PCollection<Row> rows =
         pipeline.apply(

Reply via email to