George-Wu commented on a change in pull request #12331:
URL: https://github.com/apache/beam/pull/12331#discussion_r463382633



##########
File path: sdks/python/apache_beam/io/gcp/dicomio.py
##########
@@ -372,52 +420,63 @@ def __init__(self, destination_dict, input_type, credential=None):
       credential: # type: Google credential object, if it is specified, the
       Http client will use it instead of the default one.
     """
-    self.credential = credential
     self.destination_dict = destination_dict
     # input_type pre-check
     if input_type not in ['bytes', 'fileio']:
       raise ValueError("input_type could only be 'bytes' or 'fileio'")
     self.input_type = input_type
+    self.buffer_size = buffer_size
+    self.max_workers = max_workers
+    self.credential = credential
 
   def expand(self, pcoll):
     return pcoll | beam.ParDo(
-        _StoreInstance(self.destination_dict, self.input_type, self.credential))
+        _StoreInstance(
+            self.destination_dict,
+            self.input_type,
+            self.buffer_size,
+            self.max_workers,
+            self.credential))
 
 
 class _StoreInstance(beam.DoFn):
   """A DoFn read or fetch dicom files then push it to a dicom store."""
-  def __init__(self, destination_dict, input_type, credential=None):
-    self.credential = credential
+  def __init__(
+      self,
+      destination_dict,
+      input_type,
+      buffer_size,
+      max_workers,
+      credential=None):
     # pre-check destination dict
     required_keys = ['project_id', 'region', 'dataset_id', 'dicom_store_id']
     for key in required_keys:
       if key not in destination_dict:
         raise ValueError('Must have %s in the dict.' % (key))
     self.destination_dict = destination_dict
     self.input_type = input_type
+    self.buffer_size = buffer_size
+    self.max_workers = max_workers
+    self.credential = credential
 
-  def process(self, element):
+  def start_bundle(self):
+    self.buffer = []
+
+  def finish_bundle(self):
+    return self._flush()
+
+  def process(self, element, window=beam.DoFn.WindowParam):
+    self.buffer.append((element, window))
+    if len(self.buffer) >= self.buffer_size:
+      self._flush()

Review comment:
       Good catch, fixed and tests added!




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to