diff options
Diffstat (limited to 'lib/ssl/src/inet_tls_dist.erl')
-rw-r--r-- | lib/ssl/src/inet_tls_dist.erl | 825 |
1 files changed, 507 insertions, 318 deletions
diff --git a/lib/ssl/src/inet_tls_dist.erl b/lib/ssl/src/inet_tls_dist.erl index 9118fb59f6..c93bb27596 100644 --- a/lib/ssl/src/inet_tls_dist.erl +++ b/lib/ssl/src/inet_tls_dist.erl @@ -20,19 +20,24 @@ %% -module(inet_tls_dist). +-feature(maybe_expr, enable). -export([childspecs/0]). --export([listen/2, accept/1, accept_connection/5, - setup/5, close/1, select/1, address/0, is_node_name/1]). +-export([select/1, address/0, is_node_name/1, + listen/2, accept/1, accept_connection/5, + setup/5, close/1]). %% Generalized dist API --export([gen_listen/3, gen_accept/2, gen_accept_connection/6, - gen_setup/6, gen_close/2, gen_select/2, gen_address/1]). - --export([nodelay/0]). +-export([fam_select/2, fam_address/1, fam_listen/3, fam_accept/2, + fam_accept_connection/6, fam_setup/6]). -export([verify_client/3, cert_nodes/1]). +%% kTLS helpers +-export([inet_ktls_setopt/3, inet_ktls_getopt/3, + set_ktls/1, set_ktls_ulp/2, set_ktls_cipher/5, + ktls_os/0, ktls_opt_ulp/1, ktls_opt_cipher/6]). + -export([dbg/0]). % Debug -include_lib("kernel/include/net_address.hrl"). @@ -41,37 +46,68 @@ -include_lib("public_key/include/public_key.hrl"). -include("ssl_api.hrl"). +-include("ssl_cipher.hrl"). +-include("ssl_internal.hrl"). +-include("ssl_record.hrl"). -include_lib("kernel/include/logger.hrl"). +-define(FAMILY, inet). +-define(DRIVER, inet_tcp). % Implies ?FAMILY = inet through inet_drv.c +-define(PROTOCOL, tls). + %% ------------------------------------------------------------------------- childspecs() -> {ok, [{ssl_dist_sup,{ssl_dist_sup, start_link, []}, permanent, infinity, supervisor, [ssl_dist_sup]}]}. +%% ------------------------------------------------------------------------- +%% Select this protocol based on node name select(Node) -> - gen_select(inet_tcp, Node). - -gen_select(Driver, Node) -> - inet_tcp_dist:gen_select(Driver, Node). - -%% ------------------------------------------------------------ -%% Get the address family that this distribution uses -%% ------------------------------------------------------------ + fam_select(?FAMILY, Node). +fam_select(Family, Node) -> + inet_tcp_dist:fam_select(Family, Node). +%% ------------------------------------------------------------------------- +%% Get the #net_address this distribution uses address() -> - gen_address(inet_tcp). -gen_address(Driver) -> - inet_tcp_dist:gen_address(Driver). - + fam_address(?FAMILY). +fam_address(Family) -> + NetAddress = inet_tcp_dist:fam_address(Family), + NetAddress#net_address{ protocol = ?PROTOCOL }. %% ------------------------------------------------------------------------- - +%% Is this one really needed?? is_node_name(Node) -> dist_util:is_node_name(Node). - %% ------------------------------------------------------------------------- -hs_data_common(#sslsocket{pid = [_, DistCtrl|_]} = SslSocket) -> +hs_data_inet_tcp(Driver, Socket) -> + Family = Driver:family(), + {ok, Peername} = + maybe + {error, einval} ?= inet:peername(Socket), + ?shutdown({Driver, closed}) + end, + (inet_tcp_dist:gen_hs_data(Driver, Socket)) + #hs_data{ + f_address = + fun(_, Node) -> + {node, _, Host} = dist_util:split_node(Node), + #net_address{ + address = Peername, + host = Host, + protocol = ?PROTOCOL, + family = Family + } + end}. + +hs_data_ssl(Family, #sslsocket{pid = [_, DistCtrl|_]} = SslSocket) -> + {ok, Address} = + maybe + {error, einval} ?= ssl:peername(SslSocket), + ?shutdown({sslsocket, closed}) + end, #hs_data{ + socket = DistCtrl, f_send = fun (_Ctrl, Packet) -> f_send(SslSocket, Packet) @@ -95,7 +131,7 @@ hs_data_common(#sslsocket{pid = [_, DistCtrl|_]} = SslSocket) -> end, f_address = fun (Ctrl, Node) when Ctrl == DistCtrl -> - f_address(SslSocket, Node) + f_address(Family, Address, Node) end, mf_tick = fun (Ctrl) when Ctrl == DistCtrl -> @@ -133,22 +169,21 @@ f_setopts_pre_nodeup(_SslSocket) -> ok. f_setopts_post_nodeup(SslSocket) -> - ssl:setopts(SslSocket, [nodelay()]). + ssl:setopts(SslSocket, [inet_tcp_dist:nodelay()]). f_getll(DistCtrl) -> {ok, DistCtrl}. -f_address(SslSocket, Node) -> - case ssl:peername(SslSocket) of - {ok, Address} -> - case dist_util:split_node(Node) of - {node,_,Host} -> - #net_address{ - address=Address, host=Host, - protocol=tls, family=inet}; - _ -> - {error, no_node} - end +f_address(Family, Address, Node) -> + case dist_util:split_node(Node) of + {node,_,Host} -> + #net_address{ + address = Address, + host = Host, + protocol = ?PROTOCOL, + family = Family}; + _ -> + {error, no_node} end. mf_tick(DistCtrl) -> @@ -194,52 +229,79 @@ split_stat([], R, W, P) -> %% ------------------------------------------------------------------------- listen(Name, Host) -> - gen_listen(inet_tcp, Name, Host). - -gen_listen(Driver, Name, Host) -> - case inet_tcp_dist:gen_listen(Driver, Name, Host) of - {ok, {Socket, Address, Creation}} -> - inet:setopts(Socket, [{packet, 4}, {nodelay, true}]), - {ok, {Socket, Address#net_address{protocol=tls}, Creation}}; - Other -> - Other + fam_listen(?FAMILY, Name, Host). + +fam_listen(Family, Name, Host) -> + ForcedOptions = + [Family, {active, false}, {packet, 4}, {nodelay, true}], + ListenFun = + fun (First, Last, ListenOptions) -> + listen_loop( + First, Last, + inet_tcp_dist:merge_options(ListenOptions, ForcedOptions)) + end, + maybe + %% + {ok, {ListenSocket, Address, Creation}} ?= + inet_tcp_dist:fam_listen(Family, Name, Host, ListenFun), + NetAddress = + #net_address{ + host = Host, + protocol = ?PROTOCOL, + family = Family, + address = Address}, + {ok, {ListenSocket, NetAddress, Creation}} end. +listen_loop(First, Last, ListenOptions) when First =< Last -> + case gen_tcp:listen(First, ListenOptions) of + {error, eaddrinuse} -> + listen_loop(First + 1, Last, ListenOptions); + Result -> + Result + end; +listen_loop(_, _, _) -> + {error, eaddrinuse}. + %% ------------------------------------------------------------------------- -accept(Listen) -> - gen_accept(inet_tcp, Listen). +accept(ListenSocket) -> + fam_accept(?FAMILY, ListenSocket). -gen_accept(Driver, Listen) -> - Kernel = self(), +fam_accept(Family, ListenSocket) -> + NetKernel = self(), monitor_pid( spawn_opt( fun () -> - process_flag(trap_exit, true), - LOpts = application:get_env(kernel, inet_dist_listen_options, []), - MaxPending = - case lists:keyfind(backlog, 1, LOpts) of - {backlog, Backlog} -> Backlog; - false -> 128 - end, - DLK = {Driver, Listen, Kernel}, - accept_loop(DLK, spawn_accept(DLK), MaxPending, #{}) + process_flag(trap_exit, true), + MaxPending = erlang:system_info(schedulers_online), + Continue = make_ref(), + FLNC = {Family, ListenSocket, NetKernel, Continue}, + Pending = #{}, + accept_loop( + FLNC, Continue, spawn_accept(FLNC), MaxPending, + Pending) end, - [link, {priority, max}])). + dist_util:net_ticker_spawn_options())). %% Concurrent accept loop will spawn a new HandshakePid when %% there is no HandshakePid already running, and Pending map is %% smaller than MaxPending -accept_loop(DLK, undefined, MaxPending, Pending) when map_size(Pending) < MaxPending -> - accept_loop(DLK, spawn_accept(DLK), MaxPending, Pending); -accept_loop({_, _, NetKernelPid} = DLK, HandshakePid, MaxPending, Pending) -> +accept_loop(FLNC, Continue, undefined, MaxPending, Pending) + when map_size(Pending) < MaxPending -> + accept_loop(FLNC, Continue, spawn_accept(FLNC), MaxPending, Pending); +accept_loop({_, _, NetKernelPid, _} = FLNC, Continue, HandshakePid, MaxPending, Pending) -> receive - {continue, HandshakePid} when is_pid(HandshakePid) -> - accept_loop(DLK, undefined, MaxPending, Pending#{HandshakePid => true}); + {Continue, HandshakePid} when is_pid(HandshakePid) -> + accept_loop( + FLNC, Continue, undefined, MaxPending, + Pending#{HandshakePid => true}); {'EXIT', Pid, Reason} when is_map_key(Pid, Pending) -> Reason =/= normal andalso ?LOG_ERROR("TLS distribution handshake failed: ~p~n", [Reason]), - accept_loop(DLK, HandshakePid, MaxPending, maps:remove(Pid, Pending)); + accept_loop( + FLNC, Continue, HandshakePid, MaxPending, + maps:remove(Pid, Pending)); {'EXIT', HandshakePid, Reason} when is_pid(HandshakePid) -> %% HandshakePid crashed before turning into Pending, which means %% error happened in accept. Need to restart the listener. @@ -248,20 +310,21 @@ accept_loop({_, _, NetKernelPid} = DLK, HandshakePid, MaxPending, Pending) -> %% Since we're trapping exits, need to manually propagate this signal exit(Reason); Unexpected -> - ?LOG_WARNING("TLS distribution: unexpected message: ~p~n" ,[Unexpected]), - accept_loop(DLK, HandshakePid, MaxPending, Pending) + ?LOG_WARNING( + "TLS distribution: unexpected message: ~p~n", [Unexpected]), + accept_loop(FLNC, Continue, HandshakePid, MaxPending, Pending) end. -spawn_accept({Driver, Listen, Kernel}) -> +spawn_accept({Family, ListenSocket, NetKernel, Continue}) -> AcceptLoop = self(), spawn_link( fun () -> - case Driver:accept(Listen) of + case gen_tcp:accept(ListenSocket) of {ok, Socket} -> - AcceptLoop ! {continue, self()}, - case check_ip(Driver, Socket) of + AcceptLoop ! {Continue, self()}, + case check_ip(Socket) of true -> - accept_one(Driver, Kernel, Socket); + accept_one(Family, Socket, NetKernel); {false,IP} -> ?LOG_ERROR( "** Connection attempt from " @@ -273,33 +336,37 @@ spawn_accept({Driver, Listen, Kernel}) -> end end). -accept_one(Driver, Kernel, Socket) -> +accept_one(Family, Socket, NetKernel) -> Opts = setup_verify_client(Socket, get_ssl_options(server)), - wait_for_code_server(), + KTLS = proplists:get_value(ktls, Opts, false), case ssl:handshake( Socket, trace([{active, false},{packet, 4}|Opts]), net_kernel:connecttime()) of - {ok, #sslsocket{pid = [_, DistCtrl| _]} = SslSocket} -> - trace( - Kernel ! - {accept, self(), DistCtrl, - Driver:family(), tls}), - receive - {Kernel, controller, Pid} -> - case ssl:controlling_process(SslSocket, Pid) of + {ok, SslSocket} -> + Receiver = hd(SslSocket#sslsocket.pid), + case KTLS of + true -> + {ok, KtlsInfo} = ssl_gen_statem:ktls_handover(Receiver), + case inet_set_ktls(KtlsInfo) of ok -> - trace(Pid ! {self(), controller}); - Error -> - trace(Pid ! {self(), exit}), + accept_one( + Family, maps:get(socket, KtlsInfo), NetKernel, + fun gen_tcp:controlling_process/2); + {error, KtlsReason} -> ?LOG_ERROR( - "Cannot control TLS distribution connection: ~p~n", - [Error]) + [{slogan, set_ktls_failed}, + {reason, KtlsReason}, + {pid, self()}]), + close(Socket), + trace({ktls_error, KtlsReason}) end; - {Kernel, unsupported_protocol} -> - trace(unsupported_protocol) + false -> + accept_one( + Family, SslSocket, NetKernel, + fun ssl:controlling_process/2) end; {error, {options, _}} = Error -> %% Bad options: that's probably our fault. @@ -307,12 +374,31 @@ accept_one(Driver, Kernel, Socket) -> ?LOG_ERROR( "Cannot accept TLS distribution connection: ~s~n", [ssl:format_error(Error)]), - gen_tcp:close(Socket), + close(Socket), trace(Error); Other -> - gen_tcp:close(Socket), + close(Socket), trace(Other) end. +%% +accept_one( + Family, DistSocket, NetKernel, ControllingProcessFun) -> + trace(NetKernel ! {accept, self(), DistSocket, Family, ?PROTOCOL}), + receive + {NetKernel, controller, Pid} -> + case ControllingProcessFun(DistSocket, Pid) of + ok -> + trace(Pid ! {self(), controller}); + {error, Reason} -> + trace(Pid ! {self(), exit}), + ?LOG_ERROR( + [{slogan, controlling_process_failed}, + {reason, Reason}, + {pid, self()}]) + end; + {NetKernel, unsupported_protocol} -> + trace(unsupported_protocol) + end. %% {verify_fun,{fun ?MODULE:verify_client/3,_}} is used @@ -398,72 +484,50 @@ verify_client(PeerCert, valid_peer, {AllowedHosts,PeerIP} = S) -> end. -wait_for_code_server() -> - %% This is an ugly hack. Upgrading a socket to TLS requires the - %% crypto module to be loaded. Loading the crypto module triggers - %% its on_load function, which calls code:priv_dir/1 to find the - %% directory where its NIF library is. However, distribution is - %% started earlier than the code server, so the code server is not - %% necessarily started yet, and code:priv_dir/1 might fail because - %% of that, if we receive an incoming connection on the - %% distribution port early enough. - %% - %% If the on_load function of a module fails, the module is - %% unloaded, and the function call that triggered loading it fails - %% with 'undef', which is rather confusing. - %% - %% Thus, the accept process will terminate, and be - %% restarted by ssl_dist_sup. However, it won't have any memory - %% of being asked by net_kernel to listen for incoming - %% connections. Hence, the node will believe that it's open for - %% distribution, but it actually isn't. - %% - %% So let's avoid that by waiting for the code server to start. - case whereis(code_server) of - undefined -> - timer:sleep(10), - wait_for_code_server(); - Pid when is_pid(Pid) -> - ok - end. - %% ------------------------------------------------------------------------- -accept_connection(AcceptPid, DistCtrl, MyNode, Allowed, SetupTime) -> - gen_accept_connection( - inet_tcp, AcceptPid, DistCtrl, MyNode, Allowed, SetupTime). +accept_connection(AcceptPid, DistSocket, MyNode, Allowed, SetupTime) -> + fam_accept_connection( + ?FAMILY, AcceptPid, DistSocket, MyNode, Allowed, SetupTime). -gen_accept_connection( - Driver, AcceptPid, DistCtrl, MyNode, Allowed, SetupTime) -> +fam_accept_connection( + Family, AcceptPid, DistSocket, MyNode, Allowed, SetupTime) -> Kernel = self(), monitor_pid( spawn_opt( fun() -> do_accept( - Driver, AcceptPid, DistCtrl, + Family, AcceptPid, DistSocket, MyNode, Allowed, SetupTime, Kernel) end, dist_util:net_ticker_spawn_options())). do_accept( - _Driver, AcceptPid, DistCtrl, MyNode, Allowed, SetupTime, Kernel) -> + Family, AcceptPid, DistSocket, MyNode, Allowed, SetupTime, Kernel) -> MRef = erlang:monitor(process, AcceptPid), receive {AcceptPid, controller} -> erlang:demonitor(MRef, [flush]), - {ok, SslSocket} = tls_sender:dist_tls_socket(DistCtrl), - Timer = dist_util:start_timer(SetupTime), - NewAllowed = allowed_nodes(SslSocket, Allowed), - HSData0 = hs_data_common(SslSocket), + Timer = dist_util:start_timer(SetupTime), + {HSData0, NewAllowed} = + case DistSocket of + SslSocket = #sslsocket{pid = [_Receiver, Sender| _]} -> + link(Sender), + {hs_data_ssl(Family, SslSocket), + allowed_nodes(SslSocket, Allowed)}; + PortSocket when is_port(DistSocket) -> + %%% XXX Breaking abstraction barrier + Driver = erlang:port_get_data(PortSocket), + {hs_data_inet_tcp(Driver, PortSocket), + Allowed} + end, HSData = HSData0#hs_data{ kernel_pid = Kernel, this_node = MyNode, - socket = DistCtrl, timer = Timer, this_flags = 0, allowed = NewAllowed}, - link(DistCtrl), dist_util:handshake_other_started(trace(HSData)); {AcceptPid, exit} -> %% this can happen when connection was initiated, but dropped @@ -535,138 +599,147 @@ allowed_nodes(PeerCert, Allowed, PeerIP, Node, Host) -> allowed_nodes(PeerCert, Allowed, PeerIP) end. + +%% ------------------------------------------------------------------------- + setup(Node, Type, MyNode, LongOrShortNames, SetupTime) -> - gen_setup(inet_tcp, Node, Type, MyNode, LongOrShortNames, SetupTime). + fam_setup(?FAMILY, Node, Type, MyNode, LongOrShortNames, SetupTime). -gen_setup(Driver, Node, Type, MyNode, LongOrShortNames, SetupTime) -> - Kernel = self(), +fam_setup(Family, Node, Type, MyNode, LongOrShortNames, SetupTime) -> + NetKernel = self(), monitor_pid( - spawn_opt(setup_fun(Driver, Kernel, Node, Type, MyNode, LongOrShortNames, SetupTime), - dist_util:net_ticker_spawn_options())). + spawn_opt( + setup_fun( + Family, Node, Type, MyNode, LongOrShortNames, SetupTime, + NetKernel), + dist_util:net_ticker_spawn_options())). -spec setup_fun(_,_,_,_,_,_,_) -> fun(() -> no_return()). -setup_fun(Driver, Kernel, Node, Type, MyNode, LongOrShortNames, SetupTime) -> +setup_fun( + Family, Node, Type, MyNode, LongOrShortNames, SetupTime, NetKernel) -> fun() -> do_setup( - Driver, Kernel, Node, Type, - MyNode, LongOrShortNames, SetupTime) + Family, Node, Type, MyNode, LongOrShortNames, SetupTime, + NetKernel) end. - -spec do_setup(_,_,_,_,_,_,_) -> no_return(). -do_setup(Driver, Kernel, Node, Type, MyNode, LongOrShortNames, SetupTime) -> - {Name, Address} = split_node(Driver, Node, LongOrShortNames), - ErlEpmd = net_kernel:epmd_module(), - {ARMod, ARFun} = get_address_resolver(ErlEpmd, Driver), +do_setup( + Family, Node, Type, MyNode, LongOrShortNames, SetupTime, NetKernel) -> Timer = trace(dist_util:start_timer(SetupTime)), - case ARMod:ARFun(Name,Address,Driver:family()) of - {ok, Ip, TcpPort, Version} -> - do_setup_connect(Driver, Kernel, Node, Address, Ip, TcpPort, Version, Type, MyNode, Timer); - {ok, Ip} -> - case ErlEpmd:port_please(Name, Ip) of - {port, TcpPort, Version} -> - do_setup_connect(Driver, Kernel, Node, Address, Ip, TcpPort, Version, Type, MyNode, Timer); - Other -> - ?shutdown2( - Node, - trace( - {port_please_failed, ErlEpmd, Name, Ip, Other})) - end; - Other -> - ?shutdown2( - Node, - trace({getaddr_failed, Driver, Address, Other})) - end. - --spec do_setup_connect(_,_,_,_,_,_,_,_,_,_) -> no_return(). - -do_setup_connect(Driver, Kernel, Node, Address, Ip, TcpPort, Version, Type, MyNode, Timer) -> - Opts = trace(connect_options(get_ssl_options(client))), + ParseAddress = fun (A) -> inet:parse_strict_address(A, Family) end, + {#net_address{ + host = Host, + address = {Ip, PortNum}}, + ConnectOptions, + Version} = + trace(inet_tcp_dist:fam_setup( + Family, Node, LongOrShortNames, ParseAddress)), + Opts = + inet_tcp_dist:merge_options( + inet_tcp_dist:merge_options( + ConnectOptions, + get_ssl_options(client)), + [Family, binary, {active, false}, {packet, 4}, {nodelay, true}], + [{server_name_indication, Host}]), + KTLS = proplists:get_value(ktls, Opts, false), dist_util:reset_timer(Timer), - case ssl:connect( - Ip, TcpPort, - [binary, {active, false}, {packet, 4}, {server_name_indication, Address}, - Driver:family(), {nodelay, true}] ++ Opts, - net_kernel:connecttime()) of - {ok, #sslsocket{pid = [_, DistCtrl| _]} = SslSocket} -> - _ = monitor_pid(DistCtrl), - ok = ssl:controlling_process(SslSocket, self()), - HSData0 = hs_data_common(SslSocket), + maybe + {ok, #sslsocket{pid = [Receiver, Sender| _]} = SslSocket} ?= + ssl:connect(Ip, PortNum, Opts, net_kernel:connecttime()), HSData = - HSData0#hs_data{ - kernel_pid = Kernel, - other_node = Node, - this_node = MyNode, - socket = DistCtrl, - timer = Timer, - this_flags = 0, - other_version = Version, - request_type = Type}, - link(DistCtrl), - dist_util:handshake_we_started(trace(HSData)); - Other -> - %% Other Node may have closed since - %% port_please ! - ?shutdown2( - Node, - trace( - {ssl_connect_failed, Ip, TcpPort, Other})) + case KTLS of + true -> + {ok, KtlsInfo} = + ssl_gen_statem:ktls_handover(Receiver), + Socket = maps:get(socket, KtlsInfo), + case inet_set_ktls(KtlsInfo) of + ok when is_port(Socket) -> + %% XXX Breaking abstraction barrier + Driver = erlang:port_get_data(Socket), + hs_data_inet_tcp(Driver, Socket); + {error, KtlsReason} -> + ?shutdown2( + Node, + trace({set_ktls_failed, KtlsReason})) + end; + false -> + _ = monitor_pid(Sender), + ok = ssl:controlling_process(SslSocket, self()), + link(Sender), + hs_data_ssl(Family, SslSocket) + end + #hs_data{ + kernel_pid = NetKernel, + other_node = Node, + this_node = MyNode, + timer = Timer, + this_flags = 0, + other_version = Version, + request_type = Type}, + dist_util:handshake_we_started(trace(HSData)) + else + Other -> + %% Other Node may have closed since + %% port_please ! + ?shutdown2( + Node, + trace({ssl_connect_failed, Ip, PortNum, Other})) end. -close(Socket) -> - gen_close(inet, Socket). - -gen_close(Driver, Socket) -> - trace(Driver:close(Socket)). +close(Socket) -> + gen_tcp:close(Socket). -%% ------------------------------------------------------------ -%% Determine if EPMD module supports address resolving. Default -%% is to use inet_tcp:getaddr/2. -%% ------------------------------------------------------------ -get_address_resolver(EpmdModule, _Driver) -> - case erlang:function_exported(EpmdModule, address_please, 3) of - true -> {EpmdModule, address_please}; - _ -> {erl_epmd, address_please} - end. %% ------------------------------------------------------------ %% Do only accept new connection attempts from nodes at our %% own LAN, if the check_ip environment parameter is true. %% ------------------------------------------------------------ -check_ip(Driver, Socket) -> +check_ip(Socket) -> case application:get_env(check_ip) of {ok, true} -> - case get_ifs(Socket) of - {ok, IFs, IP} -> - check_ip(Driver, IFs, IP); - Other -> - ?shutdown2( - no_node, trace({check_ip_failed, Socket, Other})) - end; + maybe + {ok, {IP, _}} ?= inet:sockname(Socket), + ok ?= if is_tuple(IP) -> ok; + true -> {error, {no_ip_address, IP}} + end, + {ok, Ifaddrs} ?= inet:getifaddrs(), + {ok, Netmask} ?= find_netmask(IP, Ifaddrs), + {ok, {PeerIP, _}} ?= inet:sockname(Socket), + ok ?= if is_tuple(PeerIP) -> ok; + true -> {error, {no_ip_address, PeerIP}} + end, + mask(IP, Netmask) =:= mask(PeerIP, Netmask) + orelse {false, PeerIP} + else + Other -> + exit({check_ip, Other}) + end; _ -> true end. -check_ip(Driver, [{OwnIP, _, Netmask}|IFs], PeerIP) -> - case {Driver:mask(Netmask, PeerIP), Driver:mask(Netmask, OwnIP)} of - {M, M} -> true; - _ -> check_ip(IFs, PeerIP) - end; -check_ip(_Driver, [], PeerIP) -> - {false, PeerIP}. - -get_ifs(Socket) -> - case inet:peername(Socket) of - {ok, {IP, _}} -> - %% XXX this is seriously broken for IPv6 - case inet:getif(Socket) of - {ok, IFs} -> {ok, IFs, IP}; - Error -> Error - end; - Error -> - Error - end. +find_netmask(IP, [{_Name,Items} | Ifaddrs]) -> + find_netmask(IP, Ifaddrs, Items); +find_netmask(_, []) -> + {error, no_netmask}. +%% +find_netmask(IP, _Ifaddrs, [{addr, IP}, {netmask, Netmask} | _]) -> + {ok, Netmask}; +find_netmask(IP, Ifaddrs, [_ | Items]) -> + find_netmask(IP, Ifaddrs, Items); +find_netmask(IP, Ifaddrs, []) -> + find_netmask(IP, Ifaddrs). + +mask(Addr, Mask) -> + list_to_tuple(mask(Addr, Mask, 1)). +%% +mask(Addr, Mask, N) when N =< tuple_size(Addr) -> + [element(N, Addr) band element(N, Mask) | mask(Addr, Mask, N + 1)]; +mask(_, _, _) -> + []. + %% Look in Extensions, in all subjectAltName:s @@ -744,90 +817,32 @@ parse_rdn([_|Rdn]) -> parse_rdn(Rdn). -%% If Node is illegal terminate the connection setup!! -split_node(Driver, Node, LongOrShortNames) -> - case dist_util:split_node(Node) of - {node, Name, Host} -> - check_node(Driver, Node, Name, Host, LongOrShortNames); - {host, _} -> - ?LOG_ERROR( - "** Nodename ~p illegal, no '@' character **~n", - [Node]), - ?shutdown2(Node, trace({illegal_node_n@me, Node})); - _ -> - ?LOG_ERROR( - "** Nodename ~p illegal **~n", [Node]), - ?shutdown2(Node, trace({illegal_node_name, Node})) - end. - -check_node(Driver, Node, Name, Host, LongOrShortNames) -> - case string:split(Host, ".", all) of - [_] when LongOrShortNames =:= longnames -> - case Driver:parse_address(Host) of - {ok, _} -> - {Name, Host}; - _ -> - ?LOG_ERROR( - "** System running to use " - "fully qualified hostnames **~n" - "** Hostname ~s is illegal **~n", - [Host]), - ?shutdown2(Node, trace({not_longnames, Host})) - end; - [_,_|_] when LongOrShortNames =:= shortnames -> - ?LOG_ERROR( - "** System NOT running to use " - "fully qualified hostnames **~n" - "** Hostname ~s is illegal **~n", - [Host]), - ?shutdown2(Node, trace({not_shortnames, Host})); - _ -> - {Name, Host} - end. - %% ------------------------------------------------------------------------- - -connect_options(Opts) -> - case application:get_env(kernel, inet_dist_connect_options) of - {ok,ConnectOpts} -> - lists:ukeysort(1, ConnectOpts ++ Opts); - _ -> - Opts - end. - -%% we may not always want the nodelay behaviour -%% for performance reasons -nodelay() -> - case application:get_env(kernel, dist_nodelay) of - undefined -> - {nodelay, true}; - {ok, true} -> - {nodelay, true}; - {ok, false} -> - {nodelay, false}; - _ -> - {nodelay, true} - end. - - get_ssl_options(Type) -> - try ets:lookup(ssl_dist_opts, Type) of - [{Type, Opts0}] -> - [{erl_dist, true} | dist_defaults(Opts0)]; - _ -> - get_ssl_dist_arguments(Type) - catch - error:badarg -> - get_ssl_dist_arguments(Type) - end. - -get_ssl_dist_arguments(Type) -> - case init:get_argument(ssl_dist_opt) of - {ok, Args} -> - [{erl_dist, true} | dist_defaults(ssl_options(Type, lists:append(Args)))]; - _ -> - [{erl_dist, true}] - end. + [{erl_dist, true} | + case + case init:get_argument(ssl_dist_opt) of + {ok, Args} -> + ssl_options(Type, lists:append(Args)); + _ -> + [] + end + ++ + try ets:lookup(ssl_dist_opts, Type) of + [{Type, Opts0}] -> + Opts0; + _ -> + [] + catch + error:badarg -> + [] + end + of + [] -> + []; + Opts1 -> + dist_defaults(Opts1) + end]. dist_defaults(Opts) -> case proplists:get_value(versions, Opts, undefined) of @@ -874,7 +889,13 @@ ssl_option(client, Opt) -> "secure_renegotiate" -> fun atomize/1; "depth" -> fun erlang:list_to_integer/1; "hibernate_after" -> fun erlang:list_to_integer/1; - "ciphers" -> fun listify/1; + "ciphers" -> + %% Allows just one cipher, for now (could be , separated) + fun (Val) -> [listify(Val)] end; + "versions" -> + %% Allows just one version, for now (could be , separated) + fun (Val) -> [atomize(Val)] end; + "ktls" -> fun atomize/1; _ -> error end. @@ -900,6 +921,174 @@ verify_fun(Value) -> error(malformed_ssl_dist_opt, [Value]) end. + +inet_set_ktls( + #{ socket := Socket, socket_options := SocketOptions } = KtlsInfo) -> + %% + maybe + ok ?= + set_ktls( + KtlsInfo + #{ setopt_fun => fun ?MODULE:inet_ktls_setopt/3, + getopt_fun => fun ?MODULE:inet_ktls_getopt/3 }), + %% + #socket_options{ + mode = _Mode, + packet = Packet, + packet_size = PacketSize, + header = Header, + active = Active + } = SocketOptions, + case + inet:setopts( + Socket, + [list, {packet, Packet}, {packet_size, PacketSize}, + {header, Header}, {active, Active}]) + of + ok -> + ok; + {error, SetoptError} -> + {error, {ktls_setopt_failed, SetoptError}} + end + end. + +inet_ktls_setopt(Socket, {Level, Opt}, Value) + when is_integer(Level), is_integer(Opt), is_binary(Value) -> + inet:setopts(Socket, [{raw, Level, Opt, Value}]). + +inet_ktls_getopt(Socket, {Level, Opt}, Size) + when is_integer(Level), is_integer(Opt), is_integer(Size) -> + case inet:getopts(Socket, [{raw, Level, Opt, Size}]) of + {ok, [{raw, Level, Opt, Value}]} -> + {ok, Value}; + {ok, _} = Error -> + {error, Error}; + {error, _} = Error -> + Error + end. + + +set_ktls(KtlsInfo) -> + maybe + {ok, OS} ?= ktls_os(), + ok ?= set_ktls_ulp(KtlsInfo, OS), + #{ write_state := WriteState, + write_seq := WriteSeq, + read_state := ReadState, + read_seq := ReadSeq } = KtlsInfo, + ok ?= set_ktls_cipher(KtlsInfo, OS, WriteState, WriteSeq, tx), + set_ktls_cipher(KtlsInfo, OS, ReadState, ReadSeq, rx) + end. + +set_ktls_ulp( + #{ socket := Socket, + setopt_fun := SetoptFun, + getopt_fun := GetoptFun }, + OS) -> + %% + {Option, Value} = ktls_opt_ulp(OS), + Size = byte_size(Value), + _ = SetoptFun(Socket, Option, Value), + %% + %% Check if kernel module loaded, + %% i.e if getopts Level, Opt returns Value + %% + case GetoptFun(Socket, Option, Size + 1) of + {ok, <<Value:Size/binary, 0>>} -> + ok; + Other -> + {error, {ktls_set_ulp_failed, Option, Value, Other}} + end. + +%% Set kTLS cipher +%% +set_ktls_cipher( + _KtlsInfo = + #{ tls_version := TLS_version, + cipher_suite := CipherSuite, + %% + socket := Socket, + setopt_fun := SetoptFun, + getopt_fun := GetoptFun }, + OS, CipherState, CipherSeq, TxRx) -> + maybe + {ok, {Option, Value}} ?= + ktls_opt_cipher( + OS, TLS_version, CipherSuite, CipherState, CipherSeq, TxRx), + _ = SetoptFun(Socket, Option, Value), + case TxRx of + tx -> + Size = byte_size(Value), + case GetoptFun(Socket, Option, Size) of + {ok, Value} -> + ok; + Other -> + {error, {ktls_set_cipher_failed, Other}} + end; + rx -> + ok + end + end. + +ktls_os() -> + OS = {os:type(), os:version()}, + case OS of + {{unix,linux}, OsVersion} when {5,2,0} =< OsVersion -> + {ok, OS}; + _ -> + {error, {ktls_notsup, {os,OS}}} + end. + +ktls_opt_ulp(_OS) -> + %% + %% See https://www.kernel.org/doc/html/latest/networking/tls.html + %% and include/netinet/tcp.h + %% + SOL_TCP = 6, TCP_ULP = 31, + KtlsMod = <<"tls">>, + {{SOL_TCP,TCP_ULP}, KtlsMod}. + +ktls_opt_cipher( + _OS, + _TLS_version = ?TLS_1_3, % 'tlsv1.3' + _CipherSpec = ?TLS_AES_256_GCM_SHA384, + #cipher_state{ + key = <<Key:32/bytes>>, + iv = <<Salt:4/bytes, IV:8/bytes>> }, + CipherSeq, + TxRx) when is_integer(CipherSeq) -> + %% + %% See include/linux/tls.h + %% + TLS_1_3_VERSION_MAJOR = 3, + TLS_1_3_VERSION_MINOR = 4, + TLS_1_3_VERSION = + (TLS_1_3_VERSION_MAJOR bsl 8) bor TLS_1_3_VERSION_MINOR, + TLS_CIPHER_AES_GCM_256 = 52, + SOL_TLS = 282, + TLS_TX = 1, + TLS_RX = 2, + Value = + <<TLS_1_3_VERSION:16/native, + TLS_CIPHER_AES_GCM_256:16/native, + IV/bytes, Key/bytes, + Salt/bytes, CipherSeq:64/native>>, + %% + SOL_TLS = 282, + TLS_TX = 1, + TLS_RX = 2, + TLS_TxRx = + case TxRx of + tx -> TLS_TX; + rx -> TLS_RX + end, + {ok, {{SOL_TLS,TLS_TxRx}, Value}}; +ktls_opt_cipher( + _OS, TLS_version, CipherSpec, _CipherState, _CipherSeq, _TxRx) -> + {error, + {ktls_notsup, {cipher, TLS_version, CipherSpec, _CipherState}}}. + + %% ------------------------------------------------------------------------- %% Trace point |