summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Meadows-Jönsson <eric.meadows.jonsson@gmail.com>2019-08-20 16:08:37 -0700
committerEric Meadows-Jönsson <eric.meadows.jonsson@gmail.com>2019-08-20 16:15:47 -0700
commit6d40036ca63c24cff5cb843e776f4c3e3796e14a (patch)
treeddbf608798009500dd448afb4fdfe41caa6abd0f
parentc605842597310df445a49ee6313b502ba36e67cc (diff)
downloadelixir-6d40036ca63c24cff5cb843e776f4c3e3796e14a.tar.gz
Create unions when key type overlaps in map
-rw-r--r--lib/elixir/lib/module/types/infer.ex40
-rw-r--r--lib/elixir/test/elixir/module/types/infer_test.exs30
2 files changed, 41 insertions, 29 deletions
diff --git a/lib/elixir/lib/module/types/infer.ex b/lib/elixir/lib/module/types/infer.ex
index 957efc828..b46d6a6d0 100644
--- a/lib/elixir/lib/module/types/infer.ex
+++ b/lib/elixir/lib/module/types/infer.ex
@@ -104,27 +104,20 @@ defmodule Module.Types.Infer do
# %{...}
def of_pattern({:%{}, _meta, args} = expr, context) do
- # TODO: Create unions when types for keys overlap, the pattern:
- # `%{1 => :foo, 2 => :bar}`
- # should create the type:
- # `%{integer() => :foo | :bar}`
-
case expr_stack(expr, context, &of_pairs(args, &1)) do
- {:ok, pairs, context} -> {:ok, {:map, pairs}, context}
+ {:ok, pairs, context} -> {:ok, {:map, pairs_to_unions(pairs, context)}, context}
{:error, reason} -> {:error, reason}
end
end
# %var{...}
def of_pattern({:%, _meta1, [var, {:%{}, _meta2, args}]} = expr, context) when is_var(var) do
- # TODO: Create unions when types for keys overlap, see above
-
expr_stack(expr, context, fn context ->
with {:ok, pairs, context} <- of_pairs(args, context),
{var_type, context} = new_var(var, context),
{:ok, _, context} <- unify(var_type, :atom, context) do
pairs = [{{:literal, :__struct__}, var_type} | pairs]
- {:ok, {:map, pairs}, context}
+ {:ok, {:map, pairs_to_unions(pairs, context)}, context}
end
end)
end
@@ -132,16 +125,18 @@ defmodule Module.Types.Infer do
# %Struct{...}
def of_pattern({:%, _meta1, [module, {:%{}, _meta2, args}]} = expr, context)
when is_atom(module) do
- # TODO: Create unions when types for keys overlap, see above
-
struct_pairs =
Enum.map(:maps.remove(:__struct__, module.__struct__()), fn {key, value} ->
{term_to_type(key), term_to_type(value)}
end)
case expr_stack(expr, context, &of_pairs(args, &1)) do
- {:ok, pattern_pairs, context} -> {:ok, {:map, pattern_pairs ++ struct_pairs}, context}
- {:error, reason} -> {:error, reason}
+ {:ok, pattern_pairs, context} ->
+ pairs = pairs_to_unions(pattern_pairs ++ struct_pairs, context)
+ {:ok, {:map, pairs}, context}
+
+ {:error, reason} ->
+ {:error, reason}
end
end
@@ -153,6 +148,25 @@ defmodule Module.Types.Infer do
end)
end
+ defp pairs_to_unions(pairs, context) do
+ # Maps only allow simple literal keys in patterns and
+ # term_to_type/1 does not return supertypes so we do
+ # not have to do subtype checking
+
+ Enum.reduce(pairs, [], fn {key, value}, pairs ->
+ case :lists.keyfind(key, 1, pairs) do
+ {^key, {:union, union}} ->
+ :lists.keystore(key, 1, pairs, {key, to_union([value | union], context)})
+
+ {^key, original_value} ->
+ :lists.keystore(key, 1, pairs, {key, to_union([value, original_value], context)})
+
+ false ->
+ [{key, value} | pairs]
+ end
+ end)
+ end
+
def term_to_type(term) when is_atom(term), do: {:literal, term}
def term_to_type(term) when is_bitstring(term), do: :binary
def term_to_type(term) when is_float(term), do: :float
diff --git a/lib/elixir/test/elixir/module/types/infer_test.exs b/lib/elixir/test/elixir/module/types/infer_test.exs
index eb968b70b..41de3893b 100644
--- a/lib/elixir/test/elixir/module/types/infer_test.exs
+++ b/lib/elixir/test/elixir/module/types/infer_test.exs
@@ -85,9 +85,8 @@ defmodule Module.Types.InferTest do
assert quoted_pattern(%{a: :b}) == {:ok, {:map, [{{:literal, :a}, {:literal, :b}}]}}
assert quoted_pattern(%{123 => a}) == {:ok, {:map, [{:integer, {:var, 0}}]}}
- # TODO
- # assert quoted_pattern(%{123 => :foo, 456 => :bar}) ==
- # {:ok, {:map, [{:integer, {:union, [{:literal, :foo}, {:literal, :bar}]}}]}}
+ assert quoted_pattern(%{123 => :foo, 456 => :bar}) ==
+ {:ok, {:map, [{:integer, {:union, [{:literal, :bar}, {:literal, :foo}]}}]}}
assert {:error, {{:unable_unify, {:literal, :foo}, :integer, _, _}, _}} =
quoted_pattern(%{a: a = 123, b: a = :foo})
@@ -102,30 +101,29 @@ defmodule Module.Types.InferTest do
{:ok,
{:map,
[
- {{:literal, :bar}, :integer},
+ {{:literal, :foo}, {:literal, :atom}},
{{:literal, :baz}, {:map, []}},
- {{:literal, :foo}, {:literal, :atom}}
+ {{:literal, :bar}, :integer}
]}}
- # TODO
- # assert quoted_pattern(%:"Elixir.Module.Types.InferTest.Struct"{foo: 123, bar: :atom}) ==
- # {:ok,
- # {:map,
- # [
- # {{:literal, :bar}, {:literal, :atom}},
- # {{:literal, :baz}, {:map, []}},
- # {{:literal, :foo}, :integer}
- # ]}}
+ assert quoted_pattern(%:"Elixir.Module.Types.InferTest.Struct"{foo: 123, bar: :atom}) ==
+ {:ok,
+ {:map,
+ [
+ {{:literal, :baz}, {:map, []}},
+ {{:literal, :bar}, {:union, [:integer, {:literal, :atom}]}},
+ {{:literal, :foo}, {:union, [{:literal, :atom}, :integer]}}
+ ]}}
end
test "struct var" do
assert quoted_pattern(%var{}) == {:ok, {:map, [{{:literal, :__struct__}, :atom}]}}
assert quoted_pattern(%var{foo: 123}) ==
- {:ok, {:map, [{{:literal, :__struct__}, :atom}, {{:literal, :foo}, :integer}]}}
+ {:ok, {:map, [{{:literal, :foo}, :integer}, {{:literal, :__struct__}, :atom}]}}
assert quoted_pattern(%var{foo: var}) ==
- {:ok, {:map, [{{:literal, :__struct__}, :atom}, {{:literal, :foo}, :atom}]}}
+ {:ok, {:map, [{{:literal, :foo}, :atom}, {{:literal, :__struct__}, :atom}]}}
end
test "binary" do