diff options
Diffstat (limited to 'src/jwtf/src/jwtf_keystore.erl')
-rw-r--r-- | src/jwtf/src/jwtf_keystore.erl | 95 |
1 files changed, 69 insertions, 26 deletions
diff --git a/src/jwtf/src/jwtf_keystore.erl b/src/jwtf/src/jwtf_keystore.erl index 2f2f24744..01b4c8669 100644 --- a/src/jwtf/src/jwtf_keystore.erl +++ b/src/jwtf/src/jwtf_keystore.erl @@ -14,6 +14,8 @@ -behaviour(gen_server). -behaviour(config_listener). +-include_lib("public_key/include/public_key.hrl"). + % public api. -export([ get/2, @@ -29,19 +31,18 @@ % public functions -get(Alg, undefined) -> - get(Alg, "_default"); - -get(Alg, KID) when is_binary(KID) -> - get(Alg, binary_to_list(KID)); +get(Alg, undefined) when is_binary(Alg) -> + get(Alg, <<"_default">>); -get(Alg, KID) -> - case ets:lookup(?MODULE, KID) of +get(Alg, KID0) when is_binary(Alg), is_binary(KID0) -> + Kty = kty(Alg), + KID = binary_to_list(KID0), + case ets:lookup(?MODULE, {Kty, KID}) of [] -> - Key = get_from_config(Alg, KID), - ok = gen_server:call(?MODULE, {set, KID, Key}), + Key = get_from_config(Kty, KID), + ok = gen_server:call(?MODULE, {set, Kty, KID, Key}), Key; - [{KID, Key}] -> + [{{Kty, KID}, Key}] -> Key end. @@ -57,13 +58,13 @@ init(_) -> {ok, nil}. -handle_call({set, KID, Key}, _From, State) -> - true = ets:insert(?MODULE, {KID, Key}), +handle_call({set, Kty, KID, Key}, _From, State) -> + true = ets:insert(?MODULE, {{Kty, KID}, Key}), {reply, ok, State}. -handle_cast({delete, KID}, State) -> - true = ets:delete(?MODULE, KID), +handle_cast({delete, Kty, KID}, State) -> + true = ets:delete(?MODULE, {Kty, KID}), {noreply, State}; handle_cast(_Msg, State) -> @@ -88,8 +89,14 @@ code_change(_OldVsn, State, _Extra) -> % config listener callback -handle_config_change("jwt_keys", KID, _Value, _, _) -> - {ok, gen_server:cast(?MODULE, {delete, KID})}; +handle_config_change("jwt_keys", ConfigKey, _ConfigValue, _, _) -> + case string:split(ConfigKey, ":") of + [Kty, KID] -> + gen_server:cast(?MODULE, {delete, Kty, KID}); + _ -> + ignored + end, + {ok, nil}; handle_config_change(_, _, _, _, _) -> {ok, nil}. @@ -102,17 +109,53 @@ handle_config_terminate(_Server, _Reason, _State) -> % private functions -get_from_config(Alg, KID) -> - case config:get("jwt_keys", KID) of +get_from_config(Kty, KID) -> + case config:get("jwt_keys", string:join([Kty, KID], ":")) of undefined -> throw({bad_request, <<"Unknown kid">>}); - Key -> - case jwtf:verification_algorithm(Alg) of - {hmac, _} -> - base64:decode(Key); - {public_key, _} -> - BinKey = iolist_to_binary(string:replace(Key, "\\n", "\n", all)), - [PEMEntry] = public_key:pem_decode(BinKey), - public_key:pem_entry_decode(PEMEntry) + Encoded -> + case Kty of + "hmac" -> + try + base64:decode(Encoded) + catch + error:badarg -> + throw({bad_request, <<"Not a valid key">>}) + end; + "rsa" -> + case pem_decode(Encoded) of + #'RSAPublicKey'{} = Key -> + Key; + _ -> + throw({bad_request, <<"not an RSA public key">>}) + end; + "ec" -> + case pem_decode(Encoded) of + {#'ECPoint'{}, _} = Key -> + Key; + _ -> + throw({bad_request, <<"not an EC public key">>}) + end end end. + +pem_decode(PEM) -> + BinPEM = iolist_to_binary(string:replace(PEM, "\\n", "\n", all)), + case public_key:pem_decode(BinPEM) of + [PEMEntry] -> + public_key:pem_entry_decode(PEMEntry); + [] -> + throw({bad_request, <<"Not a valid key">>}) + end. + +kty(<<"HS", _/binary>>) -> + "hmac"; + +kty(<<"RS", _/binary>>) -> + "rsa"; + +kty(<<"ES", _/binary>>) -> + "ec"; + +kty(_) -> + throw({bad_request, <<"Unknown kty">>}). |