From cc9330f8c0fac9952572a590cb70eb8e63921056 Mon Sep 17 00:00:00 2001 From: Takashi Kokubun Date: Sat, 18 Mar 2023 22:13:40 -0700 Subject: RJIT: Reorder opt_case_dispatch branches --- lib/ruby_vm/rjit/insn_compiler.rb | 48 ++++++++++++++++++++++++++++++++++++--- rjit_c.rb | 23 ++++++++++++++++--- tool/rjit/bindgen.rb | 1 + 3 files changed, 66 insertions(+), 6 deletions(-) diff --git a/lib/ruby_vm/rjit/insn_compiler.rb b/lib/ruby_vm/rjit/insn_compiler.rb index 47e842e00e..5918abbd55 100644 --- a/lib/ruby_vm/rjit/insn_compiler.rb +++ b/lib/ruby_vm/rjit/insn_compiler.rb @@ -1746,9 +1746,51 @@ module RubyVM::RJIT # @param ctx [RubyVM::RJIT::Context] # @param asm [RubyVM::RJIT::Assembler] def opt_case_dispatch(jit, ctx, asm) - # Just go to === branches for now - ctx.stack_pop - KeepCompiling + # Normally this instruction would lookup the key in a hash and jump to an + # offset based on that. + # Instead we can take the fallback case and continue with the next + # instruction. + # We'd hope that our jitted code will be sufficiently fast without the + # hash lookup, at least for small hashes, but it's worth revisiting this + # assumption in the future. + unless jit.at_current_insn? + defer_compilation(jit, ctx, asm) + return EndBlock + end + starting_context = ctx.dup + + case_hash = jit.operand(0, ruby: true) + else_offset = jit.operand(1) + + # Try to reorder case/else branches so that ones that are actually used come first. + # Supporting only Fixnum for now so that the implementation can be an equality check. + key_opnd = ctx.stack_pop(1) + comptime_key = jit.peek_at_stack(0) + + # Check that all cases are fixnums to avoid having to register BOP assumptions on + # all the types that case hashes support. This spends compile time to save memory. + if fixnum?(comptime_key) && comptime_key <= 2**32 && C.rb_hash_keys(case_hash).all? { |key| fixnum?(key) } + unless Invariants.assume_bop_not_redefined(jit, C::INTEGER_REDEFINED_OP_FLAG, C::BOP_EQQ) + return CantCompile + end + + # Check if the key is the same value + asm.cmp(key_opnd, comptime_key) + side_exit = side_exit(jit, starting_context) + jit_chain_guard(:jne, jit, starting_context, asm, side_exit) + + # Get the offset for the compile-time key + offset = C.rb_hash_stlike_lookup(case_hash, comptime_key) + # NOTE: If we hit the else branch with various values, it could negatively impact the performance. + jump_offset = offset || else_offset + + # Jump to the offset of case or else + target_pc = jit.pc + (jit.insn.len + jump_offset) * C.VALUE.size + jit_direct_jump(jit.iseq, target_pc, ctx, asm) + EndBlock + else + KeepCompiling # continue with === branches + end end # @param jit [RubyVM::RJIT::JITState] diff --git a/rjit_c.rb b/rjit_c.rb index 17aa6754f0..a9b99d71e0 100644 --- a/rjit_c.rb +++ b/rjit_c.rb @@ -2,6 +2,9 @@ # Part of this file is generated by tool/rjit/bindgen.rb. # Run `make rjit-bindgen` to update code between "RJIT bindgen begin" and "RJIT bindgen end". module RubyVM::RJIT # :nodoc: all + # + # Main: Used by RJIT + # # This `class << C` section is for calling C functions with Primitive. # For importing variables or macros, use tool/rjit/bindgen.rb instead. class << C = Module.new @@ -292,10 +295,23 @@ module RubyVM::RJIT # :nodoc: all C.VALUE.new(lep_addr) end - # - # Utilities: Not used by RJIT, but useful for debugging - # + def rb_hash_keys(hash) + Primitive.cexpr! 'rb_hash_keys(hash)' + end + + def rb_hash_stlike_lookup(hash, key) + Primitive.cstmt! %{ + VALUE result = Qnil; + rb_hash_stlike_lookup(hash, key, &result); + return result; + } + end + end + # + # Utilities: Not used by RJIT, but useful for debugging + # + class << C # Convert insn BINs to encoded VM pointers. def rb_vm_insn_encode(bin) Primitive.cexpr! 'SIZET2NUM((VALUE)rb_vm_get_insns_address_table()[NUM2INT(bin)])' @@ -316,6 +332,7 @@ module RubyVM::RJIT # :nodoc: all C::BOP_AND = Primitive.cexpr! %q{ SIZET2NUM(BOP_AND) } C::BOP_AREF = Primitive.cexpr! %q{ SIZET2NUM(BOP_AREF) } C::BOP_EQ = Primitive.cexpr! %q{ SIZET2NUM(BOP_EQ) } + C::BOP_EQQ = Primitive.cexpr! %q{ SIZET2NUM(BOP_EQQ) } C::BOP_FREEZE = Primitive.cexpr! %q{ SIZET2NUM(BOP_FREEZE) } C::BOP_GE = Primitive.cexpr! %q{ SIZET2NUM(BOP_GE) } C::BOP_GT = Primitive.cexpr! %q{ SIZET2NUM(BOP_GT) } diff --git a/tool/rjit/bindgen.rb b/tool/rjit/bindgen.rb index 2e472a411b..ee2a77446a 100755 --- a/tool/rjit/bindgen.rb +++ b/tool/rjit/bindgen.rb @@ -385,6 +385,7 @@ generator = BindingGenerator.new( BOP_AND BOP_AREF BOP_EQ + BOP_EQQ BOP_FREEZE BOP_GE BOP_GT -- cgit v1.2.1