diff --git a/examples/client.py b/examples/client.py index f6f1406..b83d4c6 100755 --- a/examples/client.py +++ b/examples/client.py @@ -130,6 +130,8 @@ def main(): # DTLS connection over UDP if args.u: + if args.v > 2: + args.v = 1 bind_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) context = wolfssl.SSLContext(get_DTLSmethod(args.v)) # SSL/TLS connection over TCP @@ -151,6 +153,7 @@ def main(): if args.l: context.set_ciphers(args.l) + secure_socket = None try: secure_socket = context.wrap_socket(bind_socket) @@ -171,7 +174,8 @@ def main(): print() finally: - secure_socket.close() + if secure_socket: + secure_socket.close() if __name__ == '__main__': diff --git a/examples/server.py b/examples/server.py index 28812e4..8041ad8 100755 --- a/examples/server.py +++ b/examples/server.py @@ -170,7 +170,8 @@ def main(): finally: if secure_socket: secure_socket.shutdown(socket.SHUT_RDWR) - secure_socket.close() + if not args.u: + secure_socket.close() if not args.i: break diff --git a/tests/test_context.py b/tests/test_context.py index 9609d14..0aa804b 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -74,3 +74,17 @@ def test_load_verify_locations_with_cafile(ssl_context): def test_load_verify_locations_with_cadata(ssl_context): ssl_context.load_verify_locations(cadata=_CADATA) + + +def test_check_hostname_requires_cert_required(ssl_provider, ssl_context): + with pytest.raises(ValueError): + ssl_context.check_hostname = True + + ssl_context.verify_mode = ssl_provider.CERT_REQUIRED + ssl_context.check_hostname = True + assert ssl_context.check_hostname is True + + +def test_wrap_socket_server_side_mismatch(ssl_context, tcp_socket): + with pytest.raises(ValueError): + ssl_context.wrap_socket(tcp_socket, server_side=True) diff --git a/wolfssl/__init__.py b/wolfssl/__init__.py index 7678fb5..311bb85 100644 --- a/wolfssl/__init__.py +++ b/wolfssl/__init__.py @@ -81,7 +81,9 @@ class WolfSSL(object): @classmethod def enable_debug(self): - _lib.wolfSSL_Debugging_ON() + if _lib.wolfSSL_Debugging_ON() != _SSL_SUCCESS: + raise RuntimeError( + "wolfSSL debugging not available") @classmethod def disable_debug(self): @@ -143,9 +145,7 @@ def get_der(self): if derPtr == _ffi.NULL: return None - derBytes = _ffi.buffer(derPtr, outSz[0]) - - return derBytes + return _ffi.buffer(derPtr, outSz[0])[:] class SSLContext(object): """ @@ -154,7 +154,9 @@ class SSLContext(object): """ def __init__(self, protocol, server_side=None): - _lib.wolfSSL_Init() + if _lib.wolfSSL_Init() != _SSL_SUCCESS: + raise RuntimeError( + "wolfSSL library initialization failed") method = _WolfSSLMethod(protocol, server_side) self.protocol = protocol @@ -356,9 +358,10 @@ def load_verify_locations(self, cafile=None, capath=None, cadata=None): raise SSLError("Unable to load verify locations. E(%d)" % ret) if cadata is not None: + cadata_bytes = t2b(cadata) ret = _lib.wolfSSL_CTX_load_verify_buffer( - self.native_object, t2b(cadata), - len(cadata), _SSL_FILETYPE_PEM) + self.native_object, cadata_bytes, + len(cadata_bytes), _SSL_FILETYPE_PEM) if ret != _SSL_SUCCESS: raise SSLError("Unable to load verify locations. E(%d)" % ret) @@ -476,8 +479,11 @@ def __init__(self, sock=None, keyfile=None, certfile=None, ret = _lib.wolfSSL_check_domain_name(self.native_object, sni) if ret != _SSL_SUCCESS: - raise SSLError("Unable to set domain name check for " - "hostname verification") + self._release_native_object() + raise SSLError( + "Unable to set domain name " + "check for hostname " + "verification") if connected: try: @@ -497,6 +503,7 @@ def _release_native_object(self): self.native_object = _ffi.NULL def pending(self): + self._check_closed("pending") return _lib.wolfSSL_pending(self.native_object) @property @@ -605,13 +612,6 @@ def sendall(self, data, flags=0): while sent < length: ret = self.write(data[sent:]) - if (ret <= 0): - #expect to receive 0 when peer is reset or closed - err = _lib.wolfSSL_get_error(self.native_object, 0) - if err == _SSL_ERROR_WANT_WRITE: - raise SSLWantWriteError() - else: - raise SSLError("wolfSSL_write error (%d)" % err) sent += ret @@ -676,11 +676,13 @@ def recv_into(self, buffer, nbytes=None, flags=0): self._check_closed("read") if self._context.protocol < PROTOCOL_DTLSv1: self._check_connected() + else: + self.do_handshake() if buffer is None: raise ValueError("buffer cannot be None") - if nbytes is None: + if nbytes is None or nbytes == 0: nbytes = len(buffer) else: nbytes = min(len(buffer), nbytes) @@ -721,7 +723,9 @@ def recvmsg_into(self, *args, **kwargs): def shutdown(self, how): if self.native_object != _ffi.NULL: - _lib.wolfSSL_shutdown(self.native_object) + ret = _lib.wolfSSL_shutdown(self.native_object) + if ret == 0: + _lib.wolfSSL_shutdown(self.native_object) self._release_native_object() if self._context.protocol < PROTOCOL_DTLSv1: self._sock.shutdown(how) @@ -732,7 +736,8 @@ def unwrap(self): Returns the wrapped OS socket. """ if self.native_object != _ffi.NULL: - _lib.wolfSSL_set_fd(self.native_object, -1) + _lib.wolfSSL_shutdown(self.native_object) + self._release_native_object() sock = socket(family=self._sock.family, sock_type=self._sock.type, @@ -748,11 +753,16 @@ def add_peer(self, addr): peerAddr = _lib.wolfSSL_dtls_create_peer(addr[1],t2b(addr[0])) if peerAddr == _ffi.NULL: raise SSLError("Failed to create peer") - ret = _lib.wolfSSL_dtls_set_peer(self.native_object, peerAddr, - _SOCKADDR_SZ) - if ret != _SSL_SUCCESS: - raise SSLError("Unable to set dtls peer. E(%d)" % ret) - _lib.wolfSSL_dtls_free_peer(peerAddr) + try: + ret = _lib.wolfSSL_dtls_set_peer( + self.native_object, peerAddr, + _SOCKADDR_SZ) + if ret != _SSL_SUCCESS: + raise SSLError( + "Unable to set dtls peer." + " E(%d)" % ret) + finally: + _lib.wolfSSL_dtls_free_peer(peerAddr) def do_handshake(self, block=False): # pylint: disable=unused-argument """ @@ -814,18 +824,16 @@ def _real_connect(self, addr, connect_ex): raise ValueError("attempt to connect already-connected SSLSocket!") err = 0 - ret = _SSL_SUCCESS - + if self._context.protocol >= PROTOCOL_DTLSv1: - self.add_peer(addr) + self.add_peer(addr) else: if connect_ex: err = self._sock.connect_ex(addr) else: - err = 0 self._sock.connect(addr) - if err == 0 and ret == _SSL_SUCCESS: + if err == 0: self._connected = True if self.do_handshake_on_connect: self.do_handshake() @@ -894,12 +902,18 @@ def version(self): """ Returns the version of the protocol used in the connection. """ - return _ffi.string(_lib.wolfSSL_get_version(self.native_object)).decode("ascii") + self._check_closed("version") + return _ffi.string( + _lib.wolfSSL_get_version( + self.native_object)).decode("ascii") # The following functions expose functionality of the underlying # Socket object. These are also exposed through Python's ssl module # API and are provided here for compatibility. def close(self): + if self.native_object != _ffi.NULL: + _lib.wolfSSL_shutdown(self.native_object) + self._release_native_object() self._sock.close() def fileno(self): @@ -1029,12 +1043,17 @@ def callback(self): def _get_passwd(self, passwd, sz, rw, userdata): try: result = self._passwd_wrapper(sz, rw, userdata) - if not isinstance(result, bytes): - raise ValueError("Problem, expected String, not bytes") - if len(result) > sz: - raise ValueError("Problem with password returned being long") - for i in range(len(result)): - passwd[i] = result[i:i + 1] - return len(result) - except Exception as e: - raise ValueError("Problem getting password from callback") + except Exception: + raise ValueError( + "Problem getting password from callback") + if not isinstance(result, bytes): + raise ValueError( + "Password callback must return bytes," + " not str") + if len(result) > sz: + raise ValueError( + "Problem with password returned" + " being long") + for i in range(len(result)): + passwd[i] = result[i:i + 1] + return len(result) diff --git a/wolfssl/_build_ffi.py b/wolfssl/_build_ffi.py index dbc624b..f2df0bd 100644 --- a/wolfssl/_build_ffi.py +++ b/wolfssl/_build_ffi.py @@ -405,7 +405,7 @@ def generate_libwolfssl(): /* * Debugging */ - void wolfSSL_Debugging_ON(); + int wolfSSL_Debugging_ON(void); void wolfSSL_Debugging_OFF(); /* @@ -474,7 +474,7 @@ def generate_libwolfssl(): /* * SSL/TLS Session functions */ - void wolfSSL_Init(); + int wolfSSL_Init(void); WOLFSSL* wolfSSL_new(WOLFSSL_CTX*); void wolfSSL_free(WOLFSSL*);