summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRobert Newson <rnewson@apache.org>2020-03-13 10:33:13 +0000
committerRobert Newson <rnewson@apache.org>2020-03-13 10:48:30 +0000
commit39b9cc7e741f6b3b9a1f08e7aff8f3e9d0b14325 (patch)
treee8c7836bcdb1c5cf085ddefc2bd0ba788bd362bf
parentb14ec653beaf4e7c287f0749a647b5cb469bd123 (diff)
downloadcouchdb-jwtf-enhance-alg-check.tar.gz
Enhance alg checkjwtf-enhance-alg-check
The "alg" check can now take list of algorithms that are supported, which must be from the valid list of algorithms.
-rw-r--r--src/jwtf/src/jwtf.erl7
-rw-r--r--src/jwtf/test/jwtf_tests.erl12
2 files changed, 15 insertions, 4 deletions
diff --git a/src/jwtf/src/jwtf.erl b/src/jwtf/src/jwtf.erl
index 8e58e0897..0bdc0aa1a 100644
--- a/src/jwtf/src/jwtf.erl
+++ b/src/jwtf/src/jwtf.erl
@@ -139,10 +139,11 @@ validate_alg(Props, Checks) ->
case {Required, Alg} of
{undefined, _} ->
ok;
- {true, undefined} ->
+ {Required, undefined} when Required /= undefined ->
throw({bad_request, <<"Missing alg header parameter">>});
- {true, Alg} ->
- case lists:member(Alg, valid_algorithms()) of
+ {Required, Alg} when Required == true; is_list(Required) ->
+ AllowedAlg = if Required == true -> true; true -> lists:member(Alg, Required) end,
+ case AllowedAlg andalso lists:member(Alg, valid_algorithms()) of
true ->
ok;
false ->
diff --git a/src/jwtf/test/jwtf_tests.erl b/src/jwtf/test/jwtf_tests.erl
index dcebe5f40..222bb4792 100644
--- a/src/jwtf/test/jwtf_tests.erl
+++ b/src/jwtf/test/jwtf_tests.erl
@@ -82,6 +82,16 @@ invalid_alg_test() ->
?assertEqual({error, {bad_request,<<"Invalid alg header parameter">>}},
jwtf:decode(Encoded, [alg], nil)).
+not_allowed_alg_test() ->
+ Encoded = encode({[{<<"alg">>, <<"HS256">>}]}, []),
+ ?assertEqual({error, {bad_request,<<"Invalid alg header parameter">>}},
+ jwtf:decode(Encoded, [{alg, [<<"RS256">>]}], nil)).
+
+reject_unknown_alg_test() ->
+ Encoded = encode({[{<<"alg">>, <<"NOPE">>}]}, []),
+ ?assertEqual({error, {bad_request,<<"Invalid alg header parameter">>}},
+ jwtf:decode(Encoded, [{alg, [<<"NOPE">>]}], nil)).
+
missing_iss_test() ->
Encoded = encode(valid_header(), {[]}),
@@ -176,7 +186,7 @@ hs256_test() ->
"6MTAwMDAwMDAwMDAwMDAsImtpZCI6ImJhciJ9.iS8AH11QHHlczkBn"
"Hl9X119BYLOZyZPllOVhSBZ4RZs">>,
KS = fun(<<"HS256">>, <<"123456">>) -> <<"secret">> end,
- Checks = [{iss, <<"https://foo.com">>}, iat, exp, typ, alg, kid],
+ Checks = [{iss, <<"https://foo.com">>}, iat, exp, typ, {alg, [<<"HS256">>]}, kid],
?assertMatch({ok, _}, catch jwtf:decode(EncodedToken, Checks, KS)).