This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new a49273e050 Enable conv family fused with mish (#12228) a49273e050 is described below commit a49273e05092480bde8593c6a137bb251b5dee6c Author: billishyahao <yahao...@intel.com> AuthorDate: Mon Aug 1 16:56:49 2022 +0800 Enable conv family fused with mish (#12228) --- python/tvm/relay/op/contrib/dnnl.py | 11 +++++++++-- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 4 ++++ tests/python/contrib/test_dnnl.py | 15 +++++++++++---- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index f17b325dce..46c20e947f 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -53,7 +53,7 @@ from .register import register_pattern_table logger = logging.getLogger("DNNL") -supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None] +supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None] def _register_external_op_helper(op_name, supported=True): @@ -137,6 +137,13 @@ def append_eltwise_ops(op, eltwise): elif eltwise == "swish": sig_out = is_op("sigmoid")(op) op = is_op("multiply")(op, sig_out) + elif eltwise == "mish": + const1 = wildcard() + exp = is_op("exp")(op) + add = is_op("add")(exp, const1) + log = is_op("log")(add) + tanh = is_op("tanh")(log) + op = is_op("multiply")(op, tanh) elif eltwise: op = is_op(eltwise)(op) return op @@ -411,7 +418,7 @@ def pattern_table(): ) ) - elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None] + elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None] for with_bias in [True, False]: for elt in elt_list: if not with_bias and not elt: diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 1fe8fccc77..d019f4e811 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -191,6 +191,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::regex gelu_pat(".*_gelu.*"); std::regex swish_pat(".*_swish.*"); std::regex sum_pat(".*_sum.*"); + std::regex mish_pat(".*_mish.*"); // parsing of name to extract attributes auto op_name = nodes_[nid].GetOpName(); @@ -220,6 +221,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase { if (std::regex_match(op_name, gelu_pat)) { ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); } + if (std::regex_match(op_name, mish_pat)) { + ops.append_eltwise(1.f, dnnl::algorithm::eltwise_mish, 1.f, 0.f); + } if (ops.len() != 0) { attr.set_post_ops(ops); } diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 74d0da1238..8de8bd9ce6 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -252,6 +252,13 @@ def add_activation(activation, out, dic, param_lst): elif activation == "gelu": out = gelu_helper(out) return out, dic, param_lst + elif activation == "mish": + exp = relay.exp(out) + add = relay.add(exp, relay.const(1.0)) + log = relay.log(add) + tanh = relay.tanh(log) + out = relay.multiply(out, tanh) + return out, dic, param_lst else: return out, dic, param_lst @@ -765,7 +772,7 @@ def test_conv2d_weights_const(run_module, dtype="float32"): def test_conv2d_pattern(run_module, dtype="float32"): x_shape = (1, 32, 8, 8) k_shape = (16, 32, 3, 3) - activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"] + activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu", "mish"] for a in activation_lst: conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype) conv2d = tvm.IRModule.from_expr(conv2d) @@ -849,7 +856,7 @@ def test_conv2d_transpose(run_module, dtype="float32"): def test_conv2d_transpose_pattern(run_module, dtype="float32"): - activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"] + activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu", "mish"] for a in activation_lst: conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype) conv2d = tvm.IRModule.from_expr(conv2d) @@ -882,7 +889,7 @@ def test_conv3d(run_module, dtype="float32"): def test_conv3d_pattern(run_module, dtype="float32"): - activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"] + activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu", "mish"] for a in activation_lst: conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype) conv3d = tvm.IRModule.from_expr(conv3d) @@ -915,7 +922,7 @@ def test_conv3d_transpose(run_module, dtype="float32"): def test_conv3d_transpose_pattern(run_module, dtype="float32"): - activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu"] + activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish", "gelu", "mish"] for a in activation_lst: conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype) conv3d = tvm.IRModule.from_expr(conv3d)