Author: Brian Kearns <bdkea...@gmail.com> Branch: Changeset: r69273:9ba1d3bb478e Date: 2014-02-22 15:52 -0500 http://bitbucket.org/pypy/pypy/changeset/9ba1d3bb478e/
Log: implement comparison funcs for record types diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py --- a/pypy/module/micronumpy/compile.py +++ b/pypy/module/micronumpy/compile.py @@ -5,6 +5,7 @@ import re +from pypy.interpreter import special from pypy.interpreter.baseobjspace import InternalSpaceCache, W_Root from pypy.interpreter.error import OperationError from pypy.module.micronumpy import interp_boxes @@ -74,6 +75,7 @@ def __init__(self): """NOT_RPYTHON""" self.fromcache = InternalSpaceCache(self).getorbuild + self.w_NotImplemented = special.NotImplemented(self) def _freeze_(self): return True @@ -194,6 +196,9 @@ def is_w(self, w_obj, w_what): return w_obj is w_what + def eq_w(self, w_obj, w_what): + return w_obj == w_what + def issubtype(self, w_type1, w_type2): return BoolObject(True) diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py --- a/pypy/module/micronumpy/interp_ufuncs.py +++ b/pypy/module/micronumpy/interp_ufuncs.py @@ -371,17 +371,23 @@ w_ldtype = w_lhs.get_dtype() w_rdtype = w_rhs.get_dtype() if w_ldtype.is_str_type() and w_rdtype.is_str_type() and \ - self.comparison_func: + self.comparison_func: pass elif (w_ldtype.is_str_type() or w_rdtype.is_str_type()) and \ - self.comparison_func and w_out is None: + self.comparison_func and w_out is None: return space.wrap(False) - elif (w_ldtype.is_flexible_type() or \ - w_rdtype.is_flexible_type()): - raise OperationError(space.w_TypeError, space.wrap( - 'unsupported operand dtypes %s and %s for "%s"' % \ - (w_rdtype.get_name(), w_ldtype.get_name(), - self.name))) + elif w_ldtype.is_flexible_type() or w_rdtype.is_flexible_type(): + if self.comparison_func: + if self.name == 'equal' or self.name == 'not_equal': + res = w_ldtype.eq(space, w_rdtype) + if not res: + return space.wrap(self.name == 'not_equal') + else: + return space.w_NotImplemented + else: + raise oefmt(space.w_TypeError, + 'unsupported operand dtypes %s and %s for "%s"', + w_rdtype.name, w_ldtype.name, self.name) if self.are_common_types(w_ldtype, w_rdtype): if not w_lhs.is_scalar() and w_rhs.is_scalar(): diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -3573,6 +3573,28 @@ exc = raises(ValueError, "a.view(('float32', 2))") assert exc.value[0] == 'new type not compatible with array.' + def test_record_ufuncs(self): + import numpy as np + a = np.zeros(3, dtype=[('a', 'i8'), ('b', 'i8')]) + b = np.zeros(3, dtype=[('a', 'i8'), ('b', 'i8')]) + c = np.zeros(3, dtype=[('a', 'f8'), ('b', 'f8')]) + d = np.ones(3, dtype=[('a', 'i8'), ('b', 'i8')]) + e = np.ones(3, dtype=[('a', 'i8'), ('b', 'i8'), ('c', 'i8')]) + exc = raises(TypeError, abs, a) + assert exc.value[0] == 'Not implemented for this type' + assert (a == a).all() + assert not (a != a).any() + assert (a == b).all() + assert not (a != b).any() + assert a != c + assert not a == c + assert (a != d).all() + assert not (a == d).any() + assert a != e + assert not a == e + assert np.greater(a, a) is NotImplemented + assert np.less_equal(a, a) is NotImplemented + class AppTestPyPy(BaseNumpyAppTest): def setup_class(cls): if option.runappdirect and '__pypy__' not in sys.builtin_module_names: diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py --- a/pypy/module/micronumpy/types.py +++ b/pypy/module/micronumpy/types.py @@ -1944,6 +1944,20 @@ pieces.append(")") return "".join(pieces) + def eq(self, v1, v2): + assert isinstance(v1, interp_boxes.W_VoidBox) + assert isinstance(v2, interp_boxes.W_VoidBox) + s1 = v1.dtype.get_size() + s2 = v2.dtype.get_size() + assert s1 == s2 + for i in range(s1): + if v1.arr.storage[v1.ofs + i] != v2.arr.storage[v2.ofs + i]: + return False + return True + + def ne(self, v1, v2): + return not self.eq(v1, v2) + for tp in [Int32, Int64]: if tp.T == lltype.Signed: IntP = tp
_______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit