Author: Richard Plangger <[email protected]>
Branch: py3.5-ssl
Changeset: r88388:cdbfa886bf3b
Date: 2016-11-15 15:59 +0100
http://bitbucket.org/pypy/pypy/changeset/cdbfa886bf3b/

Log:    copy Lib/ssl.py from CPython. Pass many more tests in test_ssl.py

diff --git a/lib-python/3/ssl.py b/lib-python/3/ssl.py
--- a/lib-python/3/ssl.py
+++ b/lib-python/3/ssl.py
@@ -145,6 +145,7 @@
 from socket import SOL_SOCKET, SO_TYPE
 import base64        # for DER-to-PEM translation
 import errno
+import warnings
 
 
 socket_error = OSError  # keep that public name in module namespace
@@ -405,12 +406,16 @@
 
     def _load_windows_store_certs(self, storename, purpose):
         certs = bytearray()
-        for cert, encoding, trust in enum_certificates(storename):
-            # CA certs are never PKCS#7 encoded
-            if encoding == "x509_asn":
-                if trust is True or purpose.oid in trust:
-                    certs.extend(cert)
-        self.load_verify_locations(cadata=certs)
+        try:
+            for cert, encoding, trust in enum_certificates(storename):
+                # CA certs are never PKCS#7 encoded
+                if encoding == "x509_asn":
+                    if trust is True or purpose.oid in trust:
+                        certs.extend(cert)
+        except PermissionError:
+            warnings.warn("unable to enumerate Windows certificate store")
+        if certs:
+            self.load_verify_locations(cadata=certs)
         return certs
 
     def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
@@ -560,7 +565,7 @@
         server hostame is set."""
         return self._sslobj.server_hostname
 
-    def read(self, len=0, buffer=None):
+    def read(self, len=1024, buffer=None):
         """Read up to 'len' bytes from the SSL object and return them.
 
         If 'buffer' is provided, read into this buffer and return the number of
@@ -569,7 +574,7 @@
         if buffer is not None:
             v = self._sslobj.read(len, buffer)
         else:
-            v = self._sslobj.read(len or 1024)
+            v = self._sslobj.read(len)
         return v
 
     def write(self, data):
@@ -745,7 +750,8 @@
                         # non-blocking
                         raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
                     self.do_handshake()
-            except:
+
+            except (OSError, ValueError):
                 self.close()
                 raise
 
@@ -774,7 +780,7 @@
             # EAGAIN.
             self.getpeername()
 
-    def read(self, len=0, buffer=None):
+    def read(self, len=1024, buffer=None):
         """Read up to LEN bytes and return them.
         Return zero-length string on EOF."""
 
diff --git a/lib-python/3/test/test_ssl.py b/lib-python/3/test/test_ssl.py
--- a/lib-python/3/test/test_ssl.py
+++ b/lib-python/3/test/test_ssl.py
@@ -2889,8 +2889,10 @@
                 # will be full and the call will block
                 buf = bytearray(8192)
                 def fill_buffer():
+                    i = 0
                     while True:
                         s.send(buf)
+                        i += 1
                 self.assertRaises((ssl.SSLWantWriteError,
                                    ssl.SSLWantReadError), fill_buffer)
 
diff --git a/lib_pypy/openssl/_cffi_src/openssl/ssl.py b/lib_pypy/openssl/_cffi_src/openssl/ssl.py
--- a/lib_pypy/openssl/_cffi_src/openssl/ssl.py
+++ b/lib_pypy/openssl/_cffi_src/openssl/ssl.py
@@ -25,6 +25,9 @@
 static const long Cryptography_HAS_GET_SERVER_TMP_KEY;
 static const long Cryptography_HAS_SSL_CTX_SET_CLIENT_CERT_ENGINE;
 static const long Cryptography_HAS_SSL_CTX_CLEAR_OPTIONS;
+static const long Cryptography_HAS_NPN_NEGOTIATED;
+
+static const long Cryptography_OPENSSL_NPN_NEGOTIATED;
 
 /* Internally invented symbol to tell us if SNI is supported */
 static const long Cryptography_HAS_TLSEXT_HOSTNAME;
@@ -435,6 +438,7 @@
 long SSL_CTX_sess_misses(SSL_CTX *);
 long SSL_CTX_sess_timeouts(SSL_CTX *);
 long SSL_CTX_sess_cache_full(SSL_CTX *);
+
 """
 
 CUSTOMIZATIONS = """
@@ -689,6 +693,14 @@
 
 static const long Cryptography_HAS_SSL_CTX_CLEAR_OPTIONS = 1;
 
+#ifdef OPENSSL_NPN_NEGOTIATED
+static const long Cryptography_OPENSSL_NPN_NEGOTIATED = OPENSSL_NPN_NEGOTIATED;
+static const long Cryptography_HAS_NPN_NEGOTIATED = 1;
+#else
+static const long Cryptography_OPENSSL_NPN_NEGOTIATED = 0;
+static const long Cryptography_HAS_NPN_NEGOTIATED = 0;
+#endif
+
 /* in OpenSSL 1.1.0 the SSL_ST values were renamed to TLS_ST and several were
    removed */
 #if CRYPTOGRAPHY_OPENSSL_LESS_THAN_110 || defined(LIBRESSL_VERSION_NUMBER)
diff --git a/lib_pypy/openssl/_stdssl/__init__.py b/lib_pypy/openssl/_stdssl/__init__.py
--- a/lib_pypy/openssl/_stdssl/__init__.py
+++ b/lib_pypy/openssl/_stdssl/__init__.py
@@ -7,8 +7,9 @@
 from _openssl import lib
 from openssl._stdssl.certificate import (_test_decode_cert,
     _decode_certificate, _certificate_to_der)
-from openssl._stdssl.utility import _str_with_len, _bytes_with_len, _str_to_ffi_buffer
-from openssl._stdssl.error import (ssl_error, ssl_lib_error, ssl_socket_error,
+from openssl._stdssl.utility import (_str_with_len, _bytes_with_len,
+    _str_to_ffi_buffer, _str_from_buf)
+from openssl._stdssl.error import (ssl_error, pyssl_error,
         SSLError, SSLZeroReturnError, SSLWantReadError,
         SSLWantWriteError, SSLSyscallError,
         SSLEOFError)
@@ -35,7 +36,7 @@
 HAS_ECDH = bool(lib.Cryptography_HAS_ECDH)
 HAS_SNI = bool(lib.Cryptography_HAS_TLSEXT_HOSTNAME)
 HAS_ALPN = bool(lib.Cryptography_HAS_ALPN)
-HAS_NPN = False
+HAS_NPN = lib.Cryptography_HAS_NPN_NEGOTIATED
 HAS_TLS_UNIQUE = True
 
 CLIENT = 0
@@ -61,6 +62,8 @@
 SSL_CLIENT = 0
 SSL_SERVER = 1
 
+SSL_CB_MAXLEN=128
+
 if lib.Cryptography_HAS_SSL2:
     PROTOCOL_SSLv2  = 0
 PROTOCOL_SSLv3  = 1
@@ -73,7 +76,7 @@
 
 _PROTOCOL_NAMES = (name for name in dir(lib) if name.startswith('PROTOCOL_'))
 
-from enum import Enum as _Enum, IntEnum as _IntEnum
+from enum import IntEnum as _IntEnum
 _IntEnum._convert('_SSLMethod', __name__,
         lambda name: name.startswith('PROTOCOL_'))
 
@@ -88,6 +91,14 @@
 # TODO threads?
 lib.OpenSSL_add_all_algorithms()
 
+def _socket_timeout(s):
+    if s is None:
+        return 0.0
+    t = s.gettimeout()
+    if t is None:
+        return -1.0
+    return t
+
 class PasswordInfo(object):
     callable = None
     password = None
@@ -139,7 +150,7 @@
     if sock is None or timeout == 0:
         return SOCKET_IS_NONBLOCKING
     elif timeout < 0:
-        t = sock.gettimeout() or 0
+        t = _socket_timeout(sock)
         if t > 0:
             return SOCKET_HAS_TIMED_OUT
         else:
@@ -219,7 +230,7 @@
         # If the socket is in non-blocking mode or timeout mode, set the BIO
         # to non-blocking mode (blocking is the default)
         #
-        timeout = sock.gettimeout() or 0
+        timeout = _socket_timeout(sock)
         if sock and timeout >= 0:
             lib.BIO_set_nbio(lib.SSL_get_rbio(ssl), 1)
             lib.BIO_set_nbio(lib.SSL_get_wbio(ssl), 1)
@@ -246,6 +257,8 @@
         self.owner = None
         self.server_hostname = None
         self.socket = None
+        self.alpn_protocols = ffi.NULL
+        self.npn_protocols = ffi.NULL
 
     @property
     def context(self):
@@ -260,15 +273,13 @@
         if sock is None:
             raise ssl_error("Underlying socket connection gone", SSL_ERROR_NO_SOCKET)
         ssl = self.ssl
-        timeout = 0
+        timeout = _socket_timeout(sock)
         if sock:
-            timeout = sock.gettimeout() or 0
             nonblocking = timeout >= 0
             lib.BIO_set_nbio(lib.SSL_get_rbio(ssl), nonblocking)
             lib.BIO_set_nbio(lib.SSL_get_wbio(ssl), nonblocking)
 
         has_timeout = timeout > 0
-        has_timeout = (timeout > 0);
         deadline = -1
         if has_timeout:
             # REVIEW, cpython uses a monotonic clock here
@@ -296,7 +307,7 @@
                 sockstate = SOCKET_OPERATION_OK
 
             if sockstate == SOCKET_HAS_TIMED_OUT:
-                raise SSLError("The handshake operation timed out")
+                raise socket.timeout("The handshake operation timed out")
             elif sockstate == SOCKET_HAS_BEEN_CLOSED:
                 raise SSLError("Underlying socket has been closed.")
             elif sockstate == SOCKET_TOO_LARGE_FOR_SELECT:
@@ -306,7 +317,7 @@
             if not (err == SSL_ERROR_WANT_READ or err == SSL_ERROR_WANT_WRITE):
                 break
         if ret < 1:
-            raise ssl_lib_error()
+            raise pyssl_error(self, ret)
 
         if self.peer_cert != ffi.NULL:
             lib.X509_free(self.peer_cert)
@@ -335,15 +346,19 @@
     def write(self, bytestring):
         deadline = 0
         b = _str_to_ffi_buffer(bytestring)
-        sock = self.get_socket_or_None()
+        sock = self.get_socket_or_connection_gone()
         ssl = self.ssl
+
+        if len(b) > sys.maxsize:
+            raise OverflowError("string longer than %d bytes" % sys.maxsize)
+
+        timeout = _socket_timeout(sock)
         if sock:
-            timeout = sock.gettimeout() or 0
             nonblocking = timeout >= 0
             lib.BIO_set_nbio(lib.SSL_get_rbio(ssl), nonblocking)
             lib.BIO_set_nbio(lib.SSL_get_wbio(ssl), nonblocking)
 
-        timeout = sock.gettimeout() or 0
+
         has_timeout = timeout > 0
         if has_timeout:
             # TODO monotonic clock?
@@ -351,7 +366,7 @@
 
         sockstate = _ssl_select(sock, 1, timeout)
         if sockstate == SOCKET_HAS_TIMED_OUT:
-            raise socket.TimeoutError("The write operation timed out")
+            raise socket.timeout("The write operation timed out")
         elif sockstate == SOCKET_HAS_BEEN_CLOSED:
             raise ssl_error("Underlying socket has been closed.")
         elif sockstate == SOCKET_TOO_LARGE_FOR_SELECT:
@@ -378,7 +393,7 @@
                 sockstate = SOCKET_OPERATION_OK
 
             if sockstate == SOCKET_HAS_TIMED_OUT:
-                raise socket.TimeoutError("The write operation timed out")
+                raise socket.timeout("The write operation timed out")
             elif sockstate == SOCKET_HAS_BEEN_CLOSED:
                 raise ssl_error("Underlying socket has been closed.")
             elif sockstate == SOCKET_IS_NONBLOCKING:
@@ -389,13 +404,15 @@
         if length > 0:
             return length
         else:
-            raise ssl_lib_error()
-            # return PySSL_SetError(self, len, __FILE__, __LINE__);
+            raise pyssl_error(self, length)
 
     def read(self, length, buffer_into=None):
         sock = self.get_socket_or_None()
         ssl = self.ssl
 
+        if length < 0 and buffer_into is None:
+            raise ValueError("size should not be negative")
+
         if sock is None:
             raise ssl_error("Underlying socket connection gone", SSL_ERROR_NO_SOCKET)
 
@@ -403,20 +420,20 @@
             dest = _buffer_new(length)
             mem = dest
         else:
-            import pdb; pdb.set_trace()
             mem = ffi.from_buffer(buffer_into)
             if length <= 0 or length > len(buffer_into):
-                if len(buffer_into) != length:
+                length = len(buffer_into)
+                if length > sys.maxsize:
                     raise OverflowError("maximum length can't fit in a C 'int'")
 
         if sock:
-            timeout = sock.gettimeout() or 0
+            timeout = _socket_timeout(sock)
             nonblocking = timeout >= 0
             lib.BIO_set_nbio(lib.SSL_get_rbio(ssl), nonblocking)
             lib.BIO_set_nbio(lib.SSL_get_wbio(ssl), nonblocking)
 
         deadline = 0
-        timeout = sock.gettimeout() or 0
+        timeout = _socket_timeout(sock)
         has_timeout = timeout > 0
         if has_timeout:
             # TODO monotonic clock?
@@ -448,28 +465,29 @@
                 sockstate = SOCKET_OPERATION_OK
 
             if sockstate == SOCKET_HAS_TIMED_OUT:
-                raise socket.TimeoutError("The read operation timed out")
+                raise socket.timeout("The read operation timed out")
             elif sockstate == SOCKET_IS_NONBLOCKING:
                 break
             if not (err == SSL_ERROR_WANT_READ or err == SSL_ERROR_WANT_WRITE):
                 break
 
-        if count <= 0:
-            raise ssl_socket_error(self, err)
+        if count <= 0 and not shutdown:
+            raise pyssl_error(self, count)
 
         if not buffer_into:
             return _bytes_with_len(dest, count)
 
         return count
 
-    def selected_alpn_protocol(self):
-        out = ffi.new("const unsigned char **")
-        outlen = ffi.new("unsigned int*")
+    if HAS_ALPN:
+        def selected_alpn_protocol(self):
+            out = ffi.new("const unsigned char **")
+            outlen = ffi.new("unsigned int*")
 
-        lib.SSL_get0_alpn_selected(self.ssl, out, outlen);
-        if out == ffi.NULL:
-            return None
-        return _str_with_len(ffi.cast("char*",out[0]), outlen[0]);
+            lib.SSL_get0_alpn_selected(self.ssl, out, outlen);
+            if out[0] == ffi.NULL:
+                return None
+            return _str_with_len(out[0], outlen[0]);
 
     def shared_ciphers(self):
         sess = lib.SSL_get_session(self.ssl)
@@ -519,6 +537,11 @@
             return None
         return self.socket()
 
+    def get_socket_or_connection_gone(self):
+        if self.socket is None:
+            raise ssl_error("Underlying socket connection gone", SSL_ERROR_NO_SOCKET)
+        return self.socket()
+
     def shutdown(self):
         sock = self.get_socket_or_None()
         nonblocking = False
@@ -529,7 +552,7 @@
             if sock.fileno() < 0:
                 raise ssl_error("Underlying socket connection gone", SSL_ERROR_NO_SOCKET)
 
-            timeout = sock.gettimeout() or 0
+            timeout = _socket_timeout(sock)
             nonblocking = timeout >= 0
             if sock and timeout >= 0:
                 lib.BIO_set_nbio(lib.SSL_get_rbio(ssl), nonblocking)
@@ -588,9 +611,9 @@
 
             if sockstate == SOCKET_HAS_TIMED_OUT:
                 if ssl_err == SSL_ERROR_WANT_READ:
-                    raise socket.TimeoutError("The read operation timed out")
+                    raise socket.timeout("The read operation timed out")
                 else:
-                    raise socket.TimeoutError("The write operation timed out")
+                    raise socket.timeout("The write operation timed out")
             elif sockstate == SOCKET_TOO_LARGE_FOR_SELECT:
                 raise ssl_error("Underlying socket too large for select().")
             elif sockstate != SOCKET_OPERATION_OK:
@@ -598,12 +621,46 @@
                 break;
 
         if err < 0:
-            raise ssl_socket_error(self, err)
+            raise pyssl_error(self, err)
         if sock:
             return sock
         else:
             return None
 
+    def pending(self):
+        # TODO PySSL_BEGIN_ALLOW_THREADS
+        count = lib.SSL_pending(self.ssl)
+        # TODO PySSL_END_ALLOW_THREADS
+        if count < 0:
+            raise pyssl_error(self, count)
+        else:
+            return count
+
+    def tls_unique_cb(self):
+        buf = ffi.new("char[%d]" % SSL_CB_MAXLEN)
+
+        if lib.SSL_session_reused(self.ssl) ^ (not self.socket_type):
+            # if session is resumed XOR we are the client
+            length = lib.SSL_get_finished(self.ssl, buf, SSL_CB_MAXLEN)
+        else:
+            # if a new session XOR we are the server
+            length = lib.SSL_get_peer_finished(self.ssl, buf, SSL_CB_MAXLEN)
+
+        # It cannot be negative in current OpenSSL version as of July 2011
+        if length == 0:
+            return None
+
+        return _bytes_with_len(buf, length)
+
+    if HAS_NPN:
+        def selected_npn_protocol(self):
+            out = ffi.new("unsigned char**")
+            outlen = ffi.new("unsigned int*")
+            lib.SSL_get0_next_proto_negotiated(self.ssl, out, outlen)
+            if (out[0] == ffi.NULL):
+                return None
+            return _str_with_len(out[0], outlen[0])
+
 
 def _fs_decode(name):
     # TODO return PyUnicode_DecodeFSDefault(short_name);
@@ -638,7 +695,8 @@
     SSL_CTX_STATS.append((name, getattr(lib, attr)))
 
 class _SSLContext(object):
-    __slots__ = ('ctx', '_check_hostname', 'servername_callback')
+    __slots__ = ('ctx', '_check_hostname', 'servername_callback',
+                 'alpn_protocols', 'npn_protocols')
 
     def __new__(cls, protocol):
         self = object.__new__(cls)
@@ -684,12 +742,8 @@
                 lib.SSL_CTX_set_ecdh_auto(self.ctx, 1)
             else:
                 key = lib.EC_KEY_new_by_curve_name(lib.NID_X9_62_prime256v1)
-                if not key:
-                    raise ssl_lib_error()
-                try:
-                    lib.SSL_CTX_set_tmp_ecdh(self.ctx, key)
-                finally:
-                    lib.EC_KEY_free(key)
+                lib.SSL_CTX_set_tmp_ecdh(self.ctx, key)
+                lib.EC_KEY_free(key)
         if lib.Cryptography_HAS_X509_V_FLAG_TRUSTED_FIRST:
             store = lib.SSL_CTX_get_cert_store(self.ctx)
             lib.X509_STORE_set_flags(store, lib.X509_V_FLAG_TRUSTED_FIRST)
@@ -820,7 +874,7 @@
                     lib.ERR_clear_error()
                     raise OSError(_errno, "Error")
                 else:
-                    raise ssl_lib_error()
+                    raise ssl_error(None)
 
             ffi.errno = 0
             buf = _str_to_ffi_buffer(keyfile)
@@ -835,7 +889,7 @@
                     lib.ERR_clear_error()
                     raise OSError(_errno, None)
                 else:
-                    raise ssl_lib_error()
+                    raise ssl_error(None)
 
             ret = lib.SSL_CTX_check_private_key(self.ctx)
             if ret != 1:
@@ -890,7 +944,7 @@
                     lib.ERR_clear_error()
                     raise OSError(_errno, '')
                 else:
-                    raise ssl_lib_error()
+                    raise ssl_error(None)
 
     def _add_ca_certs(self, data, size, ca_file_type):
         biobuf = lib.BIO_new_mem_buf(data, size)
@@ -935,7 +989,7 @@
                 # EOF PEM file, not an error
                 lib.ERR_clear_error()
             else:
-                raise ssl_lib_error()
+                raise ssl_error(None)
         finally:
             lib.BIO_free(biobuf)
 
@@ -969,13 +1023,6 @@
 #        if ctx:
 #            self.ctx = lltype.nullptr(SSL_CTX.TO)
 #            libssl_SSL_CTX_free(ctx)
-#
-#    @staticmethod
-#    @unwrap_spec(protocol=int)
-#    def descr_new(space, w_subtype, protocol=PY_SSL_VERSION_SSL23):
-#        self = space.allocate_instance(SSLContext, w_subtype)
-#        self.__init__(space, protocol)
-#        return space.wrap(self)
 
     def session_stats(self):
         stats = {}
@@ -987,34 +1034,6 @@
         if not lib.SSL_CTX_set_default_verify_paths(self.ctx):
             raise ssl_error("")
 
-#    def descr_get_options(self, space):
-#        return space.newlong(libssl_SSL_CTX_get_options(self.ctx))
-#
-#    def descr_set_options(self, space, w_new_opts):
-#        new_opts = space.int_w(w_new_opts)
-#        opts = libssl_SSL_CTX_get_options(self.ctx)
-#        clear = opts & ~new_opts
-#        set = ~opts & new_opts
-#        if clear:
-#            if HAVE_SSL_CTX_CLEAR_OPTIONS:
-#                libssl_SSL_CTX_clear_options(self.ctx, clear)
-#            else:
-#                raise oefmt(space.w_ValueError,
-#                            "can't clear options before OpenSSL 0.9.8m")
-#        if set:
-#            libssl_SSL_CTX_set_options(self.ctx, set)
-#
-#    def descr_get_check_hostname(self, space):
-#        return space.newbool(self.check_hostname)
-#
-#    def descr_set_check_hostname(self, space, w_obj):
-#        check_hostname = space.is_true(w_obj)
-#        if check_hostname and libssl_SSL_CTX_get_verify_mode(self.ctx) == SSL_VERIFY_NONE:
-#            raise oefmt(space.w_ValueError,
-#                        "check_hostname needs a SSL context with either "
-#                        "CERT_OPTIONAL or CERT_REQUIRED")
-#        self.check_hostname = check_hostname
-#
     def load_dh_params(self, filepath):
         ffi.errno = 0
         if filepath is None:
@@ -1037,57 +1056,13 @@
                 lib.ERR_clear_error()
                 raise OSError(_errno, '')
             else:
-                raise ssl_lib_error()
+                raise ssl_error(None)
         try:
             if lib.SSL_CTX_set_tmp_dh(self.ctx, dh) == 0:
-                raise ssl_lib_error()
+                raise ssl_error(None)
         finally:
             lib.DH_free(dh)
 
-#    def cert_store_stats_w(self, space):
-#        store = libssl_SSL_CTX_get_cert_store(self.ctx)
-#        x509 = 0
-#        x509_ca = 0
-#        crl = 0
-#        for i in range(libssl_sk_X509_OBJECT_num(store[0].c_objs)):
-#            obj = libssl_sk_X509_OBJECT_value(store[0].c_objs, i)
-#            if intmask(obj.c_type) == X509_LU_X509:
-#                x509 += 1
-#                if libssl_X509_check_ca(
-#                        libssl_pypy_X509_OBJECT_data_x509(obj)):
-#                    x509_ca += 1
-#            elif intmask(obj.c_type) == X509_LU_CRL:
-#                crl += 1
-#            else:
-#                # Ignore X509_LU_FAIL, X509_LU_RETRY, X509_LU_PKEY.
-#                # As far as I can tell they are internal states and never
-#                # stored in a cert store
-#                pass
-#        w_result = space.newdict()
-#        space.setitem(w_result,
-#                      space.wrap('x509'), space.wrap(x509))
-#        space.setitem(w_result,
-#                      space.wrap('x509_ca'), space.wrap(x509_ca))
-#        space.setitem(w_result,
-#                      space.wrap('crl'), space.wrap(crl))
-#        return w_result
-#
-#    @unwrap_spec(protos='bufferstr')
-#    def set_npn_protocols_w(self, space, protos):
-#        if not HAS_NPN:
-#            raise oefmt(space.w_NotImplementedError,
-#                        "The NPN extension requires OpenSSL 1.0.1 or later.")
-#
-#        self.npn_protocols = SSLNpnProtocols(self.ctx, protos)
-#
-#    @unwrap_spec(protos='bufferstr')
-#    def set_alpn_protocols_w(self, space, protos):
-#        if not HAS_ALPN:
-#            raise oefmt(space.w_NotImplementedError,
-#                        "The ALPN extension requires OpenSSL 1.0.2 or later.")
-#
-#        self.alpn_protocols = SSLAlpnProtocols(self.ctx, protos)
-#
     def get_ca_certs(self, binary_form=None):
         binary_mode = bool(binary_form)
         _list = []
@@ -1117,7 +1092,7 @@
             raise ValueError("unknown elliptic curve name '%s'" % name)
         key = lib.EC_KEY_new_by_curve_name(nid)
         if not key:
-            raise ssl_lib_error()
+            raise ssl_error(None)
         try:
             lib.SSL_CTX_set_tmp_ecdh(self.ctx, key)
         finally:
@@ -1139,6 +1114,29 @@
         lib.Cryptography_SSL_CTX_set_tlsext_servername_callback(self.ctx, _servername_callback)
         lib.Cryptography_SSL_CTX_set_tlsext_servername_arg(self.ctx, ffi.new_handle(callback_struct))
 
+    def _set_alpn_protocols(self, protos):
+        if HAS_ALPN:
+            self.alpn_protocols = protocols = ffi.from_buffer(protos)
+            length = len(protocols)
+
+            if lib.SSL_CTX_set_alpn_protos(self.ctx,ffi.cast("unsigned char*", protocols), length):
+                return MemoryError()
+            handle = ffi.new_handle(self)
+            lib.SSL_CTX_set_alpn_select_cb(self.ctx, select_alpn_callback, handle)
+        else:
+            raise NotImplementedError("The ALPN extension requires OpenSSL 1.0.2 or later.")
+
+    def _set_npn_protocols(self, protos):
+        if HAS_NPN:
+            self.npn_protocols = ffi.from_buffer(protos)
+            handle = ffi.new_handle(self)
+            lib.SSL_CTX_set_next_protos_advertised_cb(self.ctx, advertise_npn_callback, handle)
+            lib.SSL_CTX_set_next_proto_select_cb(self.ctx, select_npn_callback, handle)
+        else:
+            raise NotImplementedError("The NPN extension requires OpenSSL 1.0.1 or later.")
+
+
+
 @ffi.callback("void(void)")
 def _servername_callback(ssl, ad, arg):
     struct = ffi.from_handle(arg)
@@ -1209,9 +1207,6 @@
     ctx = None
 SERVERNAME_CALLBACKS = weakref.WeakValueDictionary()
 
-def _str_from_buf(buf):
-    return ffi.string(buf).decode('utf-8')
-
 def _asn1obj2py(obj):
     nid = lib.OBJ_obj2nid(obj)
     if nid == lib.NID_undef:
@@ -1282,7 +1277,7 @@
             raise ssl_error("cannot write() after write_eof()")
         nbytes = lib.BIO_write(self.bio, buf, len(buf));
         if nbytes < 0:
-            raise ssl_lib_error()
+            raise ssl_error(None)
         return nbytes
 
     def write_eof(self):
@@ -1311,6 +1306,7 @@
     def pending(self):
         return lib.BIO_ctrl_pending(self.bio)
 
+
 RAND_status = lib.RAND_status
 RAND_add = lib.RAND_add
 
@@ -1348,7 +1344,7 @@
 #            target = PyBytes_FromString(tmp); } \
 #        if (!target) goto error; \
 #    }
-    # XXX
+    # REVIEW
     return ffi.string(buf).decode(sys.getfilesystemencoding())
 
 def get_default_verify_paths():
@@ -1367,3 +1363,50 @@
         return odir
 
     return (ofile_env, ofile, odir_env, odir);
+
+@ffi.callback("int(SSL*,unsigned char **,unsigned char *,const unsigned char *,unsigned int,void *)")
+def select_alpn_callback(ssl, out, outlen, client_protocols, client_protocols_len, args):
+    ctx = ffi.from_handle(args)
+    return do_protocol_selection(1, out, outlen,
+                                 ffi.cast("unsigned char*",ctx.alpn_protocols), len(ctx.alpn_protocols),
+                                 client_protocols, client_protocols_len)
+
+@ffi.callback("int(SSL*,unsigned char **,unsigned char *,const unsigned char *,unsigned int,void *)")
+def select_npn_callback(ssl, out, outlen, server_protocols, server_protocols_len, args):
+    ctx = ffi.from_handle(args)
+    return do_protocol_selection(0, out, outlen, server_protocols, server_protocols_len,
+                                 ffi.cast("unsigned char*",ctx.npn_protocols), len(ctx.npn_protocols))
+
+
+@ffi.callback("int(SSL*,const unsigned char**, unsigned int*, void*)")
+def advertise_npn_callback(ssl, data, length, args):
+    ctx = ffi.from_handle(args)
+
+    if not ctx.npn_protocols:
+        data[0] = ffi.new("unsigned char*", b"")
+        length[0] = 0
+    else:
+        data[0] = ffi.cast("unsigned char*",ctx.npn_protocols)
+        length[0] = len(ctx.npn_protocols)
+
+    return lib.SSL_TLSEXT_ERR_OK
+
+
+if lib.Cryptography_HAS_NPN_NEGOTIATED:
+    def do_protocol_selection(alpn, out, outlen, server_protocols, server_protocols_len,
+                                                 client_protocols, client_protocols_len):
+        if client_protocols == ffi.NULL:
+            client_protocols = b""
+            client_protocols_len = 0
+        if server_protocols == ffi.NULL:
+            server_protocols = ""
+            server_protocols_len = 0
+
+        ret = lib.SSL_select_next_proto(out, outlen,
+                                        server_protocols, server_protocols_len,
+                                        client_protocols, client_protocols_len);
+        if alpn and ret != lib.Cryptography_OPENSSL_NPN_NEGOTIATED:
+            return lib.SSL_TLSEXT_ERR_NOACK
+
+        return lib.SSL_TLSEXT_ERR_OK
+
diff --git a/lib_pypy/openssl/_stdssl/certificate.py b/lib_pypy/openssl/_stdssl/certificate.py
--- a/lib_pypy/openssl/_stdssl/certificate.py
+++ b/lib_pypy/openssl/_stdssl/certificate.py
@@ -4,7 +4,7 @@
 from _openssl import ffi
 from _openssl import lib
 from openssl._stdssl.utility import _string_from_asn1, _str_with_len, _bytes_with_len
-from openssl._stdssl.error import ssl_error, ssl_socket_error, ssl_lib_error
+from openssl._stdssl.error import ssl_error, pyssl_error
 
 X509_NAME_MAXLEN = 256
 
@@ -12,13 +12,13 @@
     buf = ffi.new("char[]", X509_NAME_MAXLEN)
     length = lib.OBJ_obj2txt(buf, X509_NAME_MAXLEN, name, 0)
     if length < 0:
-        raise ssl_socket_error(None, 0)
+        raise ssl_error(None)
     name = _str_with_len(buf, length)
 
     buf_ptr = ffi.new("unsigned char**")
     length = lib.ASN1_STRING_to_UTF8(buf_ptr, value)
     if length < 0:
-        raise ssl_socket_error(None, 0)
+        raise ssl_error(None)
     try:
         value = _str_with_len(buf_ptr[0], length)
     finally:
@@ -167,7 +167,7 @@
     length = lib.BIO_gets(biobuf, STATIC_BIO_BUF, len(STATIC_BIO_BUF)-1)
     if length < 0:
         if biobuf: lib.BIO_free(biobuf)
-        raise ssl_lib_error()
+        raise ssl_error(None)
     return _str_with_len(STATIC_BIO_BUF, length)
 
 def _decode_certificate(certificate):
@@ -198,7 +198,7 @@
         buf = ffi.new("char[]", 2048)
         length = lib.BIO_gets(biobuf, buf, len(buf)-1)
         if length < 0:
-            raise ssl_lib_error()
+            raise ssl_error(None)
         retval["serialNumber"] = _str_with_len(buf, length)
 
         lib.BIO_reset(biobuf);
@@ -206,7 +206,7 @@
         lib.ASN1_TIME_print(biobuf, notBefore);
         length = lib.BIO_gets(biobuf, buf, len(buf)-1);
         if length < 0:
-            raise ssl_lib_error()
+            raise ssl_error(None)
         retval["notBefore"] = _str_with_len(buf, length)
 
         lib.BIO_reset(biobuf);
@@ -214,7 +214,7 @@
         lib.ASN1_TIME_print(biobuf, notAfter);
         length = lib.BIO_gets(biobuf, buf, len(buf)-1);
         if length < 0:
-            raise ssl_lib_error()
+            raise ssl_error(None)
         retval["notAfter"] = _str_with_len(buf, length)
 
         # Now look for subjectAltName
@@ -332,7 +332,7 @@
     buf_ptr[0] = ffi.NULL
     length = lib.i2d_X509(certificate, buf_ptr)
     if length < 0:
-        raise ssl_lib_error()
+        raise ssl_error(None)
     try:
         return _bytes_with_len(ffi.cast("char*",buf_ptr[0]), length)
     finally:
diff --git a/lib_pypy/openssl/_stdssl/error.py b/lib_pypy/openssl/_stdssl/error.py
--- a/lib_pypy/openssl/_stdssl/error.py
+++ b/lib_pypy/openssl/_stdssl/error.py
@@ -1,7 +1,7 @@
 from _openssl import ffi
 from _openssl import lib
 
-from openssl._stdssl.utility import _string_from_asn1, _str_to_ffi_buffer
+from openssl._stdssl.utility import _string_from_asn1, _str_to_ffi_buffer, _str_from_buf
 
 SSL_ERROR_NONE = 0
 SSL_ERROR_SSL = 1
@@ -42,34 +42,13 @@
 class SSLEOFError(SSLError):
     """ SSL/TLS connection terminated abruptly. """
 
-def ssl_lib_error():
-    errcode = lib.ERR_peek_last_error()
-    lib.ERR_clear_error()
-    return ssl_error(None, 0, None, errcode)
-
-def ssl_error(msg, errno=0, errtype=None, errcode=0):
-    reason_str = None
-    lib_str = None
-    if errcode:
-        err_lib = lib.ERR_GET_LIB(errcode)
-        err_reason = lib.ERR_GET_REASON(errcode)
-        reason_str = ERR_CODES_TO_NAMES.get((err_lib, err_reason), None)
-        lib_str = LIB_CODES_TO_NAMES.get(err_lib, None)
-        msg = ffi.string(lib.ERR_reason_error_string(errcode)).decode('utf-8')
-    if not msg:
-        msg = "unknown error"
-    if reason_str and lib_str:
-        msg = "[%s: %s] %s" % (lib_str, reason_str, msg)
-    elif lib_str:
-        msg = "[%s] %s" % (lib_str, msg)
-
-    if errno or errcode:
-        error = SSLError(errno, msg)
-    else:
-        error = SSLError(msg)
-    error.reason = reason_str if reason_str else None
-    error.library = lib_str if lib_str else None
-    return error
+def ssl_error(errstr, errcode=0):
+    if errstr is None:
+        errcode = lib.ERR_peek_last_error()
+    try:
+        return fill_sslerror(SSLError, errcode, errstr)
+    finally:
+        lib.ERR_clear_error()
 
 ERR_CODES_TO_NAMES = {}
 ERR_NAMES_TO_CODES = {}
@@ -87,84 +66,85 @@
 for mnemo, number in _lib_codes:
     LIB_CODES_TO_NAMES[number] = mnemo
 
-def _fill_and_raise_ssl_error(error, errcode):
-    pass
-    if errcode != 0:
-        library = lib.ERR_GET_LIB(errcode);
-        reason = lib.ERR_GET_REASON(errcode);
-        key = (library, reason)
-        reason_obj = ERR_CODES_TO_NAMES[key]
-        lib_obj = LIB_CODES_TO_NAMES[library]
-        raise error("[%S: %S]" % (lib_obj, reason_obj))
-
-def _last_error():
-    errcode = lib.ERR_peek_last_error()
-    _fill_and_raise_ssl_error(SSLError, errcode)
-    #buf = ffi.new("char[4096]")
-    #length = lib.ERR_error_string(errcode, buf)
-    #return ffi.string(buf).decode()
-
 
 # the PySSL_SetError equivalent
-def ssl_socket_error(ss, ret):
+def pyssl_error(obj, ret):
     errcode = lib.ERR_peek_last_error()
 
-    if ss is None:
-        return ssl_error(None, errcode=errcode)
-    elif ss.ssl:
-        err = lib.SSL_get_error(ss.ssl, ret)
-    else:
-        err = SSL_ERROR_SSL
     errstr = ""
     errval = 0
     errtype = SSLError
+    e = lib.ERR_peek_last_error()
 
-    if err == SSL_ERROR_ZERO_RETURN:
-        errtype = ZeroReturnError
-        errstr = "TLS/SSL connection has been closed"
-        errval = SSL_ERROR_ZERO_RETURN
-    elif err == SSL_ERROR_WANT_READ:
-        errtype = WantReadError
-        errstr = "The operation did not complete (read)"
-        errval = SSL_ERROR_WANT_READ
-    elif err == SSL_ERROR_WANT_WRITE:
-        errtype = WantWriteError
-        errstr = "The operation did not complete (write)"
-        errval = SSL_ERROR_WANT_WRITE
-    elif err == SSL_ERROR_WANT_X509_LOOKUP:
-        errstr = "The operation did not complete (X509 lookup)"
-        errval = SSL_ERROR_WANT_X509_LOOKUP
-    elif err == SSL_ERROR_WANT_CONNECT:
-        errstr = "The operation did not complete (connect)"
-        errval = SSL_ERROR_WANT_CONNECT
-    elif err == SSL_ERROR_SYSCALL:
-        xxx
-        e = lib.ERR_get_error()
-        if e == 0:
-            if ret == 0 or ss.w_socket() is None:
-                w_errtype = get_error(space).w_EOFError
-                errstr = "EOF occurred in violation of protocol"
-                errval = PY_SSL_ERROR_EOF
-            elif ret == -1:
-                # the underlying BIO reported an I/0 error
-                error = rsocket.last_error()
-                return interp_socket.converted_error(space, error)
+    if obj.ssl != ffi.NULL:
+        err = lib.SSL_get_error(obj.ssl, ret)
+
+        if err == SSL_ERROR_ZERO_RETURN:
+            errtype = SSLZeroReturnError
+            errstr = "TLS/SSL connection has been closed"
+            errval = SSL_ERROR_ZERO_RETURN
+        elif err == SSL_ERROR_WANT_READ:
+            errtype = SSLWantReadError
+            errstr = "The operation did not complete (read)"
+            errval = SSL_ERROR_WANT_READ
+        elif err == SSL_ERROR_WANT_WRITE:
+            errtype = SSLWantWriteError
+            errstr = "The operation did not complete (write)"
+            errval = SSL_ERROR_WANT_WRITE
+        elif err == SSL_ERROR_WANT_X509_LOOKUP:
+            errstr = "The operation did not complete (X509 lookup)"
+            errval = SSL_ERROR_WANT_X509_LOOKUP
+        elif err == SSL_ERROR_WANT_CONNECT:
+            errstr = "The operation did not complete (connect)"
+            errval = SSL_ERROR_WANT_CONNECT
+        elif err == SSL_ERROR_SYSCALL:
+            if e == 0:
+                if ret == 0 or obj.get_socket_or_None() is None:
+                    errtype = EOFError
+                    errstr = "EOF occurred in violation of protocol"
+                    errval = SSL_ERROR_EOF
+                elif ret == -1:
+                    # the underlying BIO reported an I/0 error
+                    errno = ffi.errno
+                    return IOError(errno)
+                else:
+                    errtype = SSLSyscallError
+                    errstr = "Some I/O error occurred"
+                    errval = SSL_ERROR_SYSCALL
             else:
-                w_errtype = get_error(space).w_SyscallError
-                errstr = "Some I/O error occurred"
-                errval = PY_SSL_ERROR_SYSCALL
+                errstr = _str_from_buf(lib.ERR_error_string(e, ffi.NULL))
+                errval = SSL_ERROR_SYSCALL
+        elif err == SSL_ERROR_SSL:
+            errval = SSL_ERROR_SSL
+            if errcode != 0:
+                errstr = _str_from_buf(lib.ERR_error_string(errcode, ffi.NULL))
+            else:
+                errstr = "A failure in the SSL library occurred"
         else:
-            errstr = rffi.charp2str(libssl_ERR_error_string(e, None))
-            errval = PY_SSL_ERROR_SYSCALL
-    elif err == SSL_ERROR_SSL:
-        errval = SSL_ERROR_SSL
-        if errcode != 0:
-            errstr = _str_to_ffi_buffer(lib.ERR_error_string(errcode, 
ffi.NULL))
-        else:
-            errstr = "A failure in the SSL library occurred"
-    else:
-        errstr = "Invalid error code"
-        errval = SSL_ERROR_INVALID_ERROR_CODE
+            errstr = "Invalid error code"
+            errval = SSL_ERROR_INVALID_ERROR_CODE
+    return fill_sslerror(errtype, errval, errstr, e)
 
-    return errtype(errstr, errval)
 
+def fill_sslerror(errtype, ssl_errno, errstr, errcode):
+    reason_str = None
+    lib_str = None
+    if errcode != 0:
+        err_lib = lib.ERR_GET_LIB(errcode)
+        err_reason = lib.ERR_GET_REASON(errcode)
+        reason_str = ERR_CODES_TO_NAMES.get((err_lib, err_reason), None)
+        lib_str = LIB_CODES_TO_NAMES.get(err_lib, None)
+        if errstr is None:
+            errstr = _str_from_buf(lib.ERR_reason_error_string(errcode))
+    if not errstr:
+        msg = "unknown error"
+    if reason_str and lib_str:
+        msg = "[%s: %s] %s" % (lib_str, reason_str, errstr)
+    elif lib_str:
+        msg = "[%s] %s" % (lib_str, errstr)
+
+    err_value = errtype(ssl_errno, msg)
+    err_value.reason = reason_str if reason_str else None
+    err_value.library = lib_str if lib_str else None
+    return err_value
+
diff --git a/lib_pypy/openssl/_stdssl/utility.py 
b/lib_pypy/openssl/_stdssl/utility.py
--- a/lib_pypy/openssl/_stdssl/utility.py
+++ b/lib_pypy/openssl/_stdssl/utility.py
@@ -33,3 +33,6 @@
         else:
             return ffi.from_buffer(view)
 
+def _str_from_buf(buf):
+    return ffi.string(buf).decode('utf-8')
+
diff --git a/lib_pypy/ssl.py b/lib_pypy/ssl.py
--- a/lib_pypy/ssl.py
+++ b/lib_pypy/ssl.py
@@ -1,9 +1,3 @@
-# This file exposes the Standard Library API for the ssl module
-
-#from openssl._stdssl import (_PROTOCOL_NAMES, _OPENSSL_API_VERSION)
-#from openssl._stdssl import _ssl
-#from openssl._stdssl import *
-
 # Wrapper module for _ssl, providing some additional facilities
 # implemented in Python.  Written by Bill Janssen.
 
@@ -151,6 +145,7 @@
 from socket import SOL_SOCKET, SO_TYPE
 import base64        # for DER-to-PEM translation
 import errno
+import warnings
 
 
 socket_error = OSError  # keep that public name in module namespace
@@ -411,12 +406,16 @@
 
     def _load_windows_store_certs(self, storename, purpose):
         certs = bytearray()
-        for cert, encoding, trust in enum_certificates(storename):
-            # CA certs are never PKCS#7 encoded
-            if encoding == "x509_asn":
-                if trust is True or purpose.oid in trust:
-                    certs.extend(cert)
-        self.load_verify_locations(cadata=certs)
+        try:
+            for cert, encoding, trust in enum_certificates(storename):
+                # CA certs are never PKCS#7 encoded
+                if encoding == "x509_asn":
+                    if trust is True or purpose.oid in trust:
+                        certs.extend(cert)
+        except PermissionError:
+            warnings.warn("unable to enumerate Windows certificate store")
+        if certs:
+            self.load_verify_locations(cadata=certs)
         return certs
 
     def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
@@ -566,7 +565,7 @@
         server hostame is set."""
         return self._sslobj.server_hostname
 
-    def read(self, len=0, buffer=None):
+    def read(self, len=1024, buffer=None):
         """Read up to 'len' bytes from the SSL object and return them.
 
         If 'buffer' is provided, read into this buffer and return the number of
@@ -575,7 +574,7 @@
         if buffer is not None:
             v = self._sslobj.read(len, buffer)
         else:
-            v = self._sslobj.read(len or 1024)
+            v = self._sslobj.read(len)
         return v
 
     def write(self, data):
@@ -781,7 +780,7 @@
             # EAGAIN.
             self.getpeername()
 
-    def read(self, len=0, buffer=None):
+    def read(self, len=1024, buffer=None):
         """Read up to LEN bytes and return them.
         Return zero-length string on EOF."""
 
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to