daria-malkova commented on a change in pull request #15049:
URL: https://github.com/apache/beam/pull/15049#discussion_r659629294



##########
File path: sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
##########
@@ -252,4 +262,67 @@ private static Calendar withTimestampAndTimezone(DateTime 
dateTime) {
 
     return calendar;
   }
+
+  /** Create partitions on a table. */
+  static class PartitioningFn extends DoFn<List<Integer>, KV<String, Integer>> 
{
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      List<Integer> params = c.element();
+      Integer lowerBound = params.get(0);
+      Integer upperBound = params.get(1);
+      Integer numPartitions = params.get(2);
+      int stride = (upperBound - lowerBound) / numPartitions + 1;

Review comment:
       Added

##########
File path: sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
##########
@@ -873,8 +944,177 @@ public void populateDisplayData(DisplayData.Builder 
builder) {
     }
   }
 
+  /** Implementation of {@link #readWithPartitions}. */
+  @AutoValue
+  public abstract static class ReadWithPartitions<T> extends 
PTransform<PBegin, PCollection<T>> {
+
+    abstract @Nullable SerializableFunction<Void, DataSource> 
getDataSourceProviderFn();
+
+    abstract @Nullable RowMapper<T> getRowMapper();
+
+    abstract @Nullable Coder<T> getCoder();
+
+    abstract int getNumPartitions();
+
+    abstract @Nullable String getPartitionColumn();
+
+    abstract int getLowerBound();
+
+    abstract int getUpperBound();
+
+    abstract @Nullable String getTable();
+
+    abstract Builder<T> toBuilder();
+
+    @AutoValue.Builder
+    abstract static class Builder<T> {
+
+      abstract Builder<T> setDataSourceProviderFn(
+          SerializableFunction<Void, DataSource> dataSourceProviderFn);
+
+      abstract Builder<T> setRowMapper(RowMapper<T> rowMapper);
+
+      abstract Builder<T> setCoder(Coder<T> coder);
+
+      abstract Builder<T> setNumPartitions(int numPartitions);
+
+      abstract Builder<T> setPartitionColumn(String partitionColumn);
+
+      abstract Builder<T> setLowerBound(int lowerBound);
+
+      abstract Builder<T> setUpperBound(int upperBound);
+
+      abstract Builder<T> setTable(String tableName);
+
+      abstract ReadWithPartitions<T> build();
+    }
+
+    public ReadWithPartitions<T> withDataSourceConfiguration(final 
DataSourceConfiguration config) {
+      return withDataSourceProviderFn(new 
DataSourceProviderFromDataSourceConfiguration(config));
+    }
+
+    public ReadWithPartitions<T> withDataSourceProviderFn(
+        SerializableFunction<Void, DataSource> dataSourceProviderFn) {
+      return toBuilder().setDataSourceProviderFn(dataSourceProviderFn).build();
+    }
+
+    public ReadWithPartitions<T> withRowMapper(RowMapper<T> rowMapper) {
+      checkNotNull(rowMapper, "rowMapper can not be null");
+      return toBuilder().setRowMapper(rowMapper).build();
+    }
+
+    public ReadWithPartitions<T> withCoder(Coder<T> coder) {
+      checkNotNull(coder, "coder can not be null");
+      return toBuilder().setCoder(coder).build();
+    }
+
+    /**
+     * The number of partitions. This, along with withLowerBound and 
withUpperBound, form partitions
+     * strides for generated WHERE clause expressions used to split the column 
withPartitionColumn
+     * evenly. When the input is less than 1, the number is set to 1.
+     */
+    public ReadWithPartitions<T> withNumPartitions(int numPartitions) {
+      checkArgument(numPartitions > 0, "numPartitions can not be less than 1");
+      return toBuilder().setNumPartitions(numPartitions).build();
+    }
+
+    /** The name of a column of numeric type that will be used for 
partitioning */
+    public ReadWithPartitions<T> withPartitionColumn(String partitionColumn) {
+      checkNotNull(partitionColumn, "partitionColumn can not be null");
+      return toBuilder().setPartitionColumn(partitionColumn).build();
+    }
+
+    public ReadWithPartitions<T> withLowerBound(int lowerBound) {
+      return toBuilder().setLowerBound(lowerBound).build();
+    }
+
+    public ReadWithPartitions<T> withUpperBound(int upperBound) {
+      return toBuilder().setUpperBound(upperBound).build();
+    }
+
+    /** Name of the table in the external database. Can be used to pass a 
user-defined subqery. */
+    public ReadWithPartitions<T> withTable(String tableName) {
+      checkNotNull(tableName, "table can not be null");
+      return toBuilder().setTable(tableName).build();
+    }
+
+    @Override
+    public PCollection<T> expand(PBegin input) {
+      checkNotNull(getRowMapper(), "withRowMapper() is required");
+      checkNotNull(getCoder(), "withCoder() is required");
+      checkNotNull(
+          getDataSourceProviderFn(),
+          "withDataSourceConfiguration() or withDataSourceProviderFn() is 
required");
+      checkNotNull(getPartitionColumn(), "withPartitionColumn() is required");
+      checkNotNull(getTable(), "withTable() is required");
+      checkArgument(
+          getLowerBound() < getUpperBound(),
+          "The lower bound of partitioning column is larger or equal than the 
upper bound");
+      checkArgument(
+          getUpperBound() - getLowerBound() >= getNumPartitions(),
+          "The specified number of partitions is more than the difference 
between upper bound and lower bound");
+
+      if (getUpperBound() == MAX_VALUE || getLowerBound() == 0) {
+        refineBounds(input);
+      }
+
+      int stride = (getUpperBound() - getLowerBound()) / getNumPartitions();
+      PCollection<List<Integer>> params =
+          input.apply(
+              Create.of(
+                  Collections.singletonList(
+                      Arrays.asList(getLowerBound(), getUpperBound(), 
getNumPartitions()))));
+      PCollection<KV<String, Iterable<Integer>>> ranges =
+          params
+              .apply("Partitioning", ParDo.of(new PartitioningFn()))
+              .apply("Group partitions", GroupByKey.create());
+
+      return ranges.apply(
+          "Read ranges",
+          JdbcIO.<KV<String, Iterable<Integer>>, T>readAll()
+              .withDataSourceProviderFn(getDataSourceProviderFn())
+              .withFetchSize(stride)
+              .withQuery(
+                  String.format(
+                      "select * from %1$s where %2$s >= ? and %2$s < ?",
+                      getTable(), getPartitionColumn()))
+              .withCoder(getCoder())
+              .withRowMapper(getRowMapper())
+              .withParameterSetter(
+                  (PreparedStatementSetter<KV<String, Iterable<Integer>>>)
+                      (element, preparedStatement) -> {
+                        String[] range = element.getKey().split(",", -1);
+                        preparedStatement.setInt(1, 
Integer.parseInt(range[0]));
+                        preparedStatement.setInt(2, 
Integer.parseInt(range[1]));
+                      })
+              .withOutputParallelization(false));
+    }
+
+    private void refineBounds(PBegin input) {
+      Integer[] bounds =
+          JdbcUtil.getBounds(input, getTable(), getDataSourceProviderFn(), 
getPartitionColumn());
+      if (getLowerBound() == 0) {
+        withLowerBound(bounds[0]);
+      }
+      if (getUpperBound() == MAX_VALUE) {
+        withUpperBound(bounds[1]);
+      }
+    }
+
+    @Override
+    public void populateDisplayData(DisplayData.Builder builder) {

Review comment:
       Added




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to