http://git-wip-us.apache.org/repos/asf/beam/blob/908c8532/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/datastoreio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/datastoreio.py b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/datastoreio.py
new file mode 100644
index 0000000..2eac4d5
--- /dev/null
+++ b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/datastoreio.py
@@ -0,0 +1,391 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A connector for reading from and writing to Google Cloud Datastore."""
+
+import logging
+
+from google.cloud.proto.datastore.v1 import datastore_pb2
+from googledatastore import helper as datastore_helper
+
+from apache_beam.io.google_cloud_platform.datastore.v1 import helper
+from apache_beam.io.google_cloud_platform.datastore.v1 import query_splitter
+from apache_beam.transforms import Create
+from apache_beam.transforms import DoFn
+from apache_beam.transforms import FlatMap
+from apache_beam.transforms import GroupByKey
+from apache_beam.transforms import Map
+from apache_beam.transforms import PTransform
+from apache_beam.transforms import ParDo
+from apache_beam.transforms.util import Values
+
+__all__ = ['ReadFromDatastore', 'WriteToDatastore', 'DeleteFromDatastore']
+
+
+class ReadFromDatastore(PTransform):
+  """A ``PTransform`` for reading from Google Cloud Datastore.
+
+  To read a ``PCollection[Entity]`` from a Cloud Datastore ``Query``, use
+  the ``ReadFromDatastore`` transform, providing a `project` id and a `query`
+  to read from. You can optionally provide a `namespace` and/or specify how
+  many splits you want for the query through the `num_splits` option.
+
+  Note: Normally, a runner will read from Cloud Datastore in parallel across
+  many workers. However, when the `query` is configured with a `limit`, or
+  contains inequality filters such as `GREATER_THAN` or `LESS_THAN`, all of
+  the returned results are read by a single worker in order to ensure
+  correct data. Since the data is read from a single worker, this can have
+  a significant impact on the performance of the job.
+
+  The semantics of query splitting are defined below:
+    1. If `num_splits` is equal to 0, then the number of splits will be
+    chosen dynamically at runtime based on the query data size.
+
+    2. Any value of `num_splits` greater than
+    `ReadFromDatastore._NUM_QUERY_SPLITS_MAX` will be capped at that value.
+
+    3. If the `query` has a user limit set, or contains inequality filters,
+    then `num_splits` will be ignored and no split will be performed.
+
+    4. In certain cases, Cloud Datastore is unable to split the query into
+    the requested number of splits. In such cases we just use whatever
+    Cloud Datastore returns.
+
+  See https://developers.google.com/datastore/ for more details on Google
+  Cloud Datastore.
+  """
+
+  # An upper bound on the number of splits for a query.
+  _NUM_QUERY_SPLITS_MAX = 50000
+  # A lower bound on the number of splits for a query. This is to ensure that
+  # we parallelize the query even when Datastore statistics are not available.
+  _NUM_QUERY_SPLITS_MIN = 12
+  # Default bundle size of 64MB.
+  _DEFAULT_BUNDLE_SIZE_BYTES = 64 * 1024 * 1024
+
+  def __init__(self, project, query, namespace=None, num_splits=0):
+    """Initialize the ReadFromDatastore transform.
+
+    Args:
+      project: The Project ID.
+      query: Cloud Datastore query to be read from.
+      namespace: An optional namespace.
+      num_splits: Number of splits for the query.
+    """
+    logging.warning('datastoreio read transform is experimental.')
+    super(ReadFromDatastore, self).__init__()
+
+    if not project:
+      raise ValueError("Project cannot be empty")
+    if not query:
+      raise ValueError("Query cannot be empty")
+    if num_splits < 0:
+      raise ValueError("num_splits must be greater than or equal to 0")
+
+    self._project = project
+    # using _namespace conflicts with DisplayData._namespace
+    self._datastore_namespace = namespace
+    self._query = query
+    self._num_splits = num_splits
+
+  def expand(self, pcoll):
+    # This composite transform involves the following steps:
+    #   1. Create a singleton of the user-provided `query` and apply a
+    #   ``ParDo`` that splits the query into `num_splits` sub-queries and
+    #   assigns each split query a unique `int` as the key. The resulting
+    #   output is of the type ``PCollection[(int, Query)]``.
+    #
+    #   If the value of `num_splits` is less than or equal to 0, then the
+    #   number of splits will be computed dynamically based on the size of
+    #   the data for the `query`.
+    #
+    #   2. The resulting ``PCollection`` is sharded using a ``GroupByKey``
+    #   operation. The queries are extracted from the (int, Iterable[Query])
+    #   pairs and flattened to output a ``PCollection[Query]``.
+    #
+    #   3. In the third step, a ``ParDo`` reads entities for each query and
+    #   outputs a ``PCollection[Entity]``.
+
+    queries = (pcoll.pipeline
+               | 'User Query' >> Create([self._query])
+               | 'Split Query' >> ParDo(ReadFromDatastore.SplitQueryFn(
+                   self._project, self._query, self._datastore_namespace,
+                   self._num_splits)))
+
+    sharded_queries = (queries
+                       | GroupByKey()
+                       | Values()
+                       | 'flatten' >> FlatMap(lambda x: x))
+
+    entities = sharded_queries | 'Read' >> ParDo(
+        ReadFromDatastore.ReadFn(self._project, self._datastore_namespace))
+    return entities
+
+  def display_data(self):
+    disp_data = {'project': self._project,
+                 'query': str(self._query),
+                 'num_splits': self._num_splits}
+
+    if self._datastore_namespace is not None:
+      disp_data['namespace'] = self._datastore_namespace
+
+    return disp_data
+
+  class SplitQueryFn(DoFn):
+    """A `DoFn` that splits a given query into multiple sub-queries."""
+    def __init__(self, project, query, namespace, num_splits):
+      super(ReadFromDatastore.SplitQueryFn, self).__init__()
+      self._datastore = None
+      self._project = project
+      self._datastore_namespace = namespace
+      self._query = query
+      self._num_splits = num_splits
+
+    def start_bundle(self):
+      self._datastore = helper.get_datastore(self._project)
+
+    def process(self, query, *args, **kwargs):
+      # distinct key to be used to group query splits.
+      key = 1
+
+      # If the query has a user-set limit, then the query cannot be split.
+      if query.HasField('limit'):
+        return [(key, query)]
+
+      # Compute the estimated num_splits if not specified by the user.
+      if self._num_splits == 0:
+        estimated_num_splits = ReadFromDatastore.get_estimated_num_splits(
+            self._project, self._datastore_namespace, self._query,
+            self._datastore)
+      else:
+        estimated_num_splits = self._num_splits
+
+      logging.info("Splitting the query into %d splits", estimated_num_splits)
+      try:
+        query_splits = query_splitter.get_splits(
+            self._datastore, query, estimated_num_splits,
+            helper.make_partition(self._project, self._datastore_namespace))
+      except Exception:
+        logging.warning("Unable to parallelize the given query: %s", query,
+                        exc_info=True)
+        query_splits = [query]
+
+      sharded_query_splits = []
+      for split_query in query_splits:
+        sharded_query_splits.append((key, split_query))
+        key += 1
+
+      return sharded_query_splits
+
+    def display_data(self):
+      disp_data = {'project': self._project,
+                   'query': str(self._query),
+                   'num_splits': self._num_splits}
+
+      if self._datastore_namespace is not None:
+        disp_data['namespace'] = self._datastore_namespace
+
+      return disp_data
+
+  class ReadFn(DoFn):
+    """A DoFn that reads entities from Cloud Datastore, for a given query."""
+    def __init__(self, project, namespace=None):
+      super(ReadFromDatastore.ReadFn, self).__init__()
+      self._project = project
+      self._datastore_namespace = namespace
+      self._datastore = None
+
+    def start_bundle(self):
+      self._datastore = helper.get_datastore(self._project)
+
+    def process(self, query, *args, **kwargs):
+      # Returns an iterator of entities that reads in batches.
+      entities = helper.fetch_entities(self._project,
+                                       self._datastore_namespace,
+                                       query, self._datastore)
+      return entities
+
+    def display_data(self):
+      disp_data = {'project': self._project}
+
+      if self._datastore_namespace is not None:
+        disp_data['namespace'] = self._datastore_namespace
+
+      return disp_data
+
+  @staticmethod
+  def query_latest_statistics_timestamp(project, namespace, datastore):
+    """Fetches the latest timestamp of statistics from Cloud Datastore.
+
+    Cloud Datastore system tables with statistics are periodically updated.
+    This method fetches the latest timestamp (in microseconds) of a
+    statistics update using the `__Stat_Total__` table.
+    """
+    query = helper.make_latest_timestamp_query(namespace)
+    req = helper.make_request(project, namespace, query)
+    resp = datastore.run_query(req)
+    if len(resp.batch.entity_results) == 0:
+      raise RuntimeError("Datastore total statistics unavailable.")
+
+    entity = resp.batch.entity_results[0].entity
+    return datastore_helper.micros_from_timestamp(
+        entity.properties['timestamp'].timestamp_value)
+
+  @staticmethod
+  def get_estimated_size_bytes(project, namespace, query, datastore):
+    """Get the estimated size of the data returned by the given query.
+
+    Cloud Datastore provides no way to get a good estimate of how large the
+    result of a query is going to be. Hence we use the __Stat_Kind__ system
+    table to get the size of the entire kind as an approximate estimate,
+    assuming exactly 1 kind is specified in the query.
+    See https://cloud.google.com/datastore/docs/concepts/stats.
+ """ + kind = query.kind[0].name + latest_timestamp = ReadFromDatastore.query_latest_statistics_timestamp( + project, namespace, datastore) + logging.info('Latest stats timestamp for kind %s is %s', + kind, latest_timestamp) + + kind_stats_query = ( + helper.make_kind_stats_query(namespace, kind, latest_timestamp)) + + req = helper.make_request(project, namespace, kind_stats_query) + resp = datastore.run_query(req) + if len(resp.batch.entity_results) == 0: + raise RuntimeError("Datastore statistics for kind %s unavailable" % kind) + + entity = resp.batch.entity_results[0].entity + return datastore_helper.get_value(entity.properties['entity_bytes']) + + @staticmethod + def get_estimated_num_splits(project, namespace, query, datastore): + """Computes the number of splits to be performed on the given query.""" + try: + estimated_size_bytes = ReadFromDatastore.get_estimated_size_bytes( + project, namespace, query, datastore) + logging.info('Estimated size bytes for query: %s', estimated_size_bytes) + num_splits = int(min(ReadFromDatastore._NUM_QUERY_SPLITS_MAX, round( + (float(estimated_size_bytes) / + ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES)))) + + except Exception as e: + logging.warning('Failed to fetch estimated size bytes: %s', e) + # Fallback in case estimated size is unavailable. + num_splits = ReadFromDatastore._NUM_QUERY_SPLITS_MIN + + return max(num_splits, ReadFromDatastore._NUM_QUERY_SPLITS_MIN) + + +class _Mutate(PTransform): + """A ``PTransform`` that writes mutations to Cloud Datastore. + + Only idempotent Datastore mutation operations (upsert and delete) are + supported, as the commits are retried when failures occur. + """ + + # Max allowed Datastore write batch size. + _WRITE_BATCH_SIZE = 500 + + def __init__(self, project, mutation_fn): + """Initializes a Mutate transform. + + Args: + project: The Project ID + mutation_fn: A function that converts `entities` or `keys` to + `mutations`. + """ + self._project = project + self._mutation_fn = mutation_fn + logging.warning('datastoreio write transform is experimental.') + + def expand(self, pcoll): + return (pcoll + | 'Convert to Mutation' >> Map(self._mutation_fn) + | 'Write Mutation to Datastore' >> ParDo(_Mutate.DatastoreWriteFn( + self._project))) + + def display_data(self): + return {'project': self._project, + 'mutation_fn': self._mutation_fn.__class__.__name__} + + class DatastoreWriteFn(DoFn): + """A ``DoFn`` that write mutations to Datastore. + + Mutations are written in batches, where the maximum batch size is + `Mutate._WRITE_BATCH_SIZE`. + + Commits are non-transactional. If a commit fails because of a conflict over + an entity group, the commit will be retried. This means that the mutation + should be idempotent (`upsert` and `delete` mutations) to prevent duplicate + data or errors. + """ + def __init__(self, project): + self._project = project + self._datastore = None + self._mutations = [] + + def start_bundle(self): + self._mutations = [] + self._datastore = helper.get_datastore(self._project) + + def process(self, element): + self._mutations.append(element) + if len(self._mutations) >= _Mutate._WRITE_BATCH_SIZE: + self._flush_batch() + + def finish_bundle(self): + if self._mutations: + self._flush_batch() + self._mutations = [] + + def _flush_batch(self): + # Flush the current batch of mutations to Cloud Datastore. 
+      helper.write_mutations(self._datastore, self._project, self._mutations)
+      logging.debug("Successfully wrote %d mutations.", len(self._mutations))
+      self._mutations = []
+
+
+class WriteToDatastore(_Mutate):
+  """A ``PTransform`` to write a ``PCollection[Entity]`` to Cloud Datastore."""
+  def __init__(self, project):
+    super(WriteToDatastore, self).__init__(
+        project, WriteToDatastore.to_upsert_mutation)
+
+  @staticmethod
+  def to_upsert_mutation(entity):
+    if not helper.is_key_valid(entity.key):
+      raise ValueError('Entities to be written to the Cloud Datastore must '
+                       'have complete keys:\n%s' % entity)
+    mutation = datastore_pb2.Mutation()
+    mutation.upsert.CopyFrom(entity)
+    return mutation
+
+
+class DeleteFromDatastore(_Mutate):
+  """A ``PTransform`` to delete a ``PCollection[Key]`` from Cloud Datastore."""
+  def __init__(self, project):
+    super(DeleteFromDatastore, self).__init__(
+        project, DeleteFromDatastore.to_delete_mutation)
+
+  @staticmethod
+  def to_delete_mutation(key):
+    if not helper.is_key_valid(key):
+      raise ValueError('Keys to be deleted from the Cloud Datastore must be '
+                       'complete:\n%s' % key)
+    mutation = datastore_pb2.Mutation()
+    mutation.delete.CopyFrom(key)
+    return mutation
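For a sense of how these transforms compose, here is a minimal pipeline sketch (not part of the patch; the project id and kind name are hypothetical placeholders, and the import path assumes the `google_cloud_platform` layout introduced above):

    import apache_beam as beam
    from google.cloud.proto.datastore.v1 import query_pb2
    from apache_beam.io.google_cloud_platform.datastore.v1.datastoreio import (
        ReadFromDatastore, WriteToDatastore)

    # Query for all entities of a single, hypothetical kind.
    query = query_pb2.Query()
    query.kind.add().name = 'ExampleKind'

    p = beam.Pipeline()
    entities = p | 'Read' >> ReadFromDatastore(project='example-project',
                                               query=query)
    # A real job would transform the entities here; entities that are
    # written back must have complete keys (see to_upsert_mutation above).
    entities | 'Write' >> WriteToDatastore(project='example-project')
    p.run()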
http://git-wip-us.apache.org/repos/asf/beam/blob/908c8532/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/datastoreio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/datastoreio_test.py b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/datastoreio_test.py
new file mode 100644
index 0000000..1dd7779
--- /dev/null
+++ b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/datastoreio_test.py
@@ -0,0 +1,237 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from google.cloud.proto.datastore.v1 import datastore_pb2
+from google.cloud.proto.datastore.v1 import query_pb2
+from google.protobuf import timestamp_pb2
+from googledatastore import helper as datastore_helper
+from mock import MagicMock, call, patch
+
+from apache_beam.io.google_cloud_platform.datastore.v1 import fake_datastore
+from apache_beam.io.google_cloud_platform.datastore.v1 import helper
+from apache_beam.io.google_cloud_platform.datastore.v1 import query_splitter
+from apache_beam.io.google_cloud_platform.datastore.v1.datastoreio import _Mutate
+from apache_beam.io.google_cloud_platform.datastore.v1.datastoreio import ReadFromDatastore
+from apache_beam.io.google_cloud_platform.datastore.v1.datastoreio import WriteToDatastore
+
+
+class DatastoreioTest(unittest.TestCase):
+  _PROJECT = 'project'
+  _KIND = 'kind'
+  _NAMESPACE = 'namespace'
+
+  def setUp(self):
+    self._mock_datastore = MagicMock()
+    self._query = query_pb2.Query()
+    self._query.kind.add().name = self._KIND
+
+  def test_get_estimated_size_bytes_without_namespace(self):
+    entity_bytes = 100
+    timestamp = timestamp_pb2.Timestamp(seconds=1234)
+    self.check_estimated_size_bytes(entity_bytes, timestamp)
+
+  def test_get_estimated_size_bytes_with_namespace(self):
+    entity_bytes = 100
+    timestamp = timestamp_pb2.Timestamp(seconds=1234)
+    self.check_estimated_size_bytes(entity_bytes, timestamp, self._NAMESPACE)
+
+  def test_SplitQueryFn_with_num_splits(self):
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      num_splits = 23
+
+      def fake_get_splits(datastore, query, num_splits, partition=None):
+        return self.split_query(query, num_splits)
+
+      with patch.object(query_splitter, 'get_splits',
+                        side_effect=fake_get_splits):
+
+        split_query_fn = ReadFromDatastore.SplitQueryFn(
+            self._PROJECT, self._query, None, num_splits)
+        split_query_fn.start_bundle()
+        returned_split_queries = []
+        for split_query in split_query_fn.process(self._query):
+          returned_split_queries.append(split_query)
+
+        self.assertEqual(len(returned_split_queries), num_splits)
+        self.assertEqual(0, len(self._mock_datastore.run_query.call_args_list))
+        self.verify_unique_keys(returned_split_queries)
+
+  def test_SplitQueryFn_without_num_splits(self):
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      # Force SplitQueryFn to compute the number of query splits
+      num_splits = 0
+      expected_num_splits = 23
+      entity_bytes = (expected_num_splits *
+                      ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES)
+      with patch.object(ReadFromDatastore, 'get_estimated_size_bytes',
+                        return_value=entity_bytes):
+
+        def fake_get_splits(datastore, query, num_splits, partition=None):
+          return self.split_query(query, num_splits)
+
+        with patch.object(query_splitter, 'get_splits',
+                          side_effect=fake_get_splits):
+          split_query_fn = ReadFromDatastore.SplitQueryFn(
+              self._PROJECT, self._query, None, num_splits)
+          split_query_fn.start_bundle()
+          returned_split_queries = []
+          for split_query in split_query_fn.process(self._query):
+            returned_split_queries.append(split_query)
+
+          self.assertEqual(len(returned_split_queries), expected_num_splits)
+          self.assertEqual(0,
+                           len(self._mock_datastore.run_query.call_args_list))
+          self.verify_unique_keys(returned_split_queries)
+
+  def test_SplitQueryFn_with_query_limit(self):
+    """A test that verifies no split is performed when the query has a limit."""
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      self._query.limit.value = 3
+      split_query_fn = ReadFromDatastore.SplitQueryFn(
+          self._PROJECT, self._query, None, 4)
+      split_query_fn.start_bundle()
+      returned_split_queries = []
+      for split_query in split_query_fn.process(self._query):
+        returned_split_queries.append(split_query)
+
+      self.assertEqual(1, len(returned_split_queries))
+      self.assertEqual(0, len(self._mock_datastore.method_calls))
+
+  def test_SplitQueryFn_with_exception(self):
+    """A test that verifies that no split is performed when failures occur."""
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      # Force SplitQueryFn to compute the number of query splits
+      num_splits = 0
+      expected_num_splits = 1
+      entity_bytes = (expected_num_splits *
+                      ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES)
+      with patch.object(ReadFromDatastore, 'get_estimated_size_bytes',
+                        return_value=entity_bytes):
+
+        with patch.object(query_splitter, 'get_splits',
+                          side_effect=ValueError("Testing query split error")):
+          split_query_fn = ReadFromDatastore.SplitQueryFn(
+              self._PROJECT, self._query, None, num_splits)
+          split_query_fn.start_bundle()
+          returned_split_queries = []
+          for split_query in split_query_fn.process(self._query):
+            returned_split_queries.append(split_query)
+
+          self.assertEqual(len(returned_split_queries), expected_num_splits)
+          self.assertEqual(returned_split_queries[0][1], self._query)
+          self.assertEqual(0,
+                           len(self._mock_datastore.run_query.call_args_list))
+          self.verify_unique_keys(returned_split_queries)
+
+  def test_DatastoreWriteFn_with_empty_batch(self):
+    self.check_DatastoreWriteFn(0)
+
+  def test_DatastoreWriteFn_with_one_batch(self):
+    num_entities_to_write = _Mutate._WRITE_BATCH_SIZE * 1 - 50
+    self.check_DatastoreWriteFn(num_entities_to_write)
+
+  def test_DatastoreWriteFn_with_multiple_batches(self):
+    num_entities_to_write = _Mutate._WRITE_BATCH_SIZE * 3 + 50
+    self.check_DatastoreWriteFn(num_entities_to_write)
+
+  def test_DatastoreWriteFn_with_batch_size_exact_multiple(self):
+    num_entities_to_write = _Mutate._WRITE_BATCH_SIZE * 2
+    self.check_DatastoreWriteFn(num_entities_to_write)
+
+  def check_DatastoreWriteFn(self, num_entities):
+    """A helper function to test DatastoreWriteFn."""
+
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      entities = [e.entity for e in
+                  fake_datastore.create_entities(num_entities)]
+
+      expected_mutations = map(WriteToDatastore.to_upsert_mutation, entities)
+      actual_mutations = []
+
+      self._mock_datastore.commit.side_effect = (
+          fake_datastore.create_commit(actual_mutations))
+
+      datastore_write_fn = _Mutate.DatastoreWriteFn(self._PROJECT)
+
+      datastore_write_fn.start_bundle()
+      for mutation in expected_mutations:
+        datastore_write_fn.process(mutation)
+      datastore_write_fn.finish_bundle()
+
+      self.assertEqual(actual_mutations, expected_mutations)
+      self.assertEqual((num_entities - 1) / _Mutate._WRITE_BATCH_SIZE + 1,
+                       self._mock_datastore.commit.call_count)
+
+  def verify_unique_keys(self, queries):
+    """A helper function that verifies if all the queries have unique keys."""
+    keys, _ = zip(*queries)
+    keys = set(keys)
+    self.assertEqual(len(keys), len(queries))
+
+  def check_estimated_size_bytes(self, entity_bytes, timestamp, namespace=None):
+    """A helper method to test get_estimated_size_bytes."""
+
+    timestamp_req = helper.make_request(
+        self._PROJECT, namespace, helper.make_latest_timestamp_query(namespace))
+    timestamp_resp = self.make_stats_response(
+        {'timestamp': datastore_helper.from_timestamp(timestamp)})
+    kind_stat_req = helper.make_request(
+        self._PROJECT, namespace, helper.make_kind_stats_query(
+            namespace, self._query.kind[0].name,
+            datastore_helper.micros_from_timestamp(timestamp)))
+    kind_stat_resp = self.make_stats_response(
+        {'entity_bytes': entity_bytes})
+
+    def fake_run_query(req):
+      if req == timestamp_req:
+        return timestamp_resp
+      elif req == kind_stat_req:
+        return kind_stat_resp
+      else:
+        raise ValueError("Unknown req: %s" % req)
+
+    self._mock_datastore.run_query.side_effect = fake_run_query
+    self.assertEqual(entity_bytes, ReadFromDatastore.get_estimated_size_bytes(
+        self._PROJECT, namespace, self._query, self._mock_datastore))
+    self.assertEqual(self._mock_datastore.run_query.call_args_list,
+                     [call(timestamp_req), call(kind_stat_req)])
+
+  def make_stats_response(self, property_map):
+    resp = datastore_pb2.RunQueryResponse()
+    entity_result = resp.batch.entity_results.add()
+    datastore_helper.add_properties(entity_result.entity, property_map)
+    return resp
+
+  def split_query(self, query, num_splits):
+    """Generate dummy query splits."""
+    split_queries = []
+    for _ in range(0, num_splits):
+      q = query_pb2.Query()
+      q.CopyFrom(query)
+      split_queries.append(q)
+    return split_queries
+
+if __name__ == '__main__':
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/908c8532/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/fake_datastore.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/fake_datastore.py b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/fake_datastore.py
new file mode 100644
index 0000000..ac8e0e0
--- /dev/null
+++ b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/fake_datastore.py
@@ -0,0 +1,92 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Fake datastore used for unit testing."""
+import uuid
+
+from google.cloud.proto.datastore.v1 import datastore_pb2
+from google.cloud.proto.datastore.v1 import query_pb2
+
+
+def create_run_query(entities, batch_size):
+  """A fake Datastore run_query method that returns entities in batches.
+
+  Note: the outer method is needed to make the `entities` and `batch_size`
+  available in the scope of the fake_run_query method.
+
+  Args:
+    entities: list of entities supposed to be contained in the datastore.
+    batch_size: the number of entities that the run_query method returns in
+      one request.
+  """
+  def run_query(req):
+    start = int(req.query.start_cursor) if req.query.start_cursor else 0
+    # If the query limit is less than batch_size, only return that many.
+    count = min(batch_size, req.query.limit.value)
+    # Cannot return more than the number of entities in the datastore.
+    end = min(len(entities), start + count)
+    finish = False
+    # Finish reading when there are no more entities to return,
+    # or the request's query limit has been satisfied.
+    if end == len(entities) or count == req.query.limit.value:
+      finish = True
+    return create_response(entities[start:end], str(end), finish)
+  return run_query
+
+
+def create_commit(mutations):
+  """A fake Datastore commit method that writes the mutations to a list.
+
+  Args:
+    mutations: A list to write mutations to.
+
+  Returns:
+    A fake Datastore commit method.
+  """
+
+  def commit(req):
+    for mutation in req.mutations:
+      mutations.append(mutation)
+
+  return commit
+
+
+def create_response(entities, end_cursor, finish):
+  """Creates a query response for a given batch of scatter entities."""
+  resp = datastore_pb2.RunQueryResponse()
+  if finish:
+    resp.batch.more_results = query_pb2.QueryResultBatch.NO_MORE_RESULTS
+  else:
+    resp.batch.more_results = query_pb2.QueryResultBatch.NOT_FINISHED
+
+  resp.batch.end_cursor = end_cursor
+  for entity_result in entities:
+    resp.batch.entity_results.add().CopyFrom(entity_result)
+
+  return resp
+
+
+def create_entities(count):
+  """Creates a list of entities with random keys."""
+  entities = []
+
+  for _ in range(count):
+    entity_result = query_pb2.EntityResult()
+    entity_result.entity.key.path.add().name = str(uuid.uuid4())
+    entities.append(entity_result)
+
+  return entities

http://git-wip-us.apache.org/repos/asf/beam/blob/908c8532/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/helper.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/helper.py b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/helper.py
new file mode 100644
index 0000000..1497862
--- /dev/null
+++ b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/helper.py
@@ -0,0 +1,267 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Cloud Datastore helper functions."""
+import sys
+
+from google.cloud.proto.datastore.v1 import datastore_pb2
+from google.cloud.proto.datastore.v1 import entity_pb2
+from google.cloud.proto.datastore.v1 import query_pb2
+from googledatastore import PropertyFilter, CompositeFilter
+from googledatastore import helper as datastore_helper
+from googledatastore.connection import Datastore
+from googledatastore.connection import RPCError
+
+from apache_beam.internal import auth
+from apache_beam.utils import retry
+
+
+def key_comparator(k1, k2):
+  """A comparator for Datastore keys.
+
+  Comparison is only valid for keys in the same partition. The comparison
+  here is between the list of paths for each key.
+  """
+
+  if k1.partition_id != k2.partition_id:
+    raise ValueError('Cannot compare keys with different partition ids.')
+
+  k2_iter = iter(k2.path)
+
+  for k1_path in k1.path:
+    k2_path = next(k2_iter, None)
+    if not k2_path:
+      return 1
+
+    result = compare_path(k1_path, k2_path)
+
+    if result != 0:
+      return result
+
+  k2_path = next(k2_iter, None)
+  if k2_path:
+    return -1
+  else:
+    return 0
+
+
+def compare_path(p1, p2):
+  """A comparator for key paths.
+
+  A path has either an `id` or a `name` field defined. The
+  comparison works with the following rules:
+
+  1. If one path has `id` defined while the other doesn't, then the
+     one with `id` defined is considered smaller.
+  2. If both paths have `id` defined, then their ids are compared.
+  3. If no `id` is defined for either path, then their `names` are compared.
+  """
+
+  result = str_compare(p1.kind, p2.kind)
+  if result != 0:
+    return result
+
+  if p1.HasField('id'):
+    if not p2.HasField('id'):
+      return -1
+
+    return p1.id - p2.id
+
+  if p2.HasField('id'):
+    return 1
+
+  return str_compare(p1.name, p2.name)
+
+
+def str_compare(s1, s2):
+  if s1 == s2:
+    return 0
+  elif s1 < s2:
+    return -1
+  else:
+    return 1
+
+
+def get_datastore(project):
+  """Returns a Cloud Datastore client."""
+  credentials = auth.get_service_credentials()
+  return Datastore(project, credentials)
+
+
+def make_request(project, namespace, query):
+  """Make a Cloud Datastore request for the given query."""
+  req = datastore_pb2.RunQueryRequest()
+  req.partition_id.CopyFrom(make_partition(project, namespace))
+
+  req.query.CopyFrom(query)
+  return req
+
+
+def make_partition(project, namespace):
+  """Make a PartitionId for the given project and namespace."""
+  partition = entity_pb2.PartitionId()
+  partition.project_id = project
+  if namespace is not None:
+    partition.namespace_id = namespace
+
+  return partition
+
+
+def retry_on_rpc_error(exception):
+  """A retry filter for Cloud Datastore RPCErrors."""
+  if isinstance(exception, RPCError):
+    return exception.code >= 500
+  # TODO(vikasrk): Figure out what other errors should be retried.
+  return False
+
+
+def fetch_entities(project, namespace, query, datastore):
+  """A helper method to fetch entities from Cloud Datastore.
+
+  Args:
+    project: Project ID
+    namespace: Cloud Datastore namespace
+    query: Query to be read from
+    datastore: Cloud Datastore Client
+
+  Returns:
+    An iterator of entities.
+  """
+  return QueryIterator(project, namespace, query, datastore)
+
+
+def is_key_valid(key):
+  """Returns True if a Cloud Datastore key is complete.
+
+  A key is complete if its last element has either an id or a name.
+  """
+  if not key.path:
+    return False
+  return key.path[-1].HasField('id') or key.path[-1].HasField('name')
+
+
+def write_mutations(datastore, project, mutations):
+  """A helper function to write a batch of mutations to Cloud Datastore.
+
+  If a commit fails, it will be retried up to 5 times. All mutations in the
+  batch will be committed again, even if the commit was partially successful.
+  If the retry limit is exceeded, the last exception from Cloud Datastore
+  will be raised.
+  """
+  commit_request = datastore_pb2.CommitRequest()
+  commit_request.mode = datastore_pb2.CommitRequest.NON_TRANSACTIONAL
+  commit_request.project_id = project
+  for mutation in mutations:
+    commit_request.mutations.add().CopyFrom(mutation)
+
+  @retry.with_exponential_backoff(num_retries=5,
+                                  retry_filter=retry_on_rpc_error)
+  def commit(req):
+    datastore.commit(req)
+
+  commit(commit_request)
+
+
+def make_latest_timestamp_query(namespace):
+  """Make a Query to fetch the latest timestamp statistics."""
+  query = query_pb2.Query()
+  if namespace is None:
+    query.kind.add().name = '__Stat_Total__'
+  else:
+    query.kind.add().name = '__Stat_Ns_Total__'
+
+  # Descending order of `timestamp`
+  datastore_helper.add_property_orders(query, "-timestamp")
+  # Only get the latest entity
+  query.limit.value = 1
+  return query
+
+
+def make_kind_stats_query(namespace, kind, latest_timestamp):
+  """Make a Query to fetch the latest kind statistics."""
+  kind_stat_query = query_pb2.Query()
+  if namespace is None:
+    kind_stat_query.kind.add().name = '__Stat_Kind__'
+  else:
+    kind_stat_query.kind.add().name = '__Stat_Ns_Kind__'
+
+  kind_filter = datastore_helper.set_property_filter(
+      query_pb2.Filter(), 'kind_name', PropertyFilter.EQUAL, unicode(kind))
+  timestamp_filter = datastore_helper.set_property_filter(
+      query_pb2.Filter(), 'timestamp', PropertyFilter.EQUAL,
+      latest_timestamp)
+
+  datastore_helper.set_composite_filter(kind_stat_query.filter,
+                                        CompositeFilter.AND, kind_filter,
+                                        timestamp_filter)
+  return kind_stat_query
+
+
+class QueryIterator(object):
+  """An iterator class for entities of a given query.
+
+  Entities are read in batches. Retries on failures.
+  """
+  _NOT_FINISHED = query_pb2.QueryResultBatch.NOT_FINISHED
+  # Maximum number of results to request per query.
+  _BATCH_SIZE = 500
+
+  def __init__(self, project, namespace, query, datastore):
+    self._query = query
+    self._datastore = datastore
+    self._project = project
+    self._namespace = namespace
+    self._start_cursor = None
+    self._limit = self._query.limit.value or sys.maxint
+    self._req = make_request(project, namespace, query)
+
+  @retry.with_exponential_backoff(num_retries=5,
+                                  retry_filter=retry_on_rpc_error)
+  def _next_batch(self):
+    """Fetches the next batch of entities."""
+    if self._start_cursor is not None:
+      self._req.query.start_cursor = self._start_cursor
+
+    # Set the batch size.
+    self._req.query.limit.value = min(self._BATCH_SIZE, self._limit)
+    resp = self._datastore.run_query(self._req)
+    return resp
+
+  def __iter__(self):
+    more_results = True
+    while more_results:
+      resp = self._next_batch()
+      for entity_result in resp.batch.entity_results:
+        yield entity_result.entity
+
+      self._start_cursor = resp.batch.end_cursor
+      num_results = len(resp.batch.entity_results)
+      self._limit -= num_results
+
+      # Check if we need to read more entities.
+      # True when the query limit hasn't been satisfied and there are more
+      # entities to be read. The latter is true if the response has a status
+      # `NOT_FINISHED` or if the number of results read in the previous batch
+      # is equal to `_BATCH_SIZE` (both indications that there is more data
+      # to be read).
+      more_results = ((self._limit > 0) and
+                      ((num_results == self._BATCH_SIZE) or
+                       (resp.batch.more_results == self._NOT_FINISHED)))

http://git-wip-us.apache.org/repos/asf/beam/blob/908c8532/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/helper_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/helper_test.py b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/helper_test.py
new file mode 100644
index 0000000..689c462
--- /dev/null
+++ b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/helper_test.py
@@ -0,0 +1,256 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +"""Tests for datastore helper.""" +import sys +import unittest + +from google.cloud.proto.datastore.v1 import datastore_pb2 +from google.cloud.proto.datastore.v1 import entity_pb2 +from google.cloud.proto.datastore.v1 import query_pb2 +from google.cloud.proto.datastore.v1.entity_pb2 import Key +from googledatastore.connection import RPCError +from googledatastore import helper as datastore_helper +from mock import MagicMock + +from apache_beam.io.google_cloud_platform.datastore.v1 import fake_datastore +from apache_beam.io.google_cloud_platform.datastore.v1 import helper +from apache_beam.tests.test_utils import patch_retry + + +class HelperTest(unittest.TestCase): + + def setUp(self): + self._mock_datastore = MagicMock() + self._query = query_pb2.Query() + self._query.kind.add().name = 'dummy_kind' + patch_retry(self, helper) + + def permanent_datastore_failure(self, req): + raise RPCError("dummy", 500, "failed") + + def transient_datastore_failure(self, req): + if self._transient_fail_count: + self._transient_fail_count -= 1 + raise RPCError("dummy", 500, "failed") + else: + return datastore_pb2.RunQueryResponse() + + def test_query_iterator(self): + self._mock_datastore.run_query.side_effect = ( + self.permanent_datastore_failure) + query_iterator = helper.QueryIterator("project", None, self._query, + self._mock_datastore) + self.assertRaises(RPCError, iter(query_iterator).next) + self.assertEqual(6, len(self._mock_datastore.run_query.call_args_list)) + + def test_query_iterator_with_transient_failures(self): + self._mock_datastore.run_query.side_effect = ( + self.transient_datastore_failure) + query_iterator = helper.QueryIterator("project", None, self._query, + self._mock_datastore) + fail_count = 2 + self._transient_fail_count = fail_count + for _ in query_iterator: + pass + + self.assertEqual(fail_count + 1, + len(self._mock_datastore.run_query.call_args_list)) + + def test_query_iterator_with_single_batch(self): + num_entities = 100 + batch_size = 500 + self.check_query_iterator(num_entities, batch_size, self._query) + + def test_query_iterator_with_multiple_batches(self): + num_entities = 1098 + batch_size = 500 + self.check_query_iterator(num_entities, batch_size, self._query) + + def test_query_iterator_with_exact_batch_multiple(self): + num_entities = 1000 + batch_size = 500 + self.check_query_iterator(num_entities, batch_size, self._query) + + def test_query_iterator_with_query_limit(self): + num_entities = 1098 + batch_size = 500 + self._query.limit.value = 1004 + self.check_query_iterator(num_entities, batch_size, self._query) + + def test_query_iterator_with_large_query_limit(self): + num_entities = 1098 + batch_size = 500 + self._query.limit.value = 10000 + self.check_query_iterator(num_entities, batch_size, self._query) + + def check_query_iterator(self, num_entities, batch_size, query): + """A helper method to test the QueryIterator. + + Args: + num_entities: number of entities contained in the fake datastore. + batch_size: the number of entities returned by fake datastore in one req. 
+      query: the query to be executed.
+    """
+    entities = fake_datastore.create_entities(num_entities)
+    self._mock_datastore.run_query.side_effect = \
+        fake_datastore.create_run_query(entities, batch_size)
+    query_iterator = helper.QueryIterator("project", None, query,
+                                          self._mock_datastore)
+
+    i = 0
+    for entity in query_iterator:
+      self.assertEqual(entity, entities[i].entity)
+      i += 1
+
+    limit = query.limit.value if query.HasField('limit') else sys.maxint
+    self.assertEqual(i, min(num_entities, limit))
+
+  def test_is_key_valid(self):
+    key = entity_pb2.Key()
+    # Complete with name, no ancestor
+    datastore_helper.add_key_path(key, 'kind', 'name')
+    self.assertTrue(helper.is_key_valid(key))
+
+    key = entity_pb2.Key()
+    # Complete with id, no ancestor
+    datastore_helper.add_key_path(key, 'kind', 12)
+    self.assertTrue(helper.is_key_valid(key))
+
+    key = entity_pb2.Key()
+    # Incomplete, no ancestor
+    datastore_helper.add_key_path(key, 'kind')
+    self.assertFalse(helper.is_key_valid(key))
+
+    key = entity_pb2.Key()
+    # Complete with name and ancestor
+    datastore_helper.add_key_path(key, 'kind', 'name', 'kind2', 'name2')
+    self.assertTrue(helper.is_key_valid(key))
+
+    key = entity_pb2.Key()
+    # Complete with id and ancestor
+    datastore_helper.add_key_path(key, 'kind', 'name', 'kind2', 123)
+    self.assertTrue(helper.is_key_valid(key))
+
+    key = entity_pb2.Key()
+    # Incomplete with ancestor
+    datastore_helper.add_key_path(key, 'kind', 'name', 'kind2')
+    self.assertFalse(helper.is_key_valid(key))
+
+    key = entity_pb2.Key()
+    self.assertFalse(helper.is_key_valid(key))
+
+  def test_compare_path_with_different_kind(self):
+    p1 = Key.PathElement()
+    p1.kind = 'dummy1'
+
+    p2 = Key.PathElement()
+    p2.kind = 'dummy2'
+
+    self.assertLess(helper.compare_path(p1, p2), 0)
+
+  def test_compare_path_with_different_id(self):
+    p1 = Key.PathElement()
+    p1.kind = 'dummy'
+    p1.id = 10
+
+    p2 = Key.PathElement()
+    p2.kind = 'dummy'
+    p2.id = 15
+
+    self.assertLess(helper.compare_path(p1, p2), 0)
+
+  def test_compare_path_with_different_name(self):
+    p1 = Key.PathElement()
+    p1.kind = 'dummy'
+    p1.name = "dummy1"
+
+    p2 = Key.PathElement()
+    p2.kind = 'dummy'
+    p2.name = 'dummy2'
+
+    self.assertLess(helper.compare_path(p1, p2), 0)
+
+  def test_compare_path_of_different_type(self):
+    p1 = Key.PathElement()
+    p1.kind = 'dummy'
+    p1.id = 10
+
+    p2 = Key.PathElement()
+    p2.kind = 'dummy'
+    p2.name = 'dummy'
+
+    self.assertLess(helper.compare_path(p1, p2), 0)
+
+  def test_key_comparator_with_different_partition(self):
+    k1 = Key()
+    k1.partition_id.namespace_id = 'dummy1'
+    k2 = Key()
+    k2.partition_id.namespace_id = 'dummy2'
+    self.assertRaises(ValueError, helper.key_comparator, k1, k2)
+
+  def test_key_comparator_with_single_path(self):
+    k1 = Key()
+    k2 = Key()
+    p1 = k1.path.add()
+    p2 = k2.path.add()
+    p1.kind = p2.kind = 'dummy'
+    self.assertEqual(helper.key_comparator(k1, k2), 0)
+
+  def test_key_comparator_with_multiple_paths_1(self):
+    k1 = Key()
+    k2 = Key()
+    p11 = k1.path.add()
+    p12 = k1.path.add()
+    p21 = k2.path.add()
+    p11.kind = p12.kind = p21.kind = 'dummy'
+    self.assertGreater(helper.key_comparator(k1, k2), 0)
+
+  def test_key_comparator_with_multiple_paths_2(self):
+    k1 = Key()
+    k2 = Key()
+    p11 = k1.path.add()
+    p21 = k2.path.add()
+    p22 = k2.path.add()
+    p11.kind = p21.kind = p22.kind = 'dummy'
+    self.assertLess(helper.key_comparator(k1, k2), 0)
+
+  def test_key_comparator_with_multiple_paths_3(self):
+    k1 = Key()
+    k2 = Key()
+    p11 = k1.path.add()
+    p12 = k1.path.add()
+    p21 = k2.path.add()
+    p22 = k2.path.add()
+    p11.kind = p12.kind = p21.kind = p22.kind = 'dummy'
+    self.assertEqual(helper.key_comparator(k1, k2), 0)
+
+  def test_key_comparator_with_multiple_paths_4(self):
+    k1 = Key()
+    k2 = Key()
+    p11 = k1.path.add()
+    p12 = k2.path.add()
+    p21 = k2.path.add()
+    p11.kind = p12.kind = 'dummy'
+    # make path2 greater than path1
+    p21.kind = 'dummy1'
+    self.assertLess(helper.key_comparator(k1, k2), 0)
+
+
+if __name__ == '__main__':
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/908c8532/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/query_splitter.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/query_splitter.py b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/query_splitter.py
new file mode 100644
index 0000000..b101ad9
--- /dev/null
+++ b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/query_splitter.py
@@ -0,0 +1,269 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Implements a Cloud Datastore query splitter."""
+
+from apache_beam.io.google_cloud_platform.datastore.v1 import helper
+from google.cloud.proto.datastore.v1 import datastore_pb2
+from google.cloud.proto.datastore.v1 import query_pb2
+from google.cloud.proto.datastore.v1.query_pb2 import PropertyFilter
+from google.cloud.proto.datastore.v1.query_pb2 import CompositeFilter
+from googledatastore import helper as datastore_helper
+
+
+__all__ = [
+    'get_splits',
+]
+
+SCATTER_PROPERTY_NAME = '__scatter__'
+KEY_PROPERTY_NAME = '__key__'
+# The number of keys to sample for each split.
+KEYS_PER_SPLIT = 32
+
+UNSUPPORTED_OPERATORS = [PropertyFilter.LESS_THAN,
+                         PropertyFilter.LESS_THAN_OR_EQUAL,
+                         PropertyFilter.GREATER_THAN,
+                         PropertyFilter.GREATER_THAN_OR_EQUAL]
+
+
+def get_splits(datastore, query, num_splits, partition=None):
+  """Returns a list of sharded queries for the given Cloud Datastore query.
+
+  This will create up to the desired number of splits, however it may return
+  fewer splits if the desired number of splits is unavailable. This will
+  happen if the number of split points provided by the underlying Datastore
+  is less than the desired number, which will occur if the number of results
+  for the query is too small.
+
+  This implementation of the QuerySplitter uses the __scatter__ property to
+  gather random split points for a query.
+
+  Note: This implementation is derived from the Java query splitter in
+  https://github.com/GoogleCloudPlatform/google-cloud-datastore/blob/master/java/datastore/src/main/java/com/google/datastore/v1/client/QuerySplitterImpl.java
+
+  Args:
+    datastore: the datastore client.
+    query: the query to split.
+    num_splits: the desired number of splits.
+    partition: the partition the query is running in.
+
+  Returns:
+    A list of split queries, with a maximum length of `num_splits`.
+  """
+
+  # Validate that the number of splits is not out of bounds.
+  if num_splits < 1:
+    raise ValueError('The number of splits must be greater than 0.')
+
+  if num_splits == 1:
+    return [query]
+
+  _validate_query(query)
+
+  splits = []
+  scatter_keys = _get_scatter_keys(datastore, query, num_splits, partition)
+  last_key = None
+  for next_key in _get_split_key(scatter_keys, num_splits):
+    splits.append(_create_split(last_key, next_key, query))
+    last_key = next_key
+
+  splits.append(_create_split(last_key, None, query))
+  return splits
+
+
+def _validate_query(query):
+  """Verifies that the given query can be properly scattered."""
+
+  if len(query.kind) != 1:
+    raise ValueError('Query must have exactly one kind.')
+
+  if len(query.order) != 0:
+    raise ValueError('Query cannot have any sort orders.')
+
+  if query.HasField('limit'):
+    raise ValueError('Query cannot have a limit set.')
+
+  if query.offset > 0:
+    raise ValueError('Query cannot have an offset set.')
+
+  _validate_filter(query.filter)
+
+
+def _validate_filter(query_filter):
+  """Validates that we only have allowable filters.
+
+  Note that equality and ancestor filters are allowed, however they may
+  result in inefficient sharding.
+  """
+
+  if query_filter.HasField('composite_filter'):
+    for sub_filter in query_filter.composite_filter.filters:
+      _validate_filter(sub_filter)
+  elif query_filter.HasField('property_filter'):
+    if query_filter.property_filter.op in UNSUPPORTED_OPERATORS:
+      raise ValueError('Query cannot have any inequality filters.')
+
+
+def _create_scatter_query(query, num_splits):
+  """Creates a scatter query from the given user query."""
+
+  scatter_query = query_pb2.Query()
+  for kind in query.kind:
+    scatter_kind = scatter_query.kind.add()
+    scatter_kind.CopyFrom(kind)
+
+  # ascending order
+  datastore_helper.add_property_orders(scatter_query, SCATTER_PROPERTY_NAME)
+
+  # There is a split containing entities before and after each scatter entity:
+  # ||---*------*------*------*------*------*------*---||  * = scatter entity
+  # If we represent each split as a region before a scatter entity, there is
+  # an extra region following the last scatter point. Thus, we do not need
+  # the scatter entity for the last region.
+  scatter_query.limit.value = (num_splits - 1) * KEYS_PER_SPLIT
+  datastore_helper.add_projection(scatter_query, KEY_PROPERTY_NAME)
+
+  return scatter_query
+
+
+def _get_scatter_keys(datastore, query, num_splits, partition):
+  """Gets a list of split keys given a desired number of splits.
+
+  This list will contain multiple split keys for each split. Only a single
+  split key will be chosen as the split point, however providing multiple
+  keys allows for more uniform sharding.
+
+  Args:
+    datastore: the client to the datastore containing the data.
+    query: the user query.
+    num_splits: the number of desired splits.
+    partition: the partition to run the query in.
+
+  Returns:
+    A list of scatter keys returned by Datastore.
+ """ + scatter_point_query = _create_scatter_query(query, num_splits) + + key_splits = [] + while True: + req = datastore_pb2.RunQueryRequest() + if partition: + req.partition_id.CopyFrom(partition) + + req.query.CopyFrom(scatter_point_query) + + resp = datastore.run_query(req) + for entity_result in resp.batch.entity_results: + key_splits.append(entity_result.entity.key) + + if resp.batch.more_results != query_pb2.QueryResultBatch.NOT_FINISHED: + break + + scatter_point_query.start_cursor = resp.batch.end_cursor + scatter_point_query.limit.value -= len(resp.batch.entity_results) + + key_splits.sort(helper.key_comparator) + return key_splits + + +def _get_split_key(keys, num_splits): + """Given a list of keys and a number of splits find the keys to split on. + + Args: + keys: the list of keys. + num_splits: the number of splits. + + Returns: + A list of keys to split on. + + """ + + # If the number of keys is less than the number of splits, we are limited + # in the number of splits we can make. + if not keys or (len(keys) < (num_splits - 1)): + return keys + + # Calculate the number of keys per split. This should be KEYS_PER_SPLIT, + # but may be less if there are not KEYS_PER_SPLIT * (numSplits - 1) scatter + # entities. + # + # Consider the following dataset, where - represents an entity and + # * represents an entity that is returned as a scatter entity: + # ||---*-----*----*-----*-----*------*----*----|| + # If we want 4 splits in this data, the optimal split would look like: + # ||---*-----*----*-----*-----*------*----*----|| + # | | | + # The scatter keys in the last region are not useful to us, so we never + # request them: + # ||---*-----*----*-----*-----*------*---------|| + # | | | + # With 6 scatter keys we want to set scatter points at indexes: 1, 3, 5. + # + # We keep this as a float so that any "fractional" keys per split get + # distributed throughout the splits and don't make the last split + # significantly larger than the rest. + + num_keys_per_split = max(1.0, float(len(keys)) / (num_splits - 1)) + + split_keys = [] + + # Grab the last sample for each split, otherwise the first split will be too + # small. + for i in range(1, num_splits): + split_index = int(round(i * num_keys_per_split) - 1) + split_keys.append(keys[split_index]) + + return split_keys + + +def _create_split(last_key, next_key, query): + """Create a new {@link Query} given the query and range.. + + Args: + last_key: the previous key. If null then assumed to be the beginning. + next_key: the next key. If null then assumed to be the end. + query: the desired query. 
+
+  Returns:
+    A split query that fetches entities in the range [last_key, next_key).
+  """
+  if not (last_key or next_key):
+    return query
+
+  split_query = query_pb2.Query()
+  split_query.CopyFrom(query)
+  composite_filter = split_query.filter.composite_filter
+  composite_filter.op = CompositeFilter.AND
+
+  if query.HasField('filter'):
+    composite_filter.filters.add().CopyFrom(query.filter)
+
+  if last_key:
+    lower_bound = composite_filter.filters.add()
+    lower_bound.property_filter.property.name = KEY_PROPERTY_NAME
+    lower_bound.property_filter.op = PropertyFilter.GREATER_THAN_OR_EQUAL
+    lower_bound.property_filter.value.key_value.CopyFrom(last_key)
+
+  if next_key:
+    upper_bound = composite_filter.filters.add()
+    upper_bound.property_filter.property.name = KEY_PROPERTY_NAME
+    upper_bound.property_filter.op = PropertyFilter.LESS_THAN
+    upper_bound.property_filter.value.key_value.CopyFrom(next_key)
+
+  return split_query

http://git-wip-us.apache.org/repos/asf/beam/blob/908c8532/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/query_splitter_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/query_splitter_test.py b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/query_splitter_test.py
new file mode 100644
index 0000000..676f311
--- /dev/null
+++ b/sdks/python/apache_beam/io/google_cloud_platform/datastore/v1/query_splitter_test.py
@@ -0,0 +1,201 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +"""Cloud Datastore query splitter test.""" + +import unittest + +from mock import MagicMock +from mock import call + +from apache_beam.io.google_cloud_platform.datastore.v1 import fake_datastore +from apache_beam.io.google_cloud_platform.datastore.v1 import query_splitter + +from google.cloud.proto.datastore.v1 import datastore_pb2 +from google.cloud.proto.datastore.v1 import query_pb2 +from google.cloud.proto.datastore.v1.query_pb2 import PropertyFilter + + +class QuerySplitterTest(unittest.TestCase): + + def test_get_splits_query_with_multiple_kinds(self): + query = query_pb2.Query() + query.kind.add() + query.kind.add() + self.assertRaises(ValueError, query_splitter.get_splits, None, query, 4) + + def test_get_splits_query_with_order(self): + query = query_pb2.Query() + query.kind.add() + query.order.add() + + self.assertRaises(ValueError, query_splitter.get_splits, None, query, 3) + + def test_get_splits_query_with_unsupported_filter(self): + query = query_pb2.Query() + query.kind.add() + test_filter = query.filter.composite_filter.filters.add() + test_filter.property_filter.op = PropertyFilter.GREATER_THAN + self.assertRaises(ValueError, query_splitter.get_splits, None, query, 2) + + def test_get_splits_query_with_limit(self): + query = query_pb2.Query() + query.kind.add() + query.limit.value = 10 + self.assertRaises(ValueError, query_splitter.get_splits, None, query, 2) + + def test_get_splits_query_with_offset(self): + query = query_pb2.Query() + query.kind.add() + query.offset = 10 + self.assertRaises(ValueError, query_splitter.get_splits, None, query, 2) + + def test_create_scatter_query(self): + query = query_pb2.Query() + kind = query.kind.add() + kind.name = 'shakespeare-demo' + num_splits = 10 + scatter_query = query_splitter._create_scatter_query(query, num_splits) + self.assertEqual(scatter_query.kind[0], kind) + self.assertEqual(scatter_query.limit.value, + (num_splits -1) * query_splitter.KEYS_PER_SPLIT) + self.assertEqual(scatter_query.order[0].direction, + query_pb2.PropertyOrder.ASCENDING) + self.assertEqual(scatter_query.projection[0].property.name, + query_splitter.KEY_PROPERTY_NAME) + + def test_get_splits_with_two_splits(self): + query = query_pb2.Query() + kind = query.kind.add() + kind.name = 'shakespeare-demo' + num_splits = 2 + num_entities = 97 + batch_size = 9 + + self.check_get_splits(query, num_splits, num_entities, batch_size) + + def test_get_splits_with_multiple_splits(self): + query = query_pb2.Query() + kind = query.kind.add() + kind.name = 'shakespeare-demo' + num_splits = 4 + num_entities = 369 + batch_size = 12 + + self.check_get_splits(query, num_splits, num_entities, batch_size) + + def test_get_splits_with_large_num_splits(self): + query = query_pb2.Query() + kind = query.kind.add() + kind.name = 'shakespeare-demo' + num_splits = 10 + num_entities = 4 + batch_size = 10 + + self.check_get_splits(query, num_splits, num_entities, batch_size) + + def test_get_splits_with_small_num_entities(self): + query = query_pb2.Query() + kind = query.kind.add() + kind.name = 'shakespeare-demo' + num_splits = 4 + num_entities = 50 + batch_size = 10 + + self.check_get_splits(query, num_splits, num_entities, batch_size) + + def test_get_splits_with_batch_size_exact_multiple(self): + """Test get_splits when num scatter keys is a multiple of batch size.""" + query = query_pb2.Query() + kind = query.kind.add() + kind.name = 'shakespeare-demo' + num_splits = 4 + num_entities = 400 + batch_size = 32 + + self.check_get_splits(query, num_splits, num_entities, 
+
+  def test_get_splits_with_large_batch_size(self):
+    """Test get_splits when all scatter keys are returned in a single req."""
+    query = query_pb2.Query()
+    kind = query.kind.add()
+    kind.name = 'shakespeare-demo'
+    num_splits = 4
+    num_entities = 400
+    batch_size = 500
+
+    self.check_get_splits(query, num_splits, num_entities, batch_size)
+
+  def check_get_splits(self, query, num_splits, num_entities, batch_size):
+    """A helper method to test the query_splitter get_splits method.
+
+    Args:
+      query: the query to be split.
+      num_splits: number of splits.
+      num_entities: number of scatter entities contained in the fake
+        datastore.
+      batch_size: the number of entities returned by the fake datastore in
+        one request.
+    """
+
+    entities = fake_datastore.create_entities(num_entities)
+    mock_datastore = MagicMock()
+    # Assign a fake run_query method as a side_effect to the mock.
+    mock_datastore.run_query.side_effect = \
+        fake_datastore.create_run_query(entities, batch_size)
+
+    split_queries = query_splitter.get_splits(mock_datastore, query, num_splits)
+
+    # If the requested num_splits is greater than the number of scatter
+    # entities, the best the splitter can do is one entity per split.
+    expected_num_splits = min(num_splits, num_entities + 1)
+    self.assertEqual(len(split_queries), expected_num_splits)
+
+    expected_requests = QuerySplitterTest.create_scatter_requests(
+        query, num_splits, batch_size, num_entities)
+
+    expected_calls = []
+    for req in expected_requests:
+      expected_calls.append(call(req))
+
+    self.assertEqual(expected_calls, mock_datastore.run_query.call_args_list)
+
+  @staticmethod
+  def create_scatter_requests(query, num_splits, batch_size, num_entities):
+    """Creates a list of expected scatter requests from the query splitter.
+
+    This list of requests is used to verify that the query splitter made the
+    same number of requests, in the same order, to the datastore.
+    """
+
+    requests = []
+    count = (num_splits - 1) * query_splitter.KEYS_PER_SPLIT
+    start_cursor = ''
+    i = 0
+    scatter_query = query_splitter._create_scatter_query(query, count)
+    while i < count and i < num_entities:
+      request = datastore_pb2.RunQueryRequest()
+      request.query.CopyFrom(scatter_query)
+      request.query.start_cursor = start_cursor
+      request.query.limit.value = count - i
+      requests.append(request)
+      i += batch_size
+      start_cursor = str(i)
+
+    return requests
+
+
+if __name__ == '__main__':
+  unittest.main()
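As a sanity check on the splitter's index arithmetic, the following standalone sketch (not part of the patch) mirrors the comment in `_get_split_key`, using plain ints as stand-ins for scatter keys; with 6 scatter keys and num_splits=4 it picks indexes 1, 3 and 5, matching the example in that comment:

    def pick_split_indexes(num_keys, num_splits):
      # Same float math as _get_split_key: fractional keys-per-split are
      # spread across the splits instead of piling up in the last one.
      num_keys_per_split = max(1.0, float(num_keys) / (num_splits - 1))
      return [int(round(i * num_keys_per_split) - 1)
              for i in range(1, num_splits)]

    print pick_split_indexes(6, 4)  # prints [1, 3, 5]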