This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new fce8b6555ce Fix beam.Row.__eq__ for rows with trailing columns (#23876)
fce8b6555ce is described below

commit fce8b6555ced9fac2eef60ba5611a982476827f1
Author: aliftadvantage <lorenza...@google.com>
AuthorDate: Thu Feb 16 05:56:23 2023 -0800

    Fix beam.Row.__eq__ for rows with trailing columns (#23876)
    
    * Fix beam.Row.__eq__ for rows with trailing columns
    
    * Correct the condition to make sure we still check the insertion order of the elements.
    
    ---------
    
    Co-authored-by: Valentyn Tymofieiev <valen...@google.com>
---
 sdks/python/apache_beam/pvalue.py      |  7 +++++--
 sdks/python/apache_beam/pvalue_test.py | 18 ++++++++++++++++++
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py
index 35274d411f5..2e86c9eb51c 100644
--- a/sdks/python/apache_beam/pvalue.py
+++ b/sdks/python/apache_beam/pvalue.py
@@ -684,8 +684,11 @@ class Row(object):
     return hash(self.__dict__.items())
 
   def __eq__(self, other):
-    return type(self) == type(other) and all(
-        s == o for s, o in zip(self.__dict__.items(), other.__dict__.items()))
+    return (
+        type(self) == type(other) and
+        len(self.__dict__) == len(other.__dict__) and all(
+            s == o for s,
+            o in zip(self.__dict__.items(), other.__dict__.items())))
 
   def __reduce__(self):
     return _make_Row, tuple(self.__dict__.items())
diff --git a/sdks/python/apache_beam/pvalue_test.py b/sdks/python/apache_beam/pvalue_test.py
index 0ea1a8c9835..447d2327dc4 100644
--- a/sdks/python/apache_beam/pvalue_test.py
+++ b/sdks/python/apache_beam/pvalue_test.py
@@ -23,6 +23,7 @@ import unittest
 
 from apache_beam.pvalue import AsSingleton
 from apache_beam.pvalue import PValue
+from apache_beam.pvalue import Row
 from apache_beam.pvalue import TaggedOutput
 from apache_beam.testing.test_pipeline import TestPipeline
 
@@ -49,5 +50,22 @@ class TaggedValueTest(unittest.TestCase):
       TaggedOutput((1, 2, 3), 'value')
 
 
+class RowTest(unittest.TestCase):
+  def test_row_eq(self):
+    row = Row(a=1, b=2)
+    same = Row(a=1, b=2)
+    self.assertEqual(row, same)
+
+  def test_trailing_column_row_neq(self):
+    row = Row(a=1, b=2)
+    trail = Row(a=1, b=2, c=3)
+    self.assertNotEqual(row, trail)
+
+  def test_row_comparison_respects_element_order(self):
+    row = Row(a=1, b=2)
+    different = Row(b=2, a=1)
+    self.assertNotEqual(row, different)
+
+
 if __name__ == '__main__':
   unittest.main()

Reply via email to