summaryrefslogtreecommitdiff
path: root/arch/x86
diff options
context:
space:
mode:
Diffstat (limited to 'arch/x86')
-rw-r--r--arch/x86/net/bpf_jit_comp.c237
1 files changed, 189 insertions, 48 deletions
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 7b0ff169c9a0..26f43279b78b 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -221,14 +221,48 @@ struct jit_context {
/* Number of bytes emit_patch() needs to generate instructions */
#define X86_PATCH_SIZE 5
+/* Number of bytes that will be skipped on tailcall */
+#define X86_TAIL_CALL_OFFSET 11
-#define PROLOGUE_SIZE 25
+static void push_callee_regs(u8 **pprog, bool *callee_regs_used)
+{
+ u8 *prog = *pprog;
+ int cnt = 0;
+
+ if (callee_regs_used[0])
+ EMIT1(0x53); /* push rbx */
+ if (callee_regs_used[1])
+ EMIT2(0x41, 0x55); /* push r13 */
+ if (callee_regs_used[2])
+ EMIT2(0x41, 0x56); /* push r14 */
+ if (callee_regs_used[3])
+ EMIT2(0x41, 0x57); /* push r15 */
+ *pprog = prog;
+}
+
+static void pop_callee_regs(u8 **pprog, bool *callee_regs_used)
+{
+ u8 *prog = *pprog;
+ int cnt = 0;
+
+ if (callee_regs_used[3])
+ EMIT2(0x41, 0x5F); /* pop r15 */
+ if (callee_regs_used[2])
+ EMIT2(0x41, 0x5E); /* pop r14 */
+ if (callee_regs_used[1])
+ EMIT2(0x41, 0x5D); /* pop r13 */
+ if (callee_regs_used[0])
+ EMIT1(0x5B); /* pop rbx */
+ *pprog = prog;
+}
/*
- * Emit x86-64 prologue code for BPF program and check its size.
- * bpf_tail_call helper will skip it while jumping into another program
+ * Emit x86-64 prologue code for BPF program.
+ * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes
+ * while jumping to another program
*/
-static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
+static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
+ bool tail_call_reachable, bool is_subprog)
{
u8 *prog = *pprog;
int cnt = X86_PATCH_SIZE;
@@ -238,19 +272,18 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
*/
memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt);
prog += cnt;
+ if (!ebpf_from_cbpf) {
+ if (tail_call_reachable && !is_subprog)
+ EMIT2(0x31, 0xC0); /* xor eax, eax */
+ else
+ EMIT2(0x66, 0x90); /* nop2 */
+ }
EMIT1(0x55); /* push rbp */
EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
/* sub rsp, rounded_stack_depth */
EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
- EMIT1(0x53); /* push rbx */
- EMIT2(0x41, 0x55); /* push r13 */
- EMIT2(0x41, 0x56); /* push r14 */
- EMIT2(0x41, 0x57); /* push r15 */
- if (!ebpf_from_cbpf) {
- /* zero init tail_call_cnt */
- EMIT2(0x6a, 0x00);
- BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
- }
+ if (tail_call_reachable)
+ EMIT1(0x50); /* push rax */
*pprog = prog;
}
@@ -314,13 +347,14 @@ static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
mutex_lock(&text_mutex);
if (memcmp(ip, old_insn, X86_PATCH_SIZE))
goto out;
+ ret = 1;
if (memcmp(ip, new_insn, X86_PATCH_SIZE)) {
if (text_live)
text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
else
memcpy(ip, new_insn, X86_PATCH_SIZE);
+ ret = 0;
}
- ret = 0;
out:
mutex_unlock(&text_mutex);
return ret;
@@ -337,6 +371,22 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
}
+static int get_pop_bytes(bool *callee_regs_used)
+{
+ int bytes = 0;
+
+ if (callee_regs_used[3])
+ bytes += 2;
+ if (callee_regs_used[2])
+ bytes += 2;
+ if (callee_regs_used[1])
+ bytes += 2;
+ if (callee_regs_used[0])
+ bytes += 1;
+
+ return bytes;
+}
+
/*
* Generate the following code:
*
@@ -351,12 +401,26 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
* goto *(prog->bpf_func + prologue_size);
* out:
*/
-static void emit_bpf_tail_call_indirect(u8 **pprog)
+static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
+ u32 stack_depth)
{
+ int tcc_off = -4 - round_up(stack_depth, 8);
u8 *prog = *pprog;
- int label1, label2, label3;
+ int pop_bytes = 0;
+ int off1 = 49;
+ int off2 = 38;
+ int off3 = 16;
int cnt = 0;
+ /* count the additional bytes used for popping callee regs from stack
+ * that need to be taken into account for each of the offsets that
+ * are used for bailing out of the tail call
+ */
+ pop_bytes = get_pop_bytes(callee_regs_used);
+ off1 += pop_bytes;
+ off2 += pop_bytes;
+ off3 += pop_bytes;
+
/*
* rdi - pointer to ctx
* rsi - pointer to bpf_array
@@ -370,21 +434,19 @@ static void emit_bpf_tail_call_indirect(u8 **pprog)
EMIT2(0x89, 0xD2); /* mov edx, edx */
EMIT3(0x39, 0x56, /* cmp dword ptr [rsi + 16], edx */
offsetof(struct bpf_array, map.max_entries));
-#define OFFSET1 (41 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */
+#define OFFSET1 (off1 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */
EMIT2(X86_JBE, OFFSET1); /* jbe out */
- label1 = cnt;
/*
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
* goto out;
*/
- EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
+ EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
-#define OFFSET2 (30 + RETPOLINE_RCX_BPF_JIT_SIZE)
+#define OFFSET2 (off2 + RETPOLINE_RCX_BPF_JIT_SIZE)
EMIT2(X86_JA, OFFSET2); /* ja out */
- label2 = cnt;
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
- EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */
+ EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */
/* prog = array->ptrs[index]; */
EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */
@@ -394,48 +456,84 @@ static void emit_bpf_tail_call_indirect(u8 **pprog)
* if (prog == NULL)
* goto out;
*/
- EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */
-#define OFFSET3 (8 + RETPOLINE_RCX_BPF_JIT_SIZE)
+ EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */
+#define OFFSET3 (off3 + RETPOLINE_RCX_BPF_JIT_SIZE)
EMIT2(X86_JE, OFFSET3); /* je out */
- label3 = cnt;
- /* goto *(prog->bpf_func + prologue_size); */
+ *pprog = prog;
+ pop_callee_regs(pprog, callee_regs_used);
+ prog = *pprog;
+
+ EMIT1(0x58); /* pop rax */
+ EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */
+ round_up(stack_depth, 8));
+
+ /* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */
EMIT4(0x48, 0x8B, 0x49, /* mov rcx, qword ptr [rcx + 32] */
offsetof(struct bpf_prog, bpf_func));
- EMIT4(0x48, 0x83, 0xC1, PROLOGUE_SIZE); /* add rcx, prologue_size */
-
+ EMIT4(0x48, 0x83, 0xC1, /* add rcx, X86_TAIL_CALL_OFFSET */
+ X86_TAIL_CALL_OFFSET);
/*
* Now we're ready to jump into next BPF program
* rdi == ctx (1st arg)
- * rcx == prog->bpf_func + prologue_size
+ * rcx == prog->bpf_func + X86_TAIL_CALL_OFFSET
*/
RETPOLINE_RCX_BPF_JIT();
/* out: */
- BUILD_BUG_ON(cnt - label1 != OFFSET1);
- BUILD_BUG_ON(cnt - label2 != OFFSET2);
- BUILD_BUG_ON(cnt - label3 != OFFSET3);
*pprog = prog;
}
static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
- u8 **pprog, int addr, u8 *image)
+ u8 **pprog, int addr, u8 *image,
+ bool *callee_regs_used, u32 stack_depth)
{
+ int tcc_off = -4 - round_up(stack_depth, 8);
u8 *prog = *pprog;
+ int pop_bytes = 0;
+ int off1 = 27;
+ int poke_off;
int cnt = 0;
+ /* count the additional bytes used for popping callee regs to stack
+ * that need to be taken into account for jump offset that is used for
+ * bailing out from of the tail call when limit is reached
+ */
+ pop_bytes = get_pop_bytes(callee_regs_used);
+ off1 += pop_bytes;
+
+ /*
+ * total bytes for:
+ * - nop5/ jmpq $off
+ * - pop callee regs
+ * - sub rsp, $val
+ * - pop rax
+ */
+ poke_off = X86_PATCH_SIZE + pop_bytes + 7 + 1;
+
/*
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
* goto out;
*/
- EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
+ EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
- EMIT2(X86_JA, 14); /* ja out */
+ EMIT2(X86_JA, off1); /* ja out */
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
- EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */
+ EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */
+ poke->tailcall_bypass = image + (addr - poke_off - X86_PATCH_SIZE);
+ poke->adj_off = X86_TAIL_CALL_OFFSET;
poke->tailcall_target = image + (addr - X86_PATCH_SIZE);
- poke->adj_off = PROLOGUE_SIZE;
+ poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
+
+ emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
+ poke->tailcall_bypass);
+
+ *pprog = prog;
+ pop_callee_regs(pprog, callee_regs_used);
+ prog = *pprog;
+ EMIT1(0x58); /* pop rax */
+ EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
prog += X86_PATCH_SIZE;
@@ -476,6 +574,11 @@ static void bpf_tail_call_direct_fixup(struct bpf_prog *prog)
(u8 *)target->bpf_func +
poke->adj_off, false);
BUG_ON(ret < 0);
+ ret = __bpf_arch_text_poke(poke->tailcall_bypass,
+ BPF_MOD_JUMP,
+ (u8 *)poke->tailcall_target +
+ X86_PATCH_SIZE, NULL, false);
+ BUG_ON(ret < 0);
}
WRITE_ONCE(poke->tailcall_target_stable, true);
mutex_unlock(&array->aux->poke_mutex);
@@ -654,19 +757,49 @@ static bool ex_handler_bpf(const struct exception_table_entry *x,
return true;
}
+static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
+ bool *regs_used, bool *tail_call_seen)
+{
+ int i;
+
+ for (i = 1; i <= insn_cnt; i++, insn++) {
+ if (insn->code == (BPF_JMP | BPF_TAIL_CALL))
+ *tail_call_seen = true;
+ if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6)
+ regs_used[0] = true;
+ if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7)
+ regs_used[1] = true;
+ if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8)
+ regs_used[2] = true;
+ if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9)
+ regs_used[3] = true;
+ }
+}
+
static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
int oldproglen, struct jit_context *ctx)
{
+ bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
struct bpf_insn *insn = bpf_prog->insnsi;
+ bool callee_regs_used[4] = {};
int insn_cnt = bpf_prog->len;
+ bool tail_call_seen = false;
bool seen_exit = false;
u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
int i, cnt = 0, excnt = 0;
int proglen = 0;
u8 *prog = temp;
+ detect_reg_usage(insn, insn_cnt, callee_regs_used,
+ &tail_call_seen);
+
+ /* tail call's presence in current prog implies it is reachable */
+ tail_call_reachable |= tail_call_seen;
+
emit_prologue(&prog, bpf_prog->aux->stack_depth,
- bpf_prog_was_classic(bpf_prog));
+ bpf_prog_was_classic(bpf_prog), tail_call_reachable,
+ bpf_prog->aux->func_idx != 0);
+ push_callee_regs(&prog, callee_regs_used);
addrs[0] = prog - temp;
for (i = 1; i <= insn_cnt; i++, insn++) {
@@ -1104,16 +1237,27 @@ xadd: if (is_imm8(insn->off))
/* call */
case BPF_JMP | BPF_CALL:
func = (u8 *) __bpf_call_base + imm32;
- if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
- return -EINVAL;
+ if (tail_call_reachable) {
+ EMIT3_off32(0x48, 0x8B, 0x85,
+ -(bpf_prog->aux->stack_depth + 8));
+ if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7))
+ return -EINVAL;
+ } else {
+ if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
+ return -EINVAL;
+ }
break;
case BPF_JMP | BPF_TAIL_CALL:
if (imm32)
emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
- &prog, addrs[i], image);
+ &prog, addrs[i], image,
+ callee_regs_used,
+ bpf_prog->aux->stack_depth);
else
- emit_bpf_tail_call_indirect(&prog);
+ emit_bpf_tail_call_indirect(&prog,
+ callee_regs_used,
+ bpf_prog->aux->stack_depth);
break;
/* cond jump */
@@ -1296,12 +1440,9 @@ emit_jmp:
seen_exit = true;
/* Update cleanup_addr */
ctx->cleanup_addr = proglen;
- if (!bpf_prog_was_classic(bpf_prog))
- EMIT1(0x5B); /* get rid of tail_call_cnt */
- EMIT2(0x41, 0x5F); /* pop r15 */
- EMIT2(0x41, 0x5E); /* pop r14 */
- EMIT2(0x41, 0x5D); /* pop r13 */
- EMIT1(0x5B); /* pop rbx */
+ pop_callee_regs(&prog, callee_regs_used);
+ if (tail_call_reachable)
+ EMIT1(0x59); /* pop rcx, get rid of tail_call_cnt */
EMIT1(0xC9); /* leave */
EMIT1(0xC3); /* ret */
break;