From 64eaffcc4524a6fa032a0510d4cd34a64c38e8c5 Mon Sep 17 00:00:00 2001 From: Jean-Paul Calderone Date: Mon, 13 Feb 2012 11:53:49 -0500 Subject: Add Connection.get_session and have the Session object actually wrap an SSL_SESSION* (though there is actually not yet any way to tell that that is the case) --- OpenSSL/ssl/connection.c | 32 +++++++++++++++++++++++++ OpenSSL/ssl/session.c | 61 +++++++++++++++++++++++++++++------------------- OpenSSL/ssl/session.h | 2 ++ OpenSSL/test/test_ssl.py | 45 +++++++++++++++++++++++++++++++++++ 4 files changed, 116 insertions(+), 24 deletions(-) diff --git a/OpenSSL/ssl/connection.c b/OpenSSL/ssl/connection.c index 037f2a0..60887b0 100755 --- a/OpenSSL/ssl/connection.c +++ b/OpenSSL/ssl/connection.c @@ -1254,6 +1254,37 @@ ssl_Connection_want_write(ssl_ConnectionObj *self, PyObject *args) return PyLong_FromLong((long)SSL_want_write(self->ssl)); } +static char ssl_Connection_get_session_doc[] = "\n\ +Returns the Session currently used.\n\ +\n\ +@return: An instance of :py:class:`OpenSSL.SSL.Session` or :py:obj:`None` if\n\ + no session exists.\n\ +"; +static PyObject * +ssl_Connection_get_session(ssl_ConnectionObj *self, PyObject *args) { + ssl_SessionObj *session; + SSL_SESSION *native_session; + + if (!PyArg_ParseTuple(args, ":get_session")) { + return NULL; + } + + native_session = SSL_get1_session(self->ssl); + + if (native_session == NULL) { + Py_INCREF(Py_None); + return Py_None; + } + + session = ssl_Session_from_SSL_SESSION(native_session); + if (!session) { + Py_INCREF(Py_None); + return Py_None; + } + + return (PyObject*)session; +} + /* * Member methods in the Connection object * ADD_METHOD(name) expands to a correct PyMethodDef declaration @@ -1309,6 +1340,7 @@ static PyMethodDef ssl_Connection_methods[] = ADD_METHOD(want_write), ADD_METHOD(set_accept_state), ADD_METHOD(set_connect_state), + ADD_METHOD(get_session), { NULL, NULL } }; #undef ADD_ALIAS diff --git a/OpenSSL/ssl/session.c b/OpenSSL/ssl/session.c index 5cd23b6..fb9c83f 100644 --- a/OpenSSL/ssl/session.c +++ b/OpenSSL/ssl/session.c @@ -18,22 +18,15 @@ Session() -> Session instance\n\ "; /* - * Initialize an already-constructed Session instance. + * Initialize an already-constructed Session instance with an OpenSSL session + * structure (or NULL). A reference to the OpenSSL session structure is stolen. */ -static ssl_SessionObj *ssl_Session_init(ssl_SessionObj *self) { - /* - self->sess = d2i_SSL_SESSION(NULL, &buffer, len); - - if (!self->sess) { - exception_from_error_queue(ssl_Error); - return NULL; - } - */ +static ssl_SessionObj* +ssl_Session_init(ssl_SessionObj *self, SSL_SESSION *native_session) { + self->session = native_session; return self; - } - /* * Create a Session object */ @@ -50,7 +43,37 @@ ssl_Session_new(PyTypeObject *subtype, PyObject *args, PyObject *kwargs) { return NULL; } - return (PyObject *)ssl_Session_init(self); + return (PyObject *)ssl_Session_init(self, NULL); +} + +/* + * Create a Session object from an existing SSL_SESSION*. A reference to the + * SSL_SESSION* is stolen. + */ +ssl_SessionObj* +ssl_Session_from_SSL_SESSION(SSL_SESSION *native_session) { + ssl_SessionObj *self; + + self = PyObject_New(ssl_SessionObj, &ssl_Session_Type); + if (self == NULL) { + return NULL; + } + + return ssl_Session_init(self, native_session); +} + +/* + * Discard the reference to the OpenSSL session structure, if there is one, so + * that it can be freed if it is no longer in use. Also release the memory for + * the Python object. + */ +static void +ssl_Session_dealloc(ssl_SessionObj *self) { + if (self->session != NULL) { + SSL_SESSION_free(self->session); + self->session = NULL; + } + self->ob_type->tp_free((PyObject*)self); } /* @@ -63,16 +86,6 @@ ssl_Session_new(PyTypeObject *subtype, PyObject *args, PyObject *kwargs) { */ #define ADD_METHOD(name) { #name, (PyCFunction)ssl_Session_##name, METH_VARARGS, ssl_Session_##name##_doc } static PyMethodDef ssl_Session_methods[] = { -#if 0 - ADD_METHOD(asn1), - ADD_METHOD(get_time), - ADD_METHOD(get_timeout), -#ifdef SSL_SESSION_hash - ADD_METHOD(hash), -#endif - ADD_METHOD(set_time), - ADD_METHOD(set_timeout), -#endif { NULL, NULL } }; #undef ADD_METHOD @@ -85,7 +98,7 @@ PyTypeObject ssl_Session_Type = { "OpenSSL.SSL.Session", sizeof(ssl_SessionObj), 0, - NULL, // (destructor)ssl_Session_dealloc, /* tp_dealloc */ + (destructor)ssl_Session_dealloc, /* tp_dealloc */ NULL, /* print */ NULL, /* tp_getattr */ NULL, /* setattr */ diff --git a/OpenSSL/ssl/session.h b/OpenSSL/ssl/session.h index 98cee80..4e8de11 100644 --- a/OpenSSL/ssl/session.h +++ b/OpenSSL/ssl/session.h @@ -16,10 +16,12 @@ typedef struct { PyObject_HEAD + SSL_SESSION *session; } ssl_SessionObj; extern PyTypeObject ssl_Session_Type; extern int init_ssl_session(PyObject *); +extern ssl_SessionObj *ssl_Session_from_SSL_SESSION(SSL_SESSION *native_session); #endif diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index 529a454..b3cb1cb 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -1428,6 +1428,51 @@ class ConnectionTests(TestCase, _LoopbackMixin): self.assertIdentical(None, server.get_peer_cert_chain()) + def test_get_session_wrong_args(self): + """ + :py:obj:`Connection.get_session` raises :py:obj:`TypeError` if called + with any arguments. + """ + ctx = Context(TLSv1_METHOD) + server = Connection(ctx, None) + self.assertRaises(TypeError, server.get_session, 123) + self.assertRaises(TypeError, server.get_session, "hello") + self.assertRaises(TypeError, server.get_session, object()) + + + def test_get_session_unconnected(self): + """ + :py:obj:`Connection.get_session` returns :py:obj:`None` when used with + an object which has not been connected. + """ + ctx = Context(TLSv1_METHOD) + server = Connection(ctx, None) + session = server.get_session() + self.assertIdentical(None, session) + + + def test_server_get_session(self): + """ + On the server side of a connection, :py:obj:`Connection.get_session` + returns a :py:class:`Session` instance representing the SSL session for + that connection. + """ + server, client = self._loopback() + session = server.get_session() + self.assertTrue(session, Session) + + + def test_client_get_session(self): + """ + On the client side of a connection, :py:obj:`Connection.get_session` + returns a :py:class:`Session` instance representing the SSL session for + that connection. + """ + server, client = self._loopback() + session = client.get_session() + self.assertTrue(session, Session) + + class ConnectionGetCipherListTests(TestCase): """ -- cgit v1.2.1