This is an automated email from the ASF dual-hosted git repository.

chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 2fd9cd1  [BEAM-10917] Add support for BigQuery Read API in Python BEAM (#15185)
2fd9cd1 is described below

commit 2fd9cd1a5077f5370612922da723aa82cd78a636
Author: vachan-shetty <52260220+vachan-she...@users.noreply.github.com>
AuthorDate: Fri Aug 20 14:48:59 2021 -0400

    [BEAM-10917] Add support for BigQuery Read API in Python BEAM (#15185)
    
    * Adding support for reading from BigQuery ReadAPI in Python BEAM.
    
    * Formatting fixes.
    
    * Fixing lint errors.
    
    * Adding singleton comparison.
    
    * Some more lint fixes.
    
    * Doc fixes.
    
    * Updating docstring about DATETIME handling in fastavro.
    
    * Actually making fastavro the default and some minor fixes.
    
    * Updating 'estimate_size()'. This should improve AutoScaling.
    
    * Renaming use_fastavro flag.
    
    * Updating Docs.
    
    * Fix for failing pre-commit tests and some other fixes.
---
 sdks/python/apache_beam/io/gcp/bigquery.py         | 344 ++++++++++++++++++++-
 .../apache_beam/io/gcp/bigquery_read_it_test.py    | 161 +++++++++-
 sdks/python/setup.py                               |   1 +
 3 files changed, 494 insertions(+), 12 deletions(-)
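
For context, a minimal pipeline sketch of the DIRECT_READ method added by this
commit, mirroring the integration tests in bigquery_read_it_test.py below; the
project, dataset and table names are placeholders:

    import apache_beam as beam

    with beam.Pipeline() as p:
      rows = (
          p | 'Read with BigQuery Storage API' >> beam.io.ReadFromBigQuery(
              method=beam.io.ReadFromBigQuery.Method.DIRECT_READ,
              project='my-project',
              dataset='my_dataset',
              table='my_table',
              selected_fields=['number', 'str'],
              row_restriction='number > 2'))
      rows | beam.Map(print)

If method is omitted, ReadFromBigQuery keeps the existing EXPORT behaviour.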

diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py
index 38d341c..1f351ab 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py
@@ -270,6 +270,7 @@ encoding when writing to BigQuery.
 # pytype: skip-file
 
 import collections
+import io
 import itertools
 import json
 import logging
@@ -277,13 +278,20 @@ import random
 import time
 import uuid
 from typing import Dict
+from typing import List
+from typing import Optional
 from typing import Union
 
+import avro.schema
+import fastavro
+from avro import io as avroio
+
 import apache_beam as beam
 from apache_beam import coders
 from apache_beam import pvalue
 from apache_beam.internal.gcp.json_value import from_json_value
 from apache_beam.internal.gcp.json_value import to_json_value
+from apache_beam.io import range_trackers
 from apache_beam.io.avroio import _create_avro_source as create_avro_source
 from apache_beam.io.filesystems import CompressionTypes
 from apache_beam.io.filesystems import FileSystems
@@ -324,6 +332,7 @@ from apache_beam.utils.annotations import experimental
 try:
   from apache_beam.io.gcp.internal.clients.bigquery import DatasetReference
   from apache_beam.io.gcp.internal.clients.bigquery import TableReference
+  import google.cloud.bigquery_storage_v1 as bq_storage
 except ImportError:
   DatasetReference = None
   TableReference = None
@@ -794,7 +803,8 @@ class _CustomBigQuerySource(BoundedSource):
         bq.clean_up_temporary_dataset(self._get_project())
 
     for source in self.split_result:
-      yield SourceBundle(1.0, source, None, None)
+      yield SourceBundle(
+          weight=1.0, source=source, start_position=None, stop_position=None)
 
   def get_range_tracker(self, start_position, stop_position):
     class CustomBigQuerySourceRangeTracker(RangeTracker):
@@ -883,6 +893,276 @@ class _CustomBigQuerySource(BoundedSource):
     return table.schema, metadata_list
 
 
+class _CustomBigQueryStorageSourceBase(BoundedSource):
+  """A base class for BoundedSource implementations which read from BigQuery
+  using the BigQuery Storage API.
+
+  Args:
+    table (str, TableReference): The ID of the table. The ID must contain only
+      letters ``a-z``, ``A-Z``, numbers ``0-9``, or underscores ``_``. If
+      **dataset** argument is :data:`None` then the table argument must
+      contain the entire table reference specified as:
+      ``'PROJECT:DATASET.TABLE'`` or must specify a TableReference.
+    dataset (str): Optional ID of the dataset containing this table or
+      :data:`None` if the table argument specifies a TableReference.
+    project (str): Optional ID of the project containing this table or
+      :data:`None` if the table argument specifies a TableReference.
+    selected_fields (List[str]): Optional List of names of the fields in the
+      table that should be read. If empty, all fields will be read. If the
+      specified field is a nested field, all the sub-fields in the field will be
+      selected. The output field order is unrelated to the order of fields in
+      selected_fields.
+    row_restriction (str): Optional SQL text filtering statement, similar to a
+      WHERE clause in a query. Aggregates are not supported. Restricted to a
+      maximum length of 1 MB.
+  """
+
+  # The maximum number of streams which will be requested when creating a read
+  # session, regardless of the desired bundle size.
+  MAX_SPLIT_COUNT = 10000
+  # The minimum number of streams which will be requested when creating a read
+  # session, regardless of the desired bundle size. Note that the server may
+  # still choose to return fewer than ten streams based on the layout of the
+  # table.
+  MIN_SPLIT_COUNT = 10
+
+  def __init__(
+      self,
+      table: Union[str, TableReference],
+      dataset: Optional[str] = None,
+      project: Optional[str] = None,
+      selected_fields: Optional[List[str]] = None,
+      row_restriction: Optional[str] = None,
+      use_fastavro_for_direct_read: Optional[bool] = None,
+      pipeline_options: Optional[GoogleCloudOptions] = None):
+
+    self.table_reference = bigquery_tools.parse_table_reference(
+        table, dataset, project)
+    self.table = self.table_reference.tableId
+    self.dataset = self.table_reference.datasetId
+    self.project = self.table_reference.projectId
+    self.selected_fields = selected_fields
+    self.row_restriction = row_restriction
+    self.use_fastavro = \
+      True if use_fastavro_for_direct_read is None else \
+      use_fastavro_for_direct_read
+    self.pipeline_options = pipeline_options
+    self.split_result = None
+
+  def _get_parent_project(self):
+    """Returns the project that will be billed."""
+    project = self.pipeline_options.view_as(GoogleCloudOptions).project
+    if isinstance(project, vp.ValueProvider):
+      project = project.get()
+    if not project:
+      project = self.project
+    return project
+
+  def _get_table_size(self, table, dataset, project):
+    if project is None:
+      project = self._get_parent_project()
+
+    bq = bigquery_tools.BigQueryWrapper()
+    table = bq.get_table(project, dataset, table)
+    return table.numBytes
+
+  def display_data(self):
+    return {
+        'project': str(self.project),
+        'dataset': str(self.dataset),
+        'table': str(self.table),
+        'selected_fields': str(self.selected_fields),
+        'row_restriction': str(self.row_restriction),
+        'use_fastavro': str(self.use_fastavro)
+    }
+
+  def estimate_size(self):
+    # Returns the pre-filtering size of the table being read.
+    return self._get_table_size(self.table, self.dataset, self.project)
+
+  def split(self, desired_bundle_size, start_position=None, stop_position=None):
+    requested_session = bq_storage.types.ReadSession()
+    requested_session.table = 'projects/{}/datasets/{}/tables/{}'.format(
+        self.project, self.dataset, self.table)
+    requested_session.data_format = bq_storage.types.DataFormat.AVRO
+    if self.selected_fields is not None:
+      requested_session.read_options.selected_fields = self.selected_fields
+    if self.row_restriction is not None:
+      requested_session.read_options.row_restriction = self.row_restriction
+
+    storage_client = bq_storage.BigQueryReadClient()
+    stream_count = 0
+    if desired_bundle_size > 0:
+      table_size = self._get_table_size(self.table, self.dataset, self.project)
+      stream_count = min(
+          int(table_size / desired_bundle_size),
+          _CustomBigQueryStorageSourceBase.MAX_SPLIT_COUNT)
+    stream_count = max(
+        stream_count, _CustomBigQueryStorageSourceBase.MIN_SPLIT_COUNT)
+
+    parent = 'projects/{}'.format(self.project)
+    read_session = storage_client.create_read_session(
+        parent=parent,
+        read_session=requested_session,
+        max_stream_count=stream_count)
+    _LOGGER.info(
+        'Sent BigQuery Storage API CreateReadSession request: \n %s \n'
+        'Received response \n %s.',
+        requested_session,
+        read_session)
+
+    self.split_result = [
+        _CustomBigQueryStorageStreamSource(stream.name, self.use_fastavro)
+        for stream in read_session.streams
+    ]
+
+    for source in self.split_result:
+      yield SourceBundle(
+          weight=1.0, source=source, start_position=None, stop_position=None)
+
+  def get_range_tracker(self, start_position, stop_position):
+    class NonePositionRangeTracker(RangeTracker):
+      """A RangeTracker that always returns positions as None. Prevents the
+      BigQuery Storage source from being read() before being split()."""
+      def start_position(self):
+        return None
+
+      def stop_position(self):
+        return None
+
+    return NonePositionRangeTracker()
+
+  def read(self, range_tracker):
+    raise NotImplementedError(
+        'BigQuery storage source must be split before being read')
+
+
+class _CustomBigQueryStorageStreamSource(BoundedSource):
+  """A source representing a single stream in a read session."""
+  def __init__(self, read_stream_name: str, use_fastavro: bool):
+    self.read_stream_name = read_stream_name
+    self.use_fastavro = use_fastavro
+
+  def display_data(self):
+    return {
+        'read_stream': str(self.read_stream_name),
+    }
+
+  def estimate_size(self):
+    # The size of a stream source cannot be estimated due to server-side
+    # liquid sharding.
+    # TODO: Implement progress reporting.
+    return None
+
+  def split(self, desired_bundle_size, start_position=None, stop_position=None):
+    # A stream source can't be split without reading from it due to
+    # server-side liquid sharding. A split will simply return the current source
+    # for now.
+    return SourceBundle(
+        weight=1.0,
+        source=_CustomBigQueryStorageStreamSource(
+            self.read_stream_name, self.use_fastavro),
+        start_position=None,
+        stop_position=None)
+
+  def get_range_tracker(self, start_position, stop_position):
+    # TODO: Implement dynamic work rebalancing.
+    assert start_position is None
+    # Defaulting to the start of the stream.
+    start_position = 0
+    # Since the streams are unsplittable we choose OFFSET_INFINITY as the
+    # default end offset so that all data of the source gets read.
+    stop_position = range_trackers.OffsetRangeTracker.OFFSET_INFINITY
+    range_tracker = range_trackers.OffsetRangeTracker(
+        start_position, stop_position)
+    # Ensuring that all try_split() calls will be ignored by the RangeTracker.
+    range_tracker = range_trackers.UnsplittableRangeTracker(range_tracker)
+
+    return range_tracker
+
+  def read(self, range_tracker):
+    _LOGGER.info(
+        "Started BigQuery Storage API read from stream %s.",
+        self.read_stream_name)
+    storage_client = bq_storage.BigQueryReadClient()
+    read_rows_iterator = iter(storage_client.read_rows(self.read_stream_name))
+    # Handling the case where the user might provide very selective filters
+    # which can result in read_rows_response being empty.
+    first_read_rows_response = next(read_rows_iterator, None)
+    if first_read_rows_response is None:
+      return iter([])
+
+    if self.use_fastavro:
+      row_reader = _ReadRowsResponseReaderWithFastAvro(
+          read_rows_iterator, first_read_rows_response)
+      return iter(row_reader)
+
+    row_reader = _ReadRowsResponseReader(
+        read_rows_iterator, first_read_rows_response)
+    return iter(row_reader)
+
+
+class _ReadRowsResponseReaderWithFastAvro():
+  """An iterator that deserializes ReadRowsResponses using the fastavro
+  library."""
+  def __init__(self, read_rows_iterator, read_rows_response):
+    self.read_rows_iterator = read_rows_iterator
+    self.read_rows_response = read_rows_response
+    self.avro_schema = fastavro.parse_schema(
+        json.loads(self.read_rows_response.avro_schema.schema))
+    self.bytes_reader = io.BytesIO(
+        self.read_rows_response.avro_rows.serialized_binary_rows)
+
+  def __iter__(self):
+    return self
+
+  def __next__(self):
+    try:
+      return fastavro.schemaless_reader(self.bytes_reader, self.avro_schema)
+    except StopIteration:
+      self.read_rows_response = next(self.read_rows_iterator, None)
+      if self.read_rows_response is not None:
+        self.bytes_reader = io.BytesIO(
+            self.read_rows_response.avro_rows.serialized_binary_rows)
+        return fastavro.schemaless_reader(self.bytes_reader, self.avro_schema)
+      else:
+        raise StopIteration
+
+
+class _ReadRowsResponseReader():
+  """An iterator that deserializes ReadRowsResponses."""
+  def __init__(self, read_rows_iterator, read_rows_response):
+    self.read_rows_iterator = read_rows_iterator
+    self.read_rows_response = read_rows_response
+    self.avro_schema = avro.schema.Parse(
+        self.read_rows_response.avro_schema.schema)
+    self.reader = avroio.DatumReader(self.avro_schema)
+    self.decoder = avroio.BinaryDecoder(
+        io.BytesIO(self.read_rows_response.avro_rows.serialized_binary_rows))
+    self.next_row = 0
+
+  def __iter__(self):
+    return self
+
+  def get_deserialized_row(self):
+    deserialized_row = self.reader.read(self.decoder)
+    self.next_row += 1
+    return deserialized_row
+
+  def __next__(self):
+    if self.next_row < self.read_rows_response.row_count:
+      return self.get_deserialized_row()
+
+    self.read_rows_response = next(self.read_rows_iterator, None)
+    if self.read_rows_response is not None:
+      self.decoder = avroio.BinaryDecoder(
+          io.BytesIO(self.read_rows_response.avro_rows.serialized_binary_rows))
+      self.next_row = 0
+      return self.get_deserialized_row()
+    else:
+      raise StopIteration
+
+
 @deprecated(since='2.11.0', current="WriteToBigQuery")
 class BigQuerySink(dataflow_io.NativeSink):
   """A sink based on a BigQuery table.
@@ -1837,19 +2117,38 @@ bigquery_v2_messages.TableSchema`. or a `ValueProvider` that has a JSON string,
 class ReadFromBigQuery(PTransform):
   """Read data from BigQuery.
 
-    This PTransform uses a BigQuery export job to take a snapshot of the table
-    on GCS, and then reads from each produced file. File format is Avro by
+    This PTransform uses either a BigQuery export job to take a snapshot of the
+    table on GCS and then reads from each produced file (EXPORT), or reads
+    directly from BigQuery storage using the BigQuery Read API (DIRECT_READ).
+    The option is specified via the `method` parameter. File format is Avro by
     default.
 
+    NOTE: DIRECT_READ currently only supports reading from BigQuery tables. To
+    read the results of a query, please use EXPORT.
+
+  .. warning::
+      DATETIME columns are parsed as strings in the fastavro library. As a
+      result, such columns will be converted to Python strings instead of native
+      Python DATETIME types.
+
   Args:
+    method: The method to use to read from BigQuery. It may be EXPORT or
+      DIRECT_READ. EXPORT invokes a BigQuery export request
+      (https://cloud.google.com/bigquery/docs/exporting-data). DIRECT_READ reads
+      directly from BigQuery storage using the BigQuery Read API
+      (https://cloud.google.com/bigquery/docs/reference/storage). If
+      unspecified, the default is currently EXPORT.
+    use_fastavro_for_direct_read (bool): If method is `DIRECT_READ` and
+       :data:`True`, the fastavro library is used to deserialize the data
+       received from the BigQuery Read API. The default here is :data:`True`.
     table (str, callable, ValueProvider): The ID of the table, or a callable
       that returns it. The ID must contain only letters ``a-z``, ``A-Z``,
       numbers ``0-9``, or underscores ``_``. If dataset argument is
       :data:`None` then the table argument must contain the entire table
-      reference specified as: ``'DATASET.TABLE'``
-      or ``'PROJECT:DATASET.TABLE'``. If it's a callable, it must receive one
-      argument representing an element to be written to BigQuery, and return
-      a TableReference, or a string table name as specified above.
+      reference specified as: ``'PROJECT:DATASET.TABLE'``.
+      If it's a callable, it must receive one argument representing an element
+      to be written to BigQuery, and return a TableReference, or a string table
+      name as specified above.
     dataset (str): The ID of the dataset containing this table or
       :data:`None` if the table reference is specified entirely by the table
       argument.
@@ -1900,16 +2199,21 @@ class ReadFromBigQuery(PTransform):
       https://cloud.google.com/bigquery/docs/loading-data-cloud-storage-avro\
               #avro_conversions
     temp_dataset (``apache_beam.io.gcp.internal.clients.bigquery.\
-DatasetReference``):
+        DatasetReference``):
         The dataset in which to create temporary tables when performing file
         loads. By default, a new dataset is created in the execution project for
         temporary tables.
    """
+  class Method(object):
+    EXPORT = 'EXPORT'  #  This is currently the default.
+    DIRECT_READ = 'DIRECT_READ'
 
   COUNTER = 0
 
-  def __init__(self, gcs_location=None, *args, **kwargs):
-    if gcs_location:
+  def __init__(self, gcs_location=None, method=None, *args, **kwargs):
+    self.method = method or ReadFromBigQuery.Method.EXPORT
+
+    if gcs_location and self.method == ReadFromBigQuery.Method.EXPORT:
       if not isinstance(gcs_location, (str, ValueProvider)):
         raise TypeError(
             '%s: gcs_location must be of type string'
@@ -1920,12 +2224,21 @@ DatasetReference``):
         gcs_location = StaticValueProvider(str, gcs_location)
 
     self.gcs_location = gcs_location
-
     self._args = args
     self._kwargs = kwargs
 
   def expand(self, pcoll):
     # TODO(BEAM-11115): Make ReadFromBQ rely on ReadAllFromBQ implementation.
+    if self.method == ReadFromBigQuery.Method.EXPORT:
+      return self._expand_export(pcoll)
+    elif self.method == ReadFromBigQuery.Method.DIRECT_READ:
+      return self._expand_direct_read(pcoll)
+    else:
+      raise ValueError(
+          'The method to read from BigQuery must be either EXPORT '
+          'or DIRECT_READ.')
+
+  def _expand_export(self, pcoll):
     temp_location = pcoll.pipeline.options.view_as(
         GoogleCloudOptions).temp_location
     job_name = pcoll.pipeline.options.view_as(GoogleCloudOptions).job_name
@@ -1960,6 +2273,15 @@ DatasetReference``):
                 **self._kwargs))
         | _PassThroughThenCleanup(files_to_remove_pcoll))
 
+  def _expand_direct_read(self, pcoll):
+    return (
+        pcoll
+        | beam.io.Read(
+            _CustomBigQueryStorageSourceBase(
+                pipeline_options=pcoll.pipeline.options,
+                *self._args,
+                **self._kwargs)))
+
 
 class ReadFromBigQueryRequest:
   """
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py
index 472b521..53bf567 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py
@@ -26,12 +26,14 @@ import logging
 import random
 import time
 import unittest
+import uuid
 from decimal import Decimal
 from functools import wraps
 
 import pytest
 
 import apache_beam as beam
+from apache_beam.io.gcp import bigquery_tools
 from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
 from apache_beam.io.gcp.internal.clients import bigquery
 from apache_beam.options.value_provider import StaticValueProvider
@@ -128,7 +130,7 @@ class ReadTests(BigQueryReadIntegrationTests):
   @classmethod
   def setUpClass(cls):
     super(ReadTests, cls).setUpClass()
-    cls.table_name = 'python_write_table'
+    cls.table_name = 'python_read_table'
     cls.create_table(cls.table_name)
 
     table_id = '{}.{}'.format(cls.dataset_id, cls.table_name)
@@ -175,6 +177,163 @@ class ReadTests(BigQueryReadIntegrationTests):
       assert_that(result, equal_to(self.TABLE_DATA))
 
 
+class ReadUsingStorageApiTests(BigQueryReadIntegrationTests):
+  TABLE_DATA = [{
+      'number': 1, 'str': 'abc'
+  }, {
+      'number': 2, 'str': 'def'
+  }, {
+      'number': 3, 'str': u'你好'
+  }, {
+      'number': 4, 'str': u'привет'
+  }]
+
+  @classmethod
+  def setUpClass(cls):
+    super(ReadUsingStorageApiTests, cls).setUpClass()
+    cls.table_name = 'python_read_table'
+    cls._create_table(cls.table_name)
+
+    table_id = '{}.{}'.format(cls.dataset_id, cls.table_name)
+    cls.query = 'SELECT number, str FROM `%s`' % table_id
+
+    # Materializing the newly created Table to ensure the Read API can stream.
+    cls.temp_table_reference = cls._execute_query(cls.project, cls.query)
+
+  @classmethod
+  def tearDownClass(cls):
+    cls.bigquery_client.clean_up_temporary_dataset(cls.project)
+    super(ReadUsingStorageApiTests, cls).tearDownClass()
+
+  @classmethod
+  def _create_table(cls, table_name):
+    table_schema = bigquery.TableSchema()
+    table_field = bigquery.TableFieldSchema()
+    table_field.name = 'number'
+    table_field.type = 'INTEGER'
+    table_schema.fields.append(table_field)
+    table_field = bigquery.TableFieldSchema()
+    table_field.name = 'str'
+    table_field.type = 'STRING'
+    table_schema.fields.append(table_field)
+    table = bigquery.Table(
+        tableReference=bigquery.TableReference(
+            projectId=cls.project, datasetId=cls.dataset_id,
+            tableId=table_name),
+        schema=table_schema)
+    request = bigquery.BigqueryTablesInsertRequest(
+        projectId=cls.project, datasetId=cls.dataset_id, table=table)
+    cls.bigquery_client.client.tables.Insert(request)
+    cls.bigquery_client.insert_rows(
+        cls.project, cls.dataset_id, table_name, cls.TABLE_DATA)
+
+  @classmethod
+  def _setup_temporary_dataset(cls, project, query):
+    location = cls.bigquery_client.get_query_location(project, query, False)
+    cls.bigquery_client.create_temporary_dataset(project, location)
+
+  @classmethod
+  def _execute_query(cls, project, query):
+    query_job_name = bigquery_tools.generate_bq_job_name(
+        'materializing_table_before_reading',
+        str(uuid.uuid4())[0:10],
+        bigquery_tools.BigQueryJobTypes.QUERY,
+        '%s_%s' % (int(time.time()), random.randint(0, 1000)))
+    cls._setup_temporary_dataset(cls.project, cls.query)
+    job = cls.bigquery_client._start_query_job(
+        project,
+        query,
+        use_legacy_sql=False,
+        flatten_results=False,
+        job_id=query_job_name)
+    job_ref = job.jobReference
+    cls.bigquery_client.wait_for_bq_job(job_ref, max_retries=0)
+    return cls.bigquery_client._get_temp_table(project)
+
+  def test_iobase_source(self):
+    with beam.Pipeline(argv=self.args) as p:
+      result = (
+          p | 'Read with BigQuery Storage API' >> beam.io.ReadFromBigQuery(
+              method=beam.io.ReadFromBigQuery.Method.DIRECT_READ,
+              project=self.temp_table_reference.projectId,
+              dataset=self.temp_table_reference.datasetId,
+              table=self.temp_table_reference.tableId,
+              use_fastavro_for_direct_read=True))
+      assert_that(result, equal_to(self.TABLE_DATA))
+
+  def test_iobase_source_without_fastavro(self):
+    with beam.Pipeline(argv=self.args) as p:
+      result = (
+          p | 'Read with BigQuery Storage API' >> beam.io.ReadFromBigQuery(
+              method=beam.io.ReadFromBigQuery.Method.DIRECT_READ,
+              project=self.temp_table_reference.projectId,
+              dataset=self.temp_table_reference.datasetId,
+              table=self.temp_table_reference.tableId,
+              use_fastavro_for_direct_read=False))
+      assert_that(result, equal_to(self.TABLE_DATA))
+
+  def test_iobase_source_with_column_selection(self):
+    EXPECTED_TABLE_DATA = [{
+        'number': 1
+    }, {
+        'number': 2
+    }, {
+        'number': 3
+    }, {
+        'number': 4
+    }]
+    with beam.Pipeline(argv=self.args) as p:
+      result = (
+          p | 'Read with BigQuery Storage API' >> beam.io.ReadFromBigQuery(
+              method=beam.io.ReadFromBigQuery.Method.DIRECT_READ,
+              project=self.temp_table_reference.projectId,
+              dataset=self.temp_table_reference.datasetId,
+              table=self.temp_table_reference.tableId,
+              selected_fields=['number']))
+      assert_that(result, equal_to(EXPECTED_TABLE_DATA))
+
+  def test_iobase_source_with_row_restriction(self):
+    EXPECTED_TABLE_DATA = [{
+        'number': 3, 'str': u'你好'
+    }, {
+        'number': 4, 'str': u'привет'
+    }]
+    with beam.Pipeline(argv=self.args) as p:
+      result = (
+          p | 'Read with BigQuery Storage API' >> beam.io.ReadFromBigQuery(
+              method=beam.io.ReadFromBigQuery.Method.DIRECT_READ,
+              project=self.temp_table_reference.projectId,
+              dataset=self.temp_table_reference.datasetId,
+              table=self.temp_table_reference.tableId,
+              row_restriction='number > 2'))
+      assert_that(result, equal_to(EXPECTED_TABLE_DATA))
+
+  def test_iobase_source_with_column_selection_and_row_restriction(self):
+    EXPECTED_TABLE_DATA = [{'str': u'你好'}, {'str': u'привет'}]
+    with beam.Pipeline(argv=self.args) as p:
+      result = (
+          p | 'Read with BigQuery Storage API' >> beam.io.ReadFromBigQuery(
+              method=beam.io.ReadFromBigQuery.Method.DIRECT_READ,
+              project=self.temp_table_reference.projectId,
+              dataset=self.temp_table_reference.datasetId,
+              table=self.temp_table_reference.tableId,
+              selected_fields=['str'],
+              row_restriction='number > 2'))
+      assert_that(result, equal_to(EXPECTED_TABLE_DATA))
+
+  def test_iobase_source_with_very_selective_filters(self):
+    with beam.Pipeline(argv=self.args) as p:
+      result = (
+          p | 'Read with BigQuery Storage API' >> beam.io.ReadFromBigQuery(
+              method=beam.io.ReadFromBigQuery.Method.DIRECT_READ,
+              project=self.temp_table_reference.projectId,
+              dataset=self.temp_table_reference.datasetId,
+              table=self.temp_table_reference.tableId,
+              selected_fields=['str'],
+              row_restriction='number > 4'))
+      assert_that(result, equal_to([]))
+
+
 class ReadNewTypesTests(BigQueryReadIntegrationTests):
   @classmethod
   def setUpClass(cls):
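
For completeness, a self-contained sketch of the schemaless Avro decoding
pattern that _ReadRowsResponseReaderWithFastAvro (added to bigquery.py above)
relies on: rows arrive without an Avro container header, so the schema from the
read session is applied separately. The schema and rows below are made up for
illustration:

    import io

    import fastavro

    schema = fastavro.parse_schema({
        'type': 'record',
        'name': 'Row',
        'fields': [{'name': 'number', 'type': 'long'},
                   {'name': 'str', 'type': 'string'}],
    })

    # Stand-in for avro_rows.serialized_binary_rows in a ReadRowsResponse:
    # records are concatenated back-to-back with no container framing.
    buf = io.BytesIO()
    for row in [{'number': 1, 'str': 'abc'}, {'number': 2, 'str': 'def'}]:
      fastavro.schemaless_writer(buf, schema, row)
    serialized_binary_rows = buf.getvalue()

    # Decode until the buffer is exhausted, as the reader does per response.
    reader = io.BytesIO(serialized_binary_rows)
    while reader.tell() < len(serialized_binary_rows):
      print(fastavro.schemaless_reader(reader, schema))
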
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 5d72a9d..338251d 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -186,6 +186,7 @@ GCP_REQUIREMENTS = [
     'google-auth>=1.18.0,<2',
     'google-cloud-datastore>=1.8.0,<2',
     'google-cloud-pubsub>=0.39.0,<2',
+    'google-cloud-bigquery-storage>=2.4.0',
     # GCP packages required by tests
     'google-cloud-bigquery>=1.6.0,<3',
     'google-cloud-core>=0.28.1,<2',
