summaryrefslogtreecommitdiff
path: root/deps/rabbitmq_mqtt/src/rabbit_mqtt_processor.erl
diff options
context:
space:
mode:
Diffstat (limited to 'deps/rabbitmq_mqtt/src/rabbit_mqtt_processor.erl')
-rw-r--r--deps/rabbitmq_mqtt/src/rabbit_mqtt_processor.erl1054
1 files changed, 1054 insertions, 0 deletions
diff --git a/deps/rabbitmq_mqtt/src/rabbit_mqtt_processor.erl b/deps/rabbitmq_mqtt/src/rabbit_mqtt_processor.erl
new file mode 100644
index 0000000000..c3a25096e6
--- /dev/null
+++ b/deps/rabbitmq_mqtt/src/rabbit_mqtt_processor.erl
@@ -0,0 +1,1054 @@
+%% This Source Code Form is subject to the terms of the Mozilla Public
+%% License, v. 2.0. If a copy of the MPL was not distributed with this
+%% file, You can obtain one at https://mozilla.org/MPL/2.0/.
+%%
+%% Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
+%%
+
+-module(rabbit_mqtt_processor).
+
+-export([info/2, initial_state/2, initial_state/5,
+ process_frame/2, amqp_pub/2, amqp_callback/2, send_will/1,
+ close_connection/1, handle_pre_hibernate/0,
+ handle_ra_event/2]).
+
+%% for testing purposes
+-export([get_vhost_username/1, get_vhost/3, get_vhost_from_user_mapping/2,
+ add_client_id_to_adapter_info/2]).
+
+-include_lib("amqp_client/include/amqp_client.hrl").
+-include("rabbit_mqtt_frame.hrl").
+-include("rabbit_mqtt.hrl").
+
+-define(APP, rabbitmq_mqtt).
+-define(FRAME_TYPE(Frame, Type),
+ Frame = #mqtt_frame{ fixed = #mqtt_frame_fixed{ type = Type }}).
+-define(MAX_TOPIC_PERMISSION_CACHE_SIZE, 12).
+
+initial_state(Socket, SSLLoginName) ->
+ RealSocket = rabbit_net:unwrap_socket(Socket),
+ {ok, {PeerAddr, _PeerPort}} = rabbit_net:peername(RealSocket),
+ initial_state(RealSocket, SSLLoginName,
+ adapter_info(Socket, 'MQTT'),
+ fun serialise_and_send_to_client/2, PeerAddr).
+
+initial_state(Socket, SSLLoginName,
+ AdapterInfo0 = #amqp_adapter_info{additional_info = Extra},
+ SendFun, PeerAddr) ->
+ {ok, {mqtt2amqp_fun, M2A}, {amqp2mqtt_fun, A2M}} =
+ rabbit_mqtt_util:get_topic_translation_funs(),
+ %% MQTT connections use exactly one channel. The frame max is not
+ %% applicable and there is no way to know what client is used.
+ AdapterInfo = AdapterInfo0#amqp_adapter_info{additional_info = [
+ {channels, 1},
+ {channel_max, 1},
+ {frame_max, 0},
+ {client_properties,
+ [{<<"product">>, longstr, <<"MQTT client">>}]} | Extra]},
+ #proc_state{ unacked_pubs = gb_trees:empty(),
+ awaiting_ack = gb_trees:empty(),
+ message_id = 1,
+ subscriptions = #{},
+ consumer_tags = {undefined, undefined},
+ channels = {undefined, undefined},
+ exchange = rabbit_mqtt_util:env(exchange),
+ socket = Socket,
+ adapter_info = AdapterInfo,
+ ssl_login_name = SSLLoginName,
+ send_fun = SendFun,
+ peer_addr = PeerAddr,
+ mqtt2amqp_fun = M2A,
+ amqp2mqtt_fun = A2M}.
+
+process_frame(#mqtt_frame{ fixed = #mqtt_frame_fixed{ type = Type }},
+ PState = #proc_state{ connection = undefined } )
+ when Type =/= ?CONNECT ->
+ {error, connect_expected, PState};
+process_frame(Frame = #mqtt_frame{ fixed = #mqtt_frame_fixed{ type = Type }},
+ PState) ->
+ case process_request(Type, Frame, PState) of
+ {ok, PState1} -> {ok, PState1, PState1#proc_state.connection};
+ Ret -> Ret
+ end.
+
+add_client_id_to_adapter_info(ClientId, #amqp_adapter_info{additional_info = AdditionalInfo0} = AdapterInfo) ->
+ AdditionalInfo1 = [{variable_map, #{<<"client_id">> => ClientId}}
+ | AdditionalInfo0],
+ ClientProperties = proplists:get_value(client_properties, AdditionalInfo1, [])
+ ++ [{client_id, longstr, ClientId}],
+ AdditionalInfo2 = case lists:keysearch(client_properties, 1, AdditionalInfo1) of
+ {value, _} ->
+ lists:keyreplace(client_properties,
+ 1,
+ AdditionalInfo1,
+ {client_properties, ClientProperties});
+ false ->
+ [{client_properties, ClientProperties} | AdditionalInfo1]
+ end,
+ AdapterInfo#amqp_adapter_info{additional_info = AdditionalInfo2}.
+
+process_request(?CONNECT,
+ #mqtt_frame{ variable = #mqtt_frame_connect{
+ username = Username,
+ password = Password,
+ proto_ver = ProtoVersion,
+ clean_sess = CleanSess,
+ client_id = ClientId0,
+ keep_alive = Keepalive} = Var},
+ PState0 = #proc_state{ ssl_login_name = SSLLoginName,
+ send_fun = SendFun,
+ adapter_info = AdapterInfo,
+ peer_addr = Addr}) ->
+ ClientId = case ClientId0 of
+ [] -> rabbit_mqtt_util:gen_client_id();
+ [_|_] -> ClientId0
+ end,
+ rabbit_log_connection:debug("Received a CONNECT, client ID: ~p (expanded to ~p), username: ~p, "
+ "clean session: ~p, protocol version: ~p, keepalive: ~p",
+ [ClientId0, ClientId, Username, CleanSess, ProtoVersion, Keepalive]),
+ AdapterInfo1 = add_client_id_to_adapter_info(rabbit_data_coercion:to_binary(ClientId), AdapterInfo),
+ PState1 = PState0#proc_state{adapter_info = AdapterInfo1},
+ Ip = list_to_binary(inet:ntoa(Addr)),
+ {Return, PState5} =
+ case {lists:member(ProtoVersion, proplists:get_keys(?PROTOCOL_NAMES)),
+ ClientId0 =:= [] andalso CleanSess =:= false} of
+ {false, _} ->
+ {?CONNACK_PROTO_VER, PState1};
+ {_, true} ->
+ {?CONNACK_INVALID_ID, PState1};
+ _ ->
+ case creds(Username, Password, SSLLoginName) of
+ nocreds ->
+ rabbit_core_metrics:auth_attempt_failed(Ip, <<>>, mqtt),
+ rabbit_log_connection:error("MQTT login failed: no credentials provided~n"),
+ {?CONNACK_CREDENTIALS, PState1};
+ {invalid_creds, {undefined, Pass}} when is_list(Pass) ->
+ rabbit_core_metrics:auth_attempt_failed(Ip, <<>>, mqtt),
+ rabbit_log_connection:error("MQTT login failed: no username is provided"),
+ {?CONNACK_CREDENTIALS, PState1};
+ {invalid_creds, {User, undefined}} when is_list(User) ->
+ rabbit_core_metrics:auth_attempt_failed(Ip, User, mqtt),
+ rabbit_log_connection:error("MQTT login failed for user '~p': no password provided", [User]),
+ {?CONNACK_CREDENTIALS, PState1};
+ {UserBin, PassBin} ->
+ case process_login(UserBin, PassBin, ProtoVersion, PState1) of
+ connack_dup_auth ->
+ {SessionPresent0, PState2} = maybe_clean_sess(PState1),
+ {{?CONNACK_ACCEPT, SessionPresent0}, PState2};
+ {?CONNACK_ACCEPT, Conn, VHost, AState} ->
+ case rabbit_mqtt_collector:register(ClientId, self()) of
+ {ok, Corr} ->
+ RetainerPid = rabbit_mqtt_retainer_sup:child_for_vhost(VHost),
+ link(Conn),
+ {ok, Ch} = amqp_connection:open_channel(Conn),
+ link(Ch),
+ amqp_channel:enable_delivery_flow_control(Ch),
+ Prefetch = rabbit_mqtt_util:env(prefetch),
+ #'basic.qos_ok'{} = amqp_channel:call(Ch,
+ #'basic.qos'{prefetch_count = Prefetch}),
+ rabbit_mqtt_reader:start_keepalive(self(), Keepalive),
+ PState3 = PState1#proc_state{
+ will_msg = make_will_msg(Var),
+ clean_sess = CleanSess,
+ channels = {Ch, undefined},
+ connection = Conn,
+ client_id = ClientId,
+ retainer_pid = RetainerPid,
+ auth_state = AState,
+ register_state = {pending, Corr}},
+ {SessionPresent1, PState4} = maybe_clean_sess(PState3),
+ {{?CONNACK_ACCEPT, SessionPresent1}, PState4};
+ %% e.g. this node was removed from the MQTT cluster members
+ {error, _} = Err ->
+ rabbit_log_connection:error("MQTT cannot accept a connection: "
+ "client ID tracker is unavailable: ~p", [Err]),
+ %% ignore all exceptions, we are shutting down
+ catch amqp_connection:close(Conn),
+ {?CONNACK_SERVER, PState1};
+ {timeout, _} ->
+ rabbit_log_connection:error("MQTT cannot accept a connection: "
+ "client ID registration timed out"),
+ %% ignore all exceptions, we are shutting down
+ catch amqp_connection:close(Conn),
+ {?CONNACK_SERVER, PState1}
+ end;
+ ConnAck -> {ConnAck, PState1}
+ end
+ end
+ end,
+ {ReturnCode, SessionPresent} = case Return of
+ {?CONNACK_ACCEPT, Bool} -> {?CONNACK_ACCEPT, Bool};
+ Other -> {Other, false}
+ end,
+ SendFun(#mqtt_frame{fixed = #mqtt_frame_fixed{type = ?CONNACK},
+ variable = #mqtt_frame_connack{
+ session_present = SessionPresent,
+ return_code = ReturnCode}},
+ PState5),
+ case ReturnCode of
+ ?CONNACK_ACCEPT -> {ok, PState5};
+ ?CONNACK_CREDENTIALS -> {error, unauthenticated, PState5};
+ ?CONNACK_AUTH -> {error, unauthorized, PState5};
+ ?CONNACK_SERVER -> {error, unavailable, PState5};
+ ?CONNACK_INVALID_ID -> {error, invalid_client_id, PState5};
+ ?CONNACK_PROTO_VER -> {error, unsupported_protocol_version, PState5}
+ end;
+
+process_request(?PUBACK,
+ #mqtt_frame{
+ variable = #mqtt_frame_publish{ message_id = MessageId }},
+ #proc_state{ channels = {Channel, _},
+ awaiting_ack = Awaiting } = PState) ->
+ %% tag can be missing because of bogus clients and QoS downgrades
+ case gb_trees:is_defined(MessageId, Awaiting) of
+ false ->
+ {ok, PState};
+ true ->
+ Tag = gb_trees:get(MessageId, Awaiting),
+ amqp_channel:cast(Channel, #'basic.ack'{ delivery_tag = Tag }),
+ {ok, PState#proc_state{ awaiting_ack = gb_trees:delete(MessageId, Awaiting) }}
+ end;
+
+process_request(?PUBLISH,
+ Frame = #mqtt_frame{
+ fixed = Fixed = #mqtt_frame_fixed{ qos = ?QOS_2 }},
+ PState) ->
+ % Downgrade QOS_2 to QOS_1
+ process_request(?PUBLISH,
+ Frame#mqtt_frame{
+ fixed = Fixed#mqtt_frame_fixed{ qos = ?QOS_1 }},
+ PState);
+process_request(?PUBLISH,
+ #mqtt_frame{
+ fixed = #mqtt_frame_fixed{ qos = Qos,
+ retain = Retain,
+ dup = Dup },
+ variable = #mqtt_frame_publish{ topic_name = Topic,
+ message_id = MessageId },
+ payload = Payload },
+ PState = #proc_state{retainer_pid = RPid,
+ amqp2mqtt_fun = Amqp2MqttFun}) ->
+ check_publish(Topic, fun() ->
+ Msg = #mqtt_msg{retain = Retain,
+ qos = Qos,
+ topic = Topic,
+ dup = Dup,
+ message_id = MessageId,
+ payload = Payload},
+ Result = amqp_pub(Msg, PState),
+ case Retain of
+ false -> ok;
+ true -> hand_off_to_retainer(RPid, Amqp2MqttFun, Topic, Msg)
+ end,
+ {ok, Result}
+ end, PState);
+
+process_request(?SUBSCRIBE,
+ #mqtt_frame{
+ variable = #mqtt_frame_subscribe{
+ message_id = SubscribeMsgId,
+ topic_table = Topics},
+ payload = undefined},
+ #proc_state{channels = {Channel, _},
+ exchange = Exchange,
+ retainer_pid = RPid,
+ send_fun = SendFun,
+ message_id = StateMsgId,
+ mqtt2amqp_fun = Mqtt2AmqpFun} = PState0) ->
+ rabbit_log_connection:debug("Received a SUBSCRIBE for topic(s) ~p", [Topics]),
+ check_subscribe(Topics, fun() ->
+ {QosResponse, PState1} =
+ lists:foldl(fun (#mqtt_topic{name = TopicName,
+ qos = Qos}, {QosList, PState}) ->
+ SupportedQos = supported_subs_qos(Qos),
+ {Queue, #proc_state{subscriptions = Subs} = PState1} =
+ ensure_queue(SupportedQos, PState),
+ RoutingKey = Mqtt2AmqpFun(TopicName),
+ Binding = #'queue.bind'{
+ queue = Queue,
+ exchange = Exchange,
+ routing_key = RoutingKey},
+ #'queue.bind_ok'{} = amqp_channel:call(Channel, Binding),
+ SupportedQosList = case maps:find(TopicName, Subs) of
+ {ok, L} -> [SupportedQos|L];
+ error -> [SupportedQos]
+ end,
+ {[SupportedQos | QosList],
+ PState1 #proc_state{
+ subscriptions =
+ maps:put(TopicName, SupportedQosList, Subs)}}
+ end, {[], PState0}, Topics),
+ SendFun(#mqtt_frame{fixed = #mqtt_frame_fixed{type = ?SUBACK},
+ variable = #mqtt_frame_suback{
+ message_id = SubscribeMsgId,
+ qos_table = QosResponse}}, PState1),
+ %% we may need to send up to length(Topics) messages.
+ %% if QoS is > 0 then we need to generate a message id,
+ %% and increment the counter.
+ StartMsgId = safe_max_id(SubscribeMsgId, StateMsgId),
+ N = lists:foldl(fun (Topic, Acc) ->
+ case maybe_send_retained_message(RPid, Topic, Acc, PState1) of
+ {true, X} -> Acc + X;
+ false -> Acc
+ end
+ end, StartMsgId, Topics),
+ {ok, PState1#proc_state{message_id = N}}
+ end, PState0);
+
+process_request(?UNSUBSCRIBE,
+ #mqtt_frame{
+ variable = #mqtt_frame_subscribe{ message_id = MessageId,
+ topic_table = Topics },
+ payload = undefined }, #proc_state{ channels = {Channel, _},
+ exchange = Exchange,
+ client_id = ClientId,
+ subscriptions = Subs0,
+ send_fun = SendFun,
+ mqtt2amqp_fun = Mqtt2AmqpFun } = PState) ->
+ rabbit_log_connection:debug("Received an UNSUBSCRIBE for topic(s) ~p", [Topics]),
+ Queues = rabbit_mqtt_util:subcription_queue_name(ClientId),
+ Subs1 =
+ lists:foldl(
+ fun (#mqtt_topic{ name = TopicName }, Subs) ->
+ QosSubs = case maps:find(TopicName, Subs) of
+ {ok, Val} when is_list(Val) -> lists:usort(Val);
+ error -> []
+ end,
+ RoutingKey = Mqtt2AmqpFun(TopicName),
+ lists:foreach(
+ fun (QosSub) ->
+ Queue = element(QosSub + 1, Queues),
+ Binding = #'queue.unbind'{
+ queue = Queue,
+ exchange = Exchange,
+ routing_key = RoutingKey},
+ #'queue.unbind_ok'{} = amqp_channel:call(Channel, Binding)
+ end, QosSubs),
+ maps:remove(TopicName, Subs)
+ end, Subs0, Topics),
+ SendFun(#mqtt_frame{ fixed = #mqtt_frame_fixed { type = ?UNSUBACK },
+ variable = #mqtt_frame_suback{ message_id = MessageId }},
+ PState),
+ {ok, PState #proc_state{ subscriptions = Subs1 }};
+
+process_request(?PINGREQ, #mqtt_frame{}, #proc_state{ send_fun = SendFun } = PState) ->
+ rabbit_log_connection:debug("Received a PINGREQ"),
+ SendFun(#mqtt_frame{ fixed = #mqtt_frame_fixed{ type = ?PINGRESP }},
+ PState),
+ rabbit_log_connection:debug("Sent a PINGRESP"),
+ {ok, PState};
+
+process_request(?DISCONNECT, #mqtt_frame{}, PState) ->
+ rabbit_log_connection:debug("Received a DISCONNECT"),
+ {stop, PState}.
+
+hand_off_to_retainer(RetainerPid, Amqp2MqttFun, Topic0, #mqtt_msg{payload = <<"">>}) ->
+ Topic1 = Amqp2MqttFun(Topic0),
+ rabbit_mqtt_retainer:clear(RetainerPid, Topic1),
+ ok;
+hand_off_to_retainer(RetainerPid, Amqp2MqttFun, Topic0, Msg) ->
+ Topic1 = Amqp2MqttFun(Topic0),
+ rabbit_mqtt_retainer:retain(RetainerPid, Topic1, Msg),
+ ok.
+
+maybe_send_retained_message(RPid, #mqtt_topic{name = Topic0, qos = SubscribeQos}, MsgId,
+ #proc_state{ send_fun = SendFun,
+ amqp2mqtt_fun = Amqp2MqttFun } = PState) ->
+ Topic1 = Amqp2MqttFun(Topic0),
+ case rabbit_mqtt_retainer:fetch(RPid, Topic1) of
+ undefined -> false;
+ Msg ->
+ %% calculate effective QoS as the lower value of SUBSCRIBE frame QoS
+ %% and retained message QoS. The spec isn't super clear on this, we
+ %% do what Mosquitto does, per user feedback.
+ Qos = erlang:min(SubscribeQos, Msg#mqtt_msg.qos),
+ Id = case Qos of
+ ?QOS_0 -> undefined;
+ ?QOS_1 -> MsgId
+ end,
+ SendFun(#mqtt_frame{fixed = #mqtt_frame_fixed{
+ type = ?PUBLISH,
+ qos = Qos,
+ dup = false,
+ retain = Msg#mqtt_msg.retain
+ }, variable = #mqtt_frame_publish{
+ message_id = Id,
+ topic_name = Topic1
+ },
+ payload = Msg#mqtt_msg.payload}, PState),
+ case Qos of
+ ?QOS_0 -> false;
+ ?QOS_1 -> {true, 1}
+ end
+ end.
+
+amqp_callback({#'basic.deliver'{ consumer_tag = ConsumerTag,
+ delivery_tag = DeliveryTag,
+ routing_key = RoutingKey },
+ #amqp_msg{ props = #'P_basic'{ headers = Headers },
+ payload = Payload },
+ DeliveryCtx} = Delivery,
+ #proc_state{ channels = {Channel, _},
+ awaiting_ack = Awaiting,
+ message_id = MsgId,
+ send_fun = SendFun,
+ amqp2mqtt_fun = Amqp2MqttFun } = PState) ->
+ amqp_channel:notify_received(DeliveryCtx),
+ case {delivery_dup(Delivery), delivery_qos(ConsumerTag, Headers, PState)} of
+ {true, {?QOS_0, ?QOS_1}} ->
+ amqp_channel:cast(
+ Channel, #'basic.ack'{ delivery_tag = DeliveryTag }),
+ {ok, PState};
+ {true, {?QOS_0, ?QOS_0}} ->
+ {ok, PState};
+ {Dup, {DeliveryQos, _SubQos} = Qos} ->
+ TopicName = Amqp2MqttFun(RoutingKey),
+ SendFun(
+ #mqtt_frame{ fixed = #mqtt_frame_fixed{
+ type = ?PUBLISH,
+ qos = DeliveryQos,
+ dup = Dup },
+ variable = #mqtt_frame_publish{
+ message_id =
+ case DeliveryQos of
+ ?QOS_0 -> undefined;
+ ?QOS_1 -> MsgId
+ end,
+ topic_name = TopicName },
+ payload = Payload}, PState),
+ case Qos of
+ {?QOS_0, ?QOS_0} ->
+ {ok, PState};
+ {?QOS_1, ?QOS_1} ->
+ Awaiting1 = gb_trees:insert(MsgId, DeliveryTag, Awaiting),
+ PState1 = PState#proc_state{ awaiting_ack = Awaiting1 },
+ PState2 = next_msg_id(PState1),
+ {ok, PState2};
+ {?QOS_0, ?QOS_1} ->
+ amqp_channel:cast(
+ Channel, #'basic.ack'{ delivery_tag = DeliveryTag }),
+ {ok, PState}
+ end
+ end;
+
+amqp_callback(#'basic.ack'{ multiple = true, delivery_tag = Tag } = Ack,
+ PState = #proc_state{ unacked_pubs = UnackedPubs,
+ send_fun = SendFun }) ->
+ case gb_trees:size(UnackedPubs) > 0 andalso
+ gb_trees:take_smallest(UnackedPubs) of
+ {TagSmall, MsgId, UnackedPubs1} when TagSmall =< Tag ->
+ SendFun(
+ #mqtt_frame{ fixed = #mqtt_frame_fixed{ type = ?PUBACK },
+ variable = #mqtt_frame_publish{ message_id = MsgId }},
+ PState),
+ amqp_callback(Ack, PState #proc_state{ unacked_pubs = UnackedPubs1 });
+ _ ->
+ {ok, PState}
+ end;
+
+amqp_callback(#'basic.ack'{ multiple = false, delivery_tag = Tag },
+ PState = #proc_state{ unacked_pubs = UnackedPubs,
+ send_fun = SendFun }) ->
+ SendFun(
+ #mqtt_frame{ fixed = #mqtt_frame_fixed{ type = ?PUBACK },
+ variable = #mqtt_frame_publish{
+ message_id = gb_trees:get(
+ Tag, UnackedPubs) }}, PState),
+ {ok, PState #proc_state{ unacked_pubs = gb_trees:delete(Tag, UnackedPubs) }}.
+
+delivery_dup({#'basic.deliver'{ redelivered = Redelivered },
+ #amqp_msg{ props = #'P_basic'{ headers = Headers }},
+ _DeliveryCtx}) ->
+ case rabbit_mqtt_util:table_lookup(Headers, <<"x-mqtt-dup">>) of
+ undefined -> Redelivered;
+ {bool, Dup} -> Redelivered orelse Dup
+ end.
+
+ensure_valid_mqtt_message_id(Id) when Id >= 16#ffff ->
+ 1;
+ensure_valid_mqtt_message_id(Id) ->
+ Id.
+
+safe_max_id(Id0, Id1) ->
+ ensure_valid_mqtt_message_id(erlang:max(Id0, Id1)).
+
+next_msg_id(PState = #proc_state{ message_id = MsgId0 }) ->
+ MsgId1 = ensure_valid_mqtt_message_id(MsgId0 + 1),
+ PState#proc_state{ message_id = MsgId1 }.
+
+%% decide at which qos level to deliver based on subscription
+%% and the message publish qos level. non-MQTT publishes are
+%% assumed to be qos 1, regardless of delivery_mode.
+delivery_qos(Tag, _Headers, #proc_state{ consumer_tags = {Tag, _} }) ->
+ {?QOS_0, ?QOS_0};
+delivery_qos(Tag, Headers, #proc_state{ consumer_tags = {_, Tag} }) ->
+ case rabbit_mqtt_util:table_lookup(Headers, <<"x-mqtt-publish-qos">>) of
+ {byte, Qos} -> {lists:min([Qos, ?QOS_1]), ?QOS_1};
+ undefined -> {?QOS_1, ?QOS_1}
+ end.
+
+maybe_clean_sess(PState = #proc_state { clean_sess = false,
+ connection = Conn,
+ client_id = ClientId }) ->
+ SessionPresent = session_present(Conn, ClientId),
+ {_Queue, PState1} = ensure_queue(?QOS_1, PState),
+ {SessionPresent, PState1};
+maybe_clean_sess(PState = #proc_state { clean_sess = true,
+ connection = Conn,
+ client_id = ClientId }) ->
+ {_, Queue} = rabbit_mqtt_util:subcription_queue_name(ClientId),
+ {ok, Channel} = amqp_connection:open_channel(Conn),
+ ok = try amqp_channel:call(Channel, #'queue.delete'{ queue = Queue }) of
+ #'queue.delete_ok'{} -> ok
+ catch
+ exit:_Error -> ok
+ after
+ amqp_channel:close(Channel)
+ end,
+ {false, PState}.
+
+session_present(Conn, ClientId) ->
+ {_, QueueQ1} = rabbit_mqtt_util:subcription_queue_name(ClientId),
+ Declare = #'queue.declare'{queue = QueueQ1,
+ passive = true},
+ {ok, Channel} = amqp_connection:open_channel(Conn),
+ try
+ amqp_channel:call(Channel, Declare),
+ amqp_channel:close(Channel),
+ true
+ catch exit:{{shutdown, {server_initiated_close, ?NOT_FOUND, _Text}}, _} ->
+ false
+ end.
+
+make_will_msg(#mqtt_frame_connect{ will_flag = false }) ->
+ undefined;
+make_will_msg(#mqtt_frame_connect{ will_retain = Retain,
+ will_qos = Qos,
+ will_topic = Topic,
+ will_msg = Msg }) ->
+ #mqtt_msg{ retain = Retain,
+ qos = Qos,
+ topic = Topic,
+ dup = false,
+ payload = Msg }.
+
+process_login(_UserBin, _PassBin, _ProtoVersion,
+ #proc_state{channels = {Channel, _},
+ peer_addr = Addr,
+ auth_state = #auth_state{username = Username,
+ vhost = VHost}}) when is_pid(Channel) ->
+ UsernameStr = rabbit_data_coercion:to_list(Username),
+ VHostStr = rabbit_data_coercion:to_list(VHost),
+ rabbit_core_metrics:auth_attempt_failed(list_to_binary(inet:ntoa(Addr)), Username, mqtt),
+ rabbit_log_connection:warning("MQTT detected duplicate connect/login attempt for user ~p, vhost ~p",
+ [UsernameStr, VHostStr]),
+ connack_dup_auth;
+process_login(UserBin, PassBin, ProtoVersion,
+ #proc_state{channels = {undefined, undefined},
+ socket = Sock,
+ adapter_info = AdapterInfo,
+ ssl_login_name = SslLoginName,
+ peer_addr = Addr}) ->
+ {ok, {_, _, _, ToPort}} = rabbit_net:socket_ends(Sock, inbound),
+ {VHostPickedUsing, {VHost, UsernameBin}} = get_vhost(UserBin, SslLoginName, ToPort),
+ rabbit_log_connection:info(
+ "MQTT vhost picked using ~s~n",
+ [human_readable_vhost_lookup_strategy(VHostPickedUsing)]),
+ RemoteAddress = list_to_binary(inet:ntoa(Addr)),
+ case rabbit_vhost:exists(VHost) of
+ true ->
+ case amqp_connection:start(#amqp_params_direct{
+ username = UsernameBin,
+ password = PassBin,
+ virtual_host = VHost,
+ adapter_info = set_proto_version(AdapterInfo, ProtoVersion)}) of
+ {ok, Connection} ->
+ case rabbit_access_control:check_user_loopback(UsernameBin, Addr) of
+ ok ->
+ rabbit_core_metrics:auth_attempt_succeeded(RemoteAddress, UsernameBin,
+ mqtt),
+ [{internal_user, InternalUser}] = amqp_connection:info(
+ Connection, [internal_user]),
+ {?CONNACK_ACCEPT, Connection, VHost,
+ #auth_state{user = InternalUser,
+ username = UsernameBin,
+ vhost = VHost}};
+ not_allowed ->
+ rabbit_core_metrics:auth_attempt_failed(RemoteAddress, UsernameBin,
+ mqtt),
+ amqp_connection:close(Connection),
+ rabbit_log_connection:warning(
+ "MQTT login failed for ~p access_refused "
+ "(access must be from localhost)~n",
+ [binary_to_list(UsernameBin)]),
+ ?CONNACK_AUTH
+ end;
+ {error, {auth_failure, Explanation}} ->
+ rabbit_core_metrics:auth_attempt_failed(RemoteAddress, UsernameBin, mqtt),
+ rabbit_log_connection:error("MQTT login failed for user '~p' auth_failure: ~s~n",
+ [binary_to_list(UserBin), Explanation]),
+ ?CONNACK_CREDENTIALS;
+ {error, access_refused} ->
+ rabbit_core_metrics:auth_attempt_failed(RemoteAddress, UsernameBin, mqtt),
+ rabbit_log_connection:warning("MQTT login failed for user '~p': access_refused "
+ "(vhost access not allowed)~n",
+ [binary_to_list(UserBin)]),
+ ?CONNACK_AUTH;
+ {error, not_allowed} ->
+ rabbit_core_metrics:auth_attempt_failed(RemoteAddress, UsernameBin, mqtt),
+ %% when vhost allowed for TLS connection
+ rabbit_log_connection:warning("MQTT login failed for ~p access_refused "
+ "(vhost access not allowed)~n",
+ [binary_to_list(UserBin)]),
+ ?CONNACK_AUTH
+ end;
+ false ->
+ rabbit_core_metrics:auth_attempt_failed(RemoteAddress, UsernameBin, mqtt),
+ rabbit_log_connection:error("MQTT login failed for user '~p' auth_failure: vhost ~s does not exist~n",
+ [binary_to_list(UserBin), VHost]),
+ ?CONNACK_CREDENTIALS
+ end.
+
+get_vhost(UserBin, none, Port) ->
+ get_vhost_no_ssl(UserBin, Port);
+get_vhost(UserBin, undefined, Port) ->
+ get_vhost_no_ssl(UserBin, Port);
+get_vhost(UserBin, SslLogin, Port) ->
+ get_vhost_ssl(UserBin, SslLogin, Port).
+
+get_vhost_no_ssl(UserBin, Port) ->
+ case vhost_in_username(UserBin) of
+ true ->
+ {vhost_in_username_or_default, get_vhost_username(UserBin)};
+ false ->
+ PortVirtualHostMapping = rabbit_runtime_parameters:value_global(
+ mqtt_port_to_vhost_mapping
+ ),
+ case get_vhost_from_port_mapping(Port, PortVirtualHostMapping) of
+ undefined ->
+ {default_vhost, {rabbit_mqtt_util:env(vhost), UserBin}};
+ VHost ->
+ {port_to_vhost_mapping, {VHost, UserBin}}
+ end
+ end.
+
+get_vhost_ssl(UserBin, SslLoginName, Port) ->
+ UserVirtualHostMapping = rabbit_runtime_parameters:value_global(
+ mqtt_default_vhosts
+ ),
+ case get_vhost_from_user_mapping(SslLoginName, UserVirtualHostMapping) of
+ undefined ->
+ PortVirtualHostMapping = rabbit_runtime_parameters:value_global(
+ mqtt_port_to_vhost_mapping
+ ),
+ case get_vhost_from_port_mapping(Port, PortVirtualHostMapping) of
+ undefined ->
+ {vhost_in_username_or_default, get_vhost_username(UserBin)};
+ VHostFromPortMapping ->
+ {port_to_vhost_mapping, {VHostFromPortMapping, UserBin}}
+ end;
+ VHostFromCertMapping ->
+ {cert_to_vhost_mapping, {VHostFromCertMapping, UserBin}}
+ end.
+
+vhost_in_username(UserBin) ->
+ case application:get_env(?APP, ignore_colons_in_username) of
+ {ok, true} -> false;
+ _ ->
+ %% split at the last colon, disallowing colons in username
+ case re:split(UserBin, ":(?!.*?:)") of
+ [_, _] -> true;
+ [UserBin] -> false
+ end
+ end.
+
+get_vhost_username(UserBin) ->
+ Default = {rabbit_mqtt_util:env(vhost), UserBin},
+ case application:get_env(?APP, ignore_colons_in_username) of
+ {ok, true} -> Default;
+ _ ->
+ %% split at the last colon, disallowing colons in username
+ case re:split(UserBin, ":(?!.*?:)") of
+ [Vhost, UserName] -> {Vhost, UserName};
+ [UserBin] -> Default
+ end
+ end.
+
+get_vhost_from_user_mapping(_User, not_found) ->
+ undefined;
+get_vhost_from_user_mapping(User, Mapping) ->
+ M = rabbit_data_coercion:to_proplist(Mapping),
+ case rabbit_misc:pget(User, M) of
+ undefined ->
+ undefined;
+ VHost ->
+ VHost
+ end.
+
+get_vhost_from_port_mapping(_Port, not_found) ->
+ undefined;
+get_vhost_from_port_mapping(Port, Mapping) ->
+ M = rabbit_data_coercion:to_proplist(Mapping),
+ Res = case rabbit_misc:pget(rabbit_data_coercion:to_binary(Port), M) of
+ undefined ->
+ undefined;
+ VHost ->
+ VHost
+ end,
+ Res.
+
+human_readable_vhost_lookup_strategy(vhost_in_username_or_default) ->
+ "vhost in username or default";
+human_readable_vhost_lookup_strategy(port_to_vhost_mapping) ->
+ "MQTT port to vhost mapping";
+human_readable_vhost_lookup_strategy(cert_to_vhost_mapping) ->
+ "client certificate to vhost mapping";
+human_readable_vhost_lookup_strategy(default_vhost) ->
+ "plugin configuration or default";
+human_readable_vhost_lookup_strategy(Val) ->
+ atom_to_list(Val).
+
+creds(User, Pass, SSLLoginName) ->
+ DefaultUser = rabbit_mqtt_util:env(default_user),
+ DefaultPass = rabbit_mqtt_util:env(default_pass),
+ {ok, Anon} = application:get_env(?APP, allow_anonymous),
+ {ok, TLSAuth} = application:get_env(?APP, ssl_cert_login),
+ HaveDefaultCreds = Anon =:= true andalso
+ is_binary(DefaultUser) andalso
+ is_binary(DefaultPass),
+
+ CredentialsProvided = User =/= undefined orelse
+ Pass =/= undefined,
+
+ CorrectCredentials = is_list(User) andalso
+ is_list(Pass),
+
+ SSLLoginProvided = TLSAuth =:= true andalso
+ SSLLoginName =/= none,
+
+ case {CredentialsProvided, CorrectCredentials, SSLLoginProvided, HaveDefaultCreds} of
+ %% Username and password take priority
+ {true, true, _, _} -> {list_to_binary(User),
+ list_to_binary(Pass)};
+ %% Either username or password is provided
+ {true, false, _, _} -> {invalid_creds, {User, Pass}};
+ %% rabbitmq_mqtt.ssl_cert_login is true. SSL user name provided.
+ %% Authenticating using username only.
+ {false, false, true, _} -> {SSLLoginName, none};
+ %% Anonymous connection uses default credentials
+ {false, false, false, true} -> {DefaultUser, DefaultPass};
+ _ -> nocreds
+ end.
+
+supported_subs_qos(?QOS_0) -> ?QOS_0;
+supported_subs_qos(?QOS_1) -> ?QOS_1;
+supported_subs_qos(?QOS_2) -> ?QOS_1.
+
+delivery_mode(?QOS_0) -> 1;
+delivery_mode(?QOS_1) -> 2;
+delivery_mode(?QOS_2) -> 2.
+
+%% different qos subscriptions are received in different queues
+%% with appropriate durability and timeout arguments
+%% this will lead to duplicate messages for overlapping subscriptions
+%% with different qos values - todo: prevent duplicates
+ensure_queue(Qos, #proc_state{ channels = {Channel, _},
+ client_id = ClientId,
+ clean_sess = CleanSess,
+ consumer_tags = {TagQ0, TagQ1} = Tags} = PState) ->
+ {QueueQ0, QueueQ1} = rabbit_mqtt_util:subcription_queue_name(ClientId),
+ Qos1Args = case {rabbit_mqtt_util:env(subscription_ttl), CleanSess} of
+ {undefined, _} ->
+ [];
+ {Ms, false} when is_integer(Ms) ->
+ [{<<"x-expires">>, long, Ms}];
+ _ ->
+ []
+ end,
+ QueueSetup =
+ case {TagQ0, TagQ1, Qos} of
+ {undefined, _, ?QOS_0} ->
+ {QueueQ0,
+ #'queue.declare'{ queue = QueueQ0,
+ durable = false,
+ auto_delete = true },
+ #'basic.consume'{ queue = QueueQ0,
+ no_ack = true }};
+ {_, undefined, ?QOS_1} ->
+ {QueueQ1,
+ #'queue.declare'{ queue = QueueQ1,
+ durable = true,
+ %% Clean session means a transient connection,
+ %% translating into auto-delete.
+ %%
+ %% see rabbitmq/rabbitmq-mqtt#37
+ auto_delete = CleanSess,
+ arguments = Qos1Args },
+ #'basic.consume'{ queue = QueueQ1,
+ no_ack = false }};
+ {_, _, ?QOS_0} ->
+ {exists, QueueQ0};
+ {_, _, ?QOS_1} ->
+ {exists, QueueQ1}
+ end,
+ case QueueSetup of
+ {Queue, Declare, Consume} ->
+ #'queue.declare_ok'{} = amqp_channel:call(Channel, Declare),
+ #'basic.consume_ok'{ consumer_tag = Tag } =
+ amqp_channel:call(Channel, Consume),
+ {Queue, PState #proc_state{ consumer_tags = setelement(Qos+1, Tags, Tag) }};
+ {exists, Q} ->
+ {Q, PState}
+ end.
+
+send_will(PState = #proc_state{will_msg = undefined}) ->
+ PState;
+
+send_will(PState = #proc_state{will_msg = WillMsg = #mqtt_msg{retain = Retain,
+ topic = Topic},
+ retainer_pid = RPid,
+ channels = {ChQos0, ChQos1},
+ amqp2mqtt_fun = Amqp2MqttFun}) ->
+ case check_topic_access(Topic, write, PState) of
+ ok ->
+ amqp_pub(WillMsg, PState),
+ case Retain of
+ false -> ok;
+ true ->
+ hand_off_to_retainer(RPid, Amqp2MqttFun, Topic, WillMsg)
+ end;
+ Error ->
+ rabbit_log:warning(
+ "Could not send last will: ~p~n",
+ [Error])
+ end,
+ case ChQos1 of
+ undefined -> ok;
+ _ -> amqp_channel:close(ChQos1)
+ end,
+ case ChQos0 of
+ undefined -> ok;
+ _ -> amqp_channel:close(ChQos0)
+ end,
+ PState #proc_state{ channels = {undefined, undefined} }.
+
+amqp_pub(undefined, PState) ->
+ PState;
+
+%% set up a qos1 publishing channel if necessary
+%% this channel will only be used for publishing, not consuming
+amqp_pub(Msg = #mqtt_msg{ qos = ?QOS_1 },
+ PState = #proc_state{ channels = {ChQos0, undefined},
+ awaiting_seqno = undefined,
+ connection = Conn }) ->
+ {ok, Channel} = amqp_connection:open_channel(Conn),
+ #'confirm.select_ok'{} = amqp_channel:call(Channel, #'confirm.select'{}),
+ amqp_channel:register_confirm_handler(Channel, self()),
+ amqp_pub(Msg, PState #proc_state{ channels = {ChQos0, Channel},
+ awaiting_seqno = 1 });
+
+amqp_pub(#mqtt_msg{ qos = Qos,
+ topic = Topic,
+ dup = Dup,
+ message_id = MessageId,
+ payload = Payload },
+ PState = #proc_state{ channels = {ChQos0, ChQos1},
+ exchange = Exchange,
+ unacked_pubs = UnackedPubs,
+ awaiting_seqno = SeqNo,
+ mqtt2amqp_fun = Mqtt2AmqpFun }) ->
+ RoutingKey = Mqtt2AmqpFun(Topic),
+ Method = #'basic.publish'{ exchange = Exchange,
+ routing_key = RoutingKey },
+ Headers = [{<<"x-mqtt-publish-qos">>, byte, Qos},
+ {<<"x-mqtt-dup">>, bool, Dup}],
+ Msg = #amqp_msg{ props = #'P_basic'{ headers = Headers,
+ delivery_mode = delivery_mode(Qos)},
+ payload = Payload },
+ {UnackedPubs1, Ch, SeqNo1} =
+ case Qos =:= ?QOS_1 andalso MessageId =/= undefined of
+ true -> {gb_trees:enter(SeqNo, MessageId, UnackedPubs), ChQos1,
+ SeqNo + 1};
+ false -> {UnackedPubs, ChQos0, SeqNo}
+ end,
+ amqp_channel:cast_flow(Ch, Method, Msg),
+ PState #proc_state{ unacked_pubs = UnackedPubs1,
+ awaiting_seqno = SeqNo1 }.
+
+adapter_info(Sock, ProtoName) ->
+ amqp_connection:socket_adapter_info(Sock, {ProtoName, "N/A"}).
+
+set_proto_version(AdapterInfo = #amqp_adapter_info{protocol = {Proto, _}}, Vsn) ->
+ AdapterInfo#amqp_adapter_info{protocol = {Proto,
+ human_readable_mqtt_version(Vsn)}}.
+
+human_readable_mqtt_version(3) ->
+ "3.1.0";
+human_readable_mqtt_version(4) ->
+ "3.1.1";
+human_readable_mqtt_version(_) ->
+ "N/A".
+
+serialise_and_send_to_client(Frame, #proc_state{ socket = Sock }) ->
+ try rabbit_net:port_command(Sock, rabbit_mqtt_frame:serialise(Frame)) of
+ Res ->
+ Res
+ catch _:Error ->
+ rabbit_log_connection:error("MQTT: a socket write failed, the socket might already be closed"),
+ rabbit_log_connection:debug("Failed to write to socket ~p, error: ~p, frame: ~p",
+ [Sock, Error, Frame])
+ end.
+
+close_connection(PState = #proc_state{ connection = undefined }) ->
+ PState;
+close_connection(PState = #proc_state{ connection = Connection,
+ client_id = ClientId }) ->
+ % todo: maybe clean session
+ case ClientId of
+ undefined -> ok;
+ _ ->
+ case rabbit_mqtt_collector:unregister(ClientId, self()) of
+ ok -> ok;
+ %% ignore as we are shutting down
+ {timeout, _} -> ok
+ end
+ end,
+ %% ignore noproc or other exceptions, we are shutting down
+ catch amqp_connection:close(Connection),
+ PState #proc_state{ channels = {undefined, undefined},
+ connection = undefined }.
+
+handle_pre_hibernate() ->
+ erase(topic_permission_cache),
+ ok.
+
+handle_ra_event({applied, [{Corr, ok}]},
+ PState = #proc_state{register_state = {pending, Corr}}) ->
+ %% success case - command was applied transition into registered state
+ PState#proc_state{register_state = registered};
+handle_ra_event({not_leader, Leader, Corr},
+ PState = #proc_state{register_state = {pending, Corr},
+ client_id = ClientId}) ->
+ %% retry command against actual leader
+ {ok, NewCorr} = rabbit_mqtt_collector:register(Leader, ClientId, self()),
+ PState#proc_state{register_state = {pending, NewCorr}};
+handle_ra_event(register_timeout,
+ PState = #proc_state{register_state = {pending, _Corr},
+ client_id = ClientId}) ->
+ {ok, NewCorr} = rabbit_mqtt_collector:register(ClientId, self()),
+ PState#proc_state{register_state = {pending, NewCorr}};
+handle_ra_event(register_timeout, PState) ->
+ PState;
+handle_ra_event(Evt, PState) ->
+ %% log these?
+ rabbit_log:debug("unhandled ra_event: ~w ~n", [Evt]),
+ PState.
+
+%% NB: check_*: MQTT spec says we should ack normally, ie pretend there
+%% was no auth error, but here we are closing the connection with an error. This
+%% is what happens anyway if there is an authorization failure at the AMQP 0-9-1 client level.
+
+check_publish(TopicName, Fn, PState) ->
+ case check_topic_access(TopicName, write, PState) of
+ ok -> Fn();
+ _ -> {error, unauthorized, PState}
+ end.
+
+check_subscribe([], Fn, _) ->
+ Fn();
+
+check_subscribe([#mqtt_topic{name = TopicName} | Topics], Fn, PState) ->
+ case check_topic_access(TopicName, read, PState) of
+ ok -> check_subscribe(Topics, Fn, PState);
+ _ -> {error, unauthorized, PState}
+ end.
+
+check_topic_access(TopicName, Access,
+ #proc_state{
+ auth_state = #auth_state{user = User = #user{username = Username},
+ vhost = VHost},
+ exchange = Exchange,
+ client_id = ClientId,
+ mqtt2amqp_fun = Mqtt2AmqpFun }) ->
+ Cache =
+ case get(topic_permission_cache) of
+ undefined -> [];
+ Other -> Other
+ end,
+
+ Key = {TopicName, Username, ClientId, VHost, Exchange, Access},
+ case lists:member(Key, Cache) of
+ true ->
+ ok;
+ false ->
+ Resource = #resource{virtual_host = VHost,
+ kind = topic,
+ name = Exchange},
+
+ RoutingKey = Mqtt2AmqpFun(TopicName),
+ Context = #{routing_key => RoutingKey,
+ variable_map => #{
+ <<"username">> => Username,
+ <<"vhost">> => VHost,
+ <<"client_id">> => rabbit_data_coercion:to_binary(ClientId)
+ }
+ },
+
+ try rabbit_access_control:check_topic_access(User, Resource, Access, Context) of
+ ok ->
+ CacheTail = lists:sublist(Cache, ?MAX_TOPIC_PERMISSION_CACHE_SIZE - 1),
+ put(topic_permission_cache, [Key | CacheTail]),
+ ok;
+ R ->
+ R
+ catch
+ _:{amqp_error, access_refused, Msg, _} ->
+ rabbit_log:error("operation resulted in an error (access_refused): ~p~n", [Msg]),
+ {error, access_refused};
+ _:Error ->
+ rabbit_log:error("~p~n", [Error]),
+ {error, access_refused}
+ end
+ end.
+
+info(consumer_tags, #proc_state{consumer_tags = Val}) -> Val;
+info(unacked_pubs, #proc_state{unacked_pubs = Val}) -> Val;
+info(awaiting_ack, #proc_state{awaiting_ack = Val}) -> Val;
+info(awaiting_seqno, #proc_state{awaiting_seqno = Val}) -> Val;
+info(message_id, #proc_state{message_id = Val}) -> Val;
+info(client_id, #proc_state{client_id = Val}) ->
+ rabbit_data_coercion:to_binary(Val);
+info(clean_sess, #proc_state{clean_sess = Val}) -> Val;
+info(will_msg, #proc_state{will_msg = Val}) -> Val;
+info(channels, #proc_state{channels = Val}) -> Val;
+info(exchange, #proc_state{exchange = Val}) -> Val;
+info(adapter_info, #proc_state{adapter_info = Val}) -> Val;
+info(ssl_login_name, #proc_state{ssl_login_name = Val}) -> Val;
+info(retainer_pid, #proc_state{retainer_pid = Val}) -> Val;
+info(user, #proc_state{auth_state = #auth_state{username = Val}}) -> Val;
+info(vhost, #proc_state{auth_state = #auth_state{vhost = Val}}) -> Val;
+info(host, #proc_state{adapter_info = #amqp_adapter_info{host = Val}}) -> Val;
+info(port, #proc_state{adapter_info = #amqp_adapter_info{port = Val}}) -> Val;
+info(peer_host, #proc_state{adapter_info = #amqp_adapter_info{peer_host = Val}}) -> Val;
+info(peer_port, #proc_state{adapter_info = #amqp_adapter_info{peer_port = Val}}) -> Val;
+info(protocol, #proc_state{adapter_info = #amqp_adapter_info{protocol = Val}}) ->
+ case Val of
+ {Proto, Version} -> {Proto, rabbit_data_coercion:to_binary(Version)};
+ Other -> Other
+ end;
+info(channels, PState) -> additional_info(channels, PState);
+info(channel_max, PState) -> additional_info(channel_max, PState);
+info(frame_max, PState) -> additional_info(frame_max, PState);
+info(client_properties, PState) -> additional_info(client_properties, PState);
+info(ssl, PState) -> additional_info(ssl, PState);
+info(ssl_protocol, PState) -> additional_info(ssl_protocol, PState);
+info(ssl_key_exchange, PState) -> additional_info(ssl_key_exchange, PState);
+info(ssl_cipher, PState) -> additional_info(ssl_cipher, PState);
+info(ssl_hash, PState) -> additional_info(ssl_hash, PState);
+info(Other, _) -> throw({bad_argument, Other}).
+
+
+additional_info(Key,
+ #proc_state{adapter_info =
+ #amqp_adapter_info{additional_info = AddInfo}}) ->
+ proplists:get_value(Key, AddInfo).