pabloem commented on a change in pull request #15848:
URL: https://github.com/apache/beam/pull/15848#discussion_r801978211



##########
File path: 
sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
##########
@@ -333,31 +343,154 @@ private static Calendar withTimestampAndTimezone(DateTime dateTime) {
     return calendar;
   }
 
+  /**
+   * A helper for {@link ReadWithPartitions} that handles range calculations.
+   *
+   * @param <PartitionT>
+   */
+  interface JdbcReadWithPartitionsHelper<PartitionT>
+      extends PreparedStatementSetter<KV<PartitionT, PartitionT>>,
+          RowMapper<KV<Long, KV<PartitionT, PartitionT>>> {
+    static <T> JdbcReadWithPartitionsHelper<T> getPartitionsHelper(TypeDescriptor<T> type) {
+      // This cast is unchecked, thus this is a small type-checking risk. We just need
+      // to make sure that all preset helpers in `JdbcUtil.PRESET_HELPERS` are matched
+      // in type from their Key and their Value.
+      return (JdbcReadWithPartitionsHelper<T>) PRESET_HELPERS.get(type.getRawType());
+    }
+
+    Iterable<KV<PartitionT, PartitionT>> calculateRanges(
+        PartitionT lowerBound, PartitionT upperBound, Long partitions);
+
+    @Override
+    void setParameters(KV<PartitionT, PartitionT> element, PreparedStatement preparedStatement);
+
+    @Override
+    KV<Long, KV<PartitionT, PartitionT>> mapRow(ResultSet resultSet) throws Exception;
+  }
+
   /** Create partitions on a table. */
-  static class PartitioningFn extends DoFn<KV<Integer, KV<Long, Long>>, KV<String, Long>> {
+  static class PartitioningFn<T> extends DoFn<KV<Long, KV<T, T>>, KV<T, T>> {
+    private static final Logger LOG = LoggerFactory.getLogger(PartitioningFn.class);
+    final TypeDescriptor<T> partitioningColumnType;
+
+    PartitioningFn(TypeDescriptor<T> partitioningColumnType) {
+      this.partitioningColumnType = partitioningColumnType;
+    }
+
     @ProcessElement
     public void processElement(ProcessContext c) {
-      Integer numPartitions = c.element().getKey();
-      Long lowerBound = c.element().getValue().getKey();
-      Long upperBound = c.element().getValue().getValue();
-      if (lowerBound > upperBound) {
-        throw new RuntimeException(
-            String.format(
-                "Lower bound [%s] is higher than upper bound [%s]", lowerBound, upperBound));
-      }
-      long stride = (upperBound - lowerBound) / numPartitions + 1;
-      for (long i = lowerBound; i < upperBound - stride; i += stride) {
-        String range = String.format("%s,%s", i, i + stride);
-        KV<String, Long> kvRange = KV.of(range, 1L);
-        c.output(kvRange);
-      }
-      if (upperBound - lowerBound > stride * (numPartitions - 1)) {
-        long indexFrom = (numPartitions - 1) * stride;
-        long indexTo = upperBound + 1;
-        String range = String.format("%s,%s", indexFrom, indexTo);
-        KV<String, Long> kvRange = KV.of(range, 1L);
-        c.output(kvRange);
+      T lowerBound = c.element().getValue().getKey();
+      T upperBound = c.element().getValue().getValue();
+      JdbcReadWithPartitionsHelper<T> helper =
+          JdbcReadWithPartitionsHelper.getPartitionsHelper(partitioningColumnType);
+      List<KV<T, T>> ranges =
+          Lists.newArrayList(helper.calculateRanges(lowerBound, upperBound, c.element().getKey()));
+      LOG.warn("Total of {} ranges: {}", ranges.size(), ranges);
+      for (KV<T, T> e : ranges) {
+        c.output(e);
       }
     }
   }
+
+  public static final Map<Class<?>, JdbcReadWithPartitionsHelper<?>> PRESET_HELPERS =
+      ImmutableMap.of(
+          Long.class,
+          new JdbcReadWithPartitionsHelper<Long>() {
+            @Override
+            public Iterable<KV<Long, Long>> calculateRanges(
+                Long lowerBound, Long upperBound, Long partitions) {
+              List<KV<Long, Long>> ranges = new ArrayList<>();
+              // We divide by partitions FIRST to make sure that we can cover the whole LONG range.
+              // If we subtract first, then we may end up with Long.MAX - Long.MIN, which is 2*MAX,
+              // and we'd have trouble with the pipeline.
+              long stride = (upperBound / partitions - lowerBound / partitions) + 1;
+              long highest = lowerBound;
+              for (long i = lowerBound; i < upperBound - stride; i += stride) {
+                ranges.add(KV.of(i, i + stride));
+                highest = i + stride;
+              }
+              if (upperBound - lowerBound > stride * (ranges.size() - 1)) {
+                long indexFrom = highest;
+                long indexTo = upperBound + 1;
+                ranges.add(KV.of(indexFrom, indexTo));
+              }
+              return ranges;
+            }
+
+            @Override
+            public void setParameters(KV<Long, Long> element, PreparedStatement preparedStatement) {
+              try {
+                preparedStatement.setLong(1, element.getKey());
+                preparedStatement.setLong(2, element.getValue());
+              } catch (SQLException e) {
+                throw new RuntimeException(e);
+              }
+            }
+
+            @Override
+            public KV<Long, KV<Long, Long>> mapRow(ResultSet resultSet) throws Exception {
+              if (resultSet.getMetaData().getColumnCount() == 3) {
+                return KV.of(
+                    resultSet.getLong(3), KV.of(resultSet.getLong(1), resultSet.getLong(2)));
+              } else {
+                return KV.of(0L, KV.of(resultSet.getLong(1), resultSet.getLong(2)));
+              }
+            }
+          },
+          DateTime.class,
+          new JdbcReadWithPartitionsHelper<DateTime>() {
+            @Override
+            public Iterable<KV<DateTime, DateTime>> calculateRanges(
+                DateTime lowerBound, DateTime upperBound, Long partitions) {
+              final List<KV<DateTime, DateTime>> result = new ArrayList<>();
+
+              final long intervalMillis = upperBound.getMillis() - lowerBound.getMillis();
+              final long strideMillis =

Review comment:
       done. thanks!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to