From f0336417cae1c32f4ea59a9f9851a15f269340f7 Mon Sep 17 00:00:00 2001 From: tokcum <47994370+tokcum@users.noreply.github.com> Date: Wed, 30 Mar 2022 11:39:08 +0200 Subject: THRIFT-5283: add support for Unix Domain Sockets in lib/rs (#2545) Client: rs --- test/rs/src/bin/test_client.rs | 104 ++++++++++++++++++++++++++++++++--------- test/rs/src/bin/test_server.rs | 23 ++++++--- test/rs/src/lib.rs | 4 -- test/tests.json | 3 +- 4 files changed, 101 insertions(+), 33 deletions(-) (limited to 'test') diff --git a/test/rs/src/bin/test_client.rs b/test/rs/src/bin/test_client.rs index 8623915d4..8274aaeb2 100644 --- a/test/rs/src/bin/test_client.rs +++ b/test/rs/src/bin/test_client.rs @@ -21,7 +21,12 @@ use log::*; use std::collections::{BTreeMap, BTreeSet}; use std::fmt::Debug; -use std::net::TcpStream; +use std::net::{TcpStream, ToSocketAddrs}; + +#[cfg(unix)] +use std::os::unix::net::UnixStream; +#[cfg(unix)] +use std::path::Path; use thrift; use thrift::protocol::{ @@ -35,6 +40,11 @@ use thrift::transport::{ use thrift::OrderedFloat; use thrift_test::*; +type ThriftClientPair = ( + ThriftTestSyncClient, Box>, + Option, Box>>, +); + fn main() { env_logger::init(); @@ -51,7 +61,6 @@ fn main() { fn run() -> thrift::Result<()> { // unsupported options: - // --domain-socket // --pipe // --anon-pipes // --ssl @@ -62,56 +71,107 @@ fn run() -> thrift::Result<()> { (about: "Rust Thrift test client") (@arg host: --host +takes_value "Host on which the Thrift test server is located") (@arg port: --port +takes_value "Port on which the Thrift test server is listening") - (@arg transport: --transport +takes_value "Thrift transport implementation to use (\"buffered\", \"framed\")") + (@arg domain_socket: --("domain-socket") +takes_value "Unix Domain Socket on which the Thrift test server is listening") (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\", \"multi\", \"multic\")") + (@arg transport: --transport +takes_value "Thrift transport implementation to use (\"buffered\", \"framed\")") (@arg testloops: -n --testloops +takes_value "Number of times to run tests") ) .get_matches(); let host = matches.value_of("host").unwrap_or("127.0.0.1"); let port = value_t!(matches, "port", u16).unwrap_or(9090); - let testloops = value_t!(matches, "testloops", u8).unwrap_or(1); - let transport = matches.value_of("transport").unwrap_or("buffered"); + let domain_socket = matches.value_of("domain_socket"); let protocol = matches.value_of("protocol").unwrap_or("binary"); + let transport = matches.value_of("transport").unwrap_or("buffered"); + let testloops = value_t!(matches, "testloops", u8).unwrap_or(1); + + let (mut thrift_test_client, mut second_service_client) = match domain_socket { + None => { + let listen_address = format!("{}:{}", host, port); + info!( + "Client binds to {} with {}+{} stack", + listen_address, protocol, transport + ); + bind(listen_address.as_str(), protocol, transport)? + } + Some(domain_socket) => { + info!( + "Client binds to {} (UDS) with {}+{} stack", + domain_socket, protocol, transport + ); + bind_uds(domain_socket, protocol, transport)? + } + }; + + for _ in 0..testloops { + make_thrift_calls(&mut thrift_test_client, &mut second_service_client)? + } + Ok(()) +} + +fn bind( + listen_address: A, + protocol: &str, + transport: &str, +) -> Result { // create a TCPStream that will be shared by all Thrift clients // service calls from multiple Thrift clients will be interleaved over the same connection // this isn't a problem for us because we're single-threaded and all calls block to completion - let shared_stream = TcpStream::connect(format!("{}:{}", host, port))?; + let shared_stream = TcpStream::connect(listen_address)?; - let mut second_service_client = if protocol.starts_with("multi") { + let second_service_client = if protocol.starts_with("multi") { let shared_stream_clone = shared_stream.try_clone()?; - let (i_prot, o_prot) = build(shared_stream_clone, transport, protocol, "SecondService")?; + let channel = TTcpChannel::with_stream(shared_stream_clone); + let (i_prot, o_prot) = build(channel, transport, protocol, "SecondService")?; Some(SecondServiceSyncClient::new(i_prot, o_prot)) } else { None }; - let mut thrift_test_client = { - let (i_prot, o_prot) = build(shared_stream, transport, protocol, "ThriftTest")?; + let thrift_test_client = { + let channel = TTcpChannel::with_stream(shared_stream); + let (i_prot, o_prot) = build(channel, transport, protocol, "ThriftTest")?; ThriftTestSyncClient::new(i_prot, o_prot) }; - info!( - "connecting to {}:{} with {}+{} stack", - host, port, protocol, transport - ); + Ok((thrift_test_client, second_service_client)) +} - for _ in 0..testloops { - make_thrift_calls(&mut thrift_test_client, &mut second_service_client)? - } +#[cfg(unix)] +fn bind_uds>( + domain_socket: P, + protocol: &str, + transport: &str, +) -> Result { + // create a UnixStream that will be shared by all Thrift clients + // service calls from multiple Thrift clients will be interleaved over the same connection + // this isn't a problem for us because we're single-threaded and all calls block to completion + let shared_stream = UnixStream::connect(domain_socket)?; - Ok(()) + let second_service_client = if protocol.starts_with("multi") { + let shared_stream_clone = shared_stream.try_clone()?; + let (i_prot, o_prot) = build(shared_stream_clone, transport, protocol, "SecondService")?; + Some(SecondServiceSyncClient::new(i_prot, o_prot)) + } else { + None + }; + + let thrift_test_client = { + let (i_prot, o_prot) = build(shared_stream, transport, protocol, "ThriftTest")?; + ThriftTestSyncClient::new(i_prot, o_prot) + }; + + Ok((thrift_test_client, second_service_client)) } -fn build( - stream: TcpStream, +fn build( + channel: C, transport: &str, protocol: &str, service_name: &str, ) -> thrift::Result<(Box, Box)> { - let c = TTcpChannel::with_stream(stream); - let (i_chan, o_chan) = c.split()?; + let (i_chan, o_chan) = channel.split()?; let (i_tran, o_tran): (Box, Box) = match transport { "buffered" => ( diff --git a/test/rs/src/bin/test_server.rs b/test/rs/src/bin/test_server.rs index 6a05e79e5..7e6d08f1c 100644 --- a/test/rs/src/bin/test_server.rs +++ b/test/rs/src/bin/test_server.rs @@ -52,7 +52,6 @@ fn main() { fn run() -> thrift::Result<()> { // unsupported options: - // --domain-socket // --pipe // --ssl let matches = clap_app!(rust_test_client => @@ -60,21 +59,26 @@ fn run() -> thrift::Result<()> { (author: "Apache Thrift Developers ") (about: "Rust Thrift test server") (@arg port: --port +takes_value "port on which the test server listens") + (@arg domain_socket: --("domain-socket") +takes_value "Unix Domain Socket on which the test server listens") (@arg transport: --transport +takes_value "transport implementation to use (\"buffered\", \"framed\")") (@arg protocol: --protocol +takes_value "protocol implementation to use (\"binary\", \"compact\")") - (@arg server_type: --server_type +takes_value "type of server instantiated (\"simple\", \"thread-pool\")") + (@arg server_type: --("server-type") +takes_value "type of server instantiated (\"simple\", \"thread-pool\")") (@arg workers: -n --workers +takes_value "number of thread-pool workers (\"4\")") ) - .get_matches(); + .get_matches(); let port = value_t!(matches, "port", u16).unwrap_or(9090); + let domain_socket = matches.value_of("domain_socket"); let transport = matches.value_of("transport").unwrap_or("buffered"); let protocol = matches.value_of("protocol").unwrap_or("binary"); let server_type = matches.value_of("server_type").unwrap_or("thread-pool"); let workers = value_t!(matches, "workers", usize).unwrap_or(4); let listen_address = format!("127.0.0.1:{}", port); - info!("binding to {}", listen_address); + match domain_socket { + None => info!("Server is binding to {}", listen_address), + Some(domain_socket) => info!("Server is binding to {} (UDS)", domain_socket), + } let (i_transport_factory, o_transport_factory): ( Box, @@ -135,7 +139,10 @@ fn run() -> thrift::Result<()> { workers, ); - server.listen(&listen_address) + match domain_socket { + None => server.listen(&listen_address), + Some(domain_socket) => server.listen_uds(domain_socket), + } } else { let mut server = TServer::new( i_transport_factory, @@ -146,9 +153,13 @@ fn run() -> thrift::Result<()> { workers, ); - server.listen(&listen_address) + match domain_socket { + None => server.listen(&listen_address), + Some(domain_socket) => server.listen_uds(domain_socket), + } } } + unknown => Err(format!("unsupported server type {}", unknown).into()), } } diff --git a/test/rs/src/lib.rs b/test/rs/src/lib.rs index 3c7cfc09e..9cfd7a66f 100644 --- a/test/rs/src/lib.rs +++ b/test/rs/src/lib.rs @@ -15,9 +15,5 @@ // specific language governing permissions and limitations // under the License. - - - - mod thrift_test; pub use crate::thrift_test::*; diff --git a/test/tests.json b/test/tests.json index a8dbef7d4..3563dc9ab 100644 --- a/test/tests.json +++ b/test/tests.json @@ -679,7 +679,8 @@ ] }, "sockets": [ - "ip" + "ip", + "domain" ], "transports": [ "buffered", -- cgit v1.2.1