Author: Tyler Wade <way...@gmail.com> Branch: fix-bytearray-complexity Changeset: r71872:3521f66aed64 Date: 2014-05-26 03:48 -0500 http://bitbucket.org/pypy/pypy/changeset/3521f66aed64/
Log: Most bytearray methods fixed diff --git a/pypy/objspace/std/bytearrayobject.py b/pypy/objspace/std/bytearrayobject.py --- a/pypy/objspace/std/bytearrayobject.py +++ b/pypy/objspace/std/bytearrayobject.py @@ -1,7 +1,7 @@ """The builtin bytearray implementation""" from rpython.rlib.objectmodel import ( - import_from_mixin, newlist_hint, resizelist_hint) + import_from_mixin, newlist_hint, resizelist_hint, specialize) from rpython.rlib.buffer import Buffer from rpython.rlib.rstring import StringBuilder @@ -11,7 +11,7 @@ from pypy.interpreter.signature import Signature from pypy.objspace.std.sliceobject import W_SliceObject from pypy.objspace.std.stdtypedef import StdTypeDef -from pypy.objspace.std.stringmethods import StringMethods +from pypy.objspace.std.stringmethods import StringMethods, _get_buffer from pypy.objspace.std.util import get_positive_index NON_HEX_MSG = "non-hexadecimal number found in fromhex() arg at position %d" @@ -40,7 +40,11 @@ return ''.join(self.data) def _new(self, value): - return W_BytearrayObject(_make_data(value)) + return W_BytearrayObject(value) + + def _new_from_buffer(self, buffer): + length = buffer.getlength() + return W_BytearrayObject([buffer.getitem(i) for i in range(length)]) def _new_from_list(self, value): return W_BytearrayObject(value) @@ -58,7 +62,12 @@ raise oefmt(space.w_IndexError, "bytearray index out of range") return space.wrap(ord(character)) - _val = charbuf_w + def _val(self, space): + return self.data + + @staticmethod + def _use_rstr_ops(space, w_other): + return False @staticmethod def _op_val(space, w_other): @@ -68,7 +77,9 @@ assert len(char) == 1 return str(char)[0] - _builder = StringBuilder + @staticmethod + def _builder(size=100): + return BytearrayBuilder(size) def _newlist_unwrapped(self, space, res): return space.newlist([W_BytearrayObject(_make_data(i)) for i in res]) @@ -260,58 +271,116 @@ return space.wrap(''.join(self.data)) def descr_eq(self, space, w_other): + if isinstance(w_other, 
W_BytearrayObject): + return space.newbool(self.data == w_other.data) + try: - res = self._val(space) == self._op_val(space, w_other) + buffer = _get_buffer(space, w_other) except OperationError as e: if e.match(space, space.w_TypeError): return space.w_NotImplemented raise - return space.newbool(res) + + value = self._val(space) + buffer_len = buffer.getlength() + + if len(value) != buffer_len: + return space.newbool(False) + + min_length = min(len(value), buffer_len) + return space.newbool(_memcmp(value, buffer, min_length) == 0) def descr_ne(self, space, w_other): + if isinstance(w_other, W_BytearrayObject): + return space.newbool(self.data != w_other.data) + try: - res = self._val(space) != self._op_val(space, w_other) + buffer = _get_buffer(space, w_other) except OperationError as e: if e.match(space, space.w_TypeError): return space.w_NotImplemented raise - return space.newbool(res) + + value = self._val(space) + buffer_len = buffer.getlength() + + if len(value) != buffer_len: + return space.newbool(True) + + min_length = min(len(value), buffer_len) + return space.newbool(_memcmp(value, buffer, min_length) != 0) def descr_lt(self, space, w_other): + if isinstance(w_other, W_BytearrayObject): + return space.newbool(self.data < w_other.data) + try: - res = self._val(space) < self._op_val(space, w_other) + buffer = _get_buffer(space, w_other) except OperationError as e: if e.match(space, space.w_TypeError): return space.w_NotImplemented raise - return space.newbool(res) + + value = self._val(space) + buffer_len = buffer.getlength() + + cmp = _memcmp(value, buffer, min(len(value), buffer_len)) + return space.newbool( + cmp < 0 or (cmp == 0 and space.newbool(len(value) < buffer_len))) def descr_le(self, space, w_other): + if isinstance(w_other, W_BytearrayObject): + return space.newbool(self.data <= w_other.data) + try: - res = self._val(space) <= self._op_val(space, w_other) + buffer = _get_buffer(space, w_other) except OperationError as e: if e.match(space, 
space.w_TypeError): return space.w_NotImplemented raise - return space.newbool(res) + + value = self._val(space) + buffer_len = buffer.getlength() + + cmp = _memcmp(value, buffer, min(len(value), buffer_len)) + return space.newbool( + cmp < 0 or (cmp == 0 and space.newbool(len(value) <= buffer_len))) def descr_gt(self, space, w_other): + if isinstance(w_other, W_BytearrayObject): + return space.newbool(self.data > w_other.data) + try: - res = self._val(space) > self._op_val(space, w_other) + buffer = _get_buffer(space, w_other) except OperationError as e: if e.match(space, space.w_TypeError): return space.w_NotImplemented raise - return space.newbool(res) + + value = self._val(space) + buffer_len = buffer.getlength() + + cmp = _memcmp(value, buffer, min(len(value), buffer_len)) + return space.newbool( + cmp > 0 or (cmp == 0 and space.newbool(len(value) > buffer_len))) def descr_ge(self, space, w_other): + if isinstance(w_other, W_BytearrayObject): + return space.newbool(self.data >= w_other.data) + try: - res = self._val(space) >= self._op_val(space, w_other) + buffer = _get_buffer(space, w_other) except OperationError as e: if e.match(space, space.w_TypeError): return space.w_NotImplemented raise - return space.newbool(res) + + value = self._val(space) + buffer_len = buffer.getlength() + + cmp = _memcmp(value, buffer, min(len(value), buffer_len)) + return space.newbool( + cmp > 0 or (cmp == 0 and space.newbool(len(value) >= buffer_len))) def descr_iter(self, space): return space.newseqiter(self) @@ -319,8 +388,11 @@ def descr_inplace_add(self, space, w_other): if isinstance(w_other, W_BytearrayObject): self.data += w_other.data - else: - self.data += self._op_val(space, w_other) + return self + + buffer = _get_buffer(space, w_other) + for i in range(buffer.getlength()): + self.data.append(buffer.getitem(i)) return self def descr_inplace_mul(self, space, w_times): @@ -403,11 +475,42 @@ if space.isinstance_w(w_sub, space.w_int): char = space.int_w(w_sub) return 
_descr_contains_bytearray(self.data, space, char) + return self._StringMethods_descr_contains(space, w_sub) + def descr_add(self, space, w_other): + if isinstance(w_other, W_BytearrayObject): + return self._new(self.data + w_other.data) + + try: + buffer = _get_buffer(space, w_other) + except OperationError as e: + if e.match(space, space.w_TypeError): + return space.w_NotImplemented + raise + + buffer_len = buffer.getlength() + data = list(self.data + ['\0'] * buffer_len) + for i in range(buffer_len): + data[len(self.data) + i] = buffer.getitem(i) + return self._new(data) + + def descr_reverse(self, space): self.data.reverse() +class BytearrayBuilder(object): + def __init__(self, size): + self.data = newlist_hint(size) + + def append(self, s): + for i in range(len(s)): + self.data.append(s[i]) + + def build(self): + return self.data + + # ____________________________________________________________ # helpers for slow paths, moved out because they contain loops @@ -1152,3 +1255,13 @@ def setitem(self, index, char): self.data[index] = char + + +@specialize.argtype(0) +def _memcmp(selfvalue, buffer, length): + for i in range(length): + if selfvalue[i] < buffer.getitem(i): + return -1 + if selfvalue[i] > buffer.getitem(i): + return 1 + return 0 diff --git a/pypy/objspace/std/bytesobject.py b/pypy/objspace/std/bytesobject.py --- a/pypy/objspace/std/bytesobject.py +++ b/pypy/objspace/std/bytesobject.py @@ -480,6 +480,11 @@ _val = str_w @staticmethod + def _use_rstr_ops(space, w_other): + from pypy.objspace.std.unicodeobject import W_UnicodeObject + return isinstance(w_other, (W_BytesObject, W_UnicodeObject)) + + @staticmethod def _op_val(space, w_other): try: return space.str_w(w_other) diff --git a/pypy/objspace/std/stringmethods.py b/pypy/objspace/std/stringmethods.py --- a/pypy/objspace/std/stringmethods.py +++ b/pypy/objspace/std/stringmethods.py @@ -1,7 +1,7 @@ """Functionality shared between bytes/bytearray/unicode""" from rpython.rlib import jit -from 
rpython.rlib.objectmodel import specialize +from rpython.rlib.objectmodel import specialize, newlist_hint from rpython.rlib.rarithmetic import ovfcheck from rpython.rlib.rstring import endswith, replace, rsplit, split, startswith @@ -36,17 +36,27 @@ def descr_contains(self, space, w_sub): value = self._val(space) - other = self._op_val(space, w_sub) - return space.newbool(value.find(other) >= 0) + if self._use_rstr_ops(space, w_sub): + other = self._op_val(space, w_sub) + return space.newbool(value.find(other) >= 0) + + buffer = _get_buffer(space, w_sub) + res = _search_slowpath(value, buffer, 0, len(value), FAST_FIND) + return space.newbool(res >= 0) def descr_add(self, space, w_other): - try: - other = self._op_val(space, w_other) - except OperationError as e: - if e.match(space, space.w_TypeError): - return space.w_NotImplemented - raise - return self._new(self._val(space) + other) + if self._use_rstr_ops(space, w_other): + try: + other = self._op_val(space, w_other) + except OperationError as e: + if e.match(space, space.w_TypeError): + return space.w_NotImplemented + raise + return self._new(self._val(space) + other) + + # Bytearray overrides this method, CPython doesn't support concatenating + # buffers and strs, and unicodes are always handled above + return space.w_NotImplemented def descr_mul(self, space, w_times): try: @@ -128,14 +138,21 @@ def descr_count(self, space, w_sub, w_start=None, w_end=None): value, start, end = self._convert_idx_params(space, w_start, w_end) - return space.newint(value.count(self._op_val(space, w_sub), start, - end)) + + if self._use_rstr_ops(space, w_sub): + return space.newint(value.count(self._op_val(space, w_sub), start, + end)) + + buffer = _get_buffer(space, w_sub) + res = _search_slowpath(value, buffer, start, end, FAST_COUNT) + return space.wrap(max(res, 0)) def descr_decode(self, space, w_encoding=None, w_errors=None): from pypy.objspace.std.unicodeobject import ( _get_encoding_and_errors, decode_object,
unicode_from_string) encoding, errors = _get_encoding_and_errors(space, w_encoding, w_errors) + # TODO: On CPython calling bytearray.decode with no arguments works. if encoding is None and errors is None: return unicode_from_string(space, self) return decode_object(space, self, encoding, errors) @@ -192,30 +209,52 @@ def descr_find(self, space, w_sub, w_start=None, w_end=None): (value, start, end) = self._convert_idx_params(space, w_start, w_end) - res = value.find(self._op_val(space, w_sub), start, end) + + if self._use_rstr_ops(space, w_sub): + res = value.find(self._op_val(space, w_sub), start, end) + return space.wrap(res) + + buffer = _get_buffer(space, w_sub) + res = _search_slowpath(value, buffer, start, end, FAST_FIND) return space.wrap(res) def descr_rfind(self, space, w_sub, w_start=None, w_end=None): (value, start, end) = self._convert_idx_params(space, w_start, w_end) - res = value.rfind(self._op_val(space, w_sub), start, end) + + if self._use_rstr_ops(space, w_sub): + res = value.rfind(self._op_val(space, w_sub), start, end) + return space.wrap(res) + + buffer = _get_buffer(space, w_sub) + res = _search_slowpath(value, buffer, start, end, FAST_RFIND) return space.wrap(res) def descr_index(self, space, w_sub, w_start=None, w_end=None): (value, start, end) = self._convert_idx_params(space, w_start, w_end) - res = value.find(self._op_val(space, w_sub), start, end) + + if self._use_rstr_ops(space, w_sub): + res = value.find(self._op_val(space, w_sub), start, end) + else: + buffer = _get_buffer(space, w_sub) + res = _search_slowpath(value, buffer, start, end, FAST_FIND) + if res < 0: raise oefmt(space.w_ValueError, "substring not found in string.index") - return space.wrap(res) def descr_rindex(self, space, w_sub, w_start=None, w_end=None): (value, start, end) = self._convert_idx_params(space, w_start, w_end) - res = value.rfind(self._op_val(space, w_sub), start, end) + + if self._use_rstr_ops(space, w_sub): + res = value.rfind(self._op_val(space, w_sub), 
start, end) + else: + buffer = _get_buffer(space, w_sub) + res = _search_slowpath(value, buffer, start, end, FAST_RFIND) + if res < 0: raise oefmt(space.w_ValueError, "substring not found in string.rindex") - return space.wrap(res) @specialize.arg(2) @@ -328,6 +367,7 @@ value = self._val(space) prealloc_size = len(value) * (size - 1) + unwrapped = newlist_hint(size) for i in range(size): w_s = list_w[i] check_item = self._join_check_item(space, w_s) @@ -337,13 +377,16 @@ i, w_s) elif check_item == 2: return self._join_autoconvert(space, list_w) - prealloc_size += len(self._op_val(space, w_s)) + # XXX Maybe the extra copy here is okay? It was basically going to + # happen anyway, what with being placed into the builder + unwrapped.append(self._op_val(space, w_s)) + prealloc_size += len(unwrapped[0]) sb = self._builder(prealloc_size) for i in range(size): if value and i != 0: sb.append(value) - sb.append(self._op_val(space, list_w[i])) + sb.append(unwrapped[i]) return self._new(sb.build()) def _join_autoconvert(self, space, list_w): @@ -386,10 +429,22 @@ def descr_partition(self, space, w_sub): value = self._val(space) - sub = self._op_val(space, w_sub) - if not sub: + + if self._use_rstr_ops(space, w_sub): + sub = self._op_val(space, w_sub) + sublen = len(sub) + else: + sub = _get_buffer(space, w_sub) + sublen = sub.getlength() + + if sublen == 0: raise oefmt(space.w_ValueError, "empty separator") - pos = value.find(sub) + + if self._use_rstr_ops(space, w_sub): + pos = value.find(sub) + else: + pos = _search_slowpath(value, sub, 0, len(value), FAST_FIND) + if pos == -1: from pypy.objspace.std.bytearrayobject import W_BytearrayObject if isinstance(self, W_BytearrayObject): @@ -398,17 +453,29 @@ else: from pypy.objspace.std.bytearrayobject import W_BytearrayObject if isinstance(self, W_BytearrayObject): - w_sub = self._new(sub) + w_sub = self._new_from_buffer(sub) return space.newtuple( [self._sliced(space, value, 0, pos, self), w_sub, self._sliced(space, value, 
pos+len(sub), len(value), self)]) def descr_rpartition(self, space, w_sub): value = self._val(space) - sub = self._op_val(space, w_sub) - if not sub: + + if self._use_rstr_ops(space, w_sub): + sub = self._op_val(space, w_sub) + sublen = len(sub) + else: + sub = _get_buffer(space, w_sub) + sublen = sub.getlength() + + if sublen == 0: raise oefmt(space.w_ValueError, "empty separator") - pos = value.rfind(sub) + + if self._use_rstr_ops(space, w_sub): + pos = value.rfind(sub) + else: + pos = _search_slowpath(value, sub, 0, len(value), FAST_RFIND) + if pos == -1: from pypy.objspace.std.bytearrayobject import W_BytearrayObject if isinstance(self, W_BytearrayObject): @@ -417,7 +484,7 @@ else: from pypy.objspace.std.bytearrayobject import W_BytearrayObject if isinstance(self, W_BytearrayObject): - w_sub = self._new(sub) + w_sub = self._new_from_buffer(sub) return space.newtuple( [self._sliced(space, value, 0, pos, self), w_sub, self._sliced(space, value, pos+len(sub), len(value), self)]) @@ -616,10 +683,11 @@ for char in string: buf.append(table[ord(char)]) else: + # XXX Why not preallocate here too? buf = self._builder() deletion_table = [False] * 256 - for c in deletechars: - deletion_table[ord(c)] = True + for i in range(len(deletechars)): + deletion_table[ord(deletechars[i])] = True for char in string: if not deletion_table[ord(char)]: buf.append(table[ord(char)]) @@ -662,3 +730,118 @@ @specialize.argtype(0) def _descr_getslice_slowpath(selfvalue, start, step, sl): return [selfvalue[start + i*step] for i in range(sl)] + +def _get_buffer(space, w_obj): + return space.buffer_w(w_obj, space.BUF_SIMPLE) + + + +# Stolen from rpython.rtyper.lltypesystem.rstr +# TODO: Ask about what to do with this...
+ +FAST_COUNT = 0 +FAST_FIND = 1 +FAST_RFIND = 2 + +from rpython.rlib.rarithmetic import LONG_BIT as BLOOM_WIDTH + +def bloom_add(mask, c): + return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1))) + + +def bloom(mask, c): + return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1))) + +@specialize.argtype(0, 1) +def _search_slowpath(value, buffer, start, end, mode): + if start < 0: + start = 0 + if end > len(value): + end = len(value) + if start > end: + return -1 + + count = 0 + n = end - start + m = buffer.getlength() + + if m == 0: + if mode == FAST_COUNT: + return end - start + 1 + elif mode == FAST_RFIND: + return end + else: + return start + + w = n - m + + if w < 0: + return -1 + + mlast = m - 1 + skip = mlast - 1 + mask = 0 + + if mode != FAST_RFIND: + for i in range(mlast): + mask = bloom_add(mask, buffer.getitem(i)) + if buffer.getitem(i) == buffer.getitem(mlast): + skip = mlast - i - 1 + mask = bloom_add(mask, buffer.getitem(mlast)) + + i = start - 1 + while i + 1 <= start + w: + i += 1 + if value[i + m - 1] == buffer.getitem(m - 1): + for j in range(mlast): + if value[i + j] != buffer.getitem(j): + break + else: + if mode != FAST_COUNT: + return i + count += 1 + i += mlast + continue + + if i + m < len(value): + c = value[i + m] + else: + c = '\0' + if not bloom(mask, c): + i += m + else: + i += skip + else: + if i + m < len(value): + c = value[i + m] + else: + c = '\0' + if not bloom(mask, c): + i += m + else: + mask = bloom_add(mask, buffer.getitem(0)) + for i in range(mlast, 0, -1): + mask = bloom_add(mask, buffer.getitem(i)) + if buffer.getitem(i) == buffer.getitem(0): + skip = i - 1 + + i = start + w + 1 + while i - 1 >= start: + i -= 1 + if value[i] == buffer.getitem(0): + for j in xrange(mlast, 0, -1): + if value[i + j] != buffer.getitem(j): + break + else: + return i + if i - 1 >= 0 and not bloom(mask, value[i - 1]): + i -= m + else: + i -= skip + else: + if i - 1 >= 0 and not bloom(mask, value[i - 1]): + i -= m + + if mode != FAST_COUNT: + return -1 + return 
count diff --git a/pypy/objspace/std/test/test_bytearrayobject.py b/pypy/objspace/std/test/test_bytearrayobject.py --- a/pypy/objspace/std/test/test_bytearrayobject.py +++ b/pypy/objspace/std/test/test_bytearrayobject.py @@ -178,8 +178,10 @@ assert bytearray('hello').rindex('l') == 3 assert bytearray('hello').index(bytearray('e')) == 1 assert bytearray('hello').find('l') == 2 + assert bytearray('hello').find('l', -2) == 3 assert bytearray('hello').rfind('l') == 3 + # these checks used to not raise in pypy but they should raises(TypeError, bytearray('hello').index, ord('e')) raises(TypeError, bytearray('hello').rindex, ord('e')) diff --git a/pypy/objspace/std/unicodeobject.py b/pypy/objspace/std/unicodeobject.py --- a/pypy/objspace/std/unicodeobject.py +++ b/pypy/objspace/std/unicodeobject.py @@ -103,6 +103,12 @@ _val = unicode_w @staticmethod + def _use_rstr_ops(space, w_other): + # Always return true because we always need to copy the other + # operand(s) before we can do comparisons + return True + + @staticmethod def _op_val(space, w_other): if isinstance(w_other, W_UnicodeObject): return w_other._value _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit