Repository: arrow
Updated Branches:
  refs/heads/master 2e5ddfe7d -> 2c3e8b09d


ARROW-692: Integration test data generator for dictionary types

cc @BryanCutler -- sorry it took a little while to get to this. I suspect we 
may have a little bit of slippage in the JSON we are generating, so this 
should help resolve the discrepancies.
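
For reference, a dictionary-encoded field produced by this generator
serializes roughly like the sketch below (field name, value type, and
index width are illustrative, not exact output):

    {
      "name": "dict1_0",
      "type": {"name": "utf8"},
      "nullable": true,
      "children": [],
      "dictionary": {
        "id": 0,
        "indexType": {"name": "int", "isSigned": true, "bitWidth": 8},
        "isOrdered": false
      },
      "typeLayout": {...}
    }

The dictionary values themselves are emitted once, in a top-level
"dictionaries" array keyed by "id"; record batch columns then carry only
the integer indices.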

Author: Wes McKinney <wes.mckin...@twosigma.com>

Closes #800 from wesm/ARROW-692 and squashes the following commits:

d04d0a60 [Wes McKinney] Disable dictionary case for now
acd9ea4c [Wes McKinney] Start dictionary test cases from 0, since C++ does not 
preserve the ids through stream-to-file
0f30d0b0 [Wes McKinney] Comment out generate_dictionary_case to get passing CI 
build
63a9ad5b [Wes McKinney] Add hack to be able to generate a column with a 
different name
548452d9 [Wes McKinney] Do not write dictionaries if length-0
4779db0c [Wes McKinney] Add dictionary JSON test data generation, passes C++ 
tests


Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/2c3e8b09
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/2c3e8b09
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/2c3e8b09

Branch: refs/heads/master
Commit: 2c3e8b09d168597260e28c331ed6cef7a744f82c
Parents: 2e5ddfe
Author: Wes McKinney <wes.mckin...@twosigma.com>
Authored: Mon Jul 3 12:35:01 2017 -0400
Committer: Wes McKinney <wes.mckin...@twosigma.com>
Committed: Mon Jul 3 12:35:01 2017 -0400

----------------------------------------------------------------------
 cpp/src/arrow/ipc/metadata.cc   |   2 +-
 integration/integration_test.py | 177 ++++++++++++++++++++++++++++-------
 2 files changed, 142 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/arrow/blob/2c3e8b09/cpp/src/arrow/ipc/metadata.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/metadata.cc b/cpp/src/arrow/ipc/metadata.cc
index 706ab2e..54f0547 100644
--- a/cpp/src/arrow/ipc/metadata.cc
+++ b/cpp/src/arrow/ipc/metadata.cc
@@ -798,7 +798,7 @@ int64_t DictionaryMemo::GetId(const std::shared_ptr<Array>& dictionary) {
     // Dictionary already observed, return the id
     return it->second;
   } else {
-    int64_t new_id = static_cast<int64_t>(dictionary_to_id_.size()) + 1;
+    int64_t new_id = static_cast<int64_t>(dictionary_to_id_.size());
     dictionary_to_id_[address] = new_id;
     id_to_dictionary_[new_id] = dictionary;
     return new_id;
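
The one-line change above makes DictionaryMemo hand out ids starting at 0
rather than 1. This matters because C++ re-assigns ids in encounter order
when round-tripping stream-to-file, so the JSON test data has to number its
dictionaries from 0 as well (see commit acd9ea4c above). A minimal Python
sketch of the assignment scheme, mirroring DictionaryMemo::GetId
(hypothetical class, for illustration only):

    class DictionaryMemo:
        """Sketch of the 0-based id assignment in the C++ DictionaryMemo."""

        def __init__(self):
            self._dictionary_to_id = {}  # maps object identity -> dictionary id

        def get_id(self, dictionary):
            key = id(dictionary)
            if key in self._dictionary_to_id:
                # Dictionary already observed; return the existing id
                return self._dictionary_to_id[key]
            # New dictionaries are numbered 0, 1, 2, ... in encounter order
            new_id = len(self._dictionary_to_id)
            self._dictionary_to_id[key] = new_id
            return new_id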

http://git-wip-us.apache.org/repos/asf/arrow/blob/2c3e8b09/integration/integration_test.py
----------------------------------------------------------------------
diff --git a/integration/integration_test.py b/integration/integration_test.py
index cc59593..9532d4e 100644
--- a/integration/integration_test.py
+++ b/integration/integration_test.py
@@ -119,6 +119,9 @@ class Column(object):
         self.name = name
         self.count = count
 
+    def __len__(self):
+        return self.count
+
     def _get_children(self):
         return []
 
@@ -195,15 +198,21 @@ class IntegerType(PrimitiveType):
             ('bitWidth', self.bit_width)
         ])
 
-    def generate_column(self, size):
+    def generate_column(self, size, name=None):
         iinfo = np.iinfo(self.numpy_type)
+        lower_bound = max(iinfo.min, self.min_value)
+        upper_bound = min(iinfo.max, self.max_value)
+        return self.generate_range(size, lower_bound, upper_bound, name=name)
+
+    def generate_range(self, size, lower, upper, name=None):
         values = [int(x) for x in
-                  np.random.randint(max(iinfo.min, self.min_value),
-                                    min(iinfo.max, self.max_value),
-                                    size=size)]
+                  np.random.randint(lower, upper, size=size)]
 
         is_valid = self._make_is_valid(size)
-        return PrimitiveColumn(self.name, size, is_valid, values)
+
+        if name is None:
+            name = self.name
+        return PrimitiveColumn(name, size, is_valid, values)
 
 
 class DateType(IntegerType):
@@ -294,12 +303,14 @@ class FloatingPointType(PrimitiveType):
             ('precision', self.precision)
         ])
 
-    def generate_column(self, size):
+    def generate_column(self, size, name=None):
         values = np.random.randn(size) * 1000
         values = np.round(values, 3)
 
         is_valid = self._make_is_valid(size)
-        return PrimitiveColumn(self.name, size, is_valid, values)
+        if name is None:
+            name = self.name
+        return PrimitiveColumn(name, size, is_valid, values)
 
 
 class BooleanType(PrimitiveType):
@@ -313,10 +324,12 @@ class BooleanType(PrimitiveType):
     def numpy_type(self):
         return 'bool'
 
-    def generate_column(self, size):
+    def generate_column(self, size, name=None):
         values = list(map(bool, np.random.randint(0, 2, size=size)))
         is_valid = self._make_is_valid(size)
-        return PrimitiveColumn(self.name, size, is_valid, values)
+        if name is None:
+            name = self.name
+        return PrimitiveColumn(name, size, is_valid, values)
 
 
 class BinaryType(PrimitiveType):
@@ -342,7 +355,7 @@ class BinaryType(PrimitiveType):
               OrderedDict([('type', 'DATA'),
                            ('typeBitWidth', 8)])])])
 
-    def generate_column(self, size):
+    def generate_column(self, size, name=None):
         K = 7
         is_valid = self._make_is_valid(size)
         values = []
@@ -356,7 +369,9 @@ class BinaryType(PrimitiveType):
             else:
                 values.append("")
 
-        return self.column_class(self.name, size, is_valid, values)
+        if name is None:
+            name = self.name
+        return self.column_class(name, size, is_valid, values)
 
 
 class StringType(BinaryType):
@@ -368,7 +383,7 @@ class StringType(BinaryType):
     def _get_type(self):
         return OrderedDict([('name', 'utf8')])
 
-    def generate_column(self, size):
+    def generate_column(self, size, name=None):
         K = 7
         is_valid = self._make_is_valid(size)
         values = []
@@ -379,10 +394,12 @@ class StringType(BinaryType):
             else:
                 values.append("")
 
-        return self.column_class(self.name, size, is_valid, values)
+        if name is None:
+            name = self.name
+        return self.column_class(name, size, is_valid, values)
 
 
-class JSONSchema(object):
+class JsonSchema(object):
 
     def __init__(self, fields):
         self.fields = fields
@@ -447,7 +464,7 @@ class ListType(DataType):
               OrderedDict([('type', 'OFFSET'),
                            ('typeBitWidth', 32)])])])
 
-    def generate_column(self, size):
+    def generate_column(self, size, name=None):
         MAX_LIST_SIZE = 4
 
         is_valid = self._make_is_valid(size)
@@ -463,7 +480,9 @@ class ListType(DataType):
         # The offset now is the total number of elements in the child array
         values = self.value_type.generate_column(offset)
 
-        return ListColumn(self.name, size, is_valid, offsets, values)
+        if name is None:
+            name = self.name
+        return ListColumn(name, size, is_valid, offsets, values)
 
 
 class ListColumn(Column):
@@ -504,13 +523,66 @@ class StructType(DataType):
              [OrderedDict([('type', 'VALIDITY'),
                            ('typeBitWidth', 1)])])])
 
-    def generate_column(self, size):
+    def generate_column(self, size, name=None):
         is_valid = self._make_is_valid(size)
 
         field_values = [type_.generate_column(size)
                         for type_ in self.field_types]
+        if name is None:
+            name = self.name
+        return StructColumn(name, size, is_valid, field_values)
+
+
+class Dictionary(object):
+
+    def __init__(self, id_, field, values, ordered=False):
+        self.id_ = id_
+        self.field = field
+        self.values = values
+        self.ordered = ordered
+
+    def __len__(self):
+        return len(self.values)
+
+    def get_json(self):
+        dummy_batch = JsonRecordBatch(len(self.values), [self.values])
+        return OrderedDict([
+            ('id', self.id_),
+            ('data', dummy_batch.get_json())
+        ])
+
 
-        return StructColumn(self.name, size, is_valid, field_values)
+class DictionaryType(DataType):
+
+    def __init__(self, name, index_type, dictionary, nullable=True):
+        DataType.__init__(self, name, nullable=nullable)
+        assert isinstance(index_type, IntegerType)
+        assert isinstance(dictionary, Dictionary)
+
+        self.index_type = index_type
+        self.dictionary = dictionary
+
+    def get_json(self):
+        dict_field = self.dictionary.field
+        return OrderedDict([
+            ('name', self.name),
+            ('type', dict_field._get_type()),
+            ('nullable', self.nullable),
+            ('children', dict_field._get_children()),
+            ('dictionary', OrderedDict([
+                ('id', self.dictionary.id_),
+                ('indexType', self.index_type._get_type()),
+                ('isOrdered', self.dictionary.ordered)
+            ])),
+            ('typeLayout', self.index_type._get_type_layout())
+        ])
+
+    def _get_type_layout(self):
+        return self.index_type._get_type_layout()
+
+    def generate_column(self, size, name=None):
+        return self.index_type.generate_range(size, 0, len(self.dictionary),
+                                              name=name)
 
 
 class StructColumn(Column):
@@ -529,7 +601,7 @@ class StructColumn(Column):
         return [field.get_json() for field in self.field_values]
 
 
-class JSONRecordBatch(object):
+class JsonRecordBatch(object):
 
     def __init__(self, count, columns):
         self.count = count
@@ -542,18 +614,27 @@ class JSONRecordBatch(object):
         ])
 
 
-class JSONFile(object):
+class JsonFile(object):
 
-    def __init__(self, name, schema, batches):
+    def __init__(self, name, schema, batches, dictionaries=None):
         self.name = name
         self.schema = schema
+        self.dictionaries = dictionaries or []
         self.batches = batches
 
     def get_json(self):
-        return OrderedDict([
-            ('schema', self.schema.get_json()),
-            ('batches', [batch.get_json() for batch in self.batches])
-        ])
+        entries = [
+            ('schema', self.schema.get_json())
+        ]
+
+        if len(self.dictionaries) > 0:
+            entries.append(('dictionaries',
+                            [dictionary.get_json()
+                             for dictionary in self.dictionaries]))
+
+        entries.append(('batches', [batch.get_json()
+                                    for batch in self.batches]))
+        return OrderedDict(entries)
 
     def write(self, path):
         with open(path, 'wb') as f:
@@ -580,8 +661,8 @@ def get_field(name, type_, nullable=True):
         raise TypeError(dtype)
 
 
-def _generate_file(name, fields, batch_sizes):
-    schema = JSONSchema(fields)
+def _generate_file(name, fields, batch_sizes, dictionaries=None):
+    schema = JsonSchema(fields)
     batches = []
     for size in batch_sizes:
         columns = []
@@ -589,9 +670,9 @@ def _generate_file(name, fields, batch_sizes):
             col = field.generate_column(size)
             columns.append(col)
 
-        batches.append(JSONRecordBatch(size, columns))
+        batches.append(JsonRecordBatch(size, columns))
 
-    return JSONFile(name, schema, batches)
+    return JsonFile(name, schema, batches, dictionaries)
 
 
 def generate_primitive_case(batch_sizes):
@@ -645,6 +726,25 @@ def generate_nested_case():
     return _generate_file("nested", fields, batch_sizes)
 
 
+def generate_dictionary_case():
+    dict_type1 = StringType('dictionary1')
+    dict_type2 = get_field('dictionary2', 'int64')
+
+    dict1 = Dictionary(0, dict_type1,
+                       dict_type1.generate_column(10, name='DICT0'))
+    dict2 = Dictionary(1, dict_type2,
+                       dict_type2.generate_column(50, name='DICT1'))
+
+    fields = [
+        DictionaryType('dict1_0', get_field('', 'int8'), dict1),
+        DictionaryType('dict1_1', get_field('', 'int32'), dict1),
+        DictionaryType('dict2_0', get_field('', 'int16'), dict2)
+    ]
+    batch_sizes = [7, 10]
+    return _generate_file("dictionary", fields, batch_sizes,
+                          dictionaries=[dict1, dict2])
+
+
 def get_generated_json_files():
     temp_dir = tempfile.mkdtemp()
 
@@ -655,12 +755,14 @@ def get_generated_json_files():
         generate_primitive_case([7, 10]),
         generate_primitive_case([0, 0, 0]),
         generate_datetime_case(),
-        generate_nested_case()
+        generate_nested_case(),
+        # generate_dictionary_case()
     ]
 
     generated_paths = []
     for file_obj in file_objs:
-        out_path = os.path.join(temp_dir, 'generated_' + file_obj.name + '.json')
+        out_path = os.path.join(temp_dir, 'generated_' +
+                                file_obj.name + '.json')
         file_obj.write(out_path)
         generated_paths.append(out_path)
 
@@ -689,15 +791,16 @@ class IntegrationRunner(object):
                                                        consumer.name))
 
         for json_path in self.json_files:
-            print('=====================================================================================')
+            print('==========================================================')
             print('Testing file {0}'.format(json_path))
-            print('=====================================================================================')
+            print('==========================================================')
 
             name = os.path.splitext(os.path.basename(json_path))[0]
 
             # Make the random access file
             print('-- Creating binary inputs')
-            producer_file_path = os.path.join(self.temp_dir, guid() + '_' + name + '.json_to_arrow')
+            producer_file_path = os.path.join(self.temp_dir, guid() + '_' +
+                                              name + '.json_to_arrow')
             producer.json_to_file(json_path, producer_file_path)
 
             # Validate the file
@@ -705,8 +808,10 @@ class IntegrationRunner(object):
             consumer.validate(json_path, producer_file_path)
 
             print('-- Validating stream')
-            producer_stream_path = os.path.join(self.temp_dir, guid() + '_' + name + '.arrow_to_stream')
-            consumer_file_path = os.path.join(self.temp_dir, guid() + '_' + name + '.stream_to_arrow')
+            producer_stream_path = os.path.join(self.temp_dir, guid() + '_' +
+                                                name + '.arrow_to_stream')
+            consumer_file_path = os.path.join(self.temp_dir, guid() + '_' +
+                                              name + '.stream_to_arrow')
             producer.file_to_stream(producer_file_path,
                                     producer_stream_path)
             consumer.stream_to_file(producer_stream_path,
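
Once generate_dictionary_case is re-enabled, the new classes can also be
exercised by hand. A minimal sketch, assuming it runs inside
integration_test.py (the field name, file name, and output path are
illustrative):

    # Build a string dictionary with 10 values and id 0 ('DICT0' names the
    # dictionary column itself, distinct from the field name)
    dict_type = StringType('dictionary1')
    dict0 = Dictionary(0, dict_type,
                       dict_type.generate_column(10, name='DICT0'))

    # A field whose int8 indices point into dict0
    fields = [DictionaryType('dict_field', get_field('', 'int8'), dict0)]

    # Write a JSON file with two record batches of sizes 7 and 10
    json_file = _generate_file("dictionary_demo", fields, [7, 10],
                               dictionaries=[dict0])
    json_file.write('/tmp/generated_dictionary_demo.json')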
