diff options
author | Michael Klishin <klishinm@vmware.com> | 2023-01-29 10:58:08 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-29 10:58:08 -0500 |
commit | 5dc2d9b828250d2514a664d4dd6516bd776b2fb6 (patch) | |
tree | 0661c557c1c26dbd19582660e978798b39326efa | |
parent | d8da0b5f9e844c257ca6d20420738e1923c8ed49 (diff) | |
parent | 02cf072ae470125b2388f82cba1a42855a1eb325 (diff) | |
download | rabbitmq-server-git-5dc2d9b828250d2514a664d4dd6516bd776b2fb6.tar.gz |
Merge pull request #7091 from rabbitmq/mqtt-max-size-connect-packet
Set MQTT max packet size
-rw-r--r-- | .git-blame-ignore-revs | 4 | ||||
-rw-r--r-- | deps/rabbitmq_mqtt/src/rabbit_mqtt_packet.erl | 125 | ||||
-rw-r--r-- | deps/rabbitmq_mqtt/src/rabbit_mqtt_reader.erl | 4 | ||||
-rw-r--r-- | deps/rabbitmq_mqtt/test/shared_SUITE.erl | 29 | ||||
-rw-r--r-- | deps/rabbitmq_mqtt/test/util.erl | 9 | ||||
-rw-r--r-- | deps/rabbitmq_web_mqtt/src/rabbit_web_mqtt_handler.erl | 6 |
6 files changed, 124 insertions, 53 deletions
diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000000..fb86a8deb8 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,4 @@ +# Revert "Format MQTT code with erlfmt" +209f23fa2f58e0240116b3e8e5be9cd54d34b569 +# Format MQTT code with erlfmt +1de9fcf582def91d1cee6bea457dd24e8a53a431 diff --git a/deps/rabbitmq_mqtt/src/rabbit_mqtt_packet.erl b/deps/rabbitmq_mqtt/src/rabbit_mqtt_packet.erl index ededed8c5b..a525d02491 100644 --- a/deps/rabbitmq_mqtt/src/rabbit_mqtt_packet.erl +++ b/deps/rabbitmq_mqtt/src/rabbit_mqtt_packet.erl @@ -10,40 +10,53 @@ -include("rabbit_mqtt_packet.hrl"). -include("rabbit_mqtt.hrl"). --export([parse/2, initial_state/0, serialise/2]). +-export([init_state/0, reset_state/0, + parse/2, serialise/2]). -export_type([state/0]). --opaque state() :: none | fun(). +-opaque state() :: unauthenticated | authenticated | fun(). -define(RESERVED, 0). -define(MAX_LEN, 16#fffffff). -define(HIGHBIT, 2#10000000). -define(LOWBITS, 2#01111111). +-define(MAX_MULTIPLIER, ?HIGHBIT * ?HIGHBIT * ?HIGHBIT). +-define(MAX_PACKET_SIZE_CONNECT, 65_536). --spec initial_state() -> state(). -initial_state() -> none. +-spec init_state() -> state(). +init_state() -> unauthenticated. + +-spec reset_state() -> state(). +reset_state() -> authenticated. -spec parse(binary(), state()) -> {more, state()} | {ok, mqtt_packet(), binary()} | {error, any()}. -parse(<<>>, none) -> - {more, fun(Bin) -> parse(Bin, none) end}; -parse(<<MessageType:4, Dup:1, QoS:2, Retain:1, Rest/binary>>, none) -> +parse(<<>>, authenticated) -> + {more, fun(Bin) -> parse(Bin, authenticated) end}; +parse(<<MessageType:4, Dup:1, QoS:2, Retain:1, Rest/binary>>, authenticated) -> parse_remaining_len(Rest, #mqtt_packet_fixed{ type = MessageType, dup = bool(Dup), qos = QoS, retain = bool(Retain) }); -parse(Bin, Cont) -> Cont(Bin). +parse(<<?CONNECT:4, 0:4, Rest/binary>>, unauthenticated) -> + parse_remaining_len(Rest, #mqtt_packet_fixed{type = ?CONNECT}); +parse(Bin, Cont) + when is_function(Cont) -> + Cont(Bin). parse_remaining_len(<<>>, Fixed) -> {more, fun(Bin) -> parse_remaining_len(Bin, Fixed) end}; parse_remaining_len(Rest, Fixed) -> parse_remaining_len(Rest, Fixed, 1, 0). +parse_remaining_len(_Bin, _Fixed, Multiplier, _Length) + when Multiplier > ?MAX_MULTIPLIER -> + {error, malformed_remaining_length}; parse_remaining_len(_Bin, _Fixed, _Multiplier, Length) when Length > ?MAX_LEN -> - {error, invalid_mqtt_packet_len}; + {error, invalid_mqtt_packet_length}; parse_remaining_len(<<>>, Fixed, Multiplier, Length) -> {more, fun(Bin) -> parse_remaining_len(Bin, Fixed, Multiplier, Length) end}; parse_remaining_len(<<1:1, Len:7, Rest/binary>>, Fixed, Multiplier, Value) -> @@ -51,45 +64,12 @@ parse_remaining_len(<<1:1, Len:7, Rest/binary>>, Fixed, Multiplier, Value) -> parse_remaining_len(<<0:1, Len:7, Rest/binary>>, Fixed, Multiplier, Value) -> parse_packet(Rest, Fixed, Value + Len * Multiplier). -parse_packet(Bin, #mqtt_packet_fixed{ type = Type, - qos = Qos } = Fixed, Length) +parse_packet(Bin, #mqtt_packet_fixed{type = ?CONNECT} = Fixed, Length) -> + parse_connect(Bin, Fixed, Length); +parse_packet(Bin, #mqtt_packet_fixed{type = Type, + qos = Qos} = Fixed, Length) when Length =< ?MAX_LEN -> case {Type, Bin} of - {?CONNECT, <<PacketBin:Length/binary, Rest/binary>>} -> - {ProtoName, Rest1} = parse_utf(PacketBin), - <<ProtoVersion : 8, Rest2/binary>> = Rest1, - <<UsernameFlag : 1, - PasswordFlag : 1, - WillRetain : 1, - WillQos : 2, - WillFlag : 1, - CleanSession : 1, - _Reserved : 1, - KeepAlive : 16/big, - Rest3/binary>> = Rest2, - {ClientId, Rest4} = parse_utf(Rest3), - {WillTopic, Rest5} = parse_utf(Rest4, WillFlag), - {WillMsg, Rest6} = parse_msg(Rest5, WillFlag), - {UserName, Rest7} = parse_utf(Rest6, UsernameFlag), - {PasssWord, <<>>} = parse_utf(Rest7, PasswordFlag), - case protocol_name_approved(ProtoVersion, ProtoName) of - true -> - wrap(Fixed, - #mqtt_packet_connect{ - proto_ver = ProtoVersion, - will_retain = bool(WillRetain), - will_qos = WillQos, - will_flag = bool(WillFlag), - clean_sess = bool(CleanSession), - keep_alive = KeepAlive, - client_id = ClientId, - will_topic = WillTopic, - will_msg = WillMsg, - username = UserName, - password = PasssWord}, Rest); - false -> - {error, protocol_header_corrupt} - end; {?PUBLISH, <<PacketBin:Length/binary, Rest/binary>>} -> {TopicName, Rest1} = parse_utf(PacketBin), {PacketId, Payload} = case Qos of @@ -122,6 +102,59 @@ parse_packet(Bin, #mqtt_packet_fixed{ type = Type, end} end. +parse_connect(Bin, Fixed, Length) -> + MaxSize = application:get_env(?APP_NAME, + max_packet_size_unauthenticated, + ?MAX_PACKET_SIZE_CONNECT), + case Length =< MaxSize of + true -> + case Bin of + <<PacketBin:Length/binary, Rest/binary>> -> + {ProtoName, Rest1} = parse_utf(PacketBin), + <<ProtoVersion : 8, Rest2/binary>> = Rest1, + <<UsernameFlag : 1, + PasswordFlag : 1, + WillRetain : 1, + WillQos : 2, + WillFlag : 1, + CleanSession : 1, + _Reserved : 1, + KeepAlive : 16/big, + Rest3/binary>> = Rest2, + {ClientId, Rest4} = parse_utf(Rest3), + {WillTopic, Rest5} = parse_utf(Rest4, WillFlag), + {WillMsg, Rest6} = parse_msg(Rest5, WillFlag), + {UserName, Rest7} = parse_utf(Rest6, UsernameFlag), + {PasssWord, <<>>} = parse_utf(Rest7, PasswordFlag), + case protocol_name_approved(ProtoVersion, ProtoName) of + true -> + wrap(Fixed, + #mqtt_packet_connect{ + proto_ver = ProtoVersion, + will_retain = bool(WillRetain), + will_qos = WillQos, + will_flag = bool(WillFlag), + clean_sess = bool(CleanSession), + keep_alive = KeepAlive, + client_id = ClientId, + will_topic = WillTopic, + will_msg = WillMsg, + username = UserName, + password = PasssWord}, Rest); + false -> + {error, protocol_header_corrupt} + end; + TooShortBin + when byte_size(TooShortBin) < Length -> + {more, fun(BinMore) -> + parse_connect(<<TooShortBin/binary, BinMore/binary>>, + Fixed, Length) + end} + end; + false -> + {error, connect_packet_too_large} + end. + parse_topics(_, <<>>, Topics) -> Topics; parse_topics(?SUBSCRIBE = Sub, Bin, Topics) -> diff --git a/deps/rabbitmq_mqtt/src/rabbit_mqtt_reader.erl b/deps/rabbitmq_mqtt/src/rabbit_mqtt_reader.erl index 3949989230..a6c578a610 100644 --- a/deps/rabbitmq_mqtt/src/rabbit_mqtt_reader.erl +++ b/deps/rabbitmq_mqtt/src/rabbit_mqtt_reader.erl @@ -91,7 +91,7 @@ init(Ref) -> connection_state = running, received_connect_packet = false, conserve = false, - parse_state = rabbit_mqtt_packet:initial_state(), + parse_state = rabbit_mqtt_packet:init_state(), proc_state = ProcessorState}, State1 = control_throttle(State0), State = rabbit_event:init_stats_timer(State1, #state.stats_timer), @@ -336,7 +336,7 @@ process_received_bytes(Bytes, {ok, ProcState1} -> process_received_bytes( Rest, - State #state{parse_state = rabbit_mqtt_packet:initial_state(), + State #state{parse_state = rabbit_mqtt_packet:reset_state(), proc_state = ProcState1}); %% PUBLISH and more {error, unauthorized = Reason, ProcState1} -> diff --git a/deps/rabbitmq_mqtt/test/shared_SUITE.erl b/deps/rabbitmq_mqtt/test/shared_SUITE.erl index 71183d5f9c..b784426a80 100644 --- a/deps/rabbitmq_mqtt/test/shared_SUITE.erl +++ b/deps/rabbitmq_mqtt/test/shared_SUITE.erl @@ -92,6 +92,7 @@ subgroups() -> ,clean_session_kill_node ,rabbit_status_connection_count ,trace + ,max_packet_size_unauthenticated ]} ]}, {cluster_size_3, [], @@ -1337,6 +1338,34 @@ trace(Config) -> delete_queue(Ch, TraceQ), [ok = emqtt:disconnect(C) || C <- [Pub, Sub]]. +max_packet_size_unauthenticated(Config) -> + App = rabbitmq_mqtt, + Par = ClientId = ?FUNCTION_NAME, + Opts = [{will_topic, <<"will/topic">>}], + + {C1, Connect} = util:start_client( + ClientId, Config, 0, + [{will_payload, binary:copy(<<"a">>, 64_000)} | Opts]), + ?assertMatch({ok, _}, Connect(C1)), + ok = emqtt:disconnect(C1), + + MaxSize = 500, + ok = rpc(Config, application, set_env, [App, Par, MaxSize]), + + {C2, Connect} = util:start_client( + ClientId, Config, 0, + [{will_payload, binary:copy(<<"b">>, MaxSize + 1)} | Opts]), + true = unlink(C2), + ?assertMatch({error, _}, Connect(C2)), + + {C3, Connect} = util:start_client( + ClientId, Config, 0, + [{will_payload, binary:copy(<<"c">>, round(MaxSize / 2))} | Opts]), + ?assertMatch({ok, _}, Connect(C3)), + ok = emqtt:disconnect(C3), + + ok = rpc(Config, application, unset_env, [App, Par]). + %% ------------------------------------------------------------------- %% Internal helpers %% ------------------------------------------------------------------- diff --git a/deps/rabbitmq_mqtt/test/util.erl b/deps/rabbitmq_mqtt/test/util.erl index 44c684a6c5..6d86fe9a34 100644 --- a/deps/rabbitmq_mqtt/test/util.erl +++ b/deps/rabbitmq_mqtt/test/util.erl @@ -14,6 +14,7 @@ connect/2, connect/3, connect/4, + start_client/4, get_events/1, assert_event_type/2, assert_event_prop/2, @@ -119,6 +120,11 @@ connect(ClientId, Config, AdditionalOpts) -> connect(ClientId, Config, 0, AdditionalOpts). connect(ClientId, Config, Node, AdditionalOpts) -> + {C, Connect} = start_client(ClientId, Config, Node, AdditionalOpts), + {ok, _Properties} = Connect(C), + C. + +start_client(ClientId, Config, Node, AdditionalOpts) -> {Port, WsOpts, Connect} = case rabbit_ct_helpers:get_config(Config, websocket, false) of false -> @@ -136,5 +142,4 @@ connect(ClientId, Config, Node, AdditionalOpts) -> {clientid, rabbit_data_coercion:to_binary(ClientId)} ] ++ WsOpts ++ AdditionalOpts, {ok, C} = emqtt:start_link(Options), - {ok, _Properties} = Connect(C), - C. + {C, Connect}. diff --git a/deps/rabbitmq_web_mqtt/src/rabbit_web_mqtt_handler.erl b/deps/rabbitmq_web_mqtt/src/rabbit_web_mqtt_handler.erl index 72f52b3977..1fb022ee44 100644 --- a/deps/rabbitmq_web_mqtt/src/rabbit_web_mqtt_handler.erl +++ b/deps/rabbitmq_web_mqtt/src/rabbit_web_mqtt_handler.erl @@ -32,7 +32,7 @@ -record(state, { socket :: {rabbit_proxy_socket, any(), any()} | rabbit_net:socket(), - parse_state = rabbit_mqtt_packet:initial_state() :: rabbit_mqtt_packet:state(), + parse_state = rabbit_mqtt_packet:init_state() :: rabbit_mqtt_packet:state(), proc_state :: undefined | rabbit_mqtt_processor:state(), connection_state = running :: running | blocked, conserve = false :: boolean(), @@ -273,7 +273,7 @@ handle_data1(Data, State = #state{ parse_state = ParseState, {ok, ProcState1} -> handle_data1( Rest, - State#state{parse_state = rabbit_mqtt_packet:initial_state(), + State#state{parse_state = rabbit_mqtt_packet:reset_state(), proc_state = ProcState1}); {error, Reason, _} -> stop_mqtt_protocol_error(State, Reason, ConnName); @@ -296,7 +296,7 @@ parse(Data, ParseState) -> end. stop_mqtt_protocol_error(State, Reason, ConnName) -> - ?LOG_INFO("MQTT protocol error ~tp for connection ~tp", [Reason, ConnName]), + ?LOG_WARNING("Web MQTT protocol error ~tp for connection ~tp", [Reason, ConnName]), stop(State, ?CLOSE_PROTOCOL_ERROR, Reason). stop(State) -> |