From 1a73a6cdd2068b815430b775fe25186dab693faa Mon Sep 17 00:00:00 2001 From: Akinori MUSHA Date: Fri, 29 Jul 2022 13:56:54 +0900 Subject: Implement Enumerator::Product and Enumerator.product [Feature #18685] --- enumerator.c | 354 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 353 insertions(+), 1 deletion(-) (limited to 'enumerator.c') diff --git a/enumerator.c b/enumerator.c index 89abf4b888..ce2eacbd2a 100644 --- a/enumerator.c +++ b/enumerator.c @@ -125,7 +125,7 @@ */ VALUE rb_cEnumerator; static VALUE rb_cLazy; -static ID id_rewind, id_new, id_to_enum; +static ID id_rewind, id_new, id_to_enum, id_each_entry; static ID id_next, id_result, id_receiver, id_arguments, id_memo, id_method, id_force; static ID id_begin, id_end, id_step, id_exclude_end; static VALUE sym_each, sym_cycle, sym_yield; @@ -194,6 +194,12 @@ struct enum_chain { long pos; }; +static VALUE rb_cEnumProduct; + +struct enum_product { + VALUE enums; +}; + VALUE rb_cArithSeq; /* @@ -3347,6 +3353,335 @@ enumerator_plus(VALUE obj, VALUE eobj) return new_enum_chain(rb_ary_new_from_args(2, obj, eobj)); } +/* + * Document-class: Enumerator::Product + * + * Enumerator::Product generates a Cartesian product of any number of + * enumerable objects. Iterating over the product of enumerable + * objects is roughly equivalent to nested each_entry loops where the + * loop for the rightmost object is put innermost. + * + * innings = Enumerator::Product.new(1..9, ['top', 'bottom']) + * + * innings.each do |i, h| + * p [i, h] + * end + * # [1, "top"] + * # [1, "bottom"] + * # [2, "top"] + * # [2, "bottom"] + * # [3, "top"] + * # [3, "bottom"] + * # ... + * # [9, "top"] + * # [9, "bottom"] + * + * The method used against each enumerable object is `each_entry` + * instead of `each` so that the product of N enumerable objects + * yields exactly N arguments in each iteration. + * + * When no enumerator is given, it calls a given block once yielding + * an empty argument list. + * + * This type of objects can be created by Enumerator.product. + */ + +static void +enum_product_mark(void *p) +{ + struct enum_product *ptr = p; + rb_gc_mark_movable(ptr->enums); +} + +static void +enum_product_compact(void *p) +{ + struct enum_product *ptr = p; + ptr->enums = rb_gc_location(ptr->enums); +} + +#define enum_product_free RUBY_TYPED_DEFAULT_FREE + +static size_t +enum_product_memsize(const void *p) +{ + return sizeof(struct enum_product); +} + +static const rb_data_type_t enum_product_data_type = { + "product", + { + enum_product_mark, + enum_product_free, + enum_product_memsize, + enum_product_compact, + }, + 0, 0, RUBY_TYPED_FREE_IMMEDIATELY +}; + +static struct enum_product * +enum_product_ptr(VALUE obj) +{ + struct enum_product *ptr; + + TypedData_Get_Struct(obj, struct enum_product, &enum_product_data_type, ptr); + if (!ptr || ptr->enums == Qundef) { + rb_raise(rb_eArgError, "uninitialized product"); + } + return ptr; +} + +/* :nodoc: */ +static VALUE +enum_product_allocate(VALUE klass) +{ + struct enum_product *ptr; + VALUE obj; + + obj = TypedData_Make_Struct(klass, struct enum_product, &enum_product_data_type, ptr); + ptr->enums = Qundef; + + return obj; +} + +/* + * call-seq: + * Enumerator::Product.new(*enums) -> enum + * + * Generates a new enumerator object that generates a Cartesian + * product of given enumerable objects. + * + * e = Enumerator::Product.new(1..3, [4, 5]) + * e.to_a #=> [[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]] + * e.size #=> 6 + */ +static VALUE +enum_product_initialize(VALUE obj, VALUE enums) +{ + struct enum_product *ptr; + + rb_check_frozen(obj); + TypedData_Get_Struct(obj, struct enum_product, &enum_product_data_type, ptr); + + if (!ptr) rb_raise(rb_eArgError, "unallocated product"); + + ptr->enums = rb_obj_freeze(enums); + + return obj; +} + +/* :nodoc: */ +static VALUE +enum_product_init_copy(VALUE obj, VALUE orig) +{ + struct enum_product *ptr0, *ptr1; + + if (!OBJ_INIT_COPY(obj, orig)) return obj; + ptr0 = enum_product_ptr(orig); + + TypedData_Get_Struct(obj, struct enum_product, &enum_product_data_type, ptr1); + + if (!ptr1) rb_raise(rb_eArgError, "unallocated product"); + + ptr1->enums = ptr0->enums; + + return obj; +} + +static VALUE +enum_product_total_size(VALUE enums) +{ + VALUE total = INT2FIX(1); + long i; + + for (i = 0; i < RARRAY_LEN(enums); i++) { + VALUE size = enum_size(RARRAY_AREF(enums, i)); + + if (NIL_P(size) || (RB_TYPE_P(size, T_FLOAT) && isinf(NUM2DBL(size)))) { + return size; + } + if (!RB_INTEGER_TYPE_P(size)) { + return Qnil; + } + + total = rb_funcall(total, '*', 1, size); + } + + return total; +} + +/* + * call-seq: + * obj.size -> int, Float::INFINITY or nil + * + * Returns the total size of the enumerator product calculated by + * multiplying the sizes of enumerables in the product. If any of the + * enumerables reports its size as nil or Float::INFINITY, that value + * is returned as the size. + */ +static VALUE +enum_product_size(VALUE obj) +{ + return enum_product_total_size(enum_product_ptr(obj)->enums); +} + +static VALUE +enum_product_enum_size(VALUE obj, VALUE args, VALUE eobj) +{ + return enum_product_size(obj); +} + +struct product_state { + VALUE obj; + VALUE block; + int argc; + VALUE *argv; + int index; +}; + +static VALUE product_each(VALUE, struct product_state *); + +static VALUE +product_each_i(RB_BLOCK_CALL_FUNC_ARGLIST(value, state)) +{ + struct product_state *pstate = (struct product_state *)state; + pstate->argv[pstate->index++] = value; + + VALUE val = product_each(pstate->obj, pstate); + pstate->index--; + return val; +} + +static VALUE +product_each(VALUE obj, struct product_state *pstate) +{ + struct enum_product *ptr = enum_product_ptr(obj); + VALUE enums = ptr->enums; + + if (pstate->index < pstate->argc) { + VALUE eobj = RARRAY_AREF(enums, pstate->index); + + rb_block_call(eobj, id_each_entry, 0, NULL, product_each_i, (VALUE)pstate); + } else { + rb_funcallv(pstate->block, id_call, pstate->argc, pstate->argv); + } + + return obj; +} + +static VALUE +enum_product_run(VALUE obj, VALUE block) +{ + struct enum_product *ptr = enum_product_ptr(obj); + int argc = RARRAY_LENINT(ptr->enums); + struct product_state state = { + .obj = obj, + .block = block, + .index = 0, + .argc = argc, + .argv = ALLOCA_N(VALUE, argc), + }; + + return product_each(obj, &state); +} + +/* + * call-seq: + * obj.each { |...| ... } -> obj + * obj.each -> enumerator + * + * Iterates over the elements of the first enumerable by calling the + * "each_entry" method on it with the given arguments, then proceeds + * to the following enumerables in sequence until all of the + * enumerables are exhausted. + * + * If no block is given, returns an enumerator. Otherwise, returns self. + */ +static VALUE +enum_product_each(VALUE obj) +{ + RETURN_SIZED_ENUMERATOR(obj, 0, 0, enum_product_enum_size); + + return enum_product_run(obj, rb_block_proc()); +} + +/* + * call-seq: + * obj.rewind -> obj + * + * Rewinds the product enumerator by calling the "rewind" method on + * each enumerable in reverse order. Each call is performed only if + * the enumerable responds to the method. + */ +static VALUE +enum_product_rewind(VALUE obj) +{ + struct enum_product *ptr = enum_product_ptr(obj); + VALUE enums = ptr->enums; + long i; + + for (i = 0; i < RARRAY_LEN(enums); i++) { + rb_check_funcall(RARRAY_AREF(enums, i), id_rewind, 0, 0); + } + + return obj; +} + +static VALUE +inspect_enum_product(VALUE obj, VALUE dummy, int recur) +{ + VALUE klass = rb_obj_class(obj); + struct enum_product *ptr; + + TypedData_Get_Struct(obj, struct enum_product, &enum_product_data_type, ptr); + + if (!ptr || ptr->enums == Qundef) { + return rb_sprintf("#<%"PRIsVALUE": uninitialized>", rb_class_path(klass)); + } + + if (recur) { + return rb_sprintf("#<%"PRIsVALUE": ...>", rb_class_path(klass)); + } + + return rb_sprintf("#<%"PRIsVALUE": %+"PRIsVALUE">", rb_class_path(klass), ptr->enums); +} + +/* + * call-seq: + * obj.inspect -> string + * + * Returns a printable version of the product enumerator. + */ +static VALUE +enum_product_inspect(VALUE obj) +{ + return rb_exec_recursive(inspect_enum_product, obj, 0); +} + +/* + * call-seq: + * Enumerator.product(*enums) -> enumerator + * + * Generates a new enumerator object that generates a Cartesian + * product of given enumerable objects. This is equivalent to + * Enumerator::Product.new. + * + * e = Enumerator.product(1..3, [4, 5]) + * e.to_a #=> [[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]] + * e.size #=> 6 + */ +static VALUE +enumerator_s_product(VALUE klass, VALUE enums) +{ + VALUE obj = enum_product_initialize(enum_product_allocate(rb_cEnumProduct), enums); + + if (rb_block_given_p()) { + return enum_product_run(obj, rb_block_proc()); + } else { + return obj; + } +} + /* * Document-class: Enumerator::ArithmeticSequence * @@ -4214,6 +4549,22 @@ InitVM_Enumerator(void) rb_undef_method(rb_cEnumChain, "peek"); rb_undef_method(rb_cEnumChain, "peek_values"); + /* Product */ + rb_cEnumProduct = rb_define_class_under(rb_cEnumerator, "Product", rb_cEnumerator); + rb_define_alloc_func(rb_cEnumProduct, enum_product_allocate); + rb_define_method(rb_cEnumProduct, "initialize", enum_product_initialize, -2); + rb_define_method(rb_cEnumProduct, "initialize_copy", enum_product_init_copy, 1); + rb_define_method(rb_cEnumProduct, "each", enum_product_each, 0); + rb_define_method(rb_cEnumProduct, "size", enum_product_size, 0); + rb_define_method(rb_cEnumProduct, "rewind", enum_product_rewind, 0); + rb_define_method(rb_cEnumProduct, "inspect", enum_product_inspect, 0); + rb_undef_method(rb_cEnumProduct, "feed"); + rb_undef_method(rb_cEnumProduct, "next"); + rb_undef_method(rb_cEnumProduct, "next_values"); + rb_undef_method(rb_cEnumProduct, "peek"); + rb_undef_method(rb_cEnumProduct, "peek_values"); + rb_define_singleton_method(rb_cEnumerator, "product", enumerator_s_product, -2); + /* ArithmeticSequence */ rb_cArithSeq = rb_define_class_under(rb_cEnumerator, "ArithmeticSequence", rb_cEnumerator); rb_undef_alloc_func(rb_cArithSeq); @@ -4249,6 +4600,7 @@ Init_Enumerator(void) id_method = rb_intern_const("method"); id_force = rb_intern_const("force"); id_to_enum = rb_intern_const("to_enum"); + id_each_entry = rb_intern_const("each_entry"); id_begin = rb_intern_const("begin"); id_end = rb_intern_const("end"); id_step = rb_intern_const("step"); -- cgit v1.2.1