summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Evans <code@jeremyevans.net>2021-03-05 12:25:51 -0800
committerJeremy Evans <code@jeremyevans.net>2021-03-06 13:56:16 -0800
commite1d16a9e560a615e122e457325bcfb7c47228ed6 (patch)
tree1e93b0256e763f3eedcea7a03309bd12f5007005
parentbf40fe9fed19a5e22081b133661c0629988f1618 (diff)
downloadruby-e1d16a9e560a615e122e457325bcfb7c47228ed6.tar.gz
Make Enumerator#{+,chain} create lazy chain if any included enumerator is lazy
Implements [Feature #17347]
-rw-r--r--enumerator.c21
-rw-r--r--test/ruby/test_enumerator.rb12
2 files changed, 28 insertions, 5 deletions
diff --git a/enumerator.c b/enumerator.c
index 1c1ece0cfe..45620f352a 100644
--- a/enumerator.c
+++ b/enumerator.c
@@ -3137,6 +3137,20 @@ enum_chain_initialize(VALUE obj, VALUE enums)
return obj;
}
+static VALUE
+new_enum_chain(VALUE enums) {
+ long i;
+ VALUE obj = enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
+
+ for (i = 0; i < RARRAY_LEN(enums); i++) {
+ if (RTEST(rb_obj_is_kind_of(RARRAY_AREF(enums, i), rb_cLazy))) {
+ return enumerable_lazy(obj);
+ }
+ }
+
+ return obj;
+}
+
/* :nodoc: */
static VALUE
enum_chain_init_copy(VALUE obj, VALUE orig)
@@ -3306,8 +3320,7 @@ enum_chain(int argc, VALUE *argv, VALUE obj)
{
VALUE enums = rb_ary_new_from_values(1, &obj);
rb_ary_cat(enums, argv, argc);
-
- return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
+ return new_enum_chain(enums);
}
/*
@@ -3323,9 +3336,7 @@ enum_chain(int argc, VALUE *argv, VALUE obj)
static VALUE
enumerator_plus(VALUE obj, VALUE eobj)
{
- VALUE enums = rb_ary_new_from_args(2, obj, eobj);
-
- return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
+ return new_enum_chain(rb_ary_new_from_args(2, obj, eobj));
}
/*
diff --git a/test/ruby/test_enumerator.rb b/test/ruby/test_enumerator.rb
index 9b615ff9db..4e698fc478 100644
--- a/test/ruby/test_enumerator.rb
+++ b/test/ruby/test_enumerator.rb
@@ -820,6 +820,18 @@ class TestEnumerator < Test::Unit::TestCase
assert_equal([[3, 0], [4, 1]], [3].chain([4]).with_index.to_a)
end
+ def test_lazy_chain
+ ea = (10..).lazy.select(&:even?).take(10)
+ ed = (20..).lazy.select(&:odd?)
+ chain = (ea + ed).select{|x| x % 3 == 0}
+ assert_equal(12, chain.next)
+ assert_equal(18, chain.next)
+ assert_equal(24, chain.next)
+ assert_equal(21, chain.next)
+ assert_equal(27, chain.next)
+ assert_equal(33, chain.next)
+ end
+
def test_produce
assert_raise(ArgumentError) { Enumerator.produce }