diff options
author | Jimmy Miller <jimmy.miller@shopify.com> | 2023-03-07 17:03:43 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-07 17:03:43 -0500 |
commit | 56df6d5f9d986a7959eb9cac27e21bc2ed505319 (patch) | |
tree | 3ce9efc14aa28e2b5e9f050fc8a7037980d78fb1 | |
parent | 33edcc112081f96856d52e73253d73c97a5c4a3c (diff) | |
download | ruby-56df6d5f9d986a7959eb9cac27e21bc2ed505319.tar.gz |
YJIT: Handle splat+rest for args pass greater than required (#7468)
For example:
```ruby
def my_func(x, y, *rest)
p [x, y, rest]
end
my_func(1, 2, 3, *[4, 5])
```
-rw-r--r-- | array.c | 6 | ||||
-rw-r--r-- | yjit.c | 3 | ||||
-rw-r--r-- | yjit/bindgen/src/main.rs | 1 | ||||
-rw-r--r-- | yjit/src/codegen.rs | 45 | ||||
-rw-r--r-- | yjit/src/cruby_bindings.inc.rs | 5 | ||||
-rw-r--r-- | yjit/src/stats.rs | 2 |
6 files changed, 54 insertions, 8 deletions
@@ -1802,6 +1802,12 @@ rb_ary_unshift_m(int argc, VALUE *argv, VALUE ary) return ary; } +/* non-static for yjit */ +VALUE +rb_yjit_rb_ary_unshift_m(int argc, VALUE *argv, VALUE ary) { + return rb_ary_unshift_m(argc, argv, ary); +} + VALUE rb_ary_unshift(VALUE ary, VALUE item) { @@ -848,6 +848,9 @@ rb_yarv_ary_entry_internal(VALUE ary, long offset) } VALUE +rb_yjit_rb_ary_unshift_m(int argc, VALUE *argv, VALUE ary); + +VALUE rb_yarv_fix_mod_fix(VALUE recv, VALUE obj) { return rb_fix_mod_fix(recv, obj); diff --git a/yjit/bindgen/src/main.rs b/yjit/bindgen/src/main.rs index 481c403714..71bf1df3e8 100644 --- a/yjit/bindgen/src/main.rs +++ b/yjit/bindgen/src/main.rs @@ -136,6 +136,7 @@ fn main() { .allowlist_function("rb_ary_resurrect") .allowlist_function("rb_ary_clear") .allowlist_function("rb_ary_dup") + .allowlist_function("rb_yjit_rb_ary_unshift_m") // From internal/array.h .allowlist_function("rb_ec_ary_new_from_values") diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs index 64f2d6e654..74201154e4 100644 --- a/yjit/src/codegen.rs +++ b/yjit/src/codegen.rs @@ -1170,7 +1170,7 @@ fn gen_newarray( let values_ptr = if n == 0 { Opnd::UImm(0) } else { - asm.comment("load pointer to array elts"); + asm.comment("load pointer to array elements"); let offset_magnitude = (SIZEOF_VALUE as u32) * n; let values_opnd = ctx.sp_opnd(-(offset_magnitude as isize)); asm.lea(values_opnd) @@ -5321,9 +5321,13 @@ fn gen_send_iseq( // foo(1, 2, *[3, 4]) // In this case, we can just dup the splat array as the rest array. // No need to move things around between the array and stack. - if iseq_has_rest && flags & VM_CALL_ARGS_SPLAT != 0 && argc - 1 != required_num { - gen_counter_incr!(asm, send_iseq_has_rest_and_splat_not_equal); - return CantCompile; + + let non_rest_arg_count = argc - 1; + if iseq_has_rest && flags & VM_CALL_ARGS_SPLAT != 0 && non_rest_arg_count != required_num { + if non_rest_arg_count < required_num { + gen_counter_incr!(asm, send_iseq_has_rest_and_splat_fewer); + return CantCompile; + } } // This struct represents the metadata about the caller-specified @@ -5792,9 +5796,36 @@ fn gen_send_iseq( rb_ary_dup as *const u8, vec![array], ); - let stack_ret = ctx.stack_push(Type::TArray); - asm.mov(stack_ret, array); + if non_rest_arg_count > required_num { + // If we have more arguments than required, we need to prepend + // the items from the stack onto the array. + let diff = (non_rest_arg_count - required_num) as u32; + + // diff is >0 so no need to worry about null pointer + asm.comment("load pointer to array elements"); + let offset_magnitude = SIZEOF_VALUE as u32 * diff; + let values_opnd = ctx.sp_opnd(-(offset_magnitude as isize)); + let values_ptr = asm.lea(values_opnd); + asm.comment("prepend stack values to rest array"); + let array = asm.ccall( + rb_yjit_rb_ary_unshift_m as *const u8, + vec![Opnd::UImm(diff as u64), values_ptr, array], + ); + ctx.stack_pop(diff as usize); + + // We now should have the required arguments + // and an array of all the rest arguments + argc = required_num + 1; + let stack_ret = ctx.stack_push(Type::TArray); + asm.mov(stack_ret, array); + } else { + // We exit on less than right now, this is only handling + // the case where they are equal + assert!(non_rest_arg_count == required_num); + let stack_ret = ctx.stack_push(Type::TArray); + asm.mov(stack_ret, array); + } } else { assert!(argc >= required_num); let n = (argc - required_num) as u32; @@ -5803,7 +5834,7 @@ fn gen_send_iseq( let values_ptr = if n == 0 { Opnd::UImm(0) } else { - asm.comment("load pointer to array elts"); + asm.comment("load pointer to array elements"); let offset_magnitude = SIZEOF_VALUE as u32 * n; let values_opnd = ctx.sp_opnd(-(offset_magnitude as isize)); asm.lea(values_opnd) diff --git a/yjit/src/cruby_bindings.inc.rs b/yjit/src/cruby_bindings.inc.rs index 21a6d09e84..a683c5a1e4 100644 --- a/yjit/src/cruby_bindings.inc.rs +++ b/yjit/src/cruby_bindings.inc.rs @@ -1297,6 +1297,11 @@ extern "C" { pub fn rb_yarv_str_eql_internal(str1: VALUE, str2: VALUE) -> VALUE; pub fn rb_str_neq_internal(str1: VALUE, str2: VALUE) -> VALUE; pub fn rb_yarv_ary_entry_internal(ary: VALUE, offset: ::std::os::raw::c_long) -> VALUE; + pub fn rb_yjit_rb_ary_unshift_m( + argc: ::std::os::raw::c_int, + argv: *mut VALUE, + ary: VALUE, + ) -> VALUE; pub fn rb_yarv_fix_mod_fix(recv: VALUE, obj: VALUE) -> VALUE; pub fn rb_yjit_dump_iseq_loc(iseq: *const rb_iseq_t, insn_idx: u32); pub fn rb_FL_TEST(obj: VALUE, flags: VALUE) -> VALUE; diff --git a/yjit/src/stats.rs b/yjit/src/stats.rs index 4a13997f74..014030bf64 100644 --- a/yjit/src/stats.rs +++ b/yjit/src/stats.rs @@ -253,7 +253,7 @@ make_counters! { send_send_getter, send_send_builtin, send_iseq_has_rest_and_captured, - send_iseq_has_rest_and_splat_not_equal, + send_iseq_has_rest_and_splat_fewer, send_iseq_has_rest_and_send, send_iseq_has_rest_and_block, send_iseq_has_rest_and_kw, |