diff -ur m2crypto/M2Crypto/SSL/Connection.py m2crypto-0.16/M2Crypto/SSL/Connection.py --- m2crypto/M2Crypto/SSL/Connection.py 2006-05-11 00:43:47.000000000 +0200 +++ m2crypto-0.16/M2Crypto/SSL/Connection.py 2006-10-23 18:55:39.000000000 +0200 @@ -36,9 +36,11 @@ self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self._fileno = self.socket.fileno() - - self.blocking = self.socket.gettimeout() - + + self._timeout = self.socket.gettimeout() + if self._timeout is None: + self._timeout = -1.0 + def __del__(self): if getattr(self, 'sslbio', None): self.m2_bio_free(self.sslbio) @@ -122,7 +124,7 @@ m2.ssl_set_accept_state(self.ssl) def accept_ssl(self): - return m2.ssl_accept(self.ssl) + return m2.ssl_accept(self.ssl, self._timeout) def accept(self): """Accept an SSL connection. The return value is a pair (ssl, addr) where @@ -144,7 +146,7 @@ m2.ssl_set_connect_state(self.ssl) def connect_ssl(self): - return m2.ssl_connect(self.ssl) + return m2.ssl_connect(self.ssl, self._timeout) def connect(self, addr): self.socket.connect(addr) @@ -171,7 +173,7 @@ return m2.ssl_pending(self.ssl) def _write_bio(self, data): - return m2.ssl_write(self.ssl, data) + return m2.ssl_write(self.ssl, data, self._timeout) def _write_nbio(self, data): return m2.ssl_write_nbio(self.ssl, data) @@ -179,7 +181,7 @@ def _read_bio(self, size=1024): if size <= 0: raise ValueError, 'size <= 0' - return m2.ssl_read(self.ssl, size) + return m2.ssl_read(self.ssl, size, self._timeout) def _read_nbio(self, size=1024): if size <= 0: @@ -187,13 +189,13 @@ return m2.ssl_read_nbio(self.ssl, size) def write(self, data): - if self.blocking: + if self._timeout != 0.0: return self._write_bio(data) return self._write_nbio(data) sendall = send = write def read(self, size=1024): - if self.blocking: + if self._timeout != 0.0: return self._read_bio(size) return self._read_nbio(size) recv = read @@ -201,7 +203,17 @@ def setblocking(self, mode): """Set this connection's underlying socket to _mode_.""" self.socket.setblocking(mode) - self.blocking = mode + if mode: + self._timeout = -1.0 + else: + self._timeout = 0.0 + + def settimeout(self, timeout): + """Set this connection's underlying socket's timeout to _timeout_.""" + self.socket.settimeout(timeout) + self._timeout = timeout + if self._timeout is None: + self._timeout = -1.0 def fileno(self): return self.socket.fileno() @@ -280,15 +292,8 @@ """Set the cipher suites for this connection.""" return m2.ssl_set_cipher_list(self.ssl, cipher_list) - def makefile(self, mode='rb', bufsize='ignored'): - r = 'r' in mode or '+' in mode - w = 'w' in mode or 'a' in mode or '+' in mode - b = 'b' in mode - m2mode = ['', 'r'][r] + ['', 'w'][w] + ['', 'b'][b] - # XXX Need to dup(). - bio = BIO.BIO(self.sslbio, _close_cb=self.close) - m2.bio_do_handshake(bio._ptr()) - return BIO.IOBuffer(bio, m2mode) + def makefile(self, mode='rb', bufsize=-1): + return socket._fileobject(self, mode, bufsize) def getsockname(self): return self.socket.getsockname() diff -ur m2crypto/M2Crypto/SSL/__init__.py m2crypto-0.16/M2Crypto/SSL/__init__.py --- m2crypto/M2Crypto/SSL/__init__.py 2006-03-20 20:26:28.000000000 +0100 +++ m2crypto-0.16/M2Crypto/SSL/__init__.py 2006-10-23 19:53:10.000000000 +0200 @@ -2,11 +2,14 @@ Copyright (c) 1999-2004 Ng Pheng Siong. All rights reserved.""" +import socket + # M2Crypto from M2Crypto import m2 class SSLError(Exception): pass -m2.ssl_init(SSLError) +class SSLTimeoutError(SSLError, socket.timeout): pass +m2.ssl_init(SSLError, SSLTimeoutError) # M2Crypto.SSL from Cipher import Cipher, Cipher_Stack diff -ur m2crypto/SWIG/_ssl.i m2crypto-0.16/SWIG/_ssl.i --- m2crypto/SWIG/_ssl.i 2006-04-01 00:11:55.000000000 +0200 +++ m2crypto-0.16/SWIG/_ssl.i 2006-10-23 20:00:36.000000000 +0200 @@ -8,10 +8,13 @@ %{ #include <pythread.h> +#include <limits.h> #include <openssl/bio.h> #include <openssl/dh.h> #include <openssl/ssl.h> #include <openssl/x509.h> +#include <poll.h> +#include <sys/time.h> %} %apply Pointer NONNULL { SSL_CTX * }; @@ -142,6 +145,11 @@ %rename(ssl_session_get_timeout) SSL_SESSION_get_timeout; extern long SSL_SESSION_get_timeout(CONST SSL_SESSION *); +extern PyObject *ssl_accept(SSL *ssl, double timeout = -1); +extern PyObject *ssl_connect(SSL *ssl, double timeout = -1); +extern PyObject *ssl_read(SSL *ssl, int num, double timeout = -1); +extern int ssl_write(SSL *ssl, PyObject *blob, double timeout = -1); + %constant int ssl_error_none = SSL_ERROR_NONE; %constant int ssl_error_ssl = SSL_ERROR_SSL; %constant int ssl_error_want_read = SSL_ERROR_WANT_READ; @@ -192,14 +200,19 @@ %constant int SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER = SSL_MODE_ENABLE_PARTIAL_WRITE; %constant int SSL_MODE_AUTO_RETRY = SSL_MODE_AUTO_RETRY; +%ignore ssl_handle_error; +%ignore ssl_sleep_with_timeout; %inline %{ static PyObject *_ssl_err; +static PyObject *_ssl_timeout_err; -void ssl_init(PyObject *ssl_err) { +void ssl_init(PyObject *ssl_err, PyObject *ssl_timeout_err) { SSL_library_init(); SSL_load_error_strings(); Py_INCREF(ssl_err); + Py_INCREF(ssl_timeout_err); _ssl_err = ssl_err; + _ssl_timeout_err = ssl_timeout_err; } void ssl_ctx_passphrase_callback(SSL_CTX *ctx, PyObject *pyfunc) { @@ -353,36 +366,139 @@ return ret; } -PyObject *ssl_accept(SSL *ssl) { +static void ssl_handle_error(int ssl_err, int ret) { + int err; + + switch (ssl_err) { + case SSL_ERROR_SSL: + PyErr_SetString(_ssl_err, + ERR_reason_error_string(ERR_get_error())); + break; + case SSL_ERROR_SYSCALL: + err = ERR_get_error(); + if (err) + PyErr_SetString(_ssl_err, ERR_reason_error_string(err)); + else if (ret == 0) + PyErr_SetString(_ssl_err, "unexpected eof"); + else if (ret == -1) + PyErr_SetFromErrno(_ssl_err); + else + assert(0); + break; + default: + PyErr_SetString(_ssl_err, "unexpected SSL error"); + } +} + +static int ssl_sleep_with_timeout(SSL *ssl, const struct timeval *start, + double timeout, int ssl_err) { + PyGILState_STATE gilstate; + struct pollfd fd; + struct timeval tv; + int ms; + + assert(timeout > 0); + again: + gettimeofday(&tv, NULL); + /* tv >= start */ + if ((timeout + start->tv_sec - tv.tv_sec) > INT_MAX / 1000) + ms = -1; + else { + int fract; + + ms = ((start->tv_sec + (int)timeout) - tv.tv_sec) * 1000; + fract = (start->tv_usec + (timeout - (int)timeout) * 1000000 + - tv.tv_usec + 999) / 1000; + if (ms > 0 && fract > INT_MAX - ms) + ms = -1; + else { + ms += fract; + if (ms <= 0) + goto timeout; + } + } + switch (ssl_err) { + case SSL_ERROR_WANT_READ: + fd.fd = SSL_get_rfd(ssl); + fd.events = POLLIN; + break; + + case SSL_ERROR_WANT_WRITE: + fd.fd = SSL_get_wfd(ssl); + fd.events = POLLOUT; + break; + + case SSL_ERROR_WANT_X509_LOOKUP: + return 0; /* FIXME: is this correct? */ + + default: + assert(0); + } + if (fd.fd == -1) { + gilstate = PyGILState_Ensure(); + PyErr_SetString(_ssl_err, "timeout on a non-FD SSL"); + PyGILState_Release(gilstate); + return -1; + } + switch (poll(&fd, 1, ms)) { + case 1: + return 0; + case 0: + goto timeout; + case -1: + if (errno == EINTR) + goto again; + gilstate = PyGILState_Ensure(); + PyErr_SetFromErrno(_ssl_err); + PyGILState_Release(gilstate); + return -1; + } + return 0; + + timeout: + gilstate = PyGILState_Ensure(); + PyErr_SetString(_ssl_timeout_err, "timed out"); + PyGILState_Release(gilstate); + return -1; +} + +PyObject *ssl_accept(SSL *ssl, double timeout) { PyObject *obj = NULL; - int r, err; + int r, ssl_err; + struct timeval tv; PyGILState_STATE gilstate; + if (timeout > 0) + gettimeofday(&tv, NULL); + again: r = SSL_accept(ssl); + ssl_err = SSL_get_error(ssl, r); gilstate = PyGILState_Ensure(); - switch (SSL_get_error(ssl, r)) { + switch (ssl_err) { case SSL_ERROR_NONE: case SSL_ERROR_ZERO_RETURN: obj = PyInt_FromLong((long)1); break; case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_READ: - obj = PyInt_FromLong((long)0); - break; - case SSL_ERROR_SSL: - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); + if (timeout <= 0) { + obj = PyInt_FromLong((long)0); + break; + } + + PyGILState_Release(gilstate); + if (ssl_sleep_with_timeout(ssl, &tv, timeout, ssl_err) == 0) + goto again; + gilstate = PyGILState_Ensure(); + obj = NULL; break; + + case SSL_ERROR_SSL: case SSL_ERROR_SYSCALL: - err = ERR_get_error(); - if (err) - PyErr_SetString(_ssl_err, ERR_reason_error_string(err)); - else if (r == 0) - PyErr_SetString(_ssl_err, "unexpected eof"); - else if (r == -1) - PyErr_SetFromErrno(_ssl_err); + ssl_handle_error(ssl_err, r); obj = NULL; break; } @@ -392,36 +508,43 @@ return obj; } -PyObject *ssl_connect(SSL *ssl) { +PyObject *ssl_connect(SSL *ssl, double timeout) { PyObject *obj = NULL; - int r, err; + int r, ssl_err; + struct timeval tv; PyGILState_STATE gilstate; + if (timeout > 0) + gettimeofday(&tv, NULL); + again: r = SSL_connect(ssl); + ssl_err = SSL_get_error(ssl, r); gilstate = PyGILState_Ensure(); - switch (SSL_get_error(ssl, r)) { + switch (ssl_err) { case SSL_ERROR_NONE: case SSL_ERROR_ZERO_RETURN: obj = PyInt_FromLong((long)1); break; case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_READ: - obj = PyInt_FromLong((long)0); - break; - case SSL_ERROR_SSL: - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); + if (timeout <= 0) { + obj = PyInt_FromLong((long)0); + break; + } + + PyGILState_Release(gilstate); + if (ssl_sleep_with_timeout(ssl, &tv, timeout, ssl_err) == 0) + goto again; + gilstate = PyGILState_Ensure(); + obj = NULL; break; + + case SSL_ERROR_SSL: case SSL_ERROR_SYSCALL: - err = ERR_get_error(); - if (err) - PyErr_SetString(_ssl_err, ERR_reason_error_string(err)); - else if (r == 0) - PyErr_SetString(_ssl_err, "unexpected eof"); - else if (r == -1) - PyErr_SetFromErrno(_ssl_err); + ssl_handle_error(ssl_err, r); obj = NULL; break; } @@ -435,10 +558,11 @@ SSL_set_shutdown(ssl, mode); } -PyObject *ssl_read(SSL *ssl, int num) { +PyObject *ssl_read(SSL *ssl, int num, double timeout) { PyObject *obj = NULL; void *buf; - int r, err; + int r; + struct timeval tv; PyGILState_STATE gilstate; gilstate = PyGILState_Ensure(); @@ -451,36 +575,48 @@ PyGILState_Release(gilstate); + if (timeout > 0) + gettimeofday(&tv, NULL); + again: r = SSL_read(ssl, buf, num); gilstate = PyGILState_Ensure(); - switch (SSL_get_error(ssl, r)) { - case SSL_ERROR_NONE: - case SSL_ERROR_ZERO_RETURN: - buf = PyMem_Realloc(buf, r); - obj = PyString_FromStringAndSize(buf, r); - break; - case SSL_ERROR_WANT_WRITE: - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_X509_LOOKUP: - Py_INCREF(Py_None); - obj = Py_None; - break; - case SSL_ERROR_SSL: - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); - obj = NULL; - break; - case SSL_ERROR_SYSCALL: - err = ERR_get_error(); - if (err) - PyErr_SetString(_ssl_err, ERR_reason_error_string(err)); - else if (r == 0) - PyErr_SetString(_ssl_err, "unexpected eof"); - else if (r == -1) - PyErr_SetFromErrno(_ssl_err); - obj = NULL; - break; + if (r >= 0) { + buf = PyMem_Realloc(buf, r); + obj = PyString_FromStringAndSize(buf, r); + } else { + int ssl_err; + + ssl_err = SSL_get_error(ssl, r); + switch (ssl_err) { + case SSL_ERROR_NONE: + case SSL_ERROR_ZERO_RETURN: + assert(0); + + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_X509_LOOKUP: + if (timeout <= 0) { + Py_INCREF(Py_None); + obj = Py_None; + break; + } + + PyGILState_Release(gilstate); + if (ssl_sleep_with_timeout(ssl, &tv, timeout, ssl_err) == 0) + goto again; + gilstate = PyGILState_Ensure(); + + obj = NULL; + break; + + case SSL_ERROR_SSL: + case SSL_ERROR_SYSCALL: + ssl_handle_error(ssl_err, r); + obj = NULL; + break; + } } PyMem_Free(buf); @@ -543,9 +679,10 @@ return obj; } -int ssl_write(SSL *ssl, PyObject *blob) { +int ssl_write(SSL *ssl, PyObject *blob, double timeout) { const void *buf; - int len, r, err, ret; + int len, r, ssl_err, ret; + struct timeval tv; PyGILState_STATE gilstate; gilstate = PyGILState_Ensure(); @@ -556,12 +693,16 @@ } PyGILState_Release(gilstate); - + + if (timeout > 0) + gettimeofday(&tv, NULL); + again: r = SSL_write(ssl, buf, len); + ssl_err = SSL_get_error(ssl, r); gilstate = PyGILState_Ensure(); - switch (SSL_get_error(ssl, r)) { + switch (ssl_err) { case SSL_ERROR_NONE: case SSL_ERROR_ZERO_RETURN: ret = r; @@ -569,20 +710,22 @@ case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_X509_LOOKUP: + if (timeout <= 0) { + ret = -1; + break; + } + + PyGILState_Release(gilstate); + if (ssl_sleep_with_timeout(ssl, &tv, timeout, ssl_err) == 0) + goto again; + gilstate = PyGILState_Ensure(); + ret = -1; break; + case SSL_ERROR_SSL: - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); - ret = -1; - break; case SSL_ERROR_SYSCALL: - err = ERR_get_error(); - if (err) - PyErr_SetString(_ssl_err, ERR_reason_error_string(ERR_get_error())); - else if (r == 0) - PyErr_SetString(_ssl_err, "unexpected eof"); - else if (r == -1) - PyErr_SetFromErrno(_ssl_err); + ssl_handle_error(ssl_err, r); default: ret = -1; } diff -ur m2crypto/tests/test_ssl.py m2crypto-0.16/tests/test_ssl.py --- m2crypto/tests/test_ssl.py 2006-10-22 01:22:57.000000000 +0200 +++ m2crypto-0.16/tests/test_ssl.py 2006-10-23 20:14:46.000000000 +0200 @@ -631,6 +631,53 @@ self.stop_server(pid) self.failIf(string.find(data, 's_server -quiet -www') == -1) + def test_timeout(self): + pid = self.start_server(self.args) + try: + ctx = SSL.Context() + s = SSL.Connection(ctx) + # Just a really small number so we can timeout + s.settimeout(0.000000000000000000000000000001) + self.assertRaises(SSL.SSLTimeoutError, s.connect, self.srv_addr) + s.close() + finally: + self.stop_server(pid) + + def test_makefile_timeout(self): + # httpslib uses makefile to read the response + pid = self.start_server(self.args) + try: + from M2Crypto import httpslib + c = httpslib.HTTPS(srv_host, srv_port) + c.putrequest('GET', '/') + c.putheader('Accept', 'text/html') + c.putheader('Accept', 'text/plain') + c.endheaders() + c._conn.sock.settimeout(100) + err, msg, headers = c.getreply() + assert err == 200, err + f = c.getfile() + data = f.read() + c.close() + finally: + self.stop_server(pid) + self.failIf(string.find(data, 's_server -quiet -www') == -1) + + def test_makefile_timeout_fires(self): + pid = self.start_server(self.args) + try: + from M2Crypto import httpslib + c = httpslib.HTTPS(srv_host, srv_port) + c.putrequest('GET', '/') + c.putheader('Accept', 'text/html') + c.putheader('Accept', 'text/plain') + c.endheaders() + c._conn.sock.settimeout(0.0000000001) + self.assertRaises(socket.timeout, c.getreply) + c.close() + finally: + self.stop_server(pid) + def test_twisted_wrapper(self): # # LEAK ALERT!