summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--class.c56
-rw-r--r--test/ruby/test_module.rb13
2 files changed, 48 insertions, 21 deletions
diff --git a/class.c b/class.c
index 12a67d16bc..68cfbfb257 100644
--- a/class.c
+++ b/class.c
@@ -351,7 +351,7 @@ copy_tables(VALUE clone, VALUE orig)
}
}
-static void ensure_origin(VALUE klass);
+static bool ensure_origin(VALUE klass);
/* :nodoc: */
VALUE
@@ -1014,27 +1014,31 @@ clear_module_cache_i(ID id, VALUE val, void *data)
return ID_TABLE_CONTINUE;
}
+static bool
+module_in_super_chain(const VALUE klass, VALUE module)
+{
+ struct rb_id_table *const klass_m_tbl = RCLASS_M_TBL(RCLASS_ORIGIN(klass));
+ if (klass_m_tbl) {
+ while (module) {
+ if (klass_m_tbl == RCLASS_M_TBL(module))
+ return true;
+ module = RCLASS_SUPER(module);
+ }
+ }
+ return false;
+}
+
static int
-include_modules_at(const VALUE klass, VALUE c, VALUE module, int search_super)
+do_include_modules_at(const VALUE klass, VALUE c, VALUE module, int search_super, bool check_cyclic)
{
VALUE p, iclass, origin_stack = 0;
int method_changed = 0, constant_changed = 0, add_subclass;
long origin_len;
VALUE klass_origin = RCLASS_ORIGIN(klass);
- struct rb_id_table *const klass_m_tbl = RCLASS_M_TBL(klass_origin);
VALUE original_klass = klass;
- if (klass_m_tbl) {
- VALUE original_module = module;
-
- while (module) {
- if (klass_m_tbl == RCLASS_M_TBL(module))
- return -1;
- module = RCLASS_SUPER(module);
- }
-
- module = original_module;
- }
+ if (check_cyclic && module_in_super_chain(klass, module))
+ return -1;
while (module) {
int c_seen = FALSE;
@@ -1129,6 +1133,12 @@ include_modules_at(const VALUE klass, VALUE c, VALUE module, int search_super)
return method_changed;
}
+static int
+include_modules_at(const VALUE klass, VALUE c, VALUE module, int search_super)
+{
+ return do_include_modules_at(klass, c, module, search_super, true);
+}
+
static enum rb_id_table_iterator_result
move_refined_method(ID key, VALUE value, void *data)
{
@@ -1169,7 +1179,7 @@ cache_clear_refined_method(ID key, VALUE value, void *data)
return ID_TABLE_CONTINUE;
}
-static void
+static bool
ensure_origin(VALUE klass)
{
VALUE origin = RCLASS_ORIGIN(klass);
@@ -1182,20 +1192,24 @@ ensure_origin(VALUE klass)
RCLASS_M_TBL_INIT(klass);
rb_id_table_foreach(RCLASS_M_TBL(origin), cache_clear_refined_method, (void *)klass);
rb_id_table_foreach(RCLASS_M_TBL(origin), move_refined_method, (void *)klass);
+ return true;
}
+ return false;
}
void
rb_prepend_module(VALUE klass, VALUE module)
{
- int changed = 0;
- bool klass_had_no_origin = RCLASS_ORIGIN(klass) == klass;
+ int changed;
+ bool klass_had_no_origin;
ensure_includable(klass, module);
- ensure_origin(klass);
- changed = include_modules_at(klass, klass, module, FALSE);
- if (changed < 0)
- rb_raise(rb_eArgError, "cyclic prepend detected");
+ if (module_in_super_chain(klass, module))
+ rb_raise(rb_eArgError, "cyclic prepend detected");
+
+ klass_had_no_origin = ensure_origin(klass);
+ changed = do_include_modules_at(klass, klass, module, FALSE, false);
+ RUBY_ASSERT(changed >= 0); // already checked for cyclic prepend above
if (changed) {
rb_vm_check_redefinition_by_prepend(klass);
}
diff --git a/test/ruby/test_module.rb b/test/ruby/test_module.rb
index 0a5597fd6c..e5152b1012 100644
--- a/test/ruby/test_module.rb
+++ b/test/ruby/test_module.rb
@@ -485,6 +485,19 @@ class TestModule < Test::Unit::TestCase
assert_equal([m], m.ancestors)
end
+ def test_bug17590
+ m = Module.new
+ c = Class.new
+ c.prepend(m)
+ c.include(m)
+ m.prepend(m) rescue nil
+ m2 = Module.new
+ m2.prepend(m)
+ c.include(m2)
+
+ assert_equal([m, c, m2] + Object.ancestors, c.ancestors)
+ end
+
def test_prepend_works_with_duped_classes
m = Module.new
a = Class.new do