From b757beb5f326ce4a7da021d0f4c52e03e37e1945 Mon Sep 17 00:00:00 2001 From: Hugo Landau Date: Tue, 18 Apr 2023 19:30:54 +0100 Subject: QUIC TSERVER: Add support for multiple streams Reviewed-by: Matt Caswell Reviewed-by: Tomas Mraz (Merged from https://github.com/openssl/openssl/pull/20765) --- include/internal/quic_tserver.h | 14 ++++++- ssl/quic/quic_tserver.c | 93 +++++++++++++++++++++++++++++++---------- test/quic_tserver_test.c | 9 ++-- test/quicapitest.c | 9 +++- test/quicfaultstest.c | 8 ++-- 5 files changed, 98 insertions(+), 35 deletions(-) diff --git a/include/internal/quic_tserver.h b/include/internal/quic_tserver.h index 0d0d201497..fd657049ab 100644 --- a/include/internal/quic_tserver.h +++ b/include/internal/quic_tserver.h @@ -86,6 +86,7 @@ int ossl_quic_tserver_is_terminated(const QUIC_TSERVER *srv); * ossl_quic_tserver_has_read_ended() to identify this condition. */ int ossl_quic_tserver_read(QUIC_TSERVER *srv, + uint64_t stream_id, unsigned char *buf, size_t buf_len, size_t *bytes_read); @@ -93,7 +94,7 @@ int ossl_quic_tserver_read(QUIC_TSERVER *srv, /* * Returns 1 if the read part of the stream has ended normally. */ -int ossl_quic_tserver_has_read_ended(QUIC_TSERVER *srv); +int ossl_quic_tserver_has_read_ended(QUIC_TSERVER *srv, uint64_t stream_id); /* * Attempts to write to stream 0. Writes the number of bytes consumed to @@ -107,6 +108,7 @@ int ossl_quic_tserver_has_read_ended(QUIC_TSERVER *srv); * Returns 0 if connection is not currently active. */ int ossl_quic_tserver_write(QUIC_TSERVER *srv, + uint64_t stream_id, const unsigned char *buf, size_t buf_len, size_t *bytes_written); @@ -114,7 +116,15 @@ int ossl_quic_tserver_write(QUIC_TSERVER *srv, /* * Signals normal end of the stream. */ -int ossl_quic_tserver_conclude(QUIC_TSERVER *srv); +int ossl_quic_tserver_conclude(QUIC_TSERVER *srv, uint64_t stream_id); + +/* + * Create a server-initiated stream. The stream ID of the newly + * created stream is written to *stream_id. + */ +int ossl_quic_tserver_stream_new(QUIC_TSERVER *srv, + int is_uni, + uint64_t *stream_id); BIO *ossl_quic_tserver_get0_rbio(QUIC_TSERVER *srv); diff --git a/ssl/quic/quic_tserver.c b/ssl/quic/quic_tserver.c index 498ea62238..12f4516608 100644 --- a/ssl/quic/quic_tserver.c +++ b/ssl/quic/quic_tserver.c @@ -33,9 +33,6 @@ struct quic_tserver_st { /* SSL for the underlying TLS connection */ SSL *tls; - /* Our single bidirectional application data stream. */ - QUIC_STREAM *stream0; - /* The current peer L4 address. AF_UNSPEC if we do not have a peer yet. */ BIO_ADDR cur_peer_addr; @@ -104,10 +101,6 @@ QUIC_TSERVER *ossl_quic_tserver_new(const QUIC_TSERVER_ARGS *args, || !ossl_quic_channel_set_net_wbio(srv->ch, srv->args.net_wbio)) goto err; - srv->stream0 = ossl_quic_channel_get_stream_by_id(srv->ch, 0); - if (srv->stream0 == NULL) - goto err; - return srv; err: @@ -193,19 +186,40 @@ int ossl_quic_tserver_is_handshake_confirmed(const QUIC_TSERVER *srv) } int ossl_quic_tserver_read(QUIC_TSERVER *srv, + uint64_t stream_id, unsigned char *buf, size_t buf_len, size_t *bytes_read) { int is_fin = 0; + QUIC_STREAM *qs; if (!ossl_quic_channel_is_active(srv->ch)) return 0; - if (srv->stream0->recv_fin_retired) + qs = ossl_quic_stream_map_get_by_id(ossl_quic_channel_get_qsm(srv->ch), + stream_id); + if (qs == NULL) { + int is_client_init + = ((stream_id & QUIC_STREAM_INITIATOR_MASK) + == QUIC_STREAM_INITIATOR_CLIENT); + + /* + * A client-initiated stream might spontaneously come into existence, so + * allow trying to read on a client-initiated stream before it exists. + * Otherwise, fail. + */ + if (!is_client_init) + return 0; + + *bytes_read = 0; + return 1; + } + + if (qs->recv_fin_retired || qs->rstream == NULL) return 0; - if (!ossl_quic_rstream_read(srv->stream0->rstream, buf, buf_len, + if (!ossl_quic_rstream_read(qs->rstream, buf, buf_len, bytes_read, &is_fin)) return 0; @@ -220,35 +234,47 @@ int ossl_quic_tserver_read(QUIC_TSERVER *srv, ossl_statm_get_rtt_info(ossl_quic_channel_get_statm(srv->ch), &rtt_info); - if (!ossl_quic_rxfc_on_retire(&srv->stream0->rxfc, *bytes_read, + if (!ossl_quic_rxfc_on_retire(&qs->rxfc, *bytes_read, rtt_info.smoothed_rtt)) return 0; } if (is_fin) - srv->stream0->recv_fin_retired = 1; + qs->recv_fin_retired = 1; if (*bytes_read > 0) - ossl_quic_stream_map_update_state(ossl_quic_channel_get_qsm(srv->ch), - srv->stream0); + ossl_quic_stream_map_update_state(ossl_quic_channel_get_qsm(srv->ch), qs); return 1; } -int ossl_quic_tserver_has_read_ended(QUIC_TSERVER *srv) +int ossl_quic_tserver_has_read_ended(QUIC_TSERVER *srv, uint64_t stream_id) { - return srv->stream0->recv_fin_retired; + QUIC_STREAM *qs; + + qs = ossl_quic_stream_map_get_by_id(ossl_quic_channel_get_qsm(srv->ch), + stream_id); + + return qs != NULL && qs->recv_fin_retired; } int ossl_quic_tserver_write(QUIC_TSERVER *srv, + uint64_t stream_id, const unsigned char *buf, size_t buf_len, size_t *bytes_written) { + QUIC_STREAM *qs; + if (!ossl_quic_channel_is_active(srv->ch)) return 0; - if (!ossl_quic_sstream_append(srv->stream0->sstream, + qs = ossl_quic_stream_map_get_by_id(ossl_quic_channel_get_qsm(srv->ch), + stream_id); + if (qs == NULL || qs->sstream == NULL) + return 0; + + if (!ossl_quic_sstream_append(qs->sstream, buf, buf_len, bytes_written)) return 0; @@ -257,29 +283,50 @@ int ossl_quic_tserver_write(QUIC_TSERVER *srv, * We have appended at least one byte to the stream. Potentially mark * the stream as active, depending on FC. */ - ossl_quic_stream_map_update_state(ossl_quic_channel_get_qsm(srv->ch), - srv->stream0); + ossl_quic_stream_map_update_state(ossl_quic_channel_get_qsm(srv->ch), qs); /* Try and send. */ ossl_quic_tserver_tick(srv); return 1; } -int ossl_quic_tserver_conclude(QUIC_TSERVER *srv) +int ossl_quic_tserver_conclude(QUIC_TSERVER *srv, uint64_t stream_id) { + QUIC_STREAM *qs; + if (!ossl_quic_channel_is_active(srv->ch)) return 0; - if (!ossl_quic_sstream_get_final_size(srv->stream0->sstream, NULL)) { - ossl_quic_sstream_fin(srv->stream0->sstream); - ossl_quic_stream_map_update_state(ossl_quic_channel_get_qsm(srv->ch), - srv->stream0); + qs = ossl_quic_stream_map_get_by_id(ossl_quic_channel_get_qsm(srv->ch), + stream_id); + if (qs == NULL || qs->sstream == NULL) + return 0; + + if (!ossl_quic_sstream_get_final_size(qs->sstream, NULL)) { + ossl_quic_sstream_fin(qs->sstream); + ossl_quic_stream_map_update_state(ossl_quic_channel_get_qsm(srv->ch), qs); } ossl_quic_tserver_tick(srv); return 1; } +int ossl_quic_tserver_stream_new(QUIC_TSERVER *srv, + int is_uni, + uint64_t *stream_id) +{ + QUIC_STREAM *qs; + + if (!ossl_quic_channel_is_active(srv->ch)) + return 0; + + if ((qs = ossl_quic_channel_new_stream_local(srv->ch, is_uni)) == NULL) + return 0; + + *stream_id = qs->id; + return 1; +} + BIO *ossl_quic_tserver_get0_rbio(QUIC_TSERVER *srv) { return srv->args.net_rbio; diff --git a/test/quic_tserver_test.c b/test/quic_tserver_test.c index a385381716..e9ae4703b2 100644 --- a/test/quic_tserver_test.c +++ b/test/quic_tserver_test.c @@ -215,16 +215,17 @@ static int do_test(int use_thread_assist, int use_fake_time, int use_inject) } if (c_connected && c_write_done && !s_read_done) { - if (!ossl_quic_tserver_read(tserver, + if (!ossl_quic_tserver_read(tserver, 0, (unsigned char *)msg2 + s_total_read, sizeof(msg2) - s_total_read, &l)) { - if (!TEST_true(ossl_quic_tserver_has_read_ended(tserver))) + if (!TEST_true(ossl_quic_tserver_has_read_ended(tserver, 0))) goto err; if (!TEST_mem_eq(msg1, sizeof(msg1) - 1, msg2, s_total_read)) goto err; s_begin_write = 1; + s_read_done = 1; } else { s_total_read += l; if (!TEST_size_t_le(s_total_read, sizeof(msg1) - 1)) @@ -233,7 +234,7 @@ static int do_test(int use_thread_assist, int use_fake_time, int use_inject) } if (s_begin_write && s_total_written < sizeof(msg1) - 1) { - if (!TEST_true(ossl_quic_tserver_write(tserver, + if (!TEST_true(ossl_quic_tserver_write(tserver, 0, (unsigned char *)msg2 + s_total_written, sizeof(msg1) - 1 - s_total_written, &l))) goto err; @@ -241,7 +242,7 @@ static int do_test(int use_thread_assist, int use_fake_time, int use_inject) s_total_written += l; if (s_total_written == sizeof(msg1) - 1) { - ossl_quic_tserver_conclude(tserver); + ossl_quic_tserver_conclude(tserver, 0); c_begin_read = 1; } } diff --git a/test/quicapitest.c b/test/quicapitest.c index 092e303ba6..3ce695e5e6 100644 --- a/test/quicapitest.c +++ b/test/quicapitest.c @@ -42,6 +42,7 @@ static int test_quic_write_read(int idx) size_t msglen = strlen(msg); size_t numbytes = 0; int ssock = 0, csock = 0; + uint64_t sid = UINT64_MAX; if (idx == 1 && !qtest_supports_blocking()) return TEST_skip("Blocking tests not supported in this build"); @@ -61,6 +62,10 @@ static int test_quic_write_read(int idx) goto end; } + if (!TEST_true(ossl_quic_tserver_stream_new(qtserv, /*is_uni=*/0, &sid)) + || !TEST_uint64_t_eq(sid, 1)) /* server-initiated, so first SID is 1 */ + goto end; + for (j = 0; j < 2; j++) { /* Check that sending and receiving app data is ok */ if (!TEST_true(SSL_write_ex(clientquic, msg, msglen, &numbytes))) @@ -72,7 +77,7 @@ static int test_quic_write_read(int idx) ossl_quic_tserver_tick(qtserv); - if (!TEST_true(ossl_quic_tserver_read(qtserv, buf, sizeof(buf), + if (!TEST_true(ossl_quic_tserver_read(qtserv, sid, buf, sizeof(buf), &numbytes))) goto end; } while (numbytes == 0); @@ -81,7 +86,7 @@ static int test_quic_write_read(int idx) goto end; } - if (!TEST_true(ossl_quic_tserver_write(qtserv, (unsigned char *)msg, + if (!TEST_true(ossl_quic_tserver_write(qtserv, sid, (unsigned char *)msg, msglen, &numbytes))) goto end; ossl_quic_tserver_tick(qtserv); diff --git a/test/quicfaultstest.c b/test/quicfaultstest.c index beb3e4dc41..fbbbad4dd6 100644 --- a/test/quicfaultstest.c +++ b/test/quicfaultstest.c @@ -45,7 +45,7 @@ static int test_basic(void) goto err; ossl_quic_tserver_tick(qtserv); - if (!TEST_true(ossl_quic_tserver_read(qtserv, buf, sizeof(buf), &bytesread))) + if (!TEST_true(ossl_quic_tserver_read(qtserv, 0, buf, sizeof(buf), &bytesread))) goto err; /* @@ -119,7 +119,7 @@ static int test_unknown_frame(void) NULL))) goto err; - if (!TEST_true(ossl_quic_tserver_write(qtserv, (unsigned char *)msg, msglen, + if (!TEST_true(ossl_quic_tserver_write(qtserv, 0, (unsigned char *)msg, msglen, &byteswritten))) goto err; @@ -294,7 +294,7 @@ static int test_corrupted_data(int idx) * Send first 5 bytes of message. This will get corrupted and is treated as * "lost" */ - if (!TEST_true(ossl_quic_tserver_write(qtserv, (unsigned char *)msg, 5, + if (!TEST_true(ossl_quic_tserver_write(qtserv, 0, (unsigned char *)msg, 5, &byteswritten))) goto err; @@ -317,7 +317,7 @@ static int test_corrupted_data(int idx) OSSL_sleep(100); /* Send rest of message */ - if (!TEST_true(ossl_quic_tserver_write(qtserv, (unsigned char *)msg + 5, + if (!TEST_true(ossl_quic_tserver_write(qtserv, 0, (unsigned char *)msg + 5, msglen - 5, &byteswritten))) goto err; -- cgit v1.2.1