Module Name:    src
Committed By:   alnsn
Date:           Thu May 15 22:20:08 UTC 2014

Modified Files:
        src/sys/net: bpfjit.c

Log Message:
Refactor bpfjit code.
- Implement Array Bounds Check Elimination for packet bytes.
- Track initialization of registers and memwords.
- Remove "bj_" prefix from struct members.
- Shorten "BPFJIT_" prefix to "BJ_".
- Other small improvements.

To generate a diff of this commit:
cvs rdiff -u -r1.6 -r1.7 src/sys/net/bpfjit.c

Please note that diffs are not public domain; they are subject to the
copyright notices on the relevant files.
Modified files: Index: src/sys/net/bpfjit.c diff -u src/sys/net/bpfjit.c:1.6 src/sys/net/bpfjit.c:1.7 --- src/sys/net/bpfjit.c:1.6 Sun Dec 15 21:18:01 2013 +++ src/sys/net/bpfjit.c Thu May 15 22:20:08 2014 @@ -1,7 +1,7 @@ -/* $NetBSD: bpfjit.c,v 1.6 2013/12/15 21:18:01 pooka Exp $ */ +/* $NetBSD: bpfjit.c,v 1.7 2014/05/15 22:20:08 alnsn Exp $ */ /*- - * Copyright (c) 2011-2012 Alexander Nasonov. + * Copyright (c) 2011-2014 Alexander Nasonov. * All rights reserved. * * Redistribution and use in source and binary forms, with or without @@ -31,30 +31,33 @@ #include <sys/cdefs.h> #ifdef _KERNEL -__KERNEL_RCSID(0, "$NetBSD: bpfjit.c,v 1.6 2013/12/15 21:18:01 pooka Exp $"); +__KERNEL_RCSID(0, "$NetBSD: bpfjit.c,v 1.7 2014/05/15 22:20:08 alnsn Exp $"); #else -__RCSID("$NetBSD: bpfjit.c,v 1.6 2013/12/15 21:18:01 pooka Exp $"); +__RCSID("$NetBSD: bpfjit.c,v 1.7 2014/05/15 22:20:08 alnsn Exp $"); #endif #include <sys/types.h> #include <sys/queue.h> #ifndef _KERNEL -#include <stdlib.h> #include <assert.h> -#define BPFJIT_ALLOC(sz) malloc(sz) -#define BPFJIT_FREE(p, sz) free(p) -#define BPFJIT_ASSERT(c) assert(c) +#define BJ_ASSERT(c) assert(c) +#else +#define BJ_ASSERT(c) KASSERT(c) +#endif + +#ifndef _KERNEL +#include <stdlib.h> +#define BJ_ALLOC(sz) malloc(sz) +#define BJ_FREE(p, sz) free(p) #else #include <sys/kmem.h> -#define BPFJIT_ALLOC(sz) kmem_alloc(sz, KM_SLEEP) -#define BPFJIT_FREE(p, sz) kmem_free(p, sz) -#define BPFJIT_ASSERT(c) KASSERT(c) +#define BJ_ALLOC(sz) kmem_alloc(sz, KM_SLEEP) +#define BJ_FREE(p, sz) kmem_free(p, sz) #endif #ifndef _KERNEL #include <limits.h> -#include <stdio.h> #include <stdbool.h> #include <stddef.h> #include <stdint.h> @@ -68,28 +71,50 @@ __RCSID("$NetBSD: bpfjit.c,v 1.6 2013/12 #include <net/bpfjit.h> #include <sljitLir.h> -#define BPFJIT_A SLJIT_TEMPORARY_REG1 -#define BPFJIT_X SLJIT_TEMPORARY_EREG1 -#define BPFJIT_TMP1 SLJIT_TEMPORARY_REG2 -#define BPFJIT_TMP2 SLJIT_TEMPORARY_REG3 -#define BPFJIT_BUF SLJIT_SAVED_REG1 -#define 
BPFJIT_WIRELEN SLJIT_SAVED_REG2 -#define BPFJIT_BUFLEN SLJIT_SAVED_REG3 -#define BPFJIT_KERN_TMP SLJIT_TEMPORARY_EREG2 +#if !defined(_KERNEL) && defined(SLJIT_VERBOSE) && SLJIT_VERBOSE +#include <stdio.h> /* for stderr */ +#endif + +/* + * Permanent register assignments. + */ +#define BJ_BUF SLJIT_SAVED_REG1 +#define BJ_WIRELEN SLJIT_SAVED_REG2 +#define BJ_BUFLEN SLJIT_SAVED_REG3 +#define BJ_AREG SLJIT_TEMPORARY_REG1 +#define BJ_TMP1REG SLJIT_TEMPORARY_REG2 +#define BJ_TMP2REG SLJIT_TEMPORARY_REG3 +#define BJ_XREG SLJIT_TEMPORARY_EREG1 +#define BJ_TMP3REG SLJIT_TEMPORARY_EREG2 + +typedef unsigned int bpfjit_init_mask_t; +#define BJ_INIT_NOBITS 0u +#define BJ_INIT_MBIT(k) (1u << (k)) +#define BJ_INIT_MMASK (BJ_INIT_MBIT(BPF_MEMWORDS) - 1u) +#define BJ_INIT_ABIT BJ_INIT_MBIT(BPF_MEMWORDS) +#define BJ_INIT_XBIT BJ_INIT_MBIT(BPF_MEMWORDS + 1) -/* - * Flags for bpfjit_optimization_hints(). +struct bpfjit_stack +{ + uint32_t mem[BPF_MEMWORDS]; +#ifdef _KERNEL + void *tmp; +#endif +}; + +/* + * Data for BPF_JMP instruction. + * Forward declaration for struct bpfjit_jump. */ -#define BPFJIT_INIT_X 0x10000 -#define BPFJIT_INIT_A 0x20000 +struct bpfjit_jump_data; /* - * Node of bj_jumps list. + * Node of bjumps list. */ struct bpfjit_jump { - struct sljit_jump *bj_jump; - SLIST_ENTRY(bpfjit_jump) bj_entries; - uint32_t bj_safe_length; + struct sljit_jump *sjump; + SLIST_ENTRY(bpfjit_jump) entries; + struct bpfjit_jump_data *jdata; }; /* @@ -97,11 +122,19 @@ struct bpfjit_jump { */ struct bpfjit_jump_data { /* - * These entries make up bj_jumps list: - * bj_jtf[0] - when coming from jt path, - * bj_jtf[1] - when coming from jf path. + * These entries make up bjumps list: + * jtf[0] - when coming from jt path, + * jtf[1] - when coming from jf path. */ - struct bpfjit_jump bj_jtf[2]; + struct bpfjit_jump jtf[2]; + /* + * Length calculated by Array Bounds Check Elimination (ABC) pass. + */ + uint32_t abc_length; + /* + * Length checked by the last out-of-bounds check. 
+ */ + uint32_t checked_length; }; /* @@ -110,11 +143,16 @@ struct bpfjit_jump_data { */ struct bpfjit_read_pkt_data { /* - * If positive, emit "if (buflen < bj_check_length) return 0". + * Length calculated by Array Bounds Check Elimination (ABC) pass. + */ + uint32_t abc_length; + /* + * If positive, emit "if (buflen < check_length) return 0" + * out-of-bounds check. * We assume that buflen is never equal to UINT32_MAX (otherwise, - * we need a special bool variable to emit unconditional "return 0"). + * we'd need a special bool variable to emit unconditional "return 0"). */ - uint32_t bj_check_length; + uint32_t check_length; }; /* @@ -122,14 +160,15 @@ struct bpfjit_read_pkt_data { */ struct bpfjit_insn_data { /* List of jumps to this insn. */ - SLIST_HEAD(, bpfjit_jump) bj_jumps; + SLIST_HEAD(, bpfjit_jump) bjumps; union { - struct bpfjit_jump_data bj_jdata; - struct bpfjit_read_pkt_data bj_rdata; - } bj_aux; + struct bpfjit_jump_data jdata; + struct bpfjit_read_pkt_data rdata; + } u; - bool bj_unreachable; + bpfjit_init_mask_t invalid; + bool unreachable; }; #ifdef _KERNEL @@ -162,7 +201,7 @@ bpfjit_modcmd(modcmd_t cmd, void *arg) #endif static uint32_t -read_width(struct bpf_insn *pc) +read_width(const struct bpf_insn *pc) { switch (BPF_SIZE(pc->code)) { @@ -173,27 +212,43 @@ read_width(struct bpf_insn *pc) case BPF_B: return 1; default: - BPFJIT_ASSERT(false); + BJ_ASSERT(false); return 0; } } -/* - * Get offset of M[k] on the stack. 
- */ -static size_t -mem_local_offset(uint32_t k, unsigned int minm) +static bool +grow_jumps(struct sljit_jump ***jumps, size_t *size) { - size_t moff = (k - minm) * sizeof(uint32_t); + struct sljit_jump **newptr; + const size_t elemsz = sizeof(struct sljit_jump *); + size_t old_size = *size; + size_t new_size = 2 * old_size; + + if (new_size < old_size || new_size > SIZE_MAX / elemsz) + return false; + + newptr = BJ_ALLOC(new_size * elemsz); + if (newptr == NULL) + return false; + + memcpy(newptr, *jumps, old_size * elemsz); + BJ_FREE(*jumps, old_size * elemsz); + + *jumps = newptr; + *size = new_size; + return true; +} -#ifdef _KERNEL - /* - * 4 bytes for the third argument of m_xword/m_xhalf/m_xbyte. - */ - return sizeof(uint32_t) + moff; -#else - return moff; -#endif +static bool +append_jump(struct sljit_jump *jump, struct sljit_jump ***jumps, + size_t *size, size_t *max_size) +{ + if (*size == *max_size && !grow_jumps(jumps, max_size)) + return false; + + (*jumps)[(*size)++] = jump; + return true; } /* @@ -205,8 +260,8 @@ emit_read8(struct sljit_compiler* compil return sljit_emit_op1(compiler, SLJIT_MOV_UB, - BPFJIT_A, 0, - SLJIT_MEM1(BPFJIT_BUF), k); + BJ_AREG, 0, + SLJIT_MEM1(BJ_BUF), k); } /* @@ -220,24 +275,24 @@ emit_read16(struct sljit_compiler* compi /* tmp1 = buf[k]; */ status = sljit_emit_op1(compiler, SLJIT_MOV_UB, - BPFJIT_TMP1, 0, - SLJIT_MEM1(BPFJIT_BUF), k); + BJ_TMP1REG, 0, + SLJIT_MEM1(BJ_BUF), k); if (status != SLJIT_SUCCESS) return status; /* A = buf[k+1]; */ status = sljit_emit_op1(compiler, SLJIT_MOV_UB, - BPFJIT_A, 0, - SLJIT_MEM1(BPFJIT_BUF), k+1); + BJ_AREG, 0, + SLJIT_MEM1(BJ_BUF), k+1); if (status != SLJIT_SUCCESS) return status; /* tmp1 = tmp1 << 8; */ status = sljit_emit_op2(compiler, SLJIT_SHL, - BPFJIT_TMP1, 0, - BPFJIT_TMP1, 0, + BJ_TMP1REG, 0, + BJ_TMP1REG, 0, SLJIT_IMM, 8); if (status != SLJIT_SUCCESS) return status; @@ -245,9 +300,9 @@ emit_read16(struct sljit_compiler* compi /* A = A + tmp1; */ status = 
sljit_emit_op2(compiler, SLJIT_ADD, - BPFJIT_A, 0, - BPFJIT_A, 0, - BPFJIT_TMP1, 0); + BJ_AREG, 0, + BJ_AREG, 0, + BJ_TMP1REG, 0); return status; } @@ -262,32 +317,32 @@ emit_read32(struct sljit_compiler* compi /* tmp1 = buf[k]; */ status = sljit_emit_op1(compiler, SLJIT_MOV_UB, - BPFJIT_TMP1, 0, - SLJIT_MEM1(BPFJIT_BUF), k); + BJ_TMP1REG, 0, + SLJIT_MEM1(BJ_BUF), k); if (status != SLJIT_SUCCESS) return status; /* tmp2 = buf[k+1]; */ status = sljit_emit_op1(compiler, SLJIT_MOV_UB, - BPFJIT_TMP2, 0, - SLJIT_MEM1(BPFJIT_BUF), k+1); + BJ_TMP2REG, 0, + SLJIT_MEM1(BJ_BUF), k+1); if (status != SLJIT_SUCCESS) return status; /* A = buf[k+3]; */ status = sljit_emit_op1(compiler, SLJIT_MOV_UB, - BPFJIT_A, 0, - SLJIT_MEM1(BPFJIT_BUF), k+3); + BJ_AREG, 0, + SLJIT_MEM1(BJ_BUF), k+3); if (status != SLJIT_SUCCESS) return status; /* tmp1 = tmp1 << 24; */ status = sljit_emit_op2(compiler, SLJIT_SHL, - BPFJIT_TMP1, 0, - BPFJIT_TMP1, 0, + BJ_TMP1REG, 0, + BJ_TMP1REG, 0, SLJIT_IMM, 24); if (status != SLJIT_SUCCESS) return status; @@ -295,25 +350,25 @@ emit_read32(struct sljit_compiler* compi /* A = A + tmp1; */ status = sljit_emit_op2(compiler, SLJIT_ADD, - BPFJIT_A, 0, - BPFJIT_A, 0, - BPFJIT_TMP1, 0); + BJ_AREG, 0, + BJ_AREG, 0, + BJ_TMP1REG, 0); if (status != SLJIT_SUCCESS) return status; /* tmp1 = buf[k+2]; */ status = sljit_emit_op1(compiler, SLJIT_MOV_UB, - BPFJIT_TMP1, 0, - SLJIT_MEM1(BPFJIT_BUF), k+2); + BJ_TMP1REG, 0, + SLJIT_MEM1(BJ_BUF), k+2); if (status != SLJIT_SUCCESS) return status; /* tmp2 = tmp2 << 16; */ status = sljit_emit_op2(compiler, SLJIT_SHL, - BPFJIT_TMP2, 0, - BPFJIT_TMP2, 0, + BJ_TMP2REG, 0, + BJ_TMP2REG, 0, SLJIT_IMM, 16); if (status != SLJIT_SUCCESS) return status; @@ -321,17 +376,17 @@ emit_read32(struct sljit_compiler* compi /* A = A + tmp2; */ status = sljit_emit_op2(compiler, SLJIT_ADD, - BPFJIT_A, 0, - BPFJIT_A, 0, - BPFJIT_TMP2, 0); + BJ_AREG, 0, + BJ_AREG, 0, + BJ_TMP2REG, 0); if (status != SLJIT_SUCCESS) return status; /* tmp1 = tmp1 << 8; */ 
status = sljit_emit_op2(compiler, SLJIT_SHL, - BPFJIT_TMP1, 0, - BPFJIT_TMP1, 0, + BJ_TMP1REG, 0, + BJ_TMP1REG, 0, SLJIT_IMM, 8); if (status != SLJIT_SUCCESS) return status; @@ -339,9 +394,9 @@ emit_read32(struct sljit_compiler* compi /* A = A + tmp1; */ status = sljit_emit_op2(compiler, SLJIT_ADD, - BPFJIT_A, 0, - BPFJIT_A, 0, - BPFJIT_TMP1, 0); + BJ_AREG, 0, + BJ_AREG, 0, + BJ_TMP1REG, 0); return status; } @@ -358,16 +413,20 @@ emit_read32(struct sljit_compiler* compi * BPF_LD+BPF_B+BPF_IND A <- P[X+k:1] * BPF_LDX+BPF_B+BPF_MSH X <- 4*(P[k:1]&0xf) * - * dst must be BPFJIT_A for BPF_LD instructions and BPFJIT_X - * or any of BPFJIT_TMP* registrers for BPF_MSH instruction. + * The dst variable should be + * - BJ_AREG when emitting code for BPF_LD instructions, + * - BJ_XREG or any of BJ_TMP[1-3]REG registers when emitting + * code for BPF_MSH instruction. */ static int -emit_xcall(struct sljit_compiler* compiler, struct bpf_insn *pc, +emit_xcall(struct sljit_compiler* compiler, const struct bpf_insn *pc, int dst, sljit_w dstw, struct sljit_jump **ret0_jump, uint32_t (*fn)(const struct mbuf *, uint32_t, int *)) { -#if BPFJIT_X != SLJIT_TEMPORARY_EREG1 || \ - BPFJIT_X == SLJIT_RETURN_REG +#if BJ_XREG == SLJIT_RETURN_REG || \ + BJ_XREG == SLJIT_TEMPORARY_REG1 || \ + BJ_XREG == SLJIT_TEMPORARY_REG2 || \ + BJ_XREG == SLJIT_TEMPORARY_REG3 #error "Not supported assignment of registers." #endif int status; @@ -375,14 +434,14 @@ emit_xcall(struct sljit_compiler* compil /* * The third argument of fn is an address on stack. 
*/ - const int arg3_offset = 0; + const int arg3_offset = offsetof(struct bpfjit_stack, tmp); if (BPF_CLASS(pc->code) == BPF_LDX) { /* save A */ status = sljit_emit_op1(compiler, SLJIT_MOV, - BPFJIT_KERN_TMP, 0, - BPFJIT_A, 0); + BJ_TMP3REG, 0, + BJ_AREG, 0); if (status != SLJIT_SUCCESS) return status; } @@ -393,7 +452,7 @@ emit_xcall(struct sljit_compiler* compil status = sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TEMPORARY_REG1, 0, - BPFJIT_BUF, 0); + BJ_BUF, 0); if (status != SLJIT_SUCCESS) return status; @@ -401,7 +460,7 @@ emit_xcall(struct sljit_compiler* compil status = sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_TEMPORARY_REG2, 0, - BPFJIT_X, 0, + BJ_XREG, 0, SLJIT_IMM, (uint32_t)pc->k); } else { status = sljit_emit_op1(compiler, @@ -423,30 +482,22 @@ emit_xcall(struct sljit_compiler* compil SLJIT_CALL3, SLJIT_IMM, SLJIT_FUNC_OFFSET(fn)); - if (BPF_CLASS(pc->code) == BPF_LDX) { - + if (dst != SLJIT_RETURN_REG) { /* move return value to dst */ - BPFJIT_ASSERT(dst != SLJIT_RETURN_REG); status = sljit_emit_op1(compiler, SLJIT_MOV, dst, dstw, SLJIT_RETURN_REG, 0); if (status != SLJIT_SUCCESS) return status; + } + if (BPF_CLASS(pc->code) == BPF_LDX) { /* restore A */ status = sljit_emit_op1(compiler, SLJIT_MOV, - BPFJIT_A, 0, - BPFJIT_KERN_TMP, 0); - if (status != SLJIT_SUCCESS) - return status; - - } else if (dst != SLJIT_RETURN_REG) { - status = sljit_emit_op1(compiler, - SLJIT_MOV, - dst, dstw, - SLJIT_RETURN_REG, 0); + BJ_AREG, 0, + BJ_TMP3REG, 0); if (status != SLJIT_SUCCESS) return status; } @@ -482,8 +533,8 @@ emit_xcall(struct sljit_compiler* compil */ static int emit_pkt_read(struct sljit_compiler* compiler, - struct bpf_insn *pc, struct sljit_jump *to_mchain_jump, - struct sljit_jump **ret0, size_t *ret0_size) + const struct bpf_insn *pc, struct sljit_jump *to_mchain_jump, + struct sljit_jump ***ret0, size_t *ret0_size, size_t *ret0_maxsize) { int status = 0; /* XXX gcc 4.1 */ uint32_t width; @@ -499,10 +550,10 @@ emit_pkt_read(struct sljit_compiler* com 
if (to_mchain_jump == NULL) { to_mchain_jump = sljit_emit_cmp(compiler, SLJIT_C_EQUAL, - BPFJIT_BUFLEN, 0, + BJ_BUFLEN, 0, SLJIT_IMM, 0); if (to_mchain_jump == NULL) - return SLJIT_ERR_ALLOC_FAILED; + return SLJIT_ERR_ALLOC_FAILED; } #endif @@ -512,8 +563,8 @@ emit_pkt_read(struct sljit_compiler* com /* tmp1 = buflen - (pc->k + width); */ status = sljit_emit_op2(compiler, SLJIT_SUB, - BPFJIT_TMP1, 0, - BPFJIT_BUFLEN, 0, + BJ_TMP1REG, 0, + BJ_BUFLEN, 0, SLJIT_IMM, k + width); if (status != SLJIT_SUCCESS) return status; @@ -521,20 +572,21 @@ emit_pkt_read(struct sljit_compiler* com /* buf += X; */ status = sljit_emit_op2(compiler, SLJIT_ADD, - BPFJIT_BUF, 0, - BPFJIT_BUF, 0, - BPFJIT_X, 0); + BJ_BUF, 0, + BJ_BUF, 0, + BJ_XREG, 0); if (status != SLJIT_SUCCESS) return status; /* if (tmp1 < X) return 0; */ jump = sljit_emit_cmp(compiler, SLJIT_C_LESS, - BPFJIT_TMP1, 0, - BPFJIT_X, 0); + BJ_TMP1REG, 0, + BJ_XREG, 0); if (jump == NULL) - return SLJIT_ERR_ALLOC_FAILED; - ret0[(*ret0_size)++] = jump; + return SLJIT_ERR_ALLOC_FAILED; + if (!append_jump(jump, ret0, ret0_size, ret0_maxsize)) + return SLJIT_ERR_ALLOC_FAILED; } switch (width) { @@ -556,9 +608,9 @@ emit_pkt_read(struct sljit_compiler* com /* buf -= X; */ status = sljit_emit_op2(compiler, SLJIT_SUB, - BPFJIT_BUF, 0, - BPFJIT_BUF, 0, - BPFJIT_X, 0); + BJ_BUF, 0, + BJ_BUF, 0, + BJ_XREG, 0); if (status != SLJIT_SUCCESS) return status; } @@ -566,41 +618,43 @@ emit_pkt_read(struct sljit_compiler* com #ifdef _KERNEL over_mchain_jump = sljit_emit_jump(compiler, SLJIT_JUMP); if (over_mchain_jump == NULL) - return SLJIT_ERR_ALLOC_FAILED; + return SLJIT_ERR_ALLOC_FAILED; /* entry point to mchain handler */ label = sljit_emit_label(compiler); if (label == NULL) - return SLJIT_ERR_ALLOC_FAILED; + return SLJIT_ERR_ALLOC_FAILED; sljit_set_label(to_mchain_jump, label); if (check_zero_buflen) { /* if (buflen != 0) return 0; */ jump = sljit_emit_cmp(compiler, SLJIT_C_NOT_EQUAL, - BPFJIT_BUFLEN, 0, + BJ_BUFLEN, 0, SLJIT_IMM, 0); if 
(jump == NULL) return SLJIT_ERR_ALLOC_FAILED; - ret0[(*ret0_size)++] = jump; + if (!append_jump(jump, ret0, ret0_size, ret0_maxsize)) + return SLJIT_ERR_ALLOC_FAILED; } switch (width) { case 4: - status = emit_xcall(compiler, pc, BPFJIT_A, 0, &jump, &m_xword); + status = emit_xcall(compiler, pc, BJ_AREG, 0, &jump, &m_xword); break; case 2: - status = emit_xcall(compiler, pc, BPFJIT_A, 0, &jump, &m_xhalf); + status = emit_xcall(compiler, pc, BJ_AREG, 0, &jump, &m_xhalf); break; case 1: - status = emit_xcall(compiler, pc, BPFJIT_A, 0, &jump, &m_xbyte); + status = emit_xcall(compiler, pc, BJ_AREG, 0, &jump, &m_xbyte); break; } if (status != SLJIT_SUCCESS) return status; - ret0[(*ret0_size)++] = jump; + if (!append_jump(jump, ret0, ret0_size, ret0_maxsize)) + return SLJIT_ERR_ALLOC_FAILED; label = sljit_emit_label(compiler); if (label == NULL) @@ -616,8 +670,8 @@ emit_pkt_read(struct sljit_compiler* com */ static int emit_msh(struct sljit_compiler* compiler, - struct bpf_insn *pc, struct sljit_jump *to_mchain_jump, - struct sljit_jump **ret0, size_t *ret0_size) + const struct bpf_insn *pc, struct sljit_jump *to_mchain_jump, + struct sljit_jump ***ret0, size_t *ret0_size, size_t *ret0_maxsize) { int status; #ifdef _KERNEL @@ -631,26 +685,26 @@ emit_msh(struct sljit_compiler* compiler if (to_mchain_jump == NULL) { to_mchain_jump = sljit_emit_cmp(compiler, SLJIT_C_EQUAL, - BPFJIT_BUFLEN, 0, + BJ_BUFLEN, 0, SLJIT_IMM, 0); if (to_mchain_jump == NULL) - return SLJIT_ERR_ALLOC_FAILED; + return SLJIT_ERR_ALLOC_FAILED; } #endif /* tmp1 = buf[k] */ status = sljit_emit_op1(compiler, SLJIT_MOV_UB, - BPFJIT_TMP1, 0, - SLJIT_MEM1(BPFJIT_BUF), k); + BJ_TMP1REG, 0, + SLJIT_MEM1(BJ_BUF), k); if (status != SLJIT_SUCCESS) return status; /* tmp1 &= 0xf */ status = sljit_emit_op2(compiler, SLJIT_AND, - BPFJIT_TMP1, 0, - BPFJIT_TMP1, 0, + BJ_TMP1REG, 0, + BJ_TMP1REG, 0, SLJIT_IMM, 0xf); if (status != SLJIT_SUCCESS) return status; @@ -658,8 +712,8 @@ emit_msh(struct sljit_compiler* compiler 
/* tmp1 = tmp1 << 2 */ status = sljit_emit_op2(compiler, SLJIT_SHL, - BPFJIT_X, 0, - BPFJIT_TMP1, 0, + BJ_XREG, 0, + BJ_TMP1REG, 0, SLJIT_IMM, 2); if (status != SLJIT_SUCCESS) return status; @@ -679,23 +733,26 @@ emit_msh(struct sljit_compiler* compiler /* if (buflen != 0) return 0; */ jump = sljit_emit_cmp(compiler, SLJIT_C_NOT_EQUAL, - BPFJIT_BUFLEN, 0, + BJ_BUFLEN, 0, SLJIT_IMM, 0); if (jump == NULL) - return SLJIT_ERR_ALLOC_FAILED; - ret0[(*ret0_size)++] = jump; + return SLJIT_ERR_ALLOC_FAILED; + if (!append_jump(jump, ret0, ret0_size, ret0_maxsize)) + return SLJIT_ERR_ALLOC_FAILED; } - status = emit_xcall(compiler, pc, BPFJIT_TMP1, 0, &jump, &m_xbyte); + status = emit_xcall(compiler, pc, BJ_TMP1REG, 0, &jump, &m_xbyte); if (status != SLJIT_SUCCESS) return status; - ret0[(*ret0_size)++] = jump; + + if (!append_jump(jump, ret0, ret0_size, ret0_maxsize)) + return SLJIT_ERR_ALLOC_FAILED; /* tmp1 &= 0xf */ status = sljit_emit_op2(compiler, SLJIT_AND, - BPFJIT_TMP1, 0, - BPFJIT_TMP1, 0, + BJ_TMP1REG, 0, + BJ_TMP1REG, 0, SLJIT_IMM, 0xf); if (status != SLJIT_SUCCESS) return status; @@ -703,8 +760,8 @@ emit_msh(struct sljit_compiler* compiler /* tmp1 = tmp1 << 2 */ status = sljit_emit_op2(compiler, SLJIT_SHL, - BPFJIT_X, 0, - BPFJIT_TMP1, 0, + BJ_XREG, 0, + BJ_TMP1REG, 0, SLJIT_IMM, 2); if (status != SLJIT_SUCCESS) return status; @@ -730,13 +787,13 @@ emit_pow2_division(struct sljit_compiler shift++; } - BPFJIT_ASSERT(k == 1 && shift < 32); + BJ_ASSERT(k == 1 && shift < 32); if (shift != 0) { status = sljit_emit_op2(compiler, SLJIT_LSHR|SLJIT_INT_OP, - BPFJIT_A, 0, - BPFJIT_A, 0, + BJ_AREG, 0, + BJ_AREG, 0, SLJIT_IMM, shift); } @@ -754,25 +811,25 @@ divide(sljit_uw x, sljit_uw y) /* * Generate A = A / div. - * divt,divw are either SLJIT_IMM,pc->k or BPFJIT_X,0. + * divt,divw are either SLJIT_IMM,pc->k or BJ_XREG,0. 
*/ static int emit_division(struct sljit_compiler* compiler, int divt, sljit_w divw) { int status; -#if BPFJIT_X == SLJIT_TEMPORARY_REG1 || \ - BPFJIT_X == SLJIT_RETURN_REG || \ - BPFJIT_X == SLJIT_TEMPORARY_REG2 || \ - BPFJIT_A == SLJIT_TEMPORARY_REG2 +#if BJ_XREG == SLJIT_RETURN_REG || \ + BJ_XREG == SLJIT_TEMPORARY_REG1 || \ + BJ_XREG == SLJIT_TEMPORARY_REG2 || \ + BJ_AREG == SLJIT_TEMPORARY_REG2 #error "Not supported assignment of registers." #endif -#if BPFJIT_A != SLJIT_TEMPORARY_REG1 +#if BJ_AREG != SLJIT_TEMPORARY_REG1 status = sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TEMPORARY_REG1, 0, - BPFJIT_A, 0); + BJ_AREG, 0); if (status != SLJIT_SUCCESS) return status; #endif @@ -787,10 +844,10 @@ emit_division(struct sljit_compiler* com #if defined(BPFJIT_USE_UDIV) status = sljit_emit_op0(compiler, SLJIT_UDIV|SLJIT_INT_OP); -#if BPFJIT_A != SLJIT_TEMPORARY_REG1 +#if BJ_AREG != SLJIT_TEMPORARY_REG1 status = sljit_emit_op1(compiler, SLJIT_MOV, - BPFJIT_A, 0, + BJ_AREG, 0, SLJIT_TEMPORARY_REG1, 0); if (status != SLJIT_SUCCESS) return status; @@ -800,10 +857,10 @@ emit_division(struct sljit_compiler* com SLJIT_CALL2, SLJIT_IMM, SLJIT_FUNC_OFFSET(divide)); -#if BPFJIT_A != SLJIT_RETURN_REG +#if BJ_AREG != SLJIT_RETURN_REG status = sljit_emit_op1(compiler, SLJIT_MOV, - BPFJIT_A, 0, + BJ_AREG, 0, SLJIT_RETURN_REG, 0); if (status != SLJIT_SUCCESS) return status; @@ -814,30 +871,12 @@ emit_division(struct sljit_compiler* com } /* - * Count BPF_RET instructions. - */ -static size_t -count_returns(struct bpf_insn *insns, size_t insn_count) -{ - size_t i; - size_t rv; - - rv = 0; - for (i = 0; i < insn_count; i++) { - if (BPF_CLASS(insns[i].code) == BPF_RET) - rv++; - } - - return rv; -} - -/* * Return true if pc is a "read from packet" instruction. * If length is not NULL and return value is true, *length will * be set to a safe length required to read a packet. 
*/ static bool -read_pkt_insn(struct bpf_insn *pc, uint32_t *length) +read_pkt_insn(const struct bpf_insn *pc, uint32_t *length) { bool rv; uint32_t width; @@ -868,20 +907,14 @@ read_pkt_insn(struct bpf_insn *pc, uint3 return rv; } -/* - * Set bj_check_length for all "read from packet" instructions - * in a linear block of instructions [from, to). - */ static void -set_check_length(struct bpf_insn *insns, struct bpfjit_insn_data *insn_dat, - size_t from, size_t to, uint32_t length) +optimize_init(struct bpfjit_insn_data *insn_dat, size_t insn_count) { + size_t i; - for (; from < to; from++) { - if (read_pkt_insn(&insns[from], NULL)) { - insn_dat[from].bj_aux.bj_rdata.bj_check_length = length; - length = 0; - } + for (i = 0; i < insn_count; i++) { + SLIST_INIT(&insn_dat[i].bjumps); + insn_dat[i].invalid = BJ_INIT_NOBITS; } } @@ -891,63 +924,152 @@ set_check_length(struct bpf_insn *insns, * terminate a block. Blocks are linear, that is, there are no jumps out * from the middle of a block and there are no jumps in to the middle of * a block. - * If a block has one or more "read from packet" instructions, - * bj_check_length will be set to one value for the whole block and that - * value will be equal to the greatest value of safe lengths of "read from - * packet" instructions inside the block. + * + * The function also sets bits in *initmask for memwords that + * need to be initialized to zero. Note that this set should be empty + * for any valid kernel filter program. 
*/ -static int -optimize(struct bpf_insn *insns, - struct bpfjit_insn_data *insn_dat, size_t insn_count) +static bool +optimize_pass1(const struct bpf_insn *insns, + struct bpfjit_insn_data *insn_dat, size_t insn_count, + bpfjit_init_mask_t *initmask, int *nscratches) { + struct bpfjit_jump *jtf; size_t i; - size_t first_read; - bool unreachable; uint32_t jt, jf; - uint32_t length, safe_length; - struct bpfjit_jump *jmp, *jtf; + bpfjit_init_mask_t invalid; /* borrowed from bpf_filter() */ + bool unreachable; - for (i = 0; i < insn_count; i++) - SLIST_INIT(&insn_dat[i].bj_jumps); + *nscratches = 2; + *initmask = BJ_INIT_NOBITS; - safe_length = 0; unreachable = false; - first_read = SIZE_MAX; + invalid = ~BJ_INIT_NOBITS; for (i = 0; i < insn_count; i++) { - - if (!SLIST_EMPTY(&insn_dat[i].bj_jumps)) { + if (!SLIST_EMPTY(&insn_dat[i].bjumps)) unreachable = false; + insn_dat[i].unreachable = unreachable; - set_check_length(insns, insn_dat, - first_read, i, safe_length); - first_read = SIZE_MAX; - - safe_length = UINT32_MAX; - SLIST_FOREACH(jmp, &insn_dat[i].bj_jumps, bj_entries) { - if (jmp->bj_safe_length < safe_length) - safe_length = jmp->bj_safe_length; - } - } - - insn_dat[i].bj_unreachable = unreachable; if (unreachable) continue; - if (read_pkt_insn(&insns[i], &length)) { - if (first_read == SIZE_MAX) - first_read = i; - if (length > safe_length) - safe_length = length; - } + invalid |= insn_dat[i].invalid; switch (BPF_CLASS(insns[i].code)) { case BPF_RET: + if (BPF_RVAL(insns[i].code) == BPF_A) + *initmask |= invalid & BJ_INIT_ABIT; + unreachable = true; continue; + case BPF_LD: + if (BPF_MODE(insns[i].code) == BPF_IND || + BPF_MODE(insns[i].code) == BPF_ABS) { + if (BPF_MODE(insns[i].code) == BPF_IND && + *nscratches < 4) { + /* uses BJ_XREG */ + *nscratches = 4; + } + if (*nscratches < 3 && + read_width(&insns[i]) == 4) { + /* uses BJ_TMP2REG */ + *nscratches = 3; + } + } + + if (BPF_MODE(insns[i].code) == BPF_IND) + *initmask |= invalid & BJ_INIT_XBIT; + + 
if (BPF_MODE(insns[i].code) == BPF_MEM && + (uint32_t)insns[i].k < BPF_MEMWORDS) { + *initmask |= invalid & BJ_INIT_MBIT(insns[i].k); + } + + invalid &= ~BJ_INIT_ABIT; + continue; + + case BPF_LDX: +#if defined(_KERNEL) + /* uses BJ_TMP3REG */ + *nscratches = 5; +#endif + /* uses BJ_XREG */ + if (*nscratches < 4) + *nscratches = 4; + + if (BPF_MODE(insns[i].code) == BPF_MEM && + (uint32_t)insns[i].k < BPF_MEMWORDS) { + *initmask |= invalid & BJ_INIT_MBIT(insns[i].k); + } + + invalid &= ~BJ_INIT_XBIT; + continue; + + case BPF_ST: + *initmask |= invalid & BJ_INIT_ABIT; + + if ((uint32_t)insns[i].k < BPF_MEMWORDS) + invalid &= ~BJ_INIT_MBIT(insns[i].k); + + continue; + + case BPF_STX: + /* uses BJ_XREG */ + if (*nscratches < 4) + *nscratches = 4; + + *initmask |= invalid & BJ_INIT_XBIT; + + if ((uint32_t)insns[i].k < BPF_MEMWORDS) + invalid &= ~BJ_INIT_MBIT(insns[i].k); + + continue; + + case BPF_ALU: + *initmask |= invalid & BJ_INIT_ABIT; + + if (insns[i].code != (BPF_ALU|BPF_NEG) && + BPF_SRC(insns[i].code) == BPF_X) { + *initmask |= invalid & BJ_INIT_XBIT; + /* uses BJ_XREG */ + if (*nscratches < 4) + *nscratches = 4; + + } + + invalid &= ~BJ_INIT_ABIT; + continue; + + case BPF_MISC: + switch (BPF_MISCOP(insns[i].code)) { + case BPF_TAX: // X <- A + /* uses BJ_XREG */ + if (*nscratches < 4) + *nscratches = 4; + + *initmask |= invalid & BJ_INIT_ABIT; + invalid &= ~BJ_INIT_XBIT; + continue; + + case BPF_TXA: // A <- X + /* uses BJ_XREG */ + if (*nscratches < 4) + *nscratches = 4; + + *initmask |= invalid & BJ_INIT_XBIT; + invalid &= ~BJ_INIT_ABIT; + continue; + } + + continue; + case BPF_JMP: - if (insns[i].code == (BPF_JMP|BPF_JA)) { + /* Initialize abc_length for ABC pass. 
*/ + insn_dat[i].u.jdata.abc_length = UINT32_MAX; + + if (BPF_OP(insns[i].code) == BPF_JA) { jt = jf = insns[i].k; } else { jt = insns[i].jt; @@ -956,80 +1078,140 @@ optimize(struct bpf_insn *insns, if (jt >= insn_count - (i + 1) || jf >= insn_count - (i + 1)) { - return -1; + return false; } if (jt > 0 && jf > 0) unreachable = true; - jtf = insn_dat[i].bj_aux.bj_jdata.bj_jtf; + jt += i + 1; + jf += i + 1; + + jtf = insn_dat[i].u.jdata.jtf; - jtf[0].bj_jump = NULL; - jtf[0].bj_safe_length = safe_length; - SLIST_INSERT_HEAD(&insn_dat[i + 1 + jt].bj_jumps, - &jtf[0], bj_entries); + jtf[0].sjump = NULL; + jtf[0].jdata = &insn_dat[i].u.jdata; + SLIST_INSERT_HEAD(&insn_dat[jt].bjumps, + &jtf[0], entries); if (jf != jt) { - jtf[1].bj_jump = NULL; - jtf[1].bj_safe_length = safe_length; - SLIST_INSERT_HEAD(&insn_dat[i + 1 + jf].bj_jumps, - &jtf[1], bj_entries); + jtf[1].sjump = NULL; + jtf[1].jdata = &insn_dat[i].u.jdata; + SLIST_INSERT_HEAD(&insn_dat[jf].bjumps, + &jtf[1], entries); } + insn_dat[jf].invalid |= invalid; + insn_dat[jt].invalid |= invalid; + invalid = 0; + continue; } } - set_check_length(insns, insn_dat, first_read, insn_count, safe_length); - - return 0; + return true; } /* - * Count out-of-bounds and division by zero jumps. - * - * insn_dat should be initialized by optimize(). + * Array Bounds Check Elimination (ABC) pass. 
*/ -static size_t -get_ret0_size(struct bpf_insn *insns, struct bpfjit_insn_data *insn_dat, - size_t insn_count) +static void +optimize_pass2(const struct bpf_insn *insns, + struct bpfjit_insn_data *insn_dat, size_t insn_count) { - size_t rv = 0; + struct bpfjit_jump *jmp; + const struct bpf_insn *pc; + struct bpfjit_insn_data *pd; size_t i; + uint32_t length, abc_length = 0; - for (i = 0; i < insn_count; i++) { + for (i = insn_count; i != 0; i--) { + pc = &insns[i-1]; + pd = &insn_dat[i-1]; - if (read_pkt_insn(&insns[i], NULL)) { - if (insn_dat[i].bj_aux.bj_rdata.bj_check_length > 0) - rv++; -#ifdef _KERNEL - rv++; -#endif + if (pd->unreachable) + continue; + + switch (BPF_CLASS(pc->code)) { + case BPF_RET: + abc_length = 0; + break; + + case BPF_JMP: + abc_length = pd->u.jdata.abc_length; + break; + + default: + if (read_pkt_insn(pc, &length)) { + if (abc_length < length) + abc_length = length; + pd->u.rdata.abc_length = abc_length; + } + break; } - if (insns[i].code == (BPF_LD|BPF_IND|BPF_B) || - insns[i].code == (BPF_LD|BPF_IND|BPF_H) || - insns[i].code == (BPF_LD|BPF_IND|BPF_W)) { - rv++; + SLIST_FOREACH(jmp, &pd->bjumps, entries) { + if (jmp->jdata->abc_length > abc_length) + jmp->jdata->abc_length = abc_length; } + } +} - if (insns[i].code == (BPF_ALU|BPF_DIV|BPF_X)) - rv++; +static void +optimize_pass3(const struct bpf_insn *insns, + struct bpfjit_insn_data *insn_dat, size_t insn_count) +{ + struct bpfjit_jump *jmp; + size_t i; + uint32_t length, checked_length = 0; - if (insns[i].code == (BPF_ALU|BPF_DIV|BPF_K) && - insns[i].k == 0) { - rv++; + for (i = 0; i < insn_count; i++) { + if (insn_dat[i].unreachable) + continue; + + SLIST_FOREACH(jmp, &insn_dat[i].bjumps, entries) { + if (jmp->jdata->checked_length < checked_length) + checked_length = jmp->jdata->checked_length; + } + + if (BPF_CLASS(insns[i].code) == BPF_JMP) { + insn_dat[i].u.jdata.checked_length = checked_length; + } else if (read_pkt_insn(&insns[i], &length)) { + struct bpfjit_read_pkt_data 
*rdata = + &insn_dat[i].u.rdata; + rdata->check_length = 0; + if (checked_length < rdata->abc_length) { + checked_length = rdata->abc_length; + rdata->check_length = checked_length; + } } } +} - return rv; +static bool +optimize(const struct bpf_insn *insns, + struct bpfjit_insn_data *insn_dat, size_t insn_count, + bpfjit_init_mask_t *initmask, int *nscratches) +{ + + optimize_init(insn_dat, insn_count); + + if (!optimize_pass1(insns, insn_dat, insn_count, + initmask, nscratches)) { + return false; + } + + optimize_pass2(insns, insn_dat, insn_count); + optimize_pass3(insns, insn_dat, insn_count); + + return true; } /* * Convert BPF_ALU operations except BPF_NEG and BPF_DIV to sljit operation. */ static int -bpf_alu_to_sljit_op(struct bpf_insn *pc) +bpf_alu_to_sljit_op(const struct bpf_insn *pc) { /* @@ -1045,7 +1227,7 @@ bpf_alu_to_sljit_op(struct bpf_insn *pc) case BPF_LSH: return SLJIT_SHL; case BPF_RSH: return SLJIT_LSHR|SLJIT_INT_OP; default: - BPFJIT_ASSERT(false); + BJ_ASSERT(false); return 0; } } @@ -1054,7 +1236,7 @@ bpf_alu_to_sljit_op(struct bpf_insn *pc) * Convert BPF_JMP operations except BPF_JA to sljit condition. */ static int -bpf_jmp_to_sljit_cond(struct bpf_insn *pc, bool negate) +bpf_jmp_to_sljit_cond(const struct bpf_insn *pc, bool negate) { /* * Note: all supported 64bit arches have 32bit comparison @@ -1076,114 +1258,37 @@ bpf_jmp_to_sljit_cond(struct bpf_insn *p rv |= negate ? 
SLJIT_C_EQUAL : SLJIT_C_NOT_EQUAL; break; default: - BPFJIT_ASSERT(false); + BJ_ASSERT(false); } return rv; } -static unsigned int -bpfjit_optimization_hints(struct bpf_insn *insns, size_t insn_count) -{ - unsigned int rv = BPFJIT_INIT_A; - struct bpf_insn *pc; - unsigned int minm, maxm; - - BPFJIT_ASSERT(BPF_MEMWORDS - 1 <= 0xff); - - maxm = 0; - minm = BPF_MEMWORDS - 1; - - for (pc = insns; pc != insns + insn_count; pc++) { - switch (BPF_CLASS(pc->code)) { - case BPF_LD: - if (BPF_MODE(pc->code) == BPF_IND) - rv |= BPFJIT_INIT_X; - if (BPF_MODE(pc->code) == BPF_MEM && - (uint32_t)pc->k < BPF_MEMWORDS) { - if (pc->k > maxm) - maxm = pc->k; - if (pc->k < minm) - minm = pc->k; - } - continue; - case BPF_LDX: - rv |= BPFJIT_INIT_X; - if (BPF_MODE(pc->code) == BPF_MEM && - (uint32_t)pc->k < BPF_MEMWORDS) { - if (pc->k > maxm) - maxm = pc->k; - if (pc->k < minm) - minm = pc->k; - } - continue; - case BPF_ST: - if ((uint32_t)pc->k < BPF_MEMWORDS) { - if (pc->k > maxm) - maxm = pc->k; - if (pc->k < minm) - minm = pc->k; - } - continue; - case BPF_STX: - rv |= BPFJIT_INIT_X; - if ((uint32_t)pc->k < BPF_MEMWORDS) { - if (pc->k > maxm) - maxm = pc->k; - if (pc->k < minm) - minm = pc->k; - } - continue; - case BPF_ALU: - if (pc->code == (BPF_ALU|BPF_NEG)) - continue; - if (BPF_SRC(pc->code) == BPF_X) - rv |= BPFJIT_INIT_X; - continue; - case BPF_JMP: - if (pc->code == (BPF_JMP|BPF_JA)) - continue; - if (BPF_SRC(pc->code) == BPF_X) - rv |= BPFJIT_INIT_X; - continue; - case BPF_RET: - continue; - case BPF_MISC: - rv |= BPFJIT_INIT_X; - continue; - default: - BPFJIT_ASSERT(false); - } - } - - return rv | (maxm << 8) | minm; -} - /* * Convert BPF_K and BPF_X to sljit register. 
*/ static int -kx_to_reg(struct bpf_insn *pc) +kx_to_reg(const struct bpf_insn *pc) { switch (BPF_SRC(pc->code)) { case BPF_K: return SLJIT_IMM; - case BPF_X: return BPFJIT_X; + case BPF_X: return BJ_XREG; default: - BPFJIT_ASSERT(false); + BJ_ASSERT(false); return 0; } } static sljit_w -kx_to_reg_arg(struct bpf_insn *pc) +kx_to_reg_arg(const struct bpf_insn *pc) { switch (BPF_SRC(pc->code)) { case BPF_K: return (uint32_t)pc->k; /* SLJIT_IMM, pc->k, */ - case BPF_X: return 0; /* BPFJIT_X, 0, */ + case BPF_X: return 0; /* BJ_XREG, 0, */ default: - BPFJIT_ASSERT(false); + BJ_ASSERT(false); return 0; } } @@ -1192,26 +1297,22 @@ bpfjit_func_t bpfjit_generate_code(bpf_ctx_t *bc, struct bpf_insn *insns, size_t insn_count) { void *rv; + struct sljit_compiler *compiler; + size_t i; int status; int branching, negate; unsigned int rval, mode, src; - int ntmp; - unsigned int locals_size; - unsigned int minm, maxm; /* min/max k for M[k] */ - size_t mem_locals_start; /* start of M[] array */ - unsigned int opts; - struct bpf_insn *pc; - struct sljit_compiler* compiler; - - /* a list of jumps to a normal return from a generated function */ - struct sljit_jump **returns; - size_t returns_size, returns_maxsize; + + /* optimization related */ + bpfjit_init_mask_t initmask; + int nscratches; /* a list of jumps to out-of-bound return from a generated function */ struct sljit_jump **ret0; - size_t ret0_size = 0, ret0_maxsize = 0; + size_t ret0_size, ret0_maxsize; + const struct bpf_insn *pc; struct bpfjit_insn_data *insn_dat; /* for local use */ @@ -1226,43 +1327,29 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru rv = NULL; compiler = NULL; insn_dat = NULL; - returns = NULL; ret0 = NULL; - opts = bpfjit_optimization_hints(insns, insn_count); - minm = opts & 0xff; - maxm = (opts >> 8) & 0xff; - mem_locals_start = mem_local_offset(0, 0); - locals_size = (minm <= maxm) ? 
- mem_local_offset(maxm + 1, minm) : mem_locals_start; - - ntmp = 4; -#ifdef _KERNEL - ntmp += 1; /* for BPFJIT_KERN_TMP */ -#endif - - returns_maxsize = count_returns(insns, insn_count); - if (returns_maxsize == 0) + if (insn_count == 0 || insn_count > SIZE_MAX / sizeof(insn_dat[0])) goto fail; - insn_dat = BPFJIT_ALLOC(insn_count * sizeof(insn_dat[0])); + insn_dat = BJ_ALLOC(insn_count * sizeof(insn_dat[0])); if (insn_dat == NULL) goto fail; - if (optimize(insns, insn_dat, insn_count) < 0) + if (!optimize(insns, insn_dat, insn_count, + &initmask, &nscratches)) { goto fail; - - ret0_size = 0; - ret0_maxsize = get_ret0_size(insns, insn_dat, insn_count); - if (ret0_maxsize > 0) { - ret0 = BPFJIT_ALLOC(ret0_maxsize * sizeof(ret0[0])); - if (ret0 == NULL) - goto fail; } - returns_size = 0; - returns = BPFJIT_ALLOC(returns_maxsize * sizeof(returns[0])); - if (returns == NULL) +#if defined(_KERNEL) + /* bpf_filter() checks initialization of memwords. */ + BJ_ASSERT((initmask & BJ_INIT_MMASK) == 0); +#endif + + ret0_size = 0; + ret0_maxsize = 64; + ret0 = BJ_ALLOC(ret0_maxsize * sizeof(ret0[0])); + if (ret0 == NULL) goto fail; compiler = sljit_create_compiler(); @@ -1273,41 +1360,46 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru sljit_compiler_verbose(compiler, stderr); #endif - status = sljit_emit_enter(compiler, 3, ntmp, 3, locals_size); + status = sljit_emit_enter(compiler, + 3, nscratches, 3, sizeof(struct bpfjit_stack)); if (status != SLJIT_SUCCESS) goto fail; - for (i = mem_locals_start; i < locals_size; i+= sizeof(uint32_t)) { - status = sljit_emit_op1(compiler, - SLJIT_MOV_UI, - SLJIT_MEM1(SLJIT_LOCALS_REG), i, - SLJIT_IMM, 0); - if (status != SLJIT_SUCCESS) - goto fail; + for (i = 0; i < BPF_MEMWORDS; i++) { + if (initmask & BJ_INIT_MBIT(i)) { + status = sljit_emit_op1(compiler, + SLJIT_MOV_UI, + SLJIT_MEM1(SLJIT_LOCALS_REG), + offsetof(struct bpfjit_stack, mem) + + i * sizeof(uint32_t), + SLJIT_IMM, 0); + if (status != SLJIT_SUCCESS) + goto fail; + } } - if (opts 
& BPFJIT_INIT_A) { + if (initmask & BJ_INIT_ABIT) { /* A = 0; */ status = sljit_emit_op1(compiler, SLJIT_MOV, - BPFJIT_A, 0, + BJ_AREG, 0, SLJIT_IMM, 0); if (status != SLJIT_SUCCESS) goto fail; } - if (opts & BPFJIT_INIT_X) { + if (initmask & BJ_INIT_XBIT) { /* X = 0; */ status = sljit_emit_op1(compiler, SLJIT_MOV, - BPFJIT_X, 0, + BJ_XREG, 0, SLJIT_IMM, 0); if (status != SLJIT_SUCCESS) goto fail; } for (i = 0; i < insn_count; i++) { - if (insn_dat[i].bj_unreachable) + if (insn_dat[i].unreachable) continue; to_mchain_jump = NULL; @@ -1316,30 +1408,32 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru * Resolve jumps to the current insn. */ label = NULL; - SLIST_FOREACH(bjump, &insn_dat[i].bj_jumps, bj_entries) { - if (bjump->bj_jump != NULL) { + SLIST_FOREACH(bjump, &insn_dat[i].bjumps, entries) { + if (bjump->sjump != NULL) { if (label == NULL) label = sljit_emit_label(compiler); if (label == NULL) goto fail; - sljit_set_label(bjump->bj_jump, label); + sljit_set_label(bjump->sjump, label); } } if (read_pkt_insn(&insns[i], NULL) && - insn_dat[i].bj_aux.bj_rdata.bj_check_length > 0) { - /* if (buflen < bj_check_length) return 0; */ + insn_dat[i].u.rdata.check_length > 0) { + /* if (buflen < check_length) return 0; */ jump = sljit_emit_cmp(compiler, SLJIT_C_LESS, - BPFJIT_BUFLEN, 0, + BJ_BUFLEN, 0, SLJIT_IMM, - insn_dat[i].bj_aux.bj_rdata.bj_check_length); + insn_dat[i].u.rdata.check_length); if (jump == NULL) goto fail; #ifdef _KERNEL to_mchain_jump = jump; #else - ret0[ret0_size++] = jump; + if (!append_jump(jump, &ret0, + &ret0_size, &ret0_maxsize)) + goto fail; #endif } @@ -1354,7 +1448,7 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru if (pc->code == (BPF_LD|BPF_IMM)) { status = sljit_emit_op1(compiler, SLJIT_MOV, - BPFJIT_A, 0, + BJ_AREG, 0, SLJIT_IMM, (uint32_t)pc->k); if (status != SLJIT_SUCCESS) goto fail; @@ -1364,13 +1458,14 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru /* BPF_LD+BPF_MEM A <- M[k] */ if (pc->code == (BPF_LD|BPF_MEM)) { - if (pc->k < minm || pc->k > 
maxm) + if (pc->k >= BPF_MEMWORDS) goto fail; status = sljit_emit_op1(compiler, SLJIT_MOV_UI, - BPFJIT_A, 0, + BJ_AREG, 0, SLJIT_MEM1(SLJIT_LOCALS_REG), - mem_local_offset(pc->k, minm)); + offsetof(struct bpfjit_stack, mem) + + pc->k * sizeof(uint32_t)); if (status != SLJIT_SUCCESS) goto fail; @@ -1381,8 +1476,8 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru if (pc->code == (BPF_LD|BPF_W|BPF_LEN)) { status = sljit_emit_op1(compiler, SLJIT_MOV, - BPFJIT_A, 0, - BPFJIT_WIRELEN, 0); + BJ_AREG, 0, + BJ_WIRELEN, 0); if (status != SLJIT_SUCCESS) goto fail; @@ -1394,7 +1489,7 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru goto fail; status = emit_pkt_read(compiler, pc, - to_mchain_jump, ret0, &ret0_size); + to_mchain_jump, &ret0, &ret0_size, &ret0_maxsize); if (status != SLJIT_SUCCESS) goto fail; @@ -1409,7 +1504,7 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru goto fail; status = sljit_emit_op1(compiler, SLJIT_MOV, - BPFJIT_X, 0, + BJ_XREG, 0, SLJIT_IMM, (uint32_t)pc->k); if (status != SLJIT_SUCCESS) goto fail; @@ -1423,8 +1518,8 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru goto fail; status = sljit_emit_op1(compiler, SLJIT_MOV, - BPFJIT_X, 0, - BPFJIT_WIRELEN, 0); + BJ_XREG, 0, + BJ_WIRELEN, 0); if (status != SLJIT_SUCCESS) goto fail; @@ -1435,13 +1530,14 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru if (mode == BPF_MEM) { if (BPF_SIZE(pc->code) != BPF_W) goto fail; - if (pc->k < minm || pc->k > maxm) + if (pc->k >= BPF_MEMWORDS) goto fail; status = sljit_emit_op1(compiler, SLJIT_MOV_UI, - BPFJIT_X, 0, + BJ_XREG, 0, SLJIT_MEM1(SLJIT_LOCALS_REG), - mem_local_offset(pc->k, minm)); + offsetof(struct bpfjit_stack, mem) + + pc->k * sizeof(uint32_t)); if (status != SLJIT_SUCCESS) goto fail; @@ -1453,47 +1549,48 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru goto fail; status = emit_msh(compiler, pc, - to_mchain_jump, ret0, &ret0_size); + to_mchain_jump, &ret0, &ret0_size, &ret0_maxsize); if (status != SLJIT_SUCCESS) goto fail; continue; case BPF_ST: - if (pc->code != BPF_ST || pc->k < minm 
|| pc->k > maxm) + if (pc->code != BPF_ST || pc->k >= BPF_MEMWORDS) goto fail; status = sljit_emit_op1(compiler, SLJIT_MOV_UI, SLJIT_MEM1(SLJIT_LOCALS_REG), - mem_local_offset(pc->k, minm), - BPFJIT_A, 0); + offsetof(struct bpfjit_stack, mem) + + pc->k * sizeof(uint32_t), + BJ_AREG, 0); if (status != SLJIT_SUCCESS) goto fail; continue; case BPF_STX: - if (pc->code != BPF_STX || pc->k < minm || pc->k > maxm) + if (pc->code != BPF_STX || pc->k >= BPF_MEMWORDS) goto fail; status = sljit_emit_op1(compiler, SLJIT_MOV_UI, SLJIT_MEM1(SLJIT_LOCALS_REG), - mem_local_offset(pc->k, minm), - BPFJIT_X, 0); + offsetof(struct bpfjit_stack, mem) + + pc->k * sizeof(uint32_t), + BJ_XREG, 0); if (status != SLJIT_SUCCESS) goto fail; continue; case BPF_ALU: - if (pc->code == (BPF_ALU|BPF_NEG)) { status = sljit_emit_op1(compiler, SLJIT_NEG, - BPFJIT_A, 0, - BPFJIT_A, 0); + BJ_AREG, 0, + BJ_AREG, 0); if (status != SLJIT_SUCCESS) goto fail; @@ -1503,8 +1600,8 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru if (BPF_OP(pc->code) != BPF_DIV) { status = sljit_emit_op2(compiler, bpf_alu_to_sljit_op(pc), - BPFJIT_A, 0, - BPFJIT_A, 0, + BJ_AREG, 0, + BJ_AREG, 0, kx_to_reg(pc), kx_to_reg_arg(pc)); if (status != SLJIT_SUCCESS) goto fail; @@ -1522,20 +1619,24 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru if (src == BPF_X) { jump = sljit_emit_cmp(compiler, SLJIT_C_EQUAL|SLJIT_INT_OP, - BPFJIT_X, 0, + BJ_XREG, 0, SLJIT_IMM, 0); if (jump == NULL) goto fail; - ret0[ret0_size++] = jump; + if (!append_jump(jump, &ret0, + &ret0_size, &ret0_maxsize)) + goto fail; } else if (pc->k == 0) { jump = sljit_emit_jump(compiler, SLJIT_JUMP); if (jump == NULL) goto fail; - ret0[ret0_size++] = jump; + if (!append_jump(jump, &ret0, + &ret0_size, &ret0_maxsize)) + goto fail; } if (src == BPF_X) { - status = emit_division(compiler, BPFJIT_X, 0); + status = emit_division(compiler, BJ_XREG, 0); if (status != SLJIT_SUCCESS) goto fail; } else if (pc->k != 0) { @@ -1543,7 +1644,7 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru status 
= emit_division(compiler, SLJIT_IMM, (uint32_t)pc->k); } else { - status = emit_pow2_division(compiler, + status = emit_pow2_division(compiler, (uint32_t)pc->k); } if (status != SLJIT_SUCCESS) @@ -1553,8 +1654,7 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru continue; case BPF_JMP: - - if (pc->code == (BPF_JMP|BPF_JA)) { + if (BPF_OP(pc->code) == BPF_JA) { jt = jf = pc->k; } else { jt = pc->jt; @@ -1563,34 +1663,34 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru negate = (jt == 0) ? 1 : 0; branching = (jt == jf) ? 0 : 1; - jtf = insn_dat[i].bj_aux.bj_jdata.bj_jtf; + jtf = insn_dat[i].u.jdata.jtf; if (branching) { if (BPF_OP(pc->code) != BPF_JSET) { jump = sljit_emit_cmp(compiler, bpf_jmp_to_sljit_cond(pc, negate), - BPFJIT_A, 0, + BJ_AREG, 0, kx_to_reg(pc), kx_to_reg_arg(pc)); } else { status = sljit_emit_op2(compiler, SLJIT_AND, - BPFJIT_TMP1, 0, - BPFJIT_A, 0, + BJ_TMP1REG, 0, + BJ_AREG, 0, kx_to_reg(pc), kx_to_reg_arg(pc)); if (status != SLJIT_SUCCESS) goto fail; jump = sljit_emit_cmp(compiler, bpf_jmp_to_sljit_cond(pc, negate), - BPFJIT_TMP1, 0, + BJ_TMP1REG, 0, SLJIT_IMM, 0); } if (jump == NULL) goto fail; - BPFJIT_ASSERT(jtf[negate].bj_jump == NULL); - jtf[negate].bj_jump = jump; + BJ_ASSERT(jtf[negate].sjump == NULL); + jtf[negate].sjump = jump; } if (!branching || (jt != 0 && jf != 0)) { @@ -1598,23 +1698,21 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru if (jump == NULL) goto fail; - BPFJIT_ASSERT(jtf[branching].bj_jump == NULL); - jtf[branching].bj_jump = jump; + BJ_ASSERT(jtf[branching].sjump == NULL); + jtf[branching].sjump = jump; } continue; case BPF_RET: - rval = BPF_RVAL(pc->code); if (rval == BPF_X) goto fail; /* BPF_RET+BPF_K accept k bytes */ if (rval == BPF_K) { - status = sljit_emit_op1(compiler, - SLJIT_MOV, - BPFJIT_A, 0, + status = sljit_emit_return(compiler, + SLJIT_MOV_UI, SLJIT_IMM, (uint32_t)pc->k); if (status != SLJIT_SUCCESS) goto fail; @@ -1622,49 +1720,32 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru /* BPF_RET+BPF_A accept A bytes */ if (rval 
== BPF_A) { -#if BPFJIT_A != SLJIT_RETURN_REG - status = sljit_emit_op1(compiler, - SLJIT_MOV, - SLJIT_RETURN_REG, 0, - BPFJIT_A, 0); + status = sljit_emit_return(compiler, + SLJIT_MOV_UI, + BJ_AREG, 0); if (status != SLJIT_SUCCESS) goto fail; -#endif - } - - /* - * Save a jump to a normal return. If the program - * ends with BPF_RET, no jump is needed because - * the normal return is generated right after the - * last instruction. - */ - if (i != insn_count - 1) { - jump = sljit_emit_jump(compiler, SLJIT_JUMP); - if (jump == NULL) - goto fail; - returns[returns_size++] = jump; } continue; case BPF_MISC: - - if (pc->code == (BPF_MISC|BPF_TAX)) { + switch (BPF_MISCOP(pc->code)) { + case BPF_TAX: status = sljit_emit_op1(compiler, SLJIT_MOV_UI, - BPFJIT_X, 0, - BPFJIT_A, 0); + BJ_XREG, 0, + BJ_AREG, 0); if (status != SLJIT_SUCCESS) goto fail; continue; - } - if (pc->code == (BPF_MISC|BPF_TXA)) { + case BPF_TXA: status = sljit_emit_op1(compiler, SLJIT_MOV, - BPFJIT_A, 0, - BPFJIT_X, 0); + BJ_AREG, 0, + BJ_XREG, 0); if (status != SLJIT_SUCCESS) goto fail; @@ -1675,45 +1756,22 @@ bpfjit_generate_code(bpf_ctx_t *bc, stru } /* switch */ } /* main loop */ - BPFJIT_ASSERT(ret0_size == ret0_maxsize); - BPFJIT_ASSERT(returns_size <= returns_maxsize); + BJ_ASSERT(ret0_size <= ret0_maxsize); - if (returns_size > 0) { + if (ret0_size > 0) { label = sljit_emit_label(compiler); if (label == NULL) goto fail; - for (i = 0; i < returns_size; i++) - sljit_set_label(returns[i], label); + for (i = 0; i < ret0_size; i++) + sljit_set_label(ret0[i], label); } status = sljit_emit_return(compiler, SLJIT_MOV_UI, - BPFJIT_A, 0); + SLJIT_IMM, 0); if (status != SLJIT_SUCCESS) goto fail; - if (ret0_size > 0) { - label = sljit_emit_label(compiler); - if (label == NULL) - goto fail; - - for (i = 0; i < ret0_size; i++) - sljit_set_label(ret0[i], label); - - status = sljit_emit_op1(compiler, - SLJIT_MOV, - SLJIT_RETURN_REG, 0, - SLJIT_IMM, 0); - if (status != SLJIT_SUCCESS) - goto fail; - - status = 
sljit_emit_return(compiler, - SLJIT_MOV_UI, - SLJIT_RETURN_REG, 0); - if (status != SLJIT_SUCCESS) - goto fail; - } - rv = sljit_generate_code(compiler); fail: @@ -1721,13 +1779,10 @@ fail: sljit_free_compiler(compiler); if (insn_dat != NULL) - BPFJIT_FREE(insn_dat, insn_count * sizeof(insn_dat[0])); - - if (returns != NULL) - BPFJIT_FREE(returns, returns_maxsize * sizeof(returns[0])); + BJ_FREE(insn_dat, insn_count * sizeof(insn_dat[0])); if (ret0 != NULL) - BPFJIT_FREE(ret0, ret0_maxsize * sizeof(ret0[0])); + BJ_FREE(ret0, ret0_maxsize * sizeof(ret0[0])); return (bpfjit_func_t)rv; } @@ -1735,5 +1790,6 @@ fail: void bpfjit_free_code(bpfjit_func_t code) { + sljit_free_code((void *)code); }