This is an automated email from the ASF dual-hosted git repository.

dkulp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/avro.git


The following commit(s) were added to refs/heads/master by this push:
     new 73506f2  Revert "Reverting the changes for AVRO-1777 which failed in Travis (#413)"
73506f2 is described below

commit 73506f295a7835334d31c68da35e4beabdeb4443
Author: Daniel Kulp <[email protected]>
AuthorDate: Fri Apr 5 10:01:24 2019 -0400

    Revert "Reverting the changes for AVRO-1777 which failed in Travis (#413)"
    
    This reverts commit 52a53e9520d7657bd237cae3e92e3ed015c71811.
---
 lang/py/src/avro/io.py  | 45 +++++++++++++++++++++++++++++++++++++--------
 lang/py/test/test_io.py | 11 ++++++++++-
 2 files changed, 47 insertions(+), 9 deletions(-)

diff --git a/lang/py/src/avro/io.py b/lang/py/src/avro/io.py
index 2901660..046b94c 100644
--- a/lang/py/src/avro/io.py
+++ b/lang/py/src/avro/io.py
@@ -100,6 +100,10 @@ class SchemaResolutionException(schema.AvroException):
     if readers_schema: fail_msg += "\nReader's Schema: %s" % pretty_readers
     schema.AvroException.__init__(self, fail_msg)
 
+class RecordInitializationException(schema.AvroException):
+    def __init__(self, fail_msg):
+        schema.AvroException.__init__(self, fail_msg)
+
 #
 # Validate
 #
@@ -119,15 +123,18 @@ def validate(expected_schema, datum):
       return isinstance(datum, Decimal)
     return isinstance(datum, str)
   elif schema_type == 'int':
-    return ((isinstance(datum, int) or isinstance(datum, long)) 
-            and INT_MIN_VALUE <= datum <= INT_MAX_VALUE)
+    return (((isinstance(datum, int) and not isinstance(datum, bool)) or
+            isinstance(datum, long)) and
+            INT_MIN_VALUE <= datum <= INT_MAX_VALUE)
   elif schema_type == 'long':
-    return ((isinstance(datum, int) or isinstance(datum, long)) 
-            and LONG_MIN_VALUE <= datum <= LONG_MAX_VALUE)
+    return (((isinstance(datum, int) and not isinstance(datum, bool)) or
+            isinstance(datum, long)) and
+            LONG_MIN_VALUE <= datum <= LONG_MAX_VALUE)
   elif schema_type in ['float', 'double']:
-    return (isinstance(datum, int) or isinstance(datum, long)
-            or isinstance(datum, float))
-  # Check for int, float, long and decimal
+    # Check for int, float, long and decimal
+    return (isinstance(datum, long) or
+            (isinstance(datum, int) and not isinstance(datum, bool)) or
+            isinstance(datum, float))
   elif schema_type == 'fixed':
     if (hasattr(expected_schema, 'logical_type') and
                     expected_schema.logical_type == 'decimal'):
@@ -145,6 +152,8 @@ def validate(expected_schema, datum):
         [validate(expected_schema.values, v) for v in datum.values()])
   elif schema_type in ['union', 'error_union']:
     return True in [validate(s, datum) for s in expected_schema.schemas]
+  elif schema_type == 'record' and isinstance(datum, GenericRecord):
+      return expected_schema == datum.schema
   elif schema_type in ['record', 'error', 'request']:
     return (isinstance(datum, dict) and
       False not in
@@ -813,7 +822,7 @@ class DatumReader(object):
     """
     # schema resolution
     readers_fields_dict = readers_schema.fields_dict
-    read_record = {}
+    read_record = GenericRecord(readers_schema)
     for field in writers_schema.fields:
       readers_field = readers_fields_dict.get(field.name)
       if readers_field is not None:
@@ -1030,3 +1039,23 @@ class DatumWriter(object):
     """
     for field in writers_schema.fields:
       self.write_data(field.type, datum.get(field.name), encoder)
+
+class GenericRecord(dict):
+
+    def __init__(self, record_schema, lst = []):
+        if (record_schema is None or
+                not isinstance(record_schema, schema.Schema)):
+            raise RecordInitializationException(
+                    "Cannot initialize a record with schema: {sc}".format(sc = record_schema))
+        dict.__init__(self, lst)
+        self.schema = record_schema
+
+    def __eq__(self, other):
+        if other is None or not isinstance(other, dict):
+            return False
+        if not dict.__eq__(self, other):
+            return False
+        if isinstance(other, GenericRecord):
+            return self.schema == other.schema
+        else:
+            return True
diff --git a/lang/py/test/test_io.py b/lang/py/test/test_io.py
index df8b180..82c301c 100644
--- a/lang/py/test/test_io.py
+++ b/lang/py/test/test_io.py
@@ -48,6 +48,8 @@ SCHEMAS_TO_VALIDATE = (
   ('{"type": "array", "items": "long"}', [1, 3, 2]),
   ('{"type": "map", "values": "long"}', {'a': 1, 'b': 3, 'c': 2}),
   ('["string", "null", "long"]', None),
+  ('["double", "boolean"]', True),
+  ('["boolean", "double"]', True),
   ("""\
    {"type": "record",
     "name": "Test",
@@ -199,6 +201,13 @@ class TestIO(unittest.TestCase):
   def test_round_trip(self):
     print_test_name('TEST ROUND TRIP')
     correct = 0
+    def are_equal(datum, round_trip_datum):
+        if datum != round_trip_datum:
+            return False
+        if type(datum) == bool:
+            return type(round_trip_datum) == bool
+        else:
+            return True
     for example_schema, datum in SCHEMAS_TO_VALIDATE:
       print 'Schema: %s' % example_schema
       print 'Datum: %s' % datum
@@ -211,7 +220,7 @@ class TestIO(unittest.TestCase):
       if isinstance(round_trip_datum, Decimal):
         round_trip_datum = round_trip_datum.to_eng_string()
         datum = str(datum)
-      if datum == round_trip_datum: correct += 1
+      if are_equal(datum, round_trip_datum): correct += 1
     self.assertEquals(correct, len(SCHEMAS_TO_VALIDATE))
 
   #

Reply via email to