This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new fce8b6555ce Fix beam.Row.__eq__ for rows with trailing columns (#23876) fce8b6555ce is described below commit fce8b6555ced9fac2eef60ba5611a982476827f1 Author: aliftadvantage <lorenza...@google.com> AuthorDate: Thu Feb 16 05:56:23 2023 -0800 Fix beam.Row.__eq__ for rows with trailing columns (#23876) * Fix beam.Row.__eq__ for rows with trailing columns * Correct the condition to make sure we still check the insertion order of the elements. --------- Co-authored-by: Valentyn Tymofieiev <valen...@google.com> --- sdks/python/apache_beam/pvalue.py | 7 +++++-- sdks/python/apache_beam/pvalue_test.py | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 35274d411f5..2e86c9eb51c 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -684,8 +684,11 @@ class Row(object): return hash(self.__dict__.items()) def __eq__(self, other): - return type(self) == type(other) and all( - s == o for s, o in zip(self.__dict__.items(), other.__dict__.items())) + return ( + type(self) == type(other) and + len(self.__dict__) == len(other.__dict__) and all( + s == o for s, + o in zip(self.__dict__.items(), other.__dict__.items()))) def __reduce__(self): return _make_Row, tuple(self.__dict__.items()) diff --git a/sdks/python/apache_beam/pvalue_test.py b/sdks/python/apache_beam/pvalue_test.py index 0ea1a8c9835..447d2327dc4 100644 --- a/sdks/python/apache_beam/pvalue_test.py +++ b/sdks/python/apache_beam/pvalue_test.py @@ -23,6 +23,7 @@ import unittest from apache_beam.pvalue import AsSingleton from apache_beam.pvalue import PValue +from apache_beam.pvalue import Row from apache_beam.pvalue import TaggedOutput from apache_beam.testing.test_pipeline import TestPipeline @@ -49,5 +50,22 @@ class TaggedValueTest(unittest.TestCase): TaggedOutput((1, 2, 3), 'value') +class RowTest(unittest.TestCase): + def test_row_eq(self): + row = Row(a=1, b=2) + same = Row(a=1, b=2) + self.assertEqual(row, same) + + def test_trailing_column_row_neq(self): + row = Row(a=1, b=2) + trail = Row(a=1, b=2, c=3) + self.assertNotEqual(row, trail) + + def test_row_comparison_respects_element_order(self): + row = Row(a=1, b=2) + different = Row(b=2, a=1) + self.assertNotEqual(row, different) + + if __name__ == '__main__': unittest.main()