Author: Armin Rigo <ar...@tunes.org> Branch: Changeset: r2890:91fcbd69bce2 Date: 2017-02-19 14:37 +0100 http://bitbucket.org/cffi/cffi/changeset/91fcbd69bce2/
Log: issue #255: comparing primitive cdatas diff --git a/c/_cffi_backend.c b/c/_cffi_backend.c --- a/c/_cffi_backend.c +++ b/c/_cffi_backend.c @@ -2031,47 +2031,97 @@ static PyObject *cdata_richcompare(PyObject *v, PyObject *w, int op) { - int res; + int v_is_ptr, w_is_ptr; PyObject *pyres; - char *v_cdata, *w_cdata; assert(CData_Check(v)); - if (!CData_Check(w)) { + + /* Comparisons involving a primitive cdata work differently than + * comparisons involving a struct/array/pointer. + * + * If v or w is a struct/array/pointer, then the other must be too + * (otherwise we return NotImplemented and leave the case to + * Python). If both are, then we compare the addresses. + * + * If v and/or w is a primitive cdata, then we convert the cdata(s) + * to regular Python objects and redo the comparison there. + */ + + v_is_ptr = !(((CDataObject *)v)->c_type->ct_flags & CT_PRIMITIVE_ANY); + w_is_ptr = CData_Check(w) && + !(((CDataObject *)w)->c_type->ct_flags & CT_PRIMITIVE_ANY); + + if (v_is_ptr && w_is_ptr) { + int res; + char *v_cdata = ((CDataObject *)v)->c_data; + char *w_cdata = ((CDataObject *)w)->c_data; + + switch (op) { + case Py_EQ: res = (v_cdata == w_cdata); break; + case Py_NE: res = (v_cdata != w_cdata); break; + case Py_LT: res = (v_cdata < w_cdata); break; + case Py_LE: res = (v_cdata <= w_cdata); break; + case Py_GT: res = (v_cdata > w_cdata); break; + case Py_GE: res = (v_cdata >= w_cdata); break; + default: res = -1; + } + pyres = res ? Py_True : Py_False; + } + else if (v_is_ptr || w_is_ptr) { pyres = Py_NotImplemented; - goto done; - } - - if ((op != Py_EQ && op != Py_NE) && - ((((CDataObject *)v)->c_type->ct_flags & CT_PRIMITIVE_ANY) || - (((CDataObject *)w)->c_type->ct_flags & CT_PRIMITIVE_ANY))) - goto Error; - - v_cdata = ((CDataObject *)v)->c_data; - w_cdata = ((CDataObject *)w)->c_data; - - switch (op) { - case Py_EQ: res = (v_cdata == w_cdata); break; - case Py_NE: res = (v_cdata != w_cdata); break; - case Py_LT: res = (v_cdata < w_cdata); break; - case Py_LE: res = (v_cdata <= w_cdata); break; - case Py_GT: res = (v_cdata > w_cdata); break; - case Py_GE: res = (v_cdata >= w_cdata); break; - default: res = -1; - } - pyres = res ? Py_True : Py_False; - done: + } + else { + PyObject *aa[2]; + int i; + + aa[0] = v; Py_INCREF(v); + aa[1] = w; Py_INCREF(w); + pyres = NULL; + + for (i = 0; i < 2; i++) { + v = aa[i]; + if (!CData_Check(v)) + continue; + w = convert_to_object(((CDataObject *)v)->c_data, + ((CDataObject *)v)->c_type); + if (w == NULL) + goto error; + if (CData_Check(w)) { + Py_DECREF(w); + PyErr_Format(PyExc_NotImplementedError, + "cannot use <cdata '%s'> in a comparison", + ((CDataObject *)v)->c_type->ct_name); + goto error; + } + aa[i] = w; + Py_DECREF(v); + } + pyres = PyObject_RichCompare(aa[0], aa[1], op); + error: + Py_DECREF(aa[1]); + Py_DECREF(aa[0]); + return pyres; + } + Py_INCREF(pyres); return pyres; - - Error: - PyErr_SetString(PyExc_TypeError, - "cannot do comparison on a primitive cdata"); - return NULL; -} - -static long cdata_hash(CDataObject *cd) -{ - return _Py_HashPointer(cd->c_data); +} + +static long cdata_hash(CDataObject *v) +{ + if (((CDataObject *)v)->c_type->ct_flags & CT_PRIMITIVE_ANY) { + PyObject *vv = convert_to_object(((CDataObject *)v)->c_data, + ((CDataObject *)v)->c_type); + if (vv == NULL) + return -1; + if (!CData_Check(vv)) { + long hash = PyObject_Hash(vv); + Py_DECREF(vv); + return hash; + } + Py_DECREF(vv); + } + return _Py_HashPointer(v->c_data); } static Py_ssize_t diff --git a/c/test_c.py b/c/test_c.py --- a/c/test_c.py +++ b/c/test_c.py @@ -27,6 +27,7 @@ .replace(r'\\U', r'\U')) u = U() str2bytes = str + strict_compare = False else: type_or_class = "class" long = int @@ -38,6 +39,7 @@ bitem2bchr = bytechr u = "" str2bytes = lambda s: bytes(s, "ascii") + strict_compare = True def size_of_int(): BInt = new_primitive_type("int") @@ -106,11 +108,11 @@ x = cast(p, -66 + (1<<199)*256) assert repr(x) == "<cdata 'signed char' -66>" assert int(x) == -66 - assert (x == cast(p, -66)) is False - assert (x != cast(p, -66)) is True + assert (x == cast(p, -66)) is True + assert (x != cast(p, -66)) is False q = new_primitive_type("short") - assert (x == cast(q, -66)) is False - assert (x != cast(q, -66)) is True + assert (x == cast(q, -66)) is True + assert (x != cast(q, -66)) is False def test_sizeof_type(): py.test.raises(TypeError, sizeof, 42.5) @@ -175,7 +177,7 @@ assert float(cast(p, 1.1)) != 1.1 # rounding error assert float(cast(p, 1E200)) == INF # limited range - assert cast(p, -1.1) != cast(p, -1.1) + assert cast(p, -1.1) == cast(p, -1.1) assert repr(float(cast(p, -0.0))) == '-0.0' assert float(cast(p, b'\x09')) == 9.0 assert float(cast(p, u+'\x09')) == 9.0 @@ -219,7 +221,7 @@ p = new_primitive_type("char") assert bool(cast(p, 'A')) is True assert bool(cast(p, '\x00')) is False # since 1.7 - assert cast(p, '\x00') != cast(p, -17*256) + assert cast(p, '\x00') == cast(p, -17*256) assert int(cast(p, 'A')) == 65 assert long(cast(p, 'A')) == 65 assert type(int(cast(p, 'A'))) is int @@ -382,23 +384,6 @@ # that it is already loaded too, so it should work assert x.load_function(BVoidP, 'sqrt') -def test_hash_differences(): - BChar = new_primitive_type("char") - BInt = new_primitive_type("int") - BFloat = new_primitive_type("float") - for i in range(1, 20): - x1 = cast(BChar, chr(i)) - x2 = cast(BInt, i) - if hash(x1) != hash(x2): - break - else: - raise AssertionError("hashes are equal") - for i in range(1, 20): - if hash(cast(BFloat, i)) != hash(float(i)): - break - else: - raise AssertionError("hashes are equal") - def test_no_len_on_nonarray(): p = new_primitive_type("int") py.test.raises(TypeError, len, cast(p, 42)) @@ -2261,12 +2246,17 @@ BVoidP = new_pointer_type(new_void_type()) p = newp(BIntP, 123) q = cast(BInt, 124) - py.test.raises(TypeError, "p < q") - py.test.raises(TypeError, "p <= q") assert (p == q) is False assert (p != q) is True - py.test.raises(TypeError, "p > q") - py.test.raises(TypeError, "p >= q") + assert (q == p) is False + assert (q != p) is True + if strict_compare: + py.test.raises(TypeError, "p < q") + py.test.raises(TypeError, "p <= q") + py.test.raises(TypeError, "q < p") + py.test.raises(TypeError, "q <= p") + py.test.raises(TypeError, "p > q") + py.test.raises(TypeError, "p >= q") r = cast(BVoidP, p) assert (p < r) is False assert (p <= r) is True @@ -3840,3 +3830,86 @@ assert len(w) == 2 # check that the warnings are associated with lines in this file assert w[1].lineno == w[0].lineno + 4 + +def test_primitive_comparison(): + def assert_eq(a, b): + assert (a == b) is True + assert (b == a) is True + assert (a != b) is False + assert (b != a) is False + assert (a < b) is False + assert (a <= b) is True + assert (a > b) is False + assert (a >= b) is True + assert (b < a) is False + assert (b <= a) is True + assert (b > a) is False + assert (b >= a) is True + assert hash(a) == hash(b) + def assert_lt(a, b): + assert (a == b) is False + assert (b == a) is False + assert (a != b) is True + assert (b != a) is True + assert (a < b) is True + assert (a <= b) is True + assert (a > b) is False + assert (a >= b) is False + assert (b < a) is False + assert (b <= a) is False + assert (b > a) is True + assert (b >= a) is True + assert hash(a) != hash(b) # (or at least, it is unlikely) + def assert_gt(a, b): + assert_lt(b, a) + def assert_ne(a, b): + assert (a == b) is False + assert (b == a) is False + assert (a != b) is True + assert (b != a) is True + if strict_compare: + py.test.raises(TypeError, "a < b") + py.test.raises(TypeError, "a <= b") + py.test.raises(TypeError, "a > b") + py.test.raises(TypeError, "a >= b") + py.test.raises(TypeError, "b < a") + py.test.raises(TypeError, "b <= a") + py.test.raises(TypeError, "b > a") + py.test.raises(TypeError, "b >= a") + elif a < b: + assert_lt(a, b) + else: + assert_lt(b, a) + assert_eq(5, 5) + assert_lt(3, 5) + assert_ne('5', 5) + # + t1 = new_primitive_type("char") + t2 = new_primitive_type("int") + t3 = new_primitive_type("unsigned char") + t4 = new_primitive_type("unsigned int") + t5 = new_primitive_type("float") + t6 = new_primitive_type("double") + assert_eq(cast(t1, 65), b'A') + assert_lt(cast(t1, 64), b'\x99') + assert_gt(cast(t1, 200), b'A') + assert_ne(cast(t1, 65), 65) + assert_eq(cast(t2, -25), -25) + assert_lt(cast(t2, -25), -24) + assert_gt(cast(t2, -25), -26) + assert_eq(cast(t3, 65), 65) + assert_ne(cast(t3, 65), b'A') + assert_ne(cast(t3, 65), cast(t1, 65)) + assert_gt(cast(t4, -1), -1) + assert_gt(cast(t4, -1), cast(t2, -1)) + assert_gt(cast(t4, -1), 99999) + assert_eq(cast(t4, -1), 256 ** size_of_int() - 1) + assert_eq(cast(t5, 3.0), 3) + assert_eq(cast(t5, 3.5), 3.5) + assert_lt(cast(t5, 3.3), 3.3) # imperfect rounding + assert_eq(cast(t6, 3.3), 3.3) + assert_eq(cast(t5, 3.5), cast(t6, 3.5)) + assert_lt(cast(t5, 3.1), cast(t6, 3.1)) # imperfect rounding + assert_eq(cast(t5, 7.0), cast(t3, 7)) + assert_lt(cast(t5, 3.1), 3.101) + assert_gt(cast(t5, 3.1), 3) diff --git a/cffi/backend_ctypes.py b/cffi/backend_ctypes.py --- a/cffi/backend_ctypes.py +++ b/cffi/backend_ctypes.py @@ -112,11 +112,20 @@ def _make_cmp(name): cmpfunc = getattr(operator, name) def cmp(self, other): - if isinstance(other, CTypesData): + v_is_ptr = not isinstance(self, CTypesGenericPrimitive) + w_is_ptr = (isinstance(other, CTypesData) and + not isinstance(other, CTypesGenericPrimitive)) + if v_is_ptr and w_is_ptr: return cmpfunc(self._convert_to_address(None), other._convert_to_address(None)) + elif v_is_ptr or w_is_ptr: + return NotImplemented else: - return NotImplemented + if isinstance(self, CTypesGenericPrimitive): + self = self._value + if isinstance(other, CTypesGenericPrimitive): + other = other._value + return cmpfunc(self, other) cmp.func_name = name return cmp @@ -128,7 +137,7 @@ __ge__ = _make_cmp('__ge__') def __hash__(self): - return hash(type(self)) ^ hash(self._convert_to_address(None)) + return hash(self._convert_to_address(None)) def _to_string(self, maxlen): raise TypeError("string(): %r" % (self,)) @@ -137,14 +146,8 @@ class CTypesGenericPrimitive(CTypesData): __slots__ = [] - def __eq__(self, other): - return self is other - - def __ne__(self, other): - return self is not other - def __hash__(self): - return object.__hash__(self) + return hash(self._value) def _get_own_repr(self): return repr(self._from_ctypes(self._value)) diff --git a/doc/source/ref.rst b/doc/source/ref.rst --- a/doc/source/ref.rst +++ b/doc/source/ref.rst @@ -602,21 +602,21 @@ | C type | writing into | reading from |other operations| +===============+========================+==================+================+ | integers | an integer or anything | a Python int or | int(), bool() | -| and enums | on which int() works | long, depending | `(******)` | -| `(*****)` | (but not a float!). | on the type | | +| and enums | on which int() works | long, depending | `(******)`, | +| `(*****)` | (but not a float!). | on the type | ``<`` | | | Must be within range. | (ver. 1.10: or a | | | | | bool) | | +---------------+------------------------+------------------+----------------+ -| ``char`` | a string of length 1 | a string of | int(), bool() | -| | or another <cdata char>| length 1 | | +| ``char`` | a string of length 1 | a string of | int(), bool(), | +| | or another <cdata char>| length 1 | ``<`` | +---------------+------------------------+------------------+----------------+ | ``wchar_t`` | a unicode of length 1 | a unicode of | | -| | (or maybe 2 if | length 1 | int(), bool() | -| | surrogates) or | (or maybe 2 if | | +| | (or maybe 2 if | length 1 | int(), bool(), | +| | surrogates) or | (or maybe 2 if | ``<`` | | | another <cdata wchar_t>| surrogates) | | +---------------+------------------------+------------------+----------------+ | ``float``, | a float or anything on | a Python float | float(), int(),| -| ``double`` | which float() works | | bool() | +| ``double`` | which float() works | | bool(), ``<`` | +---------------+------------------------+------------------+----------------+ |``long double``| another <cdata> with | a <cdata>, to | float(), int(),| | | a ``long double``, or | avoid loosing | bool() | diff --git a/doc/source/whatsnew.rst b/doc/source/whatsnew.rst --- a/doc/source/whatsnew.rst +++ b/doc/source/whatsnew.rst @@ -50,6 +50,13 @@ only in out-of-line mode. This is useful for taking the address of global variables. +* Issue #255: ``cdata`` objects of a primitive type (integers, floats, + char) are now compared and ordered by value. For example, ``<cdata + 'int' 42>`` compares equal to ``42`` and ``<cdata 'char' b'A'>`` + compares equal to ``b'A'``. Unlike C, ``<cdata 'int' -1>`` does not + compare equal to ``ffi.cast("unsigned int", -1)``: it compares + smaller, because ``-1 < 4294967295``. + v1.9 ==== diff --git a/testing/cffi0/backend_tests.py b/testing/cffi0/backend_tests.py --- a/testing/cffi0/backend_tests.py +++ b/testing/cffi0/backend_tests.py @@ -54,7 +54,8 @@ min = int(min) max = int(max) p = ffi.cast(c_decl, min) - assert p != min # no __eq__(int) + assert p == min + assert hash(p) == hash(min) assert bool(p) is bool(min) assert int(p) == min p = ffi.cast(c_decl, max) @@ -65,9 +66,9 @@ assert ffi.typeof(q) is ffi.typeof(p) and int(q) == max q = ffi.cast(c_decl, long(min - 1)) assert ffi.typeof(q) is ffi.typeof(p) and int(q) == max - assert q != p + assert q == p assert int(q) == int(p) - assert hash(q) != hash(p) # unlikely + assert hash(q) == hash(p) c_decl_ptr = '%s *' % c_decl py.test.raises(OverflowError, ffi.new, c_decl_ptr, min - 1) py.test.raises(OverflowError, ffi.new, c_decl_ptr, max + 1) @@ -882,9 +883,9 @@ assert ffi.string(ffi.cast("enum bar", -2)) == "B1" assert ffi.string(ffi.cast("enum bar", -1)) == "CC1" assert ffi.string(ffi.cast("enum bar", 1)) == "E1" - assert ffi.cast("enum bar", -2) != ffi.cast("enum bar", -2) - assert ffi.cast("enum foo", 0) != ffi.cast("enum bar", 0) - assert ffi.cast("enum bar", 0) != ffi.cast("int", 0) + assert ffi.cast("enum bar", -2) == ffi.cast("enum bar", -2) + assert ffi.cast("enum foo", 0) == ffi.cast("enum bar", 0) + assert ffi.cast("enum bar", 0) == ffi.cast("int", 0) assert repr(ffi.cast("enum bar", -1)) == "<cdata 'enum bar' -1: CC1>" assert repr(ffi.cast("enum foo", -1)) == ( # enums are unsigned, if "<cdata 'enum foo' 4294967295>") # they contain no neg value @@ -1113,15 +1114,15 @@ assert (q == None) is False assert (q != None) is True - def test_no_integer_comparison(self): + def test_integer_comparison(self): ffi = FFI(backend=self.Backend()) x = ffi.cast("int", 123) y = ffi.cast("int", 456) - py.test.raises(TypeError, "x < y") + assert x < y # z = ffi.cast("double", 78.9) - py.test.raises(TypeError, "x < z") - py.test.raises(TypeError, "z < y") + assert x > z + assert y > z def test_ffi_buffer_ptr(self): ffi = FFI(backend=self.Backend()) diff --git a/testing/cffi1/test_new_ffi_1.py b/testing/cffi1/test_new_ffi_1.py --- a/testing/cffi1/test_new_ffi_1.py +++ b/testing/cffi1/test_new_ffi_1.py @@ -137,7 +137,7 @@ min = int(min) max = int(max) p = ffi.cast(c_decl, min) - assert p != min # no __eq__(int) + assert p == min assert bool(p) is bool(min) assert int(p) == min p = ffi.cast(c_decl, max) @@ -148,9 +148,9 @@ assert ffi.typeof(q) is ffi.typeof(p) and int(q) == max q = ffi.cast(c_decl, long(min - 1)) assert ffi.typeof(q) is ffi.typeof(p) and int(q) == max - assert q != p + assert q == p assert int(q) == int(p) - assert hash(q) != hash(p) # unlikely + assert hash(q) == hash(p) c_decl_ptr = '%s *' % c_decl py.test.raises(OverflowError, ffi.new, c_decl_ptr, min - 1) py.test.raises(OverflowError, ffi.new, c_decl_ptr, max + 1) @@ -896,9 +896,9 @@ assert ffi.string(ffi.cast("enum bar", -2)) == "B1" assert ffi.string(ffi.cast("enum bar", -1)) == "CC1" assert ffi.string(ffi.cast("enum bar", 1)) == "E1" - assert ffi.cast("enum bar", -2) != ffi.cast("enum bar", -2) - assert ffi.cast("enum foq", 0) != ffi.cast("enum bar", 0) - assert ffi.cast("enum bar", 0) != ffi.cast("int", 0) + assert ffi.cast("enum bar", -2) == ffi.cast("enum bar", -2) + assert ffi.cast("enum foq", 0) == ffi.cast("enum bar", 0) + assert ffi.cast("enum bar", 0) == ffi.cast("int", 0) assert repr(ffi.cast("enum bar", -1)) == "<cdata 'enum bar' -1: CC1>" assert repr(ffi.cast("enum foq", -1)) == ( # enums are unsigned, if "<cdata 'enum foq' 4294967295>") or ( # they contain no neg value @@ -1105,14 +1105,14 @@ assert (q == None) is False assert (q != None) is True - def test_no_integer_comparison(self): + def test_integer_comparison(self): x = ffi.cast("int", 123) y = ffi.cast("int", 456) - py.test.raises(TypeError, "x < y") + assert x < y # z = ffi.cast("double", 78.9) - py.test.raises(TypeError, "x < z") - py.test.raises(TypeError, "z < y") + assert x > z + assert y > z def test_ffi_buffer_ptr(self): a = ffi.new("short *", 100) _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit