Skip to content

Commit

Permalink
Make it possible to have a MFA transport (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmzeeman authored May 29, 2024
1 parent a8563f3 commit b6332bb
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
9 changes: 6 additions & 3 deletions src/mqtt_sessions.erl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@

-type session_ref() :: pid() | binary().
-type msg_options() :: #{
transport => pid() | function(),
transport => transport(),
peer_ip => tuple() | undefined,
context_prefs => map(),
connection_pid => pid()
Expand All @@ -94,6 +94,8 @@

-type callback() :: pid() | {module(), atom(), list()}.

-type transport() :: function() | callback().

-export_type([
session_ref/0,
msg_options/0,
Expand All @@ -102,7 +104,8 @@
subscriber/0,
subscriber_options/0,
topic/0,
callback/0
callback/0,
transport/0
]).

-define(SIDEJOBS_PER_SESSION, 20).
Expand Down Expand Up @@ -188,7 +191,7 @@ update_user_context(Pool, ClientId, Fun) ->
{error, _} = Error -> Error
end.

-spec get_transport( pid() ) -> {ok, pid()} | {error, notransport | noproc}.
-spec get_transport( pid() ) -> {ok, transport()} | {error, notransport | noproc}.
get_transport(SessionPid) ->
mqtt_sessions_process:get_transport(SessionPid).

Expand Down
13 changes: 8 additions & 5 deletions src/mqtt_sessions_process.erl
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,15 @@

-type packet_id() :: 0..65535. % ?MAX_PACKET_ID


-record(state, {
protocol_version :: mqtt_packet_map:mqtt_version(),
pool :: atom(),
runtime :: atom(),
client_id :: binary(),
routing_id :: binary(),
user_context :: term(),
transport = undefined :: pid() | function() | undefined,
transport = undefined :: mqtt_sessions:transport() | undefined,
connection_pid = undefined :: pid() | undefined,
is_session_present = false :: boolean(),
pending_connack = undefined :: term(),
Expand Down Expand Up @@ -145,7 +146,7 @@ update_user_context(Pid, Fun) ->
{error, noproc}
end.

-spec get_transport( pid() ) -> {ok, pid()} | {error, notransport | noproc}.
-spec get_transport( pid() ) -> {ok, mqtt_sessions:transport()} | {error, notransport | noproc}.
get_transport(Pid) ->
try
gen_server:call(Pid, get_transport, infinity)
Expand Down Expand Up @@ -230,8 +231,8 @@ handle_call({update_user_context, Fun}, _From, #state{ user_context = UserContex

handle_call(get_transport, _From, #state{ transport = undefined } = State) ->
{reply, {error, notransport}, State};
handle_call(get_transport, _From, #state{ transport = TransportPid } = State) ->
{reply, {ok, TransportPid}, State};
handle_call(get_transport, _From, #state{ transport = Transport } = State) ->
{reply, {ok, Transport}, State};

handle_call({incoming_data, NewData, ConnectionPid}, _From, #state{ incoming_data = Data, connection_pid = ConnectionPid } = State) ->
Data1 = << Data/binary, NewData/binary >>,
Expand Down Expand Up @@ -1030,7 +1031,9 @@ send_transport(Msg, #state{ transport = Pid }) when is_pid(Pid) ->
ok
end;
send_transport(Msg, #state{ transport = Fun }) when is_function(Fun) ->
Fun(Msg).
Fun(Msg);
send_transport(Msg, #state{ transport = {M, F, A} }) ->
erlang:apply(M, F, [Msg | A]).


%% @doc Queue a message, extract, type, message expiry, and QoS
Expand Down

0 comments on commit b6332bb

Please sign in to comment.