nehsyc commented on a change in pull request #13292: URL: https://github.com/apache/beam/pull/13292#discussion_r526310348
########## File path: sdks/python/apache_beam/transforms/util.py ########## @@ -780,6 +783,48 @@ def expand(self, pcoll): self.max_buffering_duration_secs, self.clock)) + @typehints.with_input_types(Tuple[K, V]) + @typehints.with_output_types(Tuple[K, Iterable[V]]) + class WithShardedKey(PTransform): + """A GroupIntoBatches transform that outputs batched elements associated + with sharded input keys. + + The sharding is determined by the runner to balance the load during the + execution time. By default, it spreads the input elements with the same key + to all available threads executing the transform. + """ + def __init__(self, batch_size, max_buffering_duration_secs=None): + """Create a new GroupIntoBatches.WithShardedKey. + + Arguments: + batch_size: (required) How many elements should be in a batch + max_buffering_duration_secs: (optional) How long in seconds at most an + incomplete batch of elements is allowed to be buffered in the states. + The duration must be a positive second duration and should be given as + an int or float. + """ + self.batch_size = batch_size + + if max_buffering_duration_secs is not None: + assert max_buffering_duration_secs > 0, ( + 'max buffering duration should be a positive value') + self.max_buffering_duration_secs = max_buffering_duration_secs + + _pid = os.getpid() + + def expand(self, pcoll): + sharded_pcoll = pcoll | Map( + lambda x: ( Review comment: Done ########## File path: sdks/python/apache_beam/transforms/util.py ########## @@ -780,6 +783,48 @@ def expand(self, pcoll): self.max_buffering_duration_secs, self.clock)) + @typehints.with_input_types(Tuple[K, V]) + @typehints.with_output_types(Tuple[K, Iterable[V]]) + class WithShardedKey(PTransform): + """A GroupIntoBatches transform that outputs batched elements associated + with sharded input keys. + + The sharding is determined by the runner to balance the load during the + execution time. 
By default, it spreads the input elements with the same key + to all available threads executing the transform. + """ + def __init__(self, batch_size, max_buffering_duration_secs=None): + """Create a new GroupIntoBatches.WithShardedKey. + + Arguments: + batch_size: (required) How many elements should be in a batch + max_buffering_duration_secs: (optional) How long in seconds at most an + incomplete batch of elements is allowed to be buffered in the states. + The duration must be a positive second duration and should be given as + an int or float. + """ + self.batch_size = batch_size + + if max_buffering_duration_secs is not None: + assert max_buffering_duration_secs > 0, ( + 'max buffering duration should be a positive value') + self.max_buffering_duration_secs = max_buffering_duration_secs + + _pid = os.getpid() Review comment: I see. The pid is probably not a good choice then. I changed it to use uuid instead, analogous to the Java implementation: https://github.com/apache/beam/blob/a016ba5632e955af25cedfbc7a9bf93c9ed858ff/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java#L96 ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org