Author: Maciej Fijalkowski <[email protected]>
Branch: 
Changeset: r68565:0390ddee24d3
Date: 2013-12-29 10:43 +0200
http://bitbucket.org/pypy/pypy/changeset/0390ddee24d3/

Log:    implement multichar split for RPython

diff --git a/rpython/rlib/rsocket.py b/rpython/rlib/rsocket.py
--- a/rpython/rlib/rsocket.py
+++ b/rpython/rlib/rsocket.py
@@ -39,8 +39,8 @@
 else:
     def rsocket_startup():
         pass
- 
- 
+
+
 def ntohs(x):
     return rffi.cast(lltype.Signed, _c.ntohs(x))
 
@@ -500,7 +500,7 @@
         self.type = type
         self.proto = proto
         self.timeout = defaults.timeout
-        
+
     def __del__(self):
         fd = self.fd
         if fd != _c.INVALID_SOCKET:
@@ -575,8 +575,8 @@
             if n == 0:
                 return 1
             return 0
-        
-        
+
+
     def error_handler(self):
         return last_error()
 
@@ -696,7 +696,7 @@
             if res < 0:
                 res = errno
             return (res, False)
-        
+
     def connect(self, address):
         """Connect the socket to a remote address."""
         err, timeout = self._connect(address)
@@ -704,7 +704,7 @@
             raise SocketTimeout
         if err:
             raise CSocketError(err)
-        
+
     def connect_ex(self, address):
         """This is like connect(address), but returns an error code (the errno
         value) instead of raising an exception when an error occurs."""
@@ -720,7 +720,7 @@
                 raise self.error_handler()
             return make_socket(fd, self.family, self.type, self.proto,
                                SocketClass=SocketClass)
-        
+
     def getpeername(self):
         """Return the address of the remote endpoint."""
         address, addr_p, addrlen_p = self._addrbuf()
@@ -790,7 +790,7 @@
         """Return the timeout of the socket. A timeout < 0 means that
         timeouts are disabled in the socket."""
         return self.timeout
-    
+
     def listen(self, backlog):
         """Enable a server to accept connections.  The backlog argument
         must be at least 1; it specifies the number of unaccepted connections
@@ -857,7 +857,7 @@
     def recvfrom_into(self, rwbuffer, nbytes, flags=0):
         buf, addr = self.recvfrom(nbytes, flags)
         rwbuffer.setslice(0, buf)
-        return len(buf), addr        
+        return len(buf), addr
 
     def send_raw(self, dataptr, length, flags=0):
         """Send data from a CCHARP buffer."""
@@ -951,7 +951,7 @@
         else:
             self.timeout = timeout
         self._setblocking(self.timeout < 0.0)
-            
+
     def shutdown(self, how):
         """Shut down the reading side of the socket (flag == SHUT_RD), the
         writing side of the socket (flag == SHUT_WR), or both ends
diff --git a/rpython/rtyper/lltypesystem/rstr.py b/rpython/rtyper/lltypesystem/rstr.py
--- a/rpython/rtyper/lltypesystem/rstr.py
+++ b/rpython/rtyper/lltypesystem/rstr.py
@@ -624,8 +624,7 @@
             i += 1
         return count
 
-    @classmethod
-    def ll_find(cls, s1, s2, start, end):
+    def ll_find(s1, s2, start, end):
         if start < 0:
             start = 0
         if end > len(s1.chars):
@@ -635,9 +634,9 @@
 
         m = len(s2.chars)
         if m == 1:
-            return cls.ll_find_char(s1, s2.chars[0], start, end)
+            return LLHelpers.ll_find_char(s1, s2.chars[0], start, end)
 
-        return cls.ll_search(s1, s2, start, end, FAST_FIND)
+        return LLHelpers.ll_search(s1, s2, start, end, FAST_FIND)
 
     @classmethod
     def ll_rfind(cls, s1, s2, start, end):
@@ -881,6 +880,37 @@
         item.copy_contents(s, item, i, 0, j - i)
         return res
 
+    def ll_split(LIST, s, c, max):
+        count = 1
+        if max == -1:
+            max = len(s.chars)
+        pos = 0
+        last = len(s.chars)
+        markerlen = len(c.chars)
+        pos = s.find(c, 0, last)
+        while pos >= 0 and count <= max:
+            pos = s.find(c, pos + markerlen, last)
+            count += 1
+        res = LIST.ll_newlist(count)
+        items = res.ll_items()
+        pos = 0
+        count = 0
+        pos = s.find(c, 0, last)
+        prev_pos = 0
+        if pos < 0:
+            items[0] = s
+            return items
+        while pos >= 0 and count < max:
+            item = items[count] = s.malloc(pos - prev_pos)
+            item.copy_contents(s, item, prev_pos, 0, pos -
+                               prev_pos)
+            count += 1
+            prev_pos = pos + markerlen
+            pos = s.find(c, pos + markerlen, last)
+        item = items[count] = s.malloc(last - prev_pos)
+        item.copy_contents(s, item, prev_pos, 0, last - prev_pos)
+        return items
+
     def ll_rsplit_chr(LIST, s, c, max):
         chars = s.chars
         strlen = len(chars)
@@ -1094,7 +1124,8 @@
                               'copy_contents' : staticAdtMethod(copy_string_contents),
                               'copy_contents_from_str' : staticAdtMethod(copy_string_contents),
                               'gethash': LLHelpers.ll_strhash,
-                              'length': LLHelpers.ll_length}))
+                              'length': LLHelpers.ll_length,
+                              'find': LLHelpers.ll_find}))
 UNICODE.become(GcStruct('rpy_unicode', ('hash', Signed),
                         ('chars', Array(UniChar, hints={'immutable': True})),
                         adtmeths={'malloc' : staticAdtMethod(mallocunicode),
diff --git a/rpython/rtyper/rstr.py b/rpython/rtyper/rstr.py
--- a/rpython/rtyper/rstr.py
+++ b/rpython/rtyper/rstr.py
@@ -336,10 +336,16 @@
 
     def rtype_method_split(self, hop):
         rstr = hop.args_r[0].repr
+        v_str = hop.inputarg(rstr.repr, 0)
+        if isinstance(hop.args_s[1], annmodel.SomeString):
+            v_chr = hop.inputarg(rstr.repr, 1)
+            fn = self.ll.ll_split
+        else:
+            v_chr = hop.inputarg(rstr.char_repr, 1)
+            fn = self.ll.ll_split_chr
         if hop.nb_args == 3:
-            v_str, v_chr, v_max = hop.inputargs(rstr.repr, rstr.char_repr, Signed)
+            v_max = hop.inputarg(Signed, 2)
         else:
-            v_str, v_chr = hop.inputargs(rstr.repr, rstr.char_repr)
             v_max = hop.inputconst(Signed, -1)
         try:
             list_type = hop.r_result.lowleveltype.TO
@@ -347,7 +353,7 @@
             list_type = hop.r_result.lowleveltype
         cLIST = hop.inputconst(Void, list_type)
         hop.exception_cannot_occur()
-        return hop.gendirectcall(self.ll.ll_split_chr, cLIST, v_str, v_chr, v_max)
+        return hop.gendirectcall(fn, cLIST, v_str, v_chr, v_max)
 
     def rtype_method_rsplit(self, hop):
         rstr = hop.args_r[0].repr
diff --git a/rpython/rtyper/test/test_rstr.py b/rpython/rtyper/test/test_rstr.py
--- a/rpython/rtyper/test/test_rstr.py
+++ b/rpython/rtyper/test/test_rstr.py
@@ -731,6 +731,19 @@
             res = self.interpret(fn, [i])
             assert res == fn(i)
 
+    def test_split_multichar(self):
+        l = ["abc::z", "abc", "abc::def:::x"]
+        exp = [["abc", "z"], ["abc"], ["abc", "def", ":x"]]
+        exp2 = [["abc", "z"], ["abc"], ["abc", "def:::x"]]
+
+        def f(i):
+            s = l[i]
+            return s.split("::") == exp[i] and s.split("::", 1) == exp2[i]
+
+        for i in range(3):
+            res = self.interpret(f, [i])
+            assert res == True
+
     def test_rsplit(self):
         fn = self._make_split_test('rsplit')
         for i in range(5):
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to