This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0406b9af22 [Unity][Frontend] Add relax onnx importer and tests (#14999)
0406b9af22 is described below
commit 0406b9af22ee0e6d8db7bbba10a74ecdedfacb6d
Author: Josh Fromm <[email protected]>
AuthorDate: Sat Jun 10 11:03:37 2023 -0400
[Unity][Frontend] Add relax onnx importer and tests (#14999)
This PR adds a direct ONNX to relax importer and test suite. It has decent
out of the box coverage and has been tested for numerous end to end use cases.
I hope this will be a valuable tool for the Unity community.
---
python/tvm/ir/supply.py | 7 +-
python/tvm/relax/frontend/onnx/__init__.py | 20 +
python/tvm/relax/frontend/onnx/onnx_frontend.py | 2173 +++++++++++++++++++++++
tests/python/relax/test_frontend_onnx.py | 1619 +++++++++++++++++
4 files changed, 3817 insertions(+), 2 deletions(-)
diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py
index 095ac43c03..a501e8849e 100644
--- a/python/tvm/ir/supply.py
+++ b/python/tvm/ir/supply.py
@@ -32,7 +32,7 @@ class NameSupply(Object):
     def __init__(self, prefix=""):
         self.__init_handle_by_constructor__(_ffi_api.NameSupply, prefix)
 
-    def fresh_name(self, name, add_prefix=True):
+    def fresh_name(self, name, add_prefix=True, add_underscore=True):
         """Generates a unique name from this NameSupply.
 
         Parameters
@@ -42,8 +42,11 @@ class NameSupply(Object):
         add_prefix: bool
             If set to true, then the prefix of this NameSupply will be
             prepended to the name.
+
+        add_underscore: bool
+            If set to True, adds '_' between prefix and digit.
         """
-        return _ffi_api.NameSupply_FreshName(self, name, add_prefix)
+        return _ffi_api.NameSupply_FreshName(self, name, add_prefix, add_underscore)
 
     def reserve_name(self, name, add_prefix=True):
         """Reserves an existing name with this NameSupply.
diff --git a/python/tvm/relax/frontend/onnx/__init__.py b/python/tvm/relax/frontend/onnx/__init__.py
new file mode 100644
index 0000000000..42f9f69383
--- /dev/null
+++ b/python/tvm/relax/frontend/onnx/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Tools for converting ONNX graphs into Relax graphs.
+"""
+from .onnx_frontend import from_onnx
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
new file mode 100644
index 0000000000..d653bb5511
--- /dev/null
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -0,0 +1,2173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""ONNX: Open Neural Network Exchange importer for Relax.
+
+This module implements the required functionality to read ONNX models
+and convert them into equivalent Relax functions. The entry point that encapsulates
+this functionality is the function from_onnx.
+
+In order to extend the functionality of the importer, you can add new
+operators to the operator registry. The operator registry is a dictionary
+that maps operator names to operator converters. The registry is defined
+in the _get_converter_map function. To add a new operator, you can define
+a new class that inherits from the OnnxOpConverter class and implement one
+or more _impl_vX classmethods, where X is the opset version from which that
+implementation applies.
+
+By default, ONNX defines models in terms of dynamic shapes. The ONNX importer
+retains dynamic shapes upon import, and when possible, the compiler attempts to
+convert the model to use static shapes at compile time.
+If this fails, there may still be dynamic operations in the model.
+Not all TVM kernels currently support dynamic shapes; please file an issue on
+github.com/apache/tvm/issues if you hit an error with dynamic kernels.
+"""
+import warnings
+from typing import Union, Tuple, Optional, List, Dict, Any
+
+import numpy as _np
+
+import tvm
+from tvm import relax, topi
+from tvm.ir import IRModule
+from tvm.ir.supply import NameSupply
+
+import onnx.onnx_ml_pb2
+
+
+def get_type(elem_type: Union[str, int]) -> str:
+    """Convert an ONNX tensor element type into a numpy dtype string.
+
+    Parameters
+    ----------
+    elem_type: Union[str, int]
+        Either an ONNX ``TensorProto`` data type enum value or a dtype string.
+        Strings are returned unchanged.
+
+    Returns
+    -------
+    str
+        The numpy dtype name for ``elem_type``.
+
+    Raises
+    ------
+    ImportError
+        If no onnx dtype-mapping API can be imported.
+    """
+    # If a string was passed instead of a tensor type, it does not need
+    # conversion and can be returned.
+    if isinstance(elem_type, str):
+        return elem_type
+
+    try:
+        # Preferred API on newer onnx releases, where ``onnx.mapping`` is deprecated.
+        from onnx.helper import tensor_dtype_to_np_dtype  # pylint: disable=import-outside-toplevel
+    except ImportError:
+        tensor_dtype_to_np_dtype = None
+
+    if tensor_dtype_to_np_dtype is not None:
+        return str(tensor_dtype_to_np_dtype(elem_type))
+
+    try:
+        from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE  # pylint: disable=import-outside-toplevel
+    except ImportError as exception:
+        raise ImportError(
+            "Unable to import onnx which is required {}".format(exception)
+        ) from exception
+
+    return str(TENSOR_TYPE_TO_NP_TYPE[elem_type])
+
+
+def get_constant(
+    var: Union[relax.Constant, relax.Var],
+    params: List[Dict[str, relax.Var]],
+) -> Union[relax.Constant, relax.Var]:
+    """Fold a graph parameter into a relax constant when possible.
+
+    This is the main point where converters interact with ``params``.
+
+    Parameters
+    ----------
+    var: Union[relax.Constant, relax.Var]
+        The value a converter would like to have as a constant.
+    params: List[Dict[str, relax.Var]]
+        A two-element sequence ``(graph_nodes, param_dict)``. ``graph_nodes`` maps
+        names to the expressions of the graph; ``param_dict`` maps a parameter
+        name to a pair whose second element is the parameter's value. When a
+        parameter is folded it is popped from ``param_dict`` and its
+        ``graph_nodes`` entry is replaced with the new constant.
+
+    Returns
+    -------
+    var : Union[relax.Constant, relax.Var]
+        The folded constant, or ``var`` unchanged when it is not a known parameter.
+    """
+    graph_nodes, param_dict = params
+    is_foldable = isinstance(var, relax.Var) and var.name_hint in param_dict
+    if not is_foldable:
+        return var
+
+    # Pop the parameter and point later lookups of this name at the constant.
+    _, param_value = param_dict.pop(var.name_hint)
+    folded = relax.const(param_value)
+    graph_nodes[var.name_hint] = folded
+    return folded
+
+
+def get_info(info_proto: onnx.onnx_ml_pb2.ValueInfoProto) -> Tuple[str, List, str, List]:
+    """Read the name, shape, dtype and dimension names of a ValueInfoProto.
+
+    Parameters
+    ----------
+    info_proto: onnx.onnx_ml_pb2.ValueInfoProto
+        The ValueInfoProto to inspect.
+
+    Returns
+    -------
+    Tuple[str, List, str, List]
+        ``(name, shape, dtype, shape_name)``. Static dimensions appear as ints in
+        both ``shape`` and ``shape_name``. A dimension whose ``dim_value`` is 0
+        becomes a fresh ``tir.Var`` in ``shape`` and its ``dim_param`` string in
+        ``shape_name``. ``dtype`` is None when no element type is recorded.
+    """
+    shape = []
+    shape_name = []
+    for dim in info_proto.type.tensor_type.shape.dim:
+        static_size = dim.dim_value
+        if static_size is None or static_size == 0:
+            # Unknown/symbolic dimension: give it its own symbolic variable.
+            shape.append(tvm.tir.Var("dyn", "int64"))
+            shape_name.append(dim.dim_param)
+        else:
+            shape.append(static_size)
+            shape_name.append(static_size)
+
+    elem_type = info_proto.type.tensor_type.elem_type
+    dtype = get_type(elem_type) if elem_type else None
+    return info_proto.name, shape, dtype, shape_name
+
+
+def get_numpy(tensor_proto: onnx.onnx_ml_pb2.TensorProto) -> _np.ndarray:
+    """Decode the payload of an ONNX TensorProto into a numpy array.
+
+    Raises
+    ------
+    ImportError
+        If ``onnx.numpy_helper`` cannot be imported.
+    """
+    try:
+        from onnx import numpy_helper  # pylint: disable=import-outside-toplevel
+    except ImportError as exception:
+        raise ImportError("Unable to import onnx which is required {}".format(exception))
+    return numpy_helper.to_array(tensor_proto)
+
+
+class onnx_input(list):  # pylint: disable=invalid-name
+    """A list that returns None when out-of-bounds indices are accessed.
+
+    Converters use this to probe optional node inputs (e.g. ``inputs[2] is None``)
+    without first checking how many inputs were provided.
+    """
+
+    def __getitem__(self, item):
+        if isinstance(item, slice):
+            # A missing or negative stop is resolved against the real length, so
+            # ``inputs[:-1]`` behaves like ordinary list slicing. A positive stop
+            # may run past the end, in which case the result is padded with None.
+            if item.stop is None or item.stop < 0:
+                stop = len(self)
+            else:
+                stop = item.stop
+            indices = list(range(stop)[item])
+            return [self[i] for i in indices]
+        if isinstance(item, int):
+            # Index the underlying list directly instead of copying it per access.
+            if -len(self) <= item < len(self):
+                return super().__getitem__(item)
+            return None
+        raise TypeError("list indices must be integers or slices, not %s" % type(item).__name__)
+
+
+# pylint: disable=invalid-name, len-as-condition, unused-argument, too-many-lines, redefined-builtin
+class OnnxOpConverter(object):
+    """Base class holding the logic shared by all ONNX op converters.
+
+    Each subclass converts one ONNX op into equivalent Relax expressions. A
+    subclass may define several ``_impl_vX`` classmethods; which one is used is
+    decided from the model's opset by :meth:`get_converter`.
+    """
+
+    @classmethod
+    def get_converter(cls, opset):
+        """Select the ``_impl_vX`` method to use for ``opset``.
+
+        Parameters
+        ----------
+        opset: int
+            opset version declared by the model.
+
+        Returns
+        -------
+        converter: callable
+            The ``_impl_vX`` method with the largest ``X`` that is <= ``opset``.
+            If ``opset`` is lower than every implemented ``X``, the lookup wraps
+            around and the newest implementation is returned.
+
+        Raises
+        ------
+        NotImplementedError
+            If the class defines no ``_impl_vX`` method at all.
+        """
+        impl_versions = [int(name.replace("_impl_v", "")) for name in dir(cls) if "_impl_v" in name]
+        ordered = sorted(impl_versions + [opset])
+        opset_pos = max(pos for pos, version in enumerate(ordered) if version == opset)
+        # NOTE(review): when opset_pos is 0 this indexes ordered[-1] (the newest
+        # implementation). Converters that only define _impl_v13 are reached from
+        # older-opset models solely through this fallback.
+        selected = ordered[opset_pos - 1]
+        impl_name = "_impl_v{}".format(selected)
+        if not hasattr(cls, impl_name):
+            raise NotImplementedError(
+                "opset version {} of {} not implemented".format(selected, cls.__name__)
+            )
+        return getattr(cls, impl_name)
+
+
+class MatMul(OnnxOpConverter):
+    """Lower an ONNX MatMul node to ``relax.op.matmul``."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        lhs, rhs = inputs[0], inputs[1]
+        return relax.op.matmul(lhs, rhs)
+
+
+class Div(OnnxOpConverter):
+    """Lower an ONNX Div node, folding it when every operand is a constant."""
+
+    @classmethod
+    def _impl_v14(cls, bb, inputs, attr, params):
+        lhs, rhs = inputs[0], inputs[1]
+        if not all(isinstance(operand, relax.Constant) for operand in inputs):
+            return relax.op.divide(lhs, rhs)
+        # Divide on the host, then store the quotient in the first operand's dtype.
+        quotient = lhs.data.numpy() / rhs.data.numpy()
+        return relax.const(quotient, lhs.struct_info.dtype)
+
+
+class Sigmoid(OnnxOpConverter):
+    """Lower an ONNX Sigmoid node to ``relax.op.sigmoid``."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        return relax.op.sigmoid(data)
+
+
+class Softmax(OnnxOpConverter):
+    """Lower an ONNX Softmax node to ``relax.op.nn.softmax``."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        # When the node has no "axis" attribute, normalise over the last axis.
+        return relax.op.nn.softmax(inputs[0], axis=attr.get("axis", -1))
+
+
+class Transpose(OnnxOpConverter):
+    """Lower an ONNX Transpose node, folding it for constant inputs."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        perm = attr.get("perm", None)
+        if not isinstance(data, relax.Constant):
+            return relax.op.permute_dims(data, perm)
+        transposed = _np.transpose(data.data.numpy(), perm)
+        return relax.const(transposed, transposed.dtype)
+
+
+class Unsqueeze(OnnxOpConverter):
+    """Converts an onnx Unsqueeze node into an equivalent Relax expression.
+
+    The ``axes`` index into the *output* tensor, may be negative and may be given
+    in any order.
+    """
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        # Before opset 13 the axes are an attribute; forward them as a constant input.
+        axes = list(attr.get("axes"))
+        inputs = inputs + [relax.const(axes, "int64")]
+        return cls._impl_v13(bb, inputs, attr, params)
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        axes = get_constant(inputs[1], params)
+
+        if not isinstance(axes, relax.Constant):
+            raise NotImplementedError("Unsqueeze with dynamic axes is not supported.")
+        axes = [int(axis) for axis in axes.data.numpy().reshape(-1)]
+
+        # If input is a constant, compute directly. numpy's tuple-axis expand_dims
+        # resolves every axis against the output rank, which also covers scalar
+        # inputs, unsorted axes and negative axes.
+        if isinstance(data, relax.Constant):
+            expanded = _np.expand_dims(data.data.numpy(), axis=tuple(axes))
+            return relax.const(expanded, data.struct_info.dtype)
+
+        # relax expand_dims likewise accepts all axes at once, interpreted against
+        # the output rank, so no manual sorting/normalisation is needed.
+        return relax.op.expand_dims(data, axis=axes)
+
+
+class Concat(OnnxOpConverter):
+    """Lower an ONNX Concat node, folding it when every input is a constant."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", 0)
+        if not all(isinstance(part, relax.Constant) for part in inputs):
+            return relax.op.concat(inputs, axis=axis)
+        # Every input is constant: concatenate on the host.
+        joined = _np.concatenate([part.data.numpy() for part in inputs], axis=axis)
+        return relax.const(joined, inputs[0].struct_info.dtype)
+
+
+class Add(OnnxOpConverter):
+    """Lower an ONNX Add node, folding it when every operand is a constant."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        lhs, rhs = inputs[0], inputs[1]
+        if not all(isinstance(operand, relax.Constant) for operand in inputs):
+            return relax.op.add(lhs, rhs)
+        total = lhs.data.numpy() + rhs.data.numpy()
+        return relax.const(total, total.dtype)
+
+
+class Mul(OnnxOpConverter):
+    """Lower an ONNX Mul node, folding it when every operand is a constant."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        lhs, rhs = inputs[0], inputs[1]
+        if not all(isinstance(operand, relax.Constant) for operand in inputs):
+            return relax.op.multiply(lhs, rhs)
+        product = lhs.data.numpy() * rhs.data.numpy()
+        return relax.const(product, product.dtype)
+
+
+class Cast(OnnxOpConverter):
+    """Lower an ONNX Cast node, casting constants on the host."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        target_dtype = get_type(attr["to"])
+        source = inputs[0]
+        if not isinstance(source, relax.Constant):
+            return relax.op.astype(source, target_dtype)
+        return relax.const(source.data.numpy().astype(target_dtype), target_dtype)
+
+
+class Gather(OnnxOpConverter):
+    """Convert an onnx Gather node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        # Unpack inputs
+        data = inputs[0]
+        indices = inputs[1]
+        axis = attr.get("axis", 0)
+
+        # If all inputs are constant, we can compute directly.
+        if all(isinstance(inp, relax.Constant) for inp in [data, indices]):
+            output = _np.take(data.data.numpy(), indices.data.numpy(), axis=axis)
+            return relax.const(output, output.dtype)
+
+        # If input is a shape expression, take values from that shape and return
+        # them as a constant.
+        if isinstance(data, relax.ShapeExpr):
+            assert isinstance(
+                indices, relax.Constant
+            ), "Only constant indices supported for shape gather."
+            np_index = indices.data.numpy()
+            shape_vals = [data[int(index)] for index in np_index.reshape(-1)]
+            if not all(hasattr(val, "value") for val in shape_vals):
+                # Symbolic dimensions cannot be folded into a constant.
+                raise ValueError("Need to fix this case.")
+            gathered = _np.array([val.value for val in shape_vals], dtype="int64")
+            # Gather on a 1-D input yields a result shaped like ``indices``: a scalar
+            # index gives a scalar, a 1-D index array gives a 1-D tensor.
+            return relax.const(gathered.reshape(np_index.shape), dtype="int64")
+
+        # TODO(jwfromm) Make relax.take work with other indices shape.
+        return bb.emit_te(topi.take, data, indices, axis)
+
+
+class Gemm(OnnxOpConverter):
+    """Lower an ONNX Gemm node: Y = alpha * op(A) @ op(B) + beta * C."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        alpha = attr.get("alpha", None)
+        beta = attr.get("beta", None)
+        trans_a = attr.get("transA", False)
+        trans_b = attr.get("transB", False)
+        lhs, rhs, bias = inputs[0], inputs[1], inputs[2]
+        dtype = lhs.checked_type.dtype
+
+        # Scale A before the (optional) transposes and the matmul.
+        if alpha is not None:
+            lhs = bb.normalize(relax.op.multiply(lhs, relax.const(alpha, dtype=dtype)))
+        if trans_a:
+            lhs = relax.op.permute_dims(lhs, [1, 0])
+        if trans_b:
+            rhs = relax.op.permute_dims(rhs, [1, 0])
+        result = bb.normalize(relax.op.matmul(lhs, rhs))
+
+        # C is optional; without it the scaled product is the answer.
+        if bias is None:
+            return result
+        if beta is not None:
+            bias = bb.normalize(relax.op.multiply(bias, relax.const(beta, dtype=dtype)))
+        return relax.op.add(result, bias)
+
+
+class Reshape(OnnxOpConverter):
+    """Convert an onnx Reshape node into an equivalent Relax expression."""
+
+    @classmethod
+    def _copy_zero_dims(cls, new_shape, input_shape):
+        """Replace each 0 in ``new_shape`` with the input dimension at that position.
+
+        Without ``allowzero``, ONNX Reshape uses 0 to mean "keep this dimension of
+        the input"; neither numpy nor relax reshape interpret 0 that way.
+        """
+        return [
+            input_shape[i] if dim == 0 and i < len(input_shape) else dim
+            for i, dim in enumerate(new_shape)
+        ]
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        new_shape = get_constant(inputs[1], params)
+
+        if not isinstance(new_shape, relax.Constant):
+            # Target shape only known at runtime; hand it to relax unchanged.
+            return relax.op.reshape(data, new_shape)
+
+        new_shape = new_shape.data.numpy().tolist()
+        keep_zero = bool(attr.get("allowzero", 0))
+
+        if isinstance(data, relax.Constant):
+            if not keep_zero:
+                new_shape = cls._copy_zero_dims(new_shape, data.data.numpy().shape)
+            out = _np.reshape(data.data.numpy(), new_shape)
+            return relax.const(out, out.dtype)
+
+        if not keep_zero:
+            # Only resolvable when the input's shape is recorded as a ShapeExpr.
+            input_shape = getattr(data.struct_info, "shape", None)
+            if isinstance(input_shape, relax.ShapeExpr):
+                new_shape = cls._copy_zero_dims(new_shape, list(input_shape.values))
+        return relax.op.reshape(data, new_shape)
+
+
+class Gelu(OnnxOpConverter):
+    """Converter for the Gelu op of the Microsoft onnxruntime contrib opset.
+
+    gelu(x) = 0.5x(1 + erf(x/sqrt(2)))
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        return relax.op.nn.gelu(data)
+
+
+class BiasGelu(OnnxOpConverter):
+    """Converter for the BiasGelu op of the Microsoft onnxruntime contrib opset.
+
+    bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2)))
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        data, bias = inputs[0], inputs[1]
+        biased = relax.op.add(data, bias)
+        return relax.op.nn.gelu(biased)
+
+
+class Where(OnnxOpConverter):
+    """Lower an ONNX Where node, folding it when every input is a constant."""
+
+    @classmethod
+    def _impl_v16(cls, bb, inputs, attr, params):
+        condition, on_true, on_false = inputs[0], inputs[1], inputs[2]
+        if not all(isinstance(arg, relax.Constant) for arg in inputs):
+            return relax.op.where(condition, on_true, on_false)
+        selected = _np.where(*(arg.data.numpy() for arg in inputs))
+        return relax.const(selected, selected.dtype)
+
+
+class Clip(OnnxOpConverter):
+    """Lower an ONNX Clip node using topi maximum/minimum."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        lower, upper = inputs[1], inputs[2]
+        clipped = inputs[0]
+        # Both bounds are optional inputs; an omitted bound leaves that side open.
+        if lower is not None:
+            clipped = bb.emit_te(topi.maximum, clipped, lower)
+        if upper is not None:
+            clipped = bb.emit_te(topi.minimum, clipped, upper)
+        return clipped
+
+
+class Equal(OnnxOpConverter):
+    """Lower an ONNX Equal node, folding it when every operand is a constant."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        lhs, rhs = inputs[0], inputs[1]
+        if not all(isinstance(operand, relax.Constant) for operand in inputs):
+            return relax.op.equal(lhs, rhs)
+        matches = lhs.data.numpy() == rhs.data.numpy()
+        return relax.const(matches, matches.dtype)
+
+
+class Shape(OnnxOpConverter):
+    """Converts an onnx Shape node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        data_info = inputs[0].struct_info
+
+        # If no shape is defined in the struct info, it must be computed at runtime.
+        if not data_info.shape:
+            data_shape = bb.normalize(relax.op.shape_of(inputs[0]))
+            return data_shape
+
+        # Otherwise hand back the shape recorded in the struct info.
+        # NOTE(review): the opset-15 "start"/"end" attributes are not handled here.
+        return data_info.shape
+
+
+class Tanh(OnnxOpConverter):
+    """Lower an ONNX Tanh node to ``relax.op.tanh``."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        return relax.op.tanh(data)
+
+
+class Sqrt(OnnxOpConverter):
+    """Lower an ONNX Sqrt node to ``relax.op.sqrt``."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        return relax.op.sqrt(data)
+
+
+class Relu(OnnxOpConverter):
+    """Lower an ONNX Relu node to ``relax.op.nn.relu``."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        return relax.op.nn.relu(data)
+
+
+class Pow(OnnxOpConverter):
+    """Lower an ONNX Pow node to ``relax.op.power``."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        base, exponent = inputs[0], inputs[1]
+        return relax.op.power(base, exponent)
+
+
+class Conv(OnnxOpConverter):
+    """Convert an onnx Conv node into an equivalent Relax expression.
+
+    Only 3-D inputs (1d conv, via topi) and 4-D inputs (2d conv, via relax) are
+    handled; ``group`` is only forwarded in the 2d case.
+    NOTE(review): the ``auto_pad`` attribute is not consulted.
+    """
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        if hasattr(inputs[0].struct_info, "ndim"):
+            ndim = inputs[0].struct_info.ndim
+        else:
+            ndim = len(inputs[0].struct_info.shape)
+
+        # ONNX names this attribute "dilations" (plural); reading "dilation"
+        # would always fall back to the default and ignore the model's value.
+        dilations = attr.get("dilations", 1)
+
+        if ndim == 3:
+            conv_out = bb.emit_te(
+                topi.nn.conv1d,
+                inputs[0],
+                inputs[1],
+                attr.get("strides", 1),
+                attr.get("pads", 0),
+                dilations,
+                "NCHW",
+                "OIHW",
+            )
+        elif ndim == 4:
+            conv_out = bb.normalize(
+                relax.op.nn.conv2d(
+                    data=inputs[0],
+                    weight=inputs[1],
+                    strides=attr.get("strides", 1),
+                    padding=attr.get("pads", 0),
+                    dilation=dilations,
+                    groups=attr.get("group", 1),
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                )
+            )
+        else:
+            raise NotImplementedError("Only 1d and 2d conv currently supported.")
+
+        if inputs[2] is not None:
+            # Reshape the per-channel bias to [1, C, 1, ...] so it broadcasts.
+            bias = relax.op.reshape(inputs[2], [1, -1] + [1] * (ndim - 2))
+            conv_out = relax.op.add(conv_out, bias)
+
+        return conv_out
+
+
+class Erf(OnnxOpConverter):
+    """Converts an onnx Erf node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        # TODO: replace with a relax erf operator once one is available.
+        # topi.erf is used rather than the gelu-based identity
+        # erf(x) = gelu(sqrt(2) * x) * sqrt(2) / x - 1, because that identity
+        # divides by x and evaluates to NaN at x == 0.
+        return bb.emit_te(topi.erf, inputs[0])
+
+
+class CumSum(OnnxOpConverter):
+    """Converts an onnx CumSum node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v14(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        axis = get_constant(inputs[1], params)
+        assert not attr.get("exclusive", False), "Exclusive option not yet supported."
+
+        if isinstance(axis, relax.Constant):
+            axis = int(axis.data.numpy())
+        flip_axis = axis if axis else 0
+        reverse = attr.get("reverse", 0) != 0
+
+        # A reverse cumulative sum is flip(cumsum(flip(x))). Flipping only the
+        # result would accumulate from the wrong end.
+        if reverse:
+            data = bb.emit_te(topi.flip, data, axis=flip_axis)
+        data = relax.op.cumsum(data, axis)
+        if reverse:
+            # Normalize so the cumsum call carries struct info for emit_te.
+            data = bb.emit_te(topi.flip, bb.normalize(data), axis=flip_axis)
+        return data
+
+
+class Squeeze(OnnxOpConverter):
+    """Lower an ONNX Squeeze node, folding it for constant inputs."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        axis = get_constant(inputs[1], params)
+        if isinstance(axis, relax.Constant):
+            axis = list(map(int, axis.data.numpy()))
+        if not isinstance(data, relax.Constant):
+            return relax.op.squeeze(data, axis)
+        # Constant input: squeeze on the host.
+        squeezed = _np.squeeze(data.data.numpy(), axis)
+        return relax.const(squeezed, data.struct_info.dtype)
+
+
+class Constant(OnnxOpConverter):
+    """Lower an ONNX Constant node to a relax constant."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        if "value" not in attr:
+            raise ValueError("no value in Constant")
+        raw_value = attr.pop("value")
+        # String-typed constants are rare (typically from other frameworks'
+        # exporters) and are replaced by a single int64 zero for compatibility.
+        if isinstance(raw_value, bytes):
+            array = _np.asarray([0]).astype("int64")
+        else:
+            array = get_numpy(raw_value)
+        return relax.const(array, array.dtype.name)
+
+
+class ConstantOfShape(OnnxOpConverter):
+    """Converts an onnx ConstantOfShape node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        shape = inputs[0]
+        # "value" is optional; ONNX specifies a float32 zero when it is absent.
+        # (Previously get_numpy(0) was called in that case, which fails.)
+        if "value" in attr:
+            value = get_numpy(attr["value"])
+            dtype = str(value.dtype)
+        else:
+            value = 0
+            dtype = "float32"
+
+        # If shape is a constant, we can directly create a relax constant.
+        if isinstance(shape, relax.Constant):
+            np_array = _np.zeros(shape=shape.data.numpy()) + value
+            return relax.const(np_array, dtype=dtype)
+        if isinstance(shape, relax.ShapeExpr):
+            np_array = _np.zeros(shape=[dim.value for dim in shape]) + value
+            return relax.const(np_array, dtype)
+
+        # Otherwise we have to use the value of shape at runtime.
+        const_value = relax.const(value, dtype)
+
+        # Convert a shape *tensor* into a symbolic shape expression if needed.
+        if not isinstance(shape.struct_info, relax.ShapeStructInfo):
+            shape_ndim = [dim.value for dim in shape.struct_info.shape.values][0]
+            shape_dataflow_var = bb.emit(
+                relax.Call(
+                    relax.ExternFunc("vm.builtin.tensor_to_shape"),
+                    [shape],
+                    sinfo_args=[relax.ShapeStructInfo(ndim=shape_ndim)],
+                )
+            )
+            shape_vars = [tvm.tir.Var("x_%d" % i, "int64") for i in range(shape_ndim)]
+            bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars))
+            shape = relax.ShapeExpr(shape_vars)
+
+        # Broadcast the constant to the requested shape.
+        return relax.op.broadcast_to(const_value, shape)
+
+
+class Sub(OnnxOpConverter):
+    """Lower an ONNX Sub node, folding it when every operand is a constant."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        lhs, rhs = inputs[0], inputs[1]
+        if not all(isinstance(operand, relax.Constant) for operand in inputs):
+            return relax.op.subtract(lhs, rhs)
+        difference = lhs.data.numpy() - rhs.data.numpy()
+        return relax.const(difference, difference.dtype)
+
+
+class Sin(OnnxOpConverter):
+    """Lower an ONNX Sin node to ``relax.op.sin``."""
+
+    @classmethod
+    def _impl_v7(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        return relax.op.sin(data)
+
+
+class Cos(OnnxOpConverter):
+    """Lower an ONNX Cos node to ``relax.op.cos``."""
+
+    @classmethod
+    def _impl_v7(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        return relax.op.cos(data)
+
+
+class Neg(OnnxOpConverter):
+    """Lower an ONNX Neg node, folding it for constant inputs."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        operand = inputs[0]
+        if not isinstance(operand, relax.Constant):
+            return relax.op.negative(operand)
+        return relax.const(_np.negative(operand.data.numpy()), operand.struct_info.dtype)
+
+
+class Abs(OnnxOpConverter):
+    """Lower an ONNX Abs node, folding it for constant inputs."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        operand = inputs[0]
+        if not isinstance(operand, relax.Constant):
+            return relax.op.abs(operand)
+        magnitude = _np.abs(operand.data.numpy())
+        return relax.const(magnitude, magnitude.dtype)
+
+
+class Min(OnnxOpConverter):
+    """Converts an onnx Min node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        if all(isinstance(inp, relax.Constant) for inp in inputs):
+            # Min is variadic but _np.minimum is binary (a third positional
+            # argument would be taken as its ``out`` buffer), so fold pairwise.
+            np_inputs = [inp.data.numpy() for inp in inputs]
+            output = np_inputs[0]
+            for operand in np_inputs[1:]:
+                output = _np.minimum(output, operand)
+            return relax.const(output, output.dtype)
+
+        # Expand inputs, stack them, then perform minimum over the new axis.
+        # NOTE(review): concat needs matching shapes, so ONNX's multidirectional
+        # broadcasting between inputs is not supported on this path.
+        inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs]
+        stacked_tensor = relax.op.concat(inputs, axis=0)
+        return relax.op.min(stacked_tensor, axis=0)
+
+
+class Max(OnnxOpConverter):
+    """Converts an onnx Max node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        if all(isinstance(inp, relax.Constant) for inp in inputs):
+            # Max is variadic but _np.maximum is binary (a third positional
+            # argument would be taken as its ``out`` buffer), so fold pairwise.
+            np_inputs = [inp.data.numpy() for inp in inputs]
+            output = np_inputs[0]
+            for operand in np_inputs[1:]:
+                output = _np.maximum(output, operand)
+            return relax.const(output, output.dtype)
+
+        # Expand inputs, stack them, then perform maximum over the new axis.
+        # NOTE(review): concat needs matching shapes, so ONNX's multidirectional
+        # broadcasting between inputs is not supported on this path.
+        inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs]
+        stacked_tensor = relax.op.concat(inputs, axis=0)
+        return relax.op.max(stacked_tensor, axis=0)
+
+
+class Log(OnnxOpConverter):
+    """Lower an ONNX Log node, folding it for constant inputs."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        operand = inputs[0]
+        if not isinstance(operand, relax.Constant):
+            return relax.op.log(operand)
+        return relax.const(_np.log(operand.data.numpy()), operand.struct_info.dtype)
+
+
+class Exp(OnnxOpConverter):
+    """Lower an ONNX Exp node after checking the input dtype is supported."""
+
+    @classmethod
+    def _check_type(cls, dtype, valid_types):
+        """Assert that ``dtype`` is one of ``valid_types``."""
+        assert dtype in valid_types, "Types {} are supported only, but {} is given".format(
+            valid_types, dtype
+        )
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        cls._check_type(x.checked_type.dtype, ["float", "float32", "double", "float64", "float16"])
+        return relax.op.exp(x)
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        # Same as v1, but bfloat16 is also accepted.
+        x = inputs[0]
+        cls._check_type(
+            x.checked_type.dtype,
+            ["float", "float32", "double", "float64", "float16", "bfloat16"],
+        )
+        return relax.op.exp(x)
+
+
+class Less(OnnxOpConverter):
+    """Lower an ONNX Less node, folding it when every operand is a constant."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        lhs, rhs = inputs[0], inputs[1]
+        if not all(isinstance(operand, relax.Constant) for operand in inputs):
+            return relax.op.less(lhs, rhs)
+        mask = _np.less(lhs.data.numpy(), rhs.data.numpy())
+        return relax.const(mask, mask.dtype)
+
+
+class LessOrEqual(OnnxOpConverter):
+    """Lower an ONNX LessOrEqual node, folding it when every operand is a constant."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        lhs, rhs = inputs[0], inputs[1]
+        if not all(isinstance(operand, relax.Constant) for operand in inputs):
+            return relax.op.less_equal(lhs, rhs)
+        mask = _np.less_equal(lhs.data.numpy(), rhs.data.numpy())
+        return relax.const(mask, mask.dtype)
+
+
+class Split(OnnxOpConverter):
+    """Converts an onnx Split node into an equivalent Relax expression."""
+
+    @classmethod
+    def _split_points(cls, splits):
+        """Turn section sizes into the cumulative split indices topi.split expects."""
+        indices = []
+        index = 0
+        for size in splits[:-1]:
+            index += size
+            indices.append(index)
+        return indices
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        splits = attr.get("split", None)
+        if splits is not None and len(splits) > 1:
+            indices = cls._split_points(splits)
+        else:
+            # When splits isn't specified, divide evenly over the axis.
+            indices = attr["tvm_custom"]["num_outputs"]
+        return bb.emit_te(topi.split, inputs[0], indices, attr.get("axis", 0))
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        # From opset 13 the section sizes arrive as an optional input.
+        splits = inputs[1]
+        splits_rank = None
+        if splits is not None:
+            splits_rank = splits.checked_type.ndim
+        if splits is not None and splits_rank > 0:
+            if not isinstance(splits, relax.Constant):
+                raise ValueError("Dynamic Split not yet supported")
+            indices = cls._split_points(splits.data.numpy())
+        else:
+            # When splits isn't specified, divide evenly over the axis.
+            indices = attr["tvm_custom"]["num_outputs"]
+        return bb.emit_te(topi.split, inputs[0], indices, axis=attr.get("axis", 0))
+
+
+class Slice(OnnxOpConverter):
+    """Converts an onnx Slice node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        # TODO (jwfromm) currently only supports constant parameters.
+        data = inputs[0]
+        starts = get_constant(inputs[1], params)
+        ends = get_constant(inputs[2], params)
+        axes = get_constant(inputs[3], params)
+        steps = get_constant(inputs[4], params)
+        if not all(
+            (isinstance(param, relax.Constant) or param is None)
+            for param in [starts, ends, axes, steps]
+        ):
+            raise ValueError("Only constant Slice parameters are currently supported.")
+        # Convert parameters to constant lists.
+        starts = starts.data.numpy().tolist()
+        ends = ends.data.numpy().tolist()
+        if axes is not None:
+            axes = axes.data.numpy().tolist()
+        else:
+            axes = list(range(len(starts)))
+        if steps is not None:
+            steps = steps.data.numpy().tolist()
+        else:
+            steps = [1] * len(axes)
+
+        # If input is a shape tensor, we can directly extract it. This happens
+        # before axis normalisation, which does not apply to shape struct info.
+        if isinstance(data, relax.ShapeExpr):
+            shape_data = [dim.value for dim in data]
+            # Starts, ends, and steps must be 1-d for shape operation.
+            assert all(len(i) == 1 for i in [starts, ends, steps])
+            sliced_values = shape_data[starts[0] : ends[0] : steps[0]]
+            return relax.const(sliced_values, "int64")
+
+        # Convert negative axes to positive. ``ndim`` is used instead of
+        # ``len(shape)`` so inputs with known rank but unknown shape also work.
+        if any(axis < 0 for axis in axes):
+            ndim = data.struct_info.ndim
+            if ndim < 0:
+                raise ValueError("Slice with negative axes requires an input of known rank.")
+            axes = [axis + ndim if axis < 0 else axis for axis in axes]
+        return relax.op.strided_slice(data, axes, starts, ends, steps)
+
+
+class Pad(OnnxOpConverter):
+ """Converts an onnx Pad node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ pads = get_constant(inputs[1], params)
+ constant_value = get_constant(inputs[2], params)
+ if constant_value is not None:
+ constant_value = constant_value.data.numpy().item()
+ else:
+ constant_value = 0.0
+
+ if isinstance(pads, relax.Constant):
+ pad_before, pad_after = _np.split(pads.data.numpy(), 2)
+ pad_before = _np.ndarray.tolist(pad_before)
+ pad_after = _np.ndarray.tolist(pad_after)
+ else:
+ raise ValueError("Dynamic pads are not supported yet.")
+
+ pad_mode = attr.get("mode", b"constant").decode("utf-8")
+ if not pad_mode in ["constant", "edge", "reflect"]:
+ raise tvm.error.OpAttributeInvalid(
+ "Value " + pad_mode + ' in attribute "mode" is invalid for
operator Pad.'
+ )
+
+ if pad_mode == "constant":
+ return bb.emit_te(topi.nn.pad, inputs[0], pad_before, pad_after,
constant_value)
+ elif pad_mode == "reflect":
+ return bb.emit_te(topi.nn.mirror_pad, inputs[0], pad_before,
pad_after, "REFLECT")
+ else:
+ # TODO(gigiblender) Support edge mode.
+ raise NotImplementedError("Pad mode {} not
implemented".format(pad_mode))
+
+
+class Tile(OnnxOpConverter):
+ """Converts an onnx Tile node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ reps = get_constant(inputs[1], params)
+ if isinstance(reps, relax.Constant):
+ reps = reps.data.numpy().tolist()
+ else:
+ raise ValueError("Dynamic reps for Tile are supported yet.")
+ return bb.emit_te(topi.tile, inputs[0], reps)
+
+
+class Expand(OnnxOpConverter):
+ """Converts an onnx Expand node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ shape = inputs[1]
+
+ if isinstance(shape, relax.ShapeExpr):
+ return relax.op.broadcast_to(data, shape)
+
+ # If possible, directly expand to constant shape.
+ if isinstance(shape, relax.Constant):
+ new_shape = shape.data.numpy().tolist()
+ # For some reason, onnx allows target shapes to be smaller than
input shapes.
+ # We need to go correct it.
+ data_shape = [dim.value for dim in data.struct_info.shape]
+ for i, s in enumerate(new_shape):
+ if s < data_shape[i]:
+ new_shape[i] = data_shape[i]
+ # If the new shape matches the input shape, no transformation is
needed.
+ if new_shape == data_shape:
+ return data
+ return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape))
+
+ # Otherwise handle dynamic shapes.
+ shape_ndim = [dim.value for dim in shape.struct_info.shape.values][0]
+ shape_dataflow_var = bb.emit(
+ relax.Call(
+ relax.ExternFunc("vm.builtin.tensor_to_shape"),
+ [shape],
+ sinfo_args=[relax.ShapeStructInfo(ndim=shape_ndim)],
+ )
+ )
+
+ shape_vars = []
+ for i in range(shape_ndim):
+ shape_vars.append(tvm.tir.Var("x_%d" % i, "int64"))
+ bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars))
+ return bb.normalize(relax.op.broadcast_to(data,
relax.ShapeExpr(shape_vars)))
+
+
+class Attention(OnnxOpConverter):
+ """Converts an onnx.microsoft Attention node into an equivalent Relax
expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ num_heads = attr["num_heads"]
+
+ assert "do_rotary" not in attr, "rotary position embedding is not
currently supported"
+ assert (
+ "past_present_share_buffer" not in attr
+ ), "past state for key and value is not currently supported"
+ assert "scale" not in attr, "custom scale is not currently supported"
+ assert "unidirectional" not in attr, "unidirectional attention is not
currently supported"
+
+ if "mask_filter_value" in attr:
+ mask_filter_value = attr["mask_filter_value"]
+ else:
+ mask_filter_value = -10000.0
+
+ # (batch_size, sequence_length, input_hidden_size)
+ input_emb = bb.normalize(inputs[0])
+
+ # (input_hidden_size, hidden_size + hidden_size + v_hidden_size)
+ weight = bb.normalize(inputs[1])
+
+ def optional_input(k: int):
+ if inputs[k] is not None:
+ return bb.normalize(inputs[k])
+ else:
+ return None
+
+ # (hidden_size + hidden_size + v_hidden_size)
+ bias = optional_input(2)
+
+ # 1. ( batch_size, 1, max_seq_len, max_seq_len,)
+ # 2. ( batch_size, total_seq_len,)
+ # 3. ( batch_size, seq_len, total_seq_len,)
+ # 4. ( batch_size,)
+ # 5. (2 * batch_size,)
+ # For now, we only support case 2 & 3.
+ mask_index = optional_input(3)
+
+ # (2, batch_size, num_heads, past_sequence_length, head_size)
+ assert inputs[4] is None, "past state for key and value is not
currently supported"
+
+ # (batch_size, num_heads, sequence_length, total_sequence_length)
+ qk_bias = optional_input(5)
+
+ assert inputs[6] is None, "past_sequence_length is not currently
supported"
+
+ (batch_size, seq_len, input_hidden_size) = [
+ val.value for val in input_emb.struct_info.shape.values
+ ]
+ weight_shape = [val.value for val in weight.struct_info.shape.values]
+
+ assert (
+ weight_shape[0] == input_hidden_size
+ ), "input and weight should share the same input hiden size"
+
+ if "qkv_hidden_sizes" in attr:
+ assert (
+ attr["qkv_hidden_sizes"][0] == attr["qkv_hidden_sizes"][1]
+ ), "Q and K should share the same hidden sizes"
+ hidden_size, _, hidden_size_v = attr["qkv_hidden_sizes"]
+ else:
+ hidden_size = hidden_size_v = weight_shape[1] // 3
+
+ assert (
+ hidden_size % num_heads == 0
+ ), "hidden size should be divisible by number of attention heads"
+ head_size = hidden_size // num_heads
+ head_size_v = hidden_size_v // num_heads
+
+ if mask_index is not None:
+ mask_index_shape = [val.value for val in
mask_index.struct_info.shape.values]
+ assert mask_index_shape in (
+ [batch_size, seq_len],
+ [
+ batch_size,
+ seq_len,
+ seq_len,
+ ],
+ ), """mask index should be in shape of (batch_size, seq_len),
+ or (batch_size, seq_len, seq_len)"""
+ mask_bias = relax.op.subtract(
+ relax.const(1, dtype=mask_index.struct_info.dtype), mask_index
+ )
+ mask_bias = relax.op.astype(mask_bias,
dtype=input_emb.struct_info.dtype)
+ mask_bias = bb.normalize(
+ relax.op.multiply(
+ mask_bias,
+ relax.const(mask_filter_value,
dtype=input_emb.struct_info.dtype),
+ )
+ )
+ if qk_bias is None:
+ qk_bias = mask_bias
+ else:
+ if len(mask_index_shape) == 2:
+ mask_bias = bb.normalize(
+ relax.op.reshape(mask_bias, [batch_size, 1, 1,
seq_len])
+ )
+ elif len(mask_index_shape) == 3:
+ mask_bias = bb.normalize(
+ relax.op.reshape(mask_bias, [batch_size, 1, seq_len,
seq_len])
+ )
+ qk_bias = bb.normalize(relax.op.add(qk_bias, mask_bias))
+
+ QKV = relax.op.matmul(input_emb, weight)
+
+ if bias:
+ bias_shape = [val.value for val in bias.struct_info.shape.values]
+ assert (
+ bias_shape[0] == weight_shape[1]
+ ), "bias and weight should share the same hidden size sum"
+ QKV = relax.op.add(QKV, bias)
+
+ QKV = relax.op.split(QKV, [hidden_size, hidden_size * 2], 2)
+ Q, K, V = QKV[0], QKV[1], QKV[2]
+
+ Q = bb.normalize(relax.op.reshape(Q, (batch_size, seq_len, num_heads,
head_size)))
+ K = bb.normalize(relax.op.reshape(K, (batch_size, seq_len, num_heads,
head_size)))
+ V = bb.normalize(relax.op.reshape(V, (batch_size, seq_len, num_heads,
head_size_v)))
+ output = relax.op.nn.attention(Q, K, V, qk_bias)
+ output = bb.normalize(
+ relax.op.reshape(output, (batch_size, seq_len, num_heads *
head_size_v))
+ )
+ # add placeholder for optional present state supported in the future
+ placeholder = relax.const(0, dtype="float32")
+ return relax.Tuple([output, placeholder])
+
+
+class Identity(OnnxOpConverter):
+ """Converts an onnx Identity node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ return inputs[0]
+
+
+class Resize(OnnxOpConverter):
+ """Converts an onnx Resize node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v18(cls, bb, inputs, attr, params):
+ # Extract the many attributes of resize.
+ coord_mode = attr.get("coordinate_transformation_mode",
b"half_pixel").decode("ascii")
+ cubic_coeff_a = attr.get("cubic_coeff_a", -0.75)
+ exclude_outside = attr.get("exclude_outside", 0)
+ extrapolation_value = attr.get("extrapolation_value", 0.0)
+ mode = attr.get("mode", b"nearest").decode("ascii")
+ rounding_method = attr.get("nearest_mode",
b"round_prefer_floor").decode("ascii")
+
+ # Adapt attributes to fit TVM definition.
+ if mode == "nearest":
+ mode = "nearest_neighbor"
+
+ # Unpack inputs.
+ x = inputs[0]
+ roi = get_constant(inputs[1], params)
+ scales = get_constant(inputs[2], params)
+ sizes = get_constant(inputs[3], params)
+ ndims = len(x.struct_info.shape)
+ assert ndims == 4, "Only resize2d is currently supported."
+
+ assert (
+ scales is None or sizes is None
+ ), "Only one of scales and sizes can be provided in Resize."
+
+ # Define relax implementation.
+ if roi is not None:
+ roi = relax.op.concat(
+ [
+ relax.op.strided_slice(roi, axes=[0], begin=[2],
end=[ndims]),
+ relax.op.strided_slice(roi, axes=[0], begin=[ndims + 2],
end=[2 * ndims]),
+ ],
+ axis=0,
+ )
+ else:
+ roi = [0.0] * 4
+
+ # Convert scales to sizes if needed.
+ if scales is not None:
+ assert isinstance(scales, relax.Constant), "Only constant scales
currently supported."
+ scales = scales.data.numpy()
+ sizes_shape = [dim.value for dim in x.struct_info.shape]
+ sizes = (sizes_shape * scales)[2:].astype("int64").tolist()
+ else:
+ assert isinstance(
+ sizes, relax.Constant
+ ), "Only constant output size currently supported."
+ sizes = sizes.data.numpy().astype("int64").tolist()[2:]
+
+ # TODO(jwfromm) relax.image.resize2d runs into some issues with
dynamism.
+ return bb.emit_te(
+ topi.image.resize2d,
+ x,
+ roi,
+ sizes,
+ layout="NCHW",
+ method=mode,
+ coordinate_transformation_mode=coord_mode,
+ rounding_method=rounding_method,
+ bicubic_alpha=cubic_coeff_a,
+ bicubic_exclude=exclude_outside,
+ extrapolation_value=extrapolation_value,
+ )
+
+
+class Einsum(OnnxOpConverter):
+ """Converts an onnx Einsum node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v12(cls, bb, inputs, attr, params):
+ equation = attr["equation"].decode("utf-8")
+ return bb.emit_te(topi.einsum, equation, *inputs)
+
+
+class Range(OnnxOpConverter):
+ """Converts an onnx Range node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v12(cls, bb, inputs, attr, params):
+ start = get_constant(inputs[0], params)
+ limit = get_constant(inputs[1], params)
+ delta = get_constant(inputs[2], params)
+ out_dtype = start.struct_info.dtype
+
+ if isinstance(start, relax.Constant):
+ start = start.data.numpy().tolist()
+
+ if isinstance(limit, relax.Constant):
+ limit = limit.data.numpy().tolist()
+
+ assert isinstance(delta, relax.Constant), "Constant delta required for
Range."
+ step = delta.data.numpy().tolist()
+
+ # If all inputs are constant, compute directly.
+ if isinstance(start, int) and isinstance(limit, int):
+ out_range = _np.arange(start=start, stop=limit, step=step)
+ return relax.const(out_range, out_dtype)
+
+ # Otherwise compute in graph.
+ return bb.emit_te(topi.arange, start, limit, step, out_dtype)
+
+
+class InstanceNormalization(OnnxOpConverter):
+ """Converts an onnx InstanceNormalization node into an equivalent Relax
expression."""
+
+ @classmethod
+ def _impl_v6(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ scale = inputs[1]
+ B = inputs[2]
+ epsilon = attr.get("epsilon", 1e-05)
+ epsilon = relax.const(epsilon, dtype=data.struct_info.dtype)
+
+ ndim = len(data.struct_info.shape)
+ redux_axes = list(range(2, ndim))
+
+ mean = relax.op.mean(data, axis=redux_axes, keepdims=True)
+ var = relax.op.variance(data, axis=redux_axes, keepdims=True)
+ sqrt = relax.op.sqrt(relax.op.add(var, epsilon))
+ out = relax.op.divide(relax.op.subtract(data, mean), sqrt)
+ broadcast_shape = [-1] + [
+ 1,
+ ] * (ndim - 2)
+ if scale is not None:
+ scale = relax.op.reshape(scale, broadcast_shape)
+ out = relax.op.multiply(out, scale)
+ if B is not None:
+ B = relax.op.reshape(B, broadcast_shape)
+ out = relax.op.add(out, B)
+ return out
+
+
+class BatchNormalization(OnnxOpConverter):
+ """Converts an onnx BatchNormalization node into an equivalent Relax
expression."""
+
+ @classmethod
+ def _impl_v15(cls, bb, inputs, attr, params):
+ # Unpack inputs
+ data = inputs[0]
+ scale = inputs[1]
+ bias = inputs[2]
+ mean = inputs[3]
+ var = inputs[4]
+ epsilon = attr.get("epsilon", 1e-05)
+ return relax.op.nn.batch_norm(
+ data, gamma=scale, beta=bias, moving_mean=mean, moving_var=var,
epsilon=epsilon, axis=1
+ )
+
+
+class MaxPool(OnnxOpConverter):
+ """Converts an onnx MaxPool node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v12(cls, bb, inputs, attr, params):
+ # Unpack inputs and attributes.
+ data = inputs[0]
+ auto_pad = attr.get("auto_pad", b"NOTSET").decode("utf-8")
+ ceil_mode = attr.get("ceil_mode", 0)
+ dilations = attr.get("dilations", [1, 1])
+ kernel_shape = attr.get("kernel_shape")
+ pads = attr.get("pads", 0)
+ strides = attr.get("strides", 1)
+
+ assert len(kernel_shape) == 2, "Currently only 2D pooling is
supported."
+ assert auto_pad in [
+ "NOTSET",
+ "SAME_UPPER",
+ "SAME_LOWER",
+ "VALID",
+ ], f"Value {auto_pad} in attribute auto_pad is invalid."
+
+ if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
+ input_spatial_shape = cls._get_input_spatial_shape(data)
+ output_spatial_shape = [0 for _ in input_spatial_shape]
+
+ pads = _np.array([(0, 0) for _ in range(len(kernel_shape))])
+
+ for i, _ in enumerate(input_spatial_shape):
+ if auto_pad == "SAME_UPPER":
+ output_spatial_shape[i] =
int(_np.ceil(input_spatial_shape[i] / strides[i]))
+ else:
+ output_spatial_shape[i] =
int(_np.floor(input_spatial_shape[i] / strides[i]))
+ pad_i = (
+ (output_spatial_shape[i] - 1) * strides[i]
+ + ((kernel_shape[i] - 1) * dilations[i] + 1)
+ - input_spatial_shape[i]
+ )
+ if auto_pad == "SAME_UPPER":
+ pads[i, 0] = pad_i // 2
+ pads[i, 1] = pad_i - pads[i, 0]
+ else:
+ pads[i, 1] = pad_i // 2
+ pads[i, 0] = pad_i - pads[i, 1]
+
+ # TODO(agladyshev): for now we support only 2D kernel
+ # (top, left, bottom, right)
+ flatten_pads = [pads[0][0], pads[1][0], pads[0][1], pads[1][1]]
+ pads = tuple(flatten_pads)
+
+ return relax.op.nn.max_pool2d(data, kernel_shape, strides, pads,
dilations, ceil_mode)
+
+ @classmethod
+ def _get_input_spatial_shape(cls, tensor):
+ # shape is (N x C x D1 x D2 ... Dn)
+ return _np.array([int(d) for d in tensor.struct_info.shape],
dtype="int64")[2:]
+
+
+class GlobalAveragePool(OnnxOpConverter):
+ """Converts an onnx GlobalAveragePool node into an equivalent Relax
expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ return relax.op.nn.adaptive_avg_pool2d(inputs[0], 1)
+
+
+class Flatten(OnnxOpConverter):
+ """Converts an onnx Flatten node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ axis = attr.get("axis", 1)
+ data_shape = [i.value for i in inputs[0].struct_info.shape]
+ new_shape = (1, -1) if axis == 0 else
(_np.prod(data_shape[0:axis]).astype("int64"), -1)
+ return relax.op.reshape(inputs[0], new_shape)
+
+
+class LayerNormalization(OnnxOpConverter):
+ """Converts an onnx LayerNormalization node into an equivalent Relax
expression."""
+
+ @classmethod
+ def _impl_v17(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ scale = inputs[1]
+ bias = inputs[2]
+ axis = attr.get("axis", -1)
+ epsilon = attr.get("epsilon", 1e-05)
+
+ output = relax.op.nn.layer_norm(data, scale, bias, axis, epsilon)
+ # Onnx layernorm has 3 outputs but only the first is used.
+ # We construct two empty constants for this.
+ placeholder = relax.const(0, dtype="float32")
+ return relax.Tuple([output, placeholder, placeholder])
+
+
+class ReduceMax(OnnxOpConverter):
+ """Converts an onnx ReduceMax node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axes = attr.get("axes", None)
+ keepdims = attr.get("keepdims", 1)
+ return relax.op.max(data, axes, keepdims)
+
+
+class ReduceMin(OnnxOpConverter):
+ """Converts an onnx ReduceMin node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axes = attr.get("axes", None)
+ keepdims = attr.get("keepdims", 1)
+ return relax.op.min(data, axes, keepdims)
+
+
+class ReduceSum(OnnxOpConverter):
+ """Converts an onnx ReduceSum node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axes = attr.get("axes", None)
+ keepdims = attr.get("keepdims", 1)
+ return relax.op.sum(data, axes, keepdims)
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axes = inputs[1]
+ keepdims = attr.get("keepdims", 1)
+ assert isinstance(axes, relax.Constant), "Only constant axes currently
supported."
+ axes = axes.data.numpy().tolist()
+ return relax.op.sum(data, axes, keepdims)
+
+
+class ReduceMean(OnnxOpConverter):
+ """Converts an onnx ReduceMean node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axes = attr.get("axes", None)
+ keepdims = attr.get("keepdims", 1)
+ return relax.op.mean(data, axes, keepdims)
+
+
+class ReduceProd(OnnxOpConverter):
+ """Converts an onnx ReduceProd node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axes = attr.get("axes", None)
+ keepdims = attr.get("keepdims", 1)
+ return relax.op.prod(data, axes, keepdims)
+
+
+class ReduceLogSumExp(OnnxOpConverter):
+ """Converts an onnx ReduceLogSumExp node into an equivalent Relax
expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ x = inputs[0]
+ axes = attr.get("axes", None)
+ keepdims = attr.get("keepdims", 1)
+ max_x = relax.op.max(x, axes, True)
+ exp_x = relax.op.exp(relax.op.subtract(x, max_x))
+ sum_x = relax.op.sum(exp_x, axes, True)
+ out_x = relax.op.add(relax.op.log(sum_x), max_x)
+ if not keepdims:
+ out_x = relax.op.squeeze(out_x, axes)
+ return out_x
+
+
+class ReduceLogSum(OnnxOpConverter):
+ """Converts an onnx ReduceLogSum node into an equivalent Relax
expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axes = attr.get("axes", None)
+ keepdims = attr.get("keepdims", 1)
+ return relax.op.log(relax.op.sum(data, axes, keepdims))
+
+
+class ReduceSumSquare(OnnxOpConverter):
+ """Converts an onnx ReduceSumSquare node into an equivalent Relax
expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axes = attr.get("axes", None)
+ keepdims = attr.get("keepdims", 1)
+ return relax.op.sum(relax.op.multiply(data, data), axes, keepdims)
+
+
+class ReduceL1(OnnxOpConverter):
+ """Converts an onnx ReduceL1 node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axes = attr.get("axes", None)
+ keepdims = attr.get("keepdims", 1)
+ return relax.op.sum(relax.op.abs(data), axes, keepdims)
+
+
+class ReduceL2(OnnxOpConverter):
+ """Converts an onnx ReduceL2 node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axes = attr.get("axes", None)
+ keepdims = attr.get("keepdims", 1)
+ return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), axes,
keepdims))
+
+
+class ArgMax(OnnxOpConverter):
+ """Converts an onnx ArgMax node into an equivalent Relax expression."""
+
+ @classmethod
+ def _check_attrs(cls, data, attr, shift_axis=True):
+ dims_num = len(data.struct_info.shape)
+ axis = attr.get("axis", 0)
+ if shift_axis and axis < 0:
+ axis += dims_num
+ assert 0 <= axis < dims_num, "Axis is out of bounds"
+ keepdims = attr.get("keepdims", True)
+ return axis, keepdims
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axis, keepdims = cls._check_attrs(data, attr, False)
+ return relax.op.argmax(data, axis, keepdims)
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axis, keepdims = cls._check_attrs(data, attr)
+ return relax.op.argmax(data, axis, keepdims)
+
+ @classmethod
+ def _impl_v12(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axis, keepdims = cls._check_attrs(data, attr)
+ select_last_index = attr.get("select_last_index", False)
+ if select_last_index:
+ # TODO(vvchernov): support attr
+ raise tvm.error.OpAttributeUnImplemented(
+ "'select_last_index' attribute has not been supported yet"
+ )
+ return relax.op.argmax(data, axis, keepdims)
+
+
+class ArgMin(OnnxOpConverter):
+ """Converts an onnx ArgMin node into an equivalent Relax expression."""
+
+ @classmethod
+ def _check_attrs(cls, data, attr, shift_axis=True):
+ dims_num = len(data.struct_info.shape)
+ axis = attr.get("axis", 0)
+ if shift_axis and axis < 0:
+ axis += dims_num
+ assert 0 <= axis < dims_num, "Axis is out of bounds"
+ keepdims = attr.get("keepdims", True)
+ return axis, keepdims
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axis, keepdims = cls._check_attrs(data, attr, False)
+ return relax.op.argmin(data, axis, keepdims)
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axis, keepdims = cls._check_attrs(data, attr)
+ return relax.op.argmin(data, axis, keepdims)
+
+ @classmethod
+ def _impl_v12(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ axis, keepdims = cls._check_attrs(data, attr)
+ select_last_index = attr.get("select_last_index", False)
+ if select_last_index:
+ # TODO(vvchernov): support attr
+ raise tvm.error.OpAttributeUnImplemented(
+ "'select_last_index' attribute has not been supported yet"
+ )
+ return relax.op.argmin(data, axis, keepdims)
+
+
+class SkipLayerNormalization(OnnxOpConverter):
+ """Converts a microsoft contrib SkipLayerNormalization node into a Relax
expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ data = inputs[0]
+ skip = inputs[1]
+ gamma = inputs[2]
+ beta = inputs[3]
+ bias = inputs[4]
+
+ assert (
+ beta is not None and bias is not None
+ ), "SkipLayerNormalization import currently only supports required
beta and bias"
+
+ epsilon = attr.get("epsilon", 1e-12)
+
+ data = relax.op.add(data, skip)
+ if bias is not None:
+ data = relax.op.add(data, bias)
+
+ output = relax.op.nn.layer_norm(data, gamma, beta, axes=-1,
epsilon=epsilon)
+
+ # Expects three outputs though only the first is used. Construct a
placeholder for others.
+ placeholder = relax.const(0, dtype="float32")
+ return relax.Tuple([output, placeholder, placeholder])
+
+
+class EmbedLayerNormalization(OnnxOpConverter):
+ """Converts a microsoft contrib EmbedLayerNormalization node into a Relax
expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ input_ids = inputs[0]
+ segment_ids = inputs[1]
+ word_emb = inputs[2]
+ pos_emb = inputs[3]
+ segment_emb = inputs[4]
+ gamma = inputs[5]
+ beta = inputs[6]
+ mask = inputs[7]
+ pos_ids = inputs[8]
+
+ epsilon = attr.get("epsilon", 1e-12)
+
+ (batch_size, seq_len) = [dim.value for dim in
input_ids.struct_info.shape]
+
+ if segment_ids:
+ assert segment_emb
+
+ if pos_ids is None:
+ pos_ids = relax.const([list(range(seq_len))] * batch_size,
dtype="int64")
+ # TODO(jwfromm) Replace with relax ops once take has better support.
+ word_vec = bb.emit_te(topi.take, word_emb, input_ids, 0)
+ if segment_ids:
+ segment_vec = bb.emit_te(topi.take, segment_emb, segment_ids, 0)
+ pos_vec = bb.emit_te(topi.take, pos_emb, pos_ids, 0)
+
+ vec_sum = relax.op.add(word_vec, pos_vec)
+ if segment_ids:
+ vec_sum = relax.op.add(vec_sum, segment_vec)
+
+ ln = relax.op.nn.layer_norm(vec_sum, gamma, beta, axes=-1,
epsilon=epsilon)
+
+ mask_index = relax.const(_np.zeros((batch_size,), dtype="int64"))
+ if mask:
+ # Caculate number of words per sentence.
+ mask_index = relax.op.sum(mask, axis=1)
+
+ return relax.Tuple([ln, mask_index])
+
+
+class Greater(OnnxOpConverter):
+ """Converts an onnx Greater node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ if all([isinstance(inp, relax.Constant) for inp in inputs]):
+ output = _np.greater(inputs[0].data.numpy(),
inputs[1].data.numpy())
+ return relax.const(output, output.dtype)
+ return relax.op.greater(inputs[0], inputs[1])
+
+
+class Reciprocal(OnnxOpConverter):
+ """Converts an onnx Reciprocal node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ input_dtype = inputs[0].struct_info.dtype
+ return relax.op.divide(relax.const(1, dtype=input_dtype), inputs[0])
+
+
+class OneHot(OnnxOpConverter):
+ """Converts an onnx OneHot node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ indices = inputs[0]
+ depth = get_constant(inputs[1], params)
+ values = get_constant(inputs[2], params)
+ axis = attr.get("axis", -1)
+ dtype = values.struct_info.dtype
+ assert isinstance(depth, relax.Constant), "Only constant depth
currently supported."
+ depth = depth.data.numpy().tolist()
+ assert isinstance(values, relax.Constant), "Only constant values
currently supported."
+ values = values.data.numpy().tolist()
+ off_value, on_value = values
+ return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth,
axis, dtype)
+
+
+def _get_convert_map():
+ return {
+ "MatMul": MatMul,
+ "Concat": Concat,
+ "Add": Add,
+ "Mul": Mul,
+ "Cast": Cast,
+ "Gather": Gather,
+ "Gemm": Gemm,
+ "Reshape": Reshape,
+ "Div": Div,
+ "Sigmoid": Sigmoid,
+ "Softmax": Softmax,
+ "Transpose": Transpose,
+ "Unsqueeze": Unsqueeze,
+ "Gelu": Gelu,
+ "BiasGelu": BiasGelu,
+ "Where": Where,
+ "Clip": Clip,
+ "Equal": Equal,
+ "Shape": Shape,
+ "Tanh": Tanh,
+ "Sqrt": Sqrt,
+ "Relu": Relu,
+ "Conv": Conv,
+ "Pow": Pow,
+ "Erf": Erf,
+ "CumSum": CumSum,
+ "Squeeze": Squeeze,
+ "Constant": Constant,
+ "Sub": Sub,
+ "Sin": Sin,
+ "Cos": Cos,
+ "Neg": Neg,
+ "Abs": Abs,
+ "Min": Min,
+ "Max": Max,
+ "Log": Log,
+ "Exp": Exp,
+ "Less": Less,
+ "LessOrEqual": LessOrEqual,
+ "LayerNormalization": LayerNormalization,
+ "SkipLayerNormalization": SkipLayerNormalization,
+ "EmbedLayerNormalization": EmbedLayerNormalization,
+ "InstanceNormalization": InstanceNormalization,
+ # defs/reduction
+ "ReduceMax": ReduceMax,
+ "ReduceMin": ReduceMin,
+ "ReduceSum": ReduceSum,
+ "ReduceMean": ReduceMean,
+ "ReduceProd": ReduceProd,
+ "ReduceLogSumExp": ReduceLogSumExp,
+ "ReduceLogSum": ReduceLogSum,
+ "ReduceSumSquare": ReduceSumSquare,
+ "ReduceL1": ReduceL1,
+ "ReduceL2": ReduceL2,
+ "ArgMax": ArgMax,
+ "ArgMin": ArgMin,
+ "Expand": Expand,
+ "ConstantOfShape": ConstantOfShape,
+ "Slice": Slice,
+ "Attention": Attention,
+ "Pad": Pad,
+ "Split": Split,
+ "Tile": Tile,
+ "BatchNormalization": BatchNormalization,
+ "GlobalAveragePool": GlobalAveragePool,
+ "Flatten": Flatten,
+ "MaxPool": MaxPool,
+ "Identity": Identity,
+ "Resize": Resize,
+ "Einsum": Einsum,
+ "Range": Range,
+ "Greater": Greater,
+ "Reciprocal": Reciprocal,
+ "OneHot": OneHot,
+ }
+
+
+class ONNXGraphImporter:
+    """A helper class for handling Relax expression copying from pb2.GraphProto.
+    Definition: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto
+
+    Parameters
+    ----------
+    shape_dict : dict of str to list, optional
+        Shape overrides keyed by ONNX input name. Inputs not listed keep the shape
+        recorded in the graph.
+    dtype_dict : str or dict of str to str
+        The input types to the graph. Per-input dtypes are taken only from the dict
+        form; when a plain string is given, each input keeps the dtype recorded in
+        the graph.
+    keep_params_in_input : bool
+        If True, parameters will be treated as input variables. If False,
+        parameters are treated as constant and folded directly into the graph.
+    sanitize : bool
+        Whether to sanitize the input names to be valid Relax identifiers.
+    """
+
+    # Class-level attribute initialized to None. NOTE(review): not referenced in the
+    # visible part of this class; presumably assigned elsewhere -- confirm its use.
+    current = None
+
+ def __init__(
+ self,
+ shape_dict: Dict[str, List],
+ dtype_dict: Union[str, Dict[str, str]],
+ keep_params_in_input: bool = False,
+ sanitize: bool = True,
+ ):
+ self._nodes: Dict[str, relax.Expr] = {}
+ self._inputs: Dict[str, relax.Var] = {}
+ self._num_input: int = 0
+ self._shape = shape_dict.copy() if shape_dict else {}
+ self._input_names: List[str] = []
+ self._dtype = dtype_dict
+ self.opset: int = None
+ self._name_supply = NameSupply()
+ self._keep_params_in_input = keep_params_in_input
+ self._sanitize: bool = sanitize
+ self.bb: relax.BlockBuilder = relax.BlockBuilder() # pylint:
disable=invalid-name
+ self._params = {}
+
+ def from_onnx(self, graph: onnx.onnx_ml_pb2.ModelProto, opset: int) ->
IRModule:
+ """Construct Relax expressions from the ONNX graph.
+ Onnx graph is a python protobuf object.
+
+ Parameters
+ ----------
+ graph : onnx protobuf object
+ The loaded onnx graph
+ opset : opset version
+ Returns
+ -------
+ mod : tvm.IRModule
+ The returned relax module
+ """
+ with self.bb.function("main"):
+ with self.bb.dataflow() as df: # pylint: disable=invalid-name,
unused-variable
+ self.opset = opset
+ self._parse_graph_input(graph)
+ self._parse_graph_initializers(graph)
+ self._check_for_unsupported_ops(graph)
+ self._construct_nodes(graph)
+
+ # now return the outputs
+ outputs = [self._nodes[self._parse_value_proto(i)] for i in
graph.output]
+ outputs = outputs[0] if len(outputs) == 1 else
relax.Tuple(outputs)
+
+ output_var = self.bb.emit_output(outputs)
+
+ # Create function attributes for this module
+ func_attrs = {"num_input": self._num_input}
+ # Create a function from our output expression and all input
variables.
+ input_list = [value for value in self._inputs.values() if
isinstance(value, relax.Var)]
+ # Attach params if they are available.
+ if self._keep_params_in_input and self._params:
+ param_var_list, param_value_list = map(list,
zip(*self._params.values()))
+ input_list = input_list + param_var_list
+ func_attrs["params"] = param_value_list
+
+ self.bb.emit_func_output(output_var, params=input_list)
+
+ relax_mod = self.bb.get()
+ # Attach attributes.
+ relax_mod["main"] = relax_mod["main"].with_attrs(func_attrs)
+ return relax_mod
+
+ def _parse_graph_initializers(self, graph: onnx.onnx_ml_pb2.GraphProto):
+ """Parse network inputs to relax, aka parameters."""
+ for init_tensor in graph.initializer:
+ # There are two cases for handling parameters, they are either
+ # treated as variables or constants.
+ if not init_tensor.name.strip():
+ raise ValueError("Tensor's name is required.")
+ array = self._parse_array(init_tensor)
+ # Create variables for constants.
+ if self._keep_params_in_input:
+ init_var = self._new_var(init_tensor.name, shape=array.shape,
dtype=array.dtype)
+ self._nodes[init_tensor.name] = init_var
+ # We need to keep track of both the real value and variable
for this variable.
+ self._params[init_tensor.name] = (init_var, array)
+ # Otherwise we can use the weight as a constant.
+ else:
+ self._nodes[init_tensor.name] = relax.const(array)
+
+ def _sanitize_name(self, name: str) -> str:
+ """Sanitize a name to make it a valid identifier.
+ If the name is None, returns a string input_0, input_1, etc.
+ If the input is an empty string, returns empty_0, empty_1, etc.
+ If the input is a string that does not start with a letter or
underscore,
+ returns input_<name>. Otherwise, returns an unique input name.
+
+ Parameters
+ ----------
+ name : str
+ The name to sanitize
+ Returns
+ -------
+ new_name : str
+ """
+
+ if name == "":
+ return self._name_supply.fresh_name("empty_")
+
+ new_name = name.replace(".", "_")
+ if not new_name[0].isalpha() and new_name[0] != "_":
+ new_name = str(self._name_supply.fresh_name("input_" + new_name))
+ else:
+ new_name = str(self._name_supply.fresh_name(new_name))
+
+ if new_name != name:
+ warnings.warn(("Renaming name %s to %s" % (name, new_name)))
+ return new_name
+
+ def _new_var(self, var_name: str, shape: List, dtype: str = "float32"):
+ """Creates a new Relax variable."""
+ return relax.Var(
+ name_hint=var_name,
struct_info=relax.TensorStructInfo(shape=shape, dtype=dtype)
+ )
+
+ def _parse_graph_input(self, graph: onnx.onnx_ml_pb2.GraphProto):
+ """Parse model inputs to Relax parameters."""
+ for i in graph.input:
+ # from onnx v0.2, GraphProto.input has type ValueInfoProto,
+ # and the name is 'i.name'
+ i_name, i_shape, d_type, i_shape_name = get_info(i)
+ if i_name not in self._nodes:
+ self._num_input += 1
+ self._input_names.append(i_name)
+ if i_name in self._shape:
+ i_shape = self._shape[i_name]
+ else:
+ if "?" in str(i_shape):
+ warning_msg = (
+ "Input %s has unknown dimension shapes: %s. "
+ "Specifying static values may improve performance"
+ % (i_name, str(i_shape_name))
+ )
+ warnings.warn(warning_msg)
+ if isinstance(self._dtype, dict):
+ dtype = self._dtype[i_name] if i_name in self._dtype else
d_type
+ else:
+ dtype = d_type
+ var_name = self._sanitize_name(i_name) if self._sanitize else
i_name
+ self._nodes[i_name] = self._new_var(var_name, shape=i_shape,
dtype=dtype)
+ self._inputs[i_name] = self._nodes[i_name]
+
+ def _check_for_unsupported_ops(self, graph: onnx.onnx_ml_pb2.GraphProto):
+ convert_map = _get_convert_map()
+ unsupported_ops = set()
+ for node in graph.node:
+ op_name = node.op_type
+ if (
+ op_name not in convert_map
+ and op_name != "Constant"
+ # and op_name not in _identity_list
+ ):
+ unsupported_ops.add(op_name)
+ if unsupported_ops:
+ msg = "The following operators are not supported for frontend
ONNX: "
+ msg += ", ".join(unsupported_ops)
+ raise tvm.error.OpNotImplemented(msg)
+
+ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
+ """Nodes are stored as directed acyclic graph."""
+ for node in graph.node:
+ op_name = node.op_type
+ attr = self._parse_attr(node.attribute)
+ # Create and populate input list.
+ inputs = onnx_input()
+ for i in node.input:
+ if i != "":
+ inputs.append(self._nodes[i])
+ else:
+ inputs.append(None)
+ i_name = self._parse_value_proto(node)
+ outputs = node.output
+ attr["tvm_custom"] = {}
+ attr["tvm_custom"]["name"] = i_name
+ attr["tvm_custom"]["num_outputs"] = len(outputs)
+
+ # Perform special handling for shape expressions. If an input is a
+ # shape expr, make sure the current op can handle it, otherwise
+ # convert it to a tensor.
+ shape_compatible_ops = ["Reshape", "ConstantOfShape", "Gather",
"Slice", "Expand"]
+ for i, inp in enumerate(inputs):
+ if (
+ inp is not None
+ and isinstance(inp.struct_info, relax.ShapeStructInfo)
+ and op_name not in shape_compatible_ops
+ ):
+ raise ValueError(f"Node {node.name} cannot handle
ShapeExpr inputs.")
+
+ op = self._convert_operator(op_name, inputs, attr, self.opset)
+ # Create struct information for the new operator.
+ op = self.bb.normalize(op)
+
+ if not isinstance(op, relax.Tuple):
+ if isinstance(op.checked_type, tvm.ir.type.TupleType):
+ # This is a var bound to a tuple. We need to unpack it and
create
+ # a new tuple.
+ tuple_items = []
+ for i in range(len(op.checked_type.fields)):
+ tuple_items.append(self.bb.emit(relax.TupleGetItem(op,
i)))
+ op = relax.Tuple(tuple_items)
+ outputs_num = len(tuple_items)
+ else:
+ outputs_num = 1
+ else:
+ outputs_num = len(op)
+ assert (
+ len(outputs) <= outputs_num
+ ), "Missing outputs during conversion. Expected {} but Got {} in
{}.".format(
+ len(outputs), outputs_num, op_name
+ )
+
+ if outputs_num == 1:
+ self._nodes[outputs[0]] = op
+ else:
+ for k, i in zip(list(outputs), range(len(outputs))):
+ self._nodes[k] = op[i]
+
+ def _parse_value_proto(self, value_proto: onnx.onnx_ml_pb2.GraphProto):
+ """Parse ValueProto or raw str."""
+ try:
+ name = value_proto.name
+ except AttributeError:
+ name = value_proto
+ return name
+
+ def _parse_array(self, tensor_proto: onnx.onnx_ml_pb2.TensorProto) ->
tvm.nd.array:
+ np_array = get_numpy(tensor_proto).reshape(tuple(tensor_proto.dims))
+ return tvm.nd.array(np_array)
+
+ def _parse_attr(self, attr_proto: onnx.onnx_ml_pb2.AttributeProto) ->
Dict[str, Any]:
+ """Convert a list of AttributeProto to a dict, with names as keys."""
+ attrs = {}
+ for a in attr_proto:
+ for f in ["f", "i", "s", "g"]:
+ if a.HasField(f):
+ attrs[a.name] = getattr(a, f)
+ for f in ["floats", "ints", "strings"]:
+ if list(getattr(a, f)):
+ assert a.name not in attrs, "Only one type of attr is
allowed"
+ attrs[a.name] = tuple(getattr(a, f))
+ for f in ["t"]:
+ if a.HasField(f):
+ attrs[a.name] = getattr(a, f)
+ for f in ["tensors"]:
+ if list(getattr(a, f)):
+ assert a.name not in attrs, "Only one type of attr is
allowed"
+ attrs[a.name] = tuple(getattr(a, f))
+ for f in ["graphs"]:
+ if list(getattr(a, f)):
+ raise NotImplementedError("Field {} is not supported in
relax.".format(f))
+ if a.name not in attrs:
+ raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
+ return attrs
+
+ def _convert_operator(
+ self,
+ op_name: str,
+ inputs: List[relax.Function],
+ attrs: Dict,
+ opset: int,
+ ) -> relax.Function:
+ """Convert ONNX operator into a Relax operator.
+ The converter must specify conversions explicitly for incompatible
name, and
+ apply handlers to operator attributes.
+
+ Parameters
+ ----------
+ op_name : str
+ Operator name, such as Convolution, FullyConnected
+ inputs : list of tvm.relax.function.Function
+ List of inputs.
+ attrs : dict
+ Dict of operator attributes
+ opset : int
+ Opset version
+ Returns
+ -------
+ sym : tvm.relax.function.Function
+ Converted relax function
+ """
+ convert_map = _get_convert_map()
+ if op_name in convert_map:
+ convert_class = convert_map[op_name]
+ op_function = convert_class.get_converter(opset)
+ sym = op_function(self.bb, inputs, attrs, [self._nodes,
self._params])
+ else:
+ raise NotImplementedError("Operator {} not
implemented.".format(op_name))
+ return sym
+
+
def from_onnx(
    model: onnx.onnx_ml_pb2.ModelProto,
    shape_dict: Optional[Dict[str, List]] = None,
    dtype_dict: Optional[Union[str, Dict[str, str]]] = "float32",
    opset: Optional[int] = None,
    keep_params_in_input: bool = False,
    sanitize_input_names: bool = True,
) -> IRModule:
    """Convert an ONNX model into an equivalent Relax IRModule.

    ONNX graphs are represented as Python Protobuf objects. The current
    implementation assumes the input model is from ONNX v1.1.0 or later.

    Parameters
    ----------
    model : onnx.ModelProto
        ONNX ModelProto with IR version >= 3.
    shape_dict : dict of str to tuple, optional
        Static shapes for graph inputs, keyed by input name.
    dtype_dict : str or dict of str to str, optional
        Input dtypes keyed by input name. NOTE(review): when a plain string
        is passed (the default), the importer appears to keep each input's
        dtype as recorded in the model; confirm in ONNXGraphImporter.
    opset : int, optional
        Override for the auto-detected opset. This can be helpful for testing.
    keep_params_in_input : bool
        If True, parameters are treated as input variables. If False, they
        are treated as constants and folded directly into the graph.
    sanitize_input_names : bool, optional
        Whether to sanitize input names so they are valid Relax identifiers.

    Returns
    -------
    mod : tvm.IRModule
        The Relax module for compilation. Parameter values are attached to
        ``main`` as its ``params`` function attribute (presumably populated
        only when ``keep_params_in_input`` is True); use
        ``relax.frontend.detach_params`` to split them out.

    Raises
    ------
    ValueError
        If ``model.ir_version`` is below 3.
    ImportError
        If the ``onnx`` package cannot be imported.
    """
    # Reject models older than IR version 3 (ONNX v1.1.0 per the message).
    if model.ir_version < 3:
        raise ValueError(
            "Model IR version {} not supported. Must be at least after 1.1.0.".format(
                model.ir_version
            )
        )

    try:
        import onnx  # pylint: disable=import-outside-toplevel, redefined-outer-name

        if hasattr(onnx.checker, "check_model"):
            # Run onnx's own model checker before converting. It is strict,
            # so its failures are reported as warnings only.
            try:
                onnx.checker.check_model(model)
            except Exception as exception:  # pylint: disable=c-extension-no-member, broad-except
                warnings.warn(str(exception))
    except ImportError as error:
        raise ImportError("Unable to import onnx which is required {}".format(error))

    g = ONNXGraphImporter(
        shape_dict,
        dtype_dict,
        keep_params_in_input=keep_params_in_input,
        sanitize=sanitize_input_names,
    )
    graph = model.graph

    try:
        opset_in_model = 1
        if model.opset_import:
            # TODO: only the ai.onnx opset is really supported for now.
            # TODO: handle other namespaces, see https://github.com/apache/tvm/issues/10950
            for opset_identifier in model.opset_import:
                # Per https://github.com/onnx/onnx/blob/main/docs/IR.md, every
                # operator set except the default one must specify a version.
                if str(opset_identifier.domain) in ["ai.onnx", ""]:
                    opset_in_model = opset_identifier.version
                    break
    except AttributeError:
        opset_in_model = 1

    if opset is None:
        opset = opset_in_model
    elif opset < opset_in_model:
        warnings.warn(
            f"You are overwritting original opset ver = {opset_in_model} by lower ver = {opset}. "
            "That might cause model conversion errors."
        )

    # The importer receives the graph proto so ops can access other nodes if needed.
    return g.from_onnx(graph, opset)
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
new file mode 100644
index 0000000000..4c4d2d5a95
--- /dev/null
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -0,0 +1,1619 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""
+ONNX testcases
+================
+This file is a test script to test Relax ONNX frontend coverage.
+"""
+
+from typing import Optional, Dict
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relax.frontend.onnx import from_onnx
+
+import onnx
+from onnx import helper, TensorProto, ModelProto, mapping
+import onnxruntime
+
# Module-wide random generator seeded with 0, so the random inputs built by
# generate_random_inputs are deterministic for a given test execution order.
bg = np.random.MT19937(0)
rg = np.random.Generator(bg)
+
+
def generate_random_inputs(
    model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None
) -> Dict[str, np.ndarray]:
    """Build a feed dict with one array per graph input of ``model``.

    Values given (and not ``None``) in ``inputs`` are used as-is. Every other
    input is filled from the module-level generator ``rg``:

    * random booleans for ``bool`` inputs;
    * standard-normal samples cast to the input's dtype otherwise, using
      ``float32`` when the element type is unset.

    NOTE(review): symbolic dimensions report ``dim_value == 0`` and so produce
    zero-sized arrays; pass explicit values for such inputs.
    """
    provided = {} if inputs is None else inputs
    feed = {}
    for graph_input in model.graph.input:
        name = graph_input.name
        given = provided.get(name)
        if given is not None:
            feed[name] = given
            continue

        tensor_type = graph_input.type.tensor_type
        shape = [dim.dim_value for dim in tensor_type.shape.dim]
        elem_type = tensor_type.elem_type
        dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type]) if elem_type else "float32"

        if dtype == "bool":
            feed[name] = rg.choice(a=[False, True], size=shape)
        else:
            feed[name] = rg.standard_normal(size=shape).astype(dtype)

    return feed
+
+
def check_correctness(
    model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None, opset: Optional[int] = None
) -> None:
    """Run an ONNX model in onnxruntime and in TVM and check the results match.

    Parameters
    ----------
    model: ModelProto
        The ONNX model to test. Note that ``opset`` is written into it in place.
    inputs: Optional[Dict[str, np.ndarray]]
        Optional values for the model's inputs. Missing entries are randomly
        generated.
    opset: Optional[int]
        The opset version to use for the ONNX importer.

    Raises
    ------
    AssertionError
        If the number of outputs differs or any output mismatches
        (``atol=1e-5``).
    """
    if opset is not None:
        model.opset_import[0].version = opset

    # Fill in random values for any inputs that were not provided.
    inputs = generate_random_inputs(model, inputs)

    # Reference result from onnxruntime.
    ort_session = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    ort_output = ort_session.run([], inputs)

    # Import into Relax, lower to TensorIR, and split weights from the module.
    tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True)
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
    tvm_model, params = relax.frontend.detach_params(tvm_model)

    with tvm.transform.PassContext(opt_level=3):
        ex = relax.build(tvm_model, target="llvm")
        vm = relax.VirtualMachine(ex, tvm.cpu())

    # Graph inputs come first (in parameter order), then the detached weights.
    input_list = [
        inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs
    ]
    if params:
        input_list += params["main"]

    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_output = vm.get_outputs("main")

    # Normalize to a list of NDArrays. A single tensor is wrapped, and a shape
    # tuple is materialized as an NDArray so it can be compared numerically.
    if isinstance(tvm_output, tvm.nd.NDArray):
        tvm_output = [tvm_output]
    elif isinstance(tvm_output, tvm.runtime.ShapeTuple):
        tvm_output = [tvm.nd.array([int(i) for i in tvm_output])]

    assert len(tvm_output) == len(ort_output), "Unequal number of outputs"

    for tvm_out, ort_out in zip(tvm_output, ort_output):
        # TODO Allow configurable tolerance.
        # Sometimes None is used to indicate an unused output.
        if ort_out is not None:
            tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, atol=1e-5)
+
+
@pytest.mark.parametrize(
    "input_names, expected_names",
    [
        ([".", "123"], ["_", "input_123"]),
        ([".", "_"], ["_", "__1"]),
        (["123", "input_123"], ["input_123", "input_123_1"]),
    ],
)
def test_sanitize(input_names, expected_names):
    """Input names that are not valid identifiers are renamed uniquely."""
    tensor_shape = [32, 32]
    graph_inputs = [
        helper.make_tensor_value_info(str(name), TensorProto.FLOAT, tensor_shape)
        for name in input_names
    ]
    graph_output = helper.make_tensor_value_info("output", TensorProto.FLOAT, tensor_shape)
    add_node = helper.make_node("Add", inputs=input_names, outputs=["output"])
    graph = helper.make_graph([add_node], "test", inputs=graph_inputs, outputs=[graph_output])
    model = helper.make_model(graph, producer_name="test_sanitizer")

    imported = from_onnx(model)

    for idx, param in enumerate(imported["main"].params):
        assert param.name_hint == expected_names[idx]
+
+
def verify_unary(op_name, shape, attrs=None, domain=None):
    """Check a single-input ONNX op ``x -> y`` against onnxruntime.

    Parameters
    ----------
    op_name : str
        ONNX operator type.
    shape : list
        Shape of both the input and the output.
    attrs : dict, optional
        Node attributes forwarded to ``helper.make_node``. ``None`` (the
        default) means no attributes; a fresh dict avoids a shared mutable
        default.
    domain : str, optional
        Operator domain, e.g. ``"com.microsoft"``.
    """
    attrs = {} if attrs is None else attrs
    test_node = helper.make_node(op_name, ["x"], ["y"], **attrs, domain=domain)
    graph = helper.make_graph(
        [test_node],
        "elemwise_test",
        inputs=[
            helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
        ],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)],
    )

    model = helper.make_model(graph, producer_name="elemwise_test")
    check_correctness(model)
+
+
def verify_binary(op_name, shape_a, shape_b, shape_c, attrs=None, domain=None):
    """Check a two-input ONNX op ``a, b -> c`` against onnxruntime.

    Parameters
    ----------
    op_name : str
        ONNX operator type.
    shape_a, shape_b : list
        Shapes of the two float inputs.
    shape_c : list
        Shape of the float output.
    attrs : dict, optional
        Node attributes forwarded to ``helper.make_node``. ``None`` (the
        default) means no attributes; a fresh dict avoids a shared mutable
        default.
    domain : str, optional
        Operator domain, e.g. ``"com.microsoft"``.
    """
    attrs = {} if attrs is None else attrs
    test_node = helper.make_node(op_name, ["a", "b"], ["c"], **attrs, domain=domain)
    graph = helper.make_graph(
        [test_node],
        "binary_test",
        inputs=[
            helper.make_tensor_value_info("a", TensorProto.FLOAT, shape_a),
            helper.make_tensor_value_info("b", TensorProto.FLOAT, shape_b),
        ],
        outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, shape_c)],
    )

    model = helper.make_model(graph, producer_name="binary_test")
    check_correctness(model)
+
+
def verify_compare(op_name, shape, attrs=None, domain=None):
    """Check a float comparison op ``a, b -> c`` (bool output) against onnxruntime.

    Parameters
    ----------
    op_name : str
        ONNX comparison operator type.
    shape : list
        Shape shared by both inputs and the output.
    attrs : dict, optional
        Node attributes forwarded to ``helper.make_node``. ``None`` (the
        default) means no attributes; a fresh dict avoids a shared mutable
        default.
    domain : str, optional
        Operator domain.
    """
    attrs = {} if attrs is None else attrs
    test_node = helper.make_node(op_name, ["a", "b"], ["c"], **attrs, domain=domain)
    graph = helper.make_graph(
        [test_node],
        "compare_test",
        inputs=[
            helper.make_tensor_value_info("a", TensorProto.FLOAT, shape),
            helper.make_tensor_value_info("b", TensorProto.FLOAT, shape),
        ],
        outputs=[helper.make_tensor_value_info("c", TensorProto.BOOL, shape)],
    )

    model = helper.make_model(graph, producer_name="compare_test")
    check_correctness(model)
+
+
def verify_ternary(op_name, shape_a, shape_b, shape_c, shape_d, attrs=None, domain=None):
    """Check a three-input ONNX op ``a, b, c -> d`` against onnxruntime.

    Parameters
    ----------
    op_name : str
        ONNX operator type.
    shape_a, shape_b, shape_c : list
        Shapes of the three float inputs.
    shape_d : list
        Shape of the float output.
    attrs : dict, optional
        Node attributes forwarded to ``helper.make_node``. ``None`` (the
        default) means no attributes; a fresh dict avoids a shared mutable
        default.
    domain : str, optional
        Operator domain.
    """
    attrs = {} if attrs is None else attrs
    test_node = helper.make_node(op_name, ["a", "b", "c"], ["d"], **attrs, domain=domain)
    graph = helper.make_graph(
        [test_node],
        "ternary_test",
        inputs=[
            helper.make_tensor_value_info("a", TensorProto.FLOAT, shape_a),
            helper.make_tensor_value_info("b", TensorProto.FLOAT, shape_b),
            helper.make_tensor_value_info("c", TensorProto.FLOAT, shape_c),
        ],
        outputs=[helper.make_tensor_value_info("d", TensorProto.FLOAT, shape_d)],
    )

    model = helper.make_model(graph, producer_name="ternary_test")
    check_correctness(model)
+
+
@pytest.mark.parametrize("dynamic", [True, False])
def test_matmul(dynamic):
    """MatMul of a graph input by a constant initializer, optionally with a dynamic input."""
    static_a_shape = [32, 48]
    b_shape = [48, 64]
    a_shape = ["?", "?"] if dynamic else static_a_shape

    weight = helper.make_tensor(
        "b", TensorProto.FLOAT, b_shape, np.random.normal(size=b_shape).astype("float32")
    )
    graph = helper.make_graph(
        [helper.make_node("MatMul", ["a", "b"], ["c"])],
        "matmul_test",
        inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, a_shape)],
        initializer=[weight],
        outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, [32, 64])],
    )
    model = helper.make_model(graph, producer_name="matmul_test")

    # Symbolic dims cannot be sampled automatically, so feed a concrete value.
    feed = {"a": np.random.normal(size=static_a_shape).astype("float32")} if dynamic else None
    check_correctness(model, feed)
+
+
def test_concat():
    """Concat two [1, 32] tensors along axis 0."""
    verify_binary("Concat", shape_a=[1, 32], shape_b=[1, 32], shape_c=[2, 32], attrs={"axis": 0})


def test_add():
    """Elementwise Add of two [1, 32] tensors."""
    verify_binary("Add", shape_a=[1, 32], shape_b=[1, 32], shape_c=[1, 32])


def test_mul():
    """Elementwise Mul of two [1, 32] tensors."""
    verify_binary("Mul", shape_a=[1, 32], shape_b=[1, 32], shape_c=[1, 32])
+
+
@pytest.mark.parametrize("from_type", [TensorProto.INT32, TensorProto.FLOAT, TensorProto.FLOAT16])
@pytest.mark.parametrize("to_type", [TensorProto.INT32, TensorProto.FLOAT, TensorProto.FLOAT16])
def test_cast(from_type, to_type):
    """Cast between every pair of int32 / float32 / float16 at opset 13."""
    shape = [1, 32]
    graph = helper.make_graph(
        [helper.make_node("Cast", ["a"], ["a_float"], to=to_type)],
        "cast_test",
        inputs=[helper.make_tensor_value_info("a", from_type, shape)],
        outputs=[helper.make_tensor_value_info("a_float", to_type, shape)],
    )
    check_correctness(helper.make_model(graph, producer_name="cast_test"), opset=13)
+
+
def test_gather():
    """Gather with list, scalar and nested indices along different axes."""

    def _verify_gather(data_shape, indices, out_shape, axis=0):
        # Scalar indices have a rank-0 shape.
        if isinstance(indices, (list, tuple)):
            indices_shape = np.asarray(indices).shape
        else:
            indices_shape = []

        graph = helper.make_graph(
            [helper.make_node("Gather", ["data", "indices"], ["y"], axis=axis)],
            "gather_test",
            inputs=[
                helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
                helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
            ],
            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_shape)],
        )
        model = helper.make_model(graph, producer_name="gather_test")
        feed = {
            "data": np.random.randn(*data_shape).astype("float32"),
            "indices": np.array(indices).astype("int64"),
        }
        check_correctness(model, inputs=feed)

    cases = (
        ([5, 4, 3, 2], [0, 1, 3], [3, 4, 3, 2], 0),
        ([3], 0, [], 0),
        ([3, 3], [[0, 2]], [3, 1, 2], 1),
    )
    for data_shape, indices, out_shape, axis in cases:
        _verify_gather(data_shape, indices, out_shape, axis)
+
+
@pytest.mark.parametrize("alpha", [None, 0.25])
@pytest.mark.parametrize("beta", [None, 0.35])
@pytest.mark.parametrize("useC", [False, True])
def test_gemm(alpha, beta, useC):
    """Gemm with both operands transposed, optional scalars and an optional bias C."""
    node_inputs = ["a", "b", "c"] if useC else ["a", "b"]
    gemm_node = helper.make_node(
        "Gemm", node_inputs, ["y"], alpha=alpha, beta=beta, transA=1, transB=1
    )

    graph_inputs = [
        helper.make_tensor_value_info("a", TensorProto.FLOAT, [4, 3]),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [5, 4]),
    ]
    if useC:
        graph_inputs.append(helper.make_tensor_value_info("c", TensorProto.FLOAT, [1, 5]))

    graph = helper.make_graph(
        [gemm_node],
        "gemm_test",
        inputs=graph_inputs,
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 5])],
    )
    check_correctness(helper.make_model(graph, producer_name="gemm_test"))
+
+
@pytest.mark.parametrize(
    "in_shape, shape, out_shape",
    [
        ([7, 32, 32, 8], [224, 256], [224, 256]),
        ([7, 32, 32, 8], [-1, 8192], [7, 8192]),
        ([7, 32, 32, 8], [0, 32, 32, 8], [7, 32, 32, 8]),
    ],
)
def test_reshape(in_shape, shape, out_shape):
    """Reshape with a constant target shape, including -1 and 0 entries."""
    target = helper.make_tensor("shape", TensorProto.INT64, [len(shape)], shape)
    graph = helper.make_graph(
        [helper.make_node("Reshape", ["data", "shape"], ["reshaped"])],
        "reshape_test",
        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, in_shape)],
        initializer=[target],
        outputs=[helper.make_tensor_value_info("reshaped", TensorProto.FLOAT, out_shape)],
    )
    feed = {"data": np.random.randn(*in_shape).astype("float32")}
    model = helper.make_model(graph, producer_name="reshape_test")
    check_correctness(model, inputs=feed)
+
+
def test_div():
    """Elementwise Div of two [32, 32] tensors."""
    verify_binary("Div", shape_a=[32, 32], shape_b=[32, 32], shape_c=[32, 32])


def test_sigmoid():
    """Sigmoid on a [32, 32] tensor."""
    verify_unary("Sigmoid", shape=[32, 32])


def test_softmax():
    """Softmax on a rank-3 tensor."""
    verify_unary("Softmax", shape=[32, 32, 32])


def test_transpose():
    """Transpose of a rank-3 tensor with a rotating permutation."""
    verify_unary("Transpose", shape=[32, 32, 32], attrs={"perm": [1, 2, 0]})
+
+
def test_unsqueeze():
    """Unsqueeze with axes provided as a constant initializer."""
    axes = [0, 2, 3]
    graph = helper.make_graph(
        [helper.make_node("Unsqueeze", ["a", "axes"], ["b"])],
        "unsqueeze",
        inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32])],
        initializer=[helper.make_tensor("axes", TensorProto.INT64, [len(axes)], vals=axes)],
        outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 1, 1, 32])],
    )
    check_correctness(helper.make_model(graph, producer_name="unsqueeze_test"))
+
+
def test_gelu():
    """Microsoft-domain Gelu on a [32, 32] tensor."""
    verify_unary("Gelu", shape=[32, 32], domain="com.microsoft")


def test_bias_gelu():
    """Microsoft-domain BiasGelu with a broadcast [32] bias."""
    verify_binary("BiasGelu", shape_a=[32, 32], shape_b=[32], shape_c=[32, 32], domain="com.microsoft")
+
+
def test_where():
    """Where selecting between two float tensors with a bool condition."""
    shape = [32, 32]
    graph_inputs = [helper.make_tensor_value_info("a", TensorProto.BOOL, shape)]
    graph_inputs += [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, shape) for name in ("b", "c")
    ]
    graph = helper.make_graph(
        [helper.make_node("Where", ["a", "b", "c"], ["d"])],
        "where_test",
        inputs=graph_inputs,
        outputs=[helper.make_tensor_value_info("d", TensorProto.FLOAT, shape)],
    )
    check_correctness(helper.make_model(graph, producer_name="where_test"))
+
+
@pytest.mark.parametrize("use_min", [True, False])
@pytest.mark.parametrize("use_max", [True, False])
def test_clip(use_min, use_max):
    """Clip with every combination of the optional ``min`` / ``max`` inputs.

    Clip's inputs are positional (input, min, max). When only ``max`` is
    given, an empty name must occupy the ``min`` slot. Previously the ``max``
    tensor was wired into ``min``, so that case silently tested the wrong
    thing.

    The parameters are named ``use_min`` / ``use_max`` so they do not shadow
    the builtins; pytest ids (built from values) are unchanged.
    """
    if use_min and use_max:
        node_inputs = ["input", "min", "max"]
    elif use_min:
        node_inputs = ["input", "min"]
    elif use_max:
        # Empty string marks the omitted optional ``min`` input.
        node_inputs = ["input", "", "max"]
    else:
        node_inputs = ["input"]
    clip_node = helper.make_node("Clip", node_inputs, ["output"])

    inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, [32, 64])]
    if use_min:
        inputs.append(helper.make_tensor_value_info("min", TensorProto.FLOAT, ()))
    if use_max:
        inputs.append(helper.make_tensor_value_info("max", TensorProto.FLOAT, ()))

    graph = helper.make_graph(
        [clip_node],
        "clip_test",
        inputs=inputs,
        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [32, 64])],
    )

    model = helper.make_model(graph, producer_name="clip_test")
    check_correctness(model)
+
+
def test_equal():
    """Equal on all-equal, all-different and random float inputs."""
    shape = [32, 32]
    graph = helper.make_graph(
        [helper.make_node("Equal", ["a", "b"], ["output"])],
        "equal_test",
        inputs=[helper.make_tensor_value_info(name, TensorProto.FLOAT, shape) for name in ("a", "b")],
        outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, shape)],
    )
    model = helper.make_model(graph, producer_name="equal_test")

    def _filled(value):
        return np.full(shape, value, dtype="float32")

    for a_value, b_value in ((0, 0), (1, 0)):
        check_correctness(model, {"a": _filled(a_value), "b": _filled(b_value)})
    check_correctness(model)
+
+
def test_shape():
    """Shape of a static rank-4 tensor."""
    data_shape = [3, 4, 5, 6]
    graph = helper.make_graph(
        [helper.make_node("Shape", ["data"], ["output"])],
        "shape_test",
        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape)],
        outputs=[helper.make_tensor_value_info("output", TensorProto.INT64, [len(data_shape)])],
    )
    check_correctness(helper.make_model(graph, producer_name="shape_test"))
+
+
def test_tanh():
    """Tanh on a rank-4 tensor."""
    verify_unary("Tanh", shape=[9, 8, 7, 6])


def test_sqrt():
    """Sqrt on a [32, 32] tensor."""
    verify_unary("Sqrt", shape=[32, 32])


def test_relu():
    """Relu on a [32, 32] tensor."""
    verify_unary("Relu", shape=[32, 32])
+
+
def test_conv():
    """2D Conv with bias, default strides and padding."""

    def _verify_conv(input_shape, weight_shape, output_shape):
        # One bias value per output channel.
        named_shapes = (
            ("x", input_shape),
            ("w", weight_shape),
            ("b", [output_shape[1]]),
        )
        graph = helper.make_graph(
            [helper.make_node("Conv", ["x", "w", "b"], ["y"])],
            "conv_test",
            inputs=[
                helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
                for name, shape in named_shapes
            ],
            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)],
        )
        check_correctness(helper.make_model(graph, producer_name="conv_test"))

    _verify_conv([3, 12, 32, 32], [4, 12, 3, 3], [3, 4, 30, 30])
+
+
def test_pow():
    """Elementwise Pow of two [32, 32] tensors."""
    verify_binary("Pow", shape_a=[32, 32], shape_b=[32, 32], shape_c=[32, 32])


def test_erf():
    """Erf on a [32, 32] tensor."""
    verify_unary("Erf", shape=[32, 32])
+
+
@pytest.mark.parametrize("reverse", [False])
@pytest.mark.parametrize("exclusive", [False])
def test_cumsum(reverse, exclusive):
    """CumSum along a constant scalar axis."""
    shape = [32, 32]
    cumsum_node = helper.make_node(
        "CumSum", ["x", "axis"], ["y"], reverse=reverse, exclusive=exclusive
    )
    axis_tensor = helper.make_tensor("axis", TensorProto.INT64, (), [1])
    graph = helper.make_graph(
        [cumsum_node],
        "cumsum_test",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, shape)],
        initializer=[axis_tensor],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)],
    )
    check_correctness(helper.make_model(graph, producer_name="cumsum_test"))
+
+
@pytest.mark.parametrize("axis", [[0, 2], None])
def test_squeeze(axis):
    """Squeeze with explicit constant axes, or with axes omitted."""
    node_inputs = ["x", "axes"] if axis else ["x"]
    initializer = None
    if axis:
        initializer = [helper.make_tensor("axes", TensorProto.INT64, [len(axis)], axis)]

    graph = helper.make_graph(
        [helper.make_node("Squeeze", node_inputs, ["y"])],
        "squeeze_test",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 32, 1, 32])],
        initializer=initializer,
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])],
    )
    check_correctness(helper.make_model(graph, producer_name="squeeze_test"), opset=13)
+
+
def test_const():
    """A graph with no inputs whose only output is a Constant node."""
    shape = [32, 32]
    values = np.random.rand(*shape).astype(np.float32).flatten()
    value_tensor = helper.make_tensor("value", TensorProto.FLOAT, shape, values)
    graph = helper.make_graph(
        [helper.make_node("Constant", [], ["y"], value=value_tensor)],
        "const_test",
        inputs=[],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)],
    )
    check_correctness(helper.make_model(graph, producer_name="const_test"))
+
+
def test_sub():
    """Elementwise Sub of two [32, 16] tensors."""
    verify_binary("Sub", shape_a=[32, 16], shape_b=[32, 16], shape_c=[32, 16])


def test_min():
    """Elementwise Min of two [32, 16] tensors."""
    verify_binary("Min", shape_a=[32, 16], shape_b=[32, 16], shape_c=[32, 16])


def test_max():
    """Elementwise Max of two [32, 16] tensors."""
    verify_binary("Max", shape_a=[32, 16], shape_b=[32, 16], shape_c=[32, 16])


def test_sin():
    """Sin on a [32, 16] tensor."""
    verify_unary("Sin", shape=[32, 16])


def test_cos():
    """Cos on a [32, 16] tensor."""
    verify_unary("Cos", shape=[32, 16])


def test_identity():
    """Identity on a [32, 16] tensor."""
    verify_unary("Identity", shape=[32, 16])


def test_neg():
    """Neg on a [32, 16] tensor."""
    verify_unary("Neg", shape=[32, 16])


def test_abs():
    """Abs on a [32, 16] tensor."""
    verify_unary("Abs", shape=[32, 16])


def test_log():
    """Log on a [32, 16] tensor."""
    verify_unary("Log", shape=[32, 16])


def test_exp():
    """Exp on a [32, 16] tensor."""
    verify_unary("Exp", shape=[32, 16])
+
+
def test_instance_norm():
    """InstanceNormalization on rank-4 and rank-3 inputs."""
    for data_shape, channels in (([1, 3, 32, 32], 3), ([1, 32, 32], 32)):
        verify_ternary(
            "InstanceNormalization",
            data_shape,
            [channels],
            [channels],
            data_shape,
            attrs={"epsilon": 1e-12},
        )
+
+
def test_layer_norm():
    """LayerNormalization over the last axis with scale and bias inputs."""
    graph = helper.make_graph(
        [helper.make_node("LayerNormalization", ["a", "b", "c"], ["d"], epsilon=1e-12)],
        "layer_norm_test",
        inputs=[
            helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
            helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]),
            helper.make_tensor_value_info("c", TensorProto.FLOAT, [32]),
        ],
        outputs=[helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32])],
    )
    check_correctness(helper.make_model(graph, producer_name="layer_norm_test"))
+
+
# TODO Enable dynamism
@pytest.mark.parametrize("dynamic", [False])
def test_skiplayernormalization(dynamic):
    """Microsoft-domain SkipLayerNormalization with gamma, beta and bias inputs."""

    def verify_skiplayernormalization(input_, skip, gamma, beta, bias):
        node = onnx.helper.make_node(
            "SkipLayerNormalization",
            inputs=["input", "skip", "gamma", "beta", "bias"],
            outputs=["output", "mean", "std_dev"],
            domain="com.microsoft",
        )
        node.attribute.append(onnx.helper.make_attribute("epsilon", 1e-4))

        def _declared_shape(array):
            # Fully symbolic shape of matching rank when testing dynamism.
            return ["?"] * len(array.shape) if dynamic else list(array.shape)

        named_arrays = (
            ("input", input_),
            ("skip", skip),
            ("gamma", gamma),
            ("beta", beta),
            ("bias", bias),
        )
        graph = helper.make_graph(
            [node],
            "skiplayernormalization_test",
            inputs=[
                helper.make_tensor_value_info(name, TensorProto.FLOAT, _declared_shape(array))
                for name, array in named_arrays
            ],
            outputs=[
                helper.make_tensor_value_info(
                    "output", TensorProto.FLOAT, _declared_shape(input_)
                ),
                helper.make_tensor_value_info("mean", TensorProto.FLOAT, [1]),
                helper.make_tensor_value_info("std_dev", TensorProto.FLOAT, [1]),
            ],
        )

        model = helper.make_model(graph, producer_name="skiplayernormalization_test")
        check_correctness(model, inputs=dict(named_arrays))

    hidden_size = 384
    batch_size = 4
    sequence_length = 4

    dtype = "float32"
    activation_shape = (batch_size, sequence_length, hidden_size)
    input_array = np.random.random(activation_shape).astype(dtype)
    skip = np.random.random(activation_shape).astype(dtype)
    gamma = np.random.uniform(0.5, 0.7, hidden_size).astype(dtype)
    beta = np.random.randn(hidden_size).astype(dtype) * 0.1
    bias = np.random.randn(hidden_size).astype(dtype)

    verify_skiplayernormalization(input_array, skip, gamma, beta, bias)
+
+
def test_embedlayernormalization():
    """Check the com.microsoft EmbedLayerNormalization contrib op.

    Runs once with every operand supplied and once with the optional segment ids
    and segment embedding omitted.
    """

    def verify_embedlayernormalization(
        input_ids,
        segment_ids,
        word_embedding,
        position_embedding,
        segment_embedding,
        gamma,
        beta,
    ):
        # An empty input name marks an optional operand as absent.
        node_inputs = [
            "input_ids",
            "segment_ids" if segment_ids is not None else "",
            "word_embedding",
            "position_embedding",
            "segment_embedding" if segment_embedding is not None else "",
            "gamma",
            "beta",
        ]
        node = onnx.helper.make_node(
            "EmbedLayerNormalization",
            inputs=node_inputs,
            outputs=["output", "mask_index"],
            domain="com.microsoft",
        )
        node.attribute.append(onnx.helper.make_attribute("epsilon", 1e-4))

        def optional_shape(array):
            # Omitted operands are still declared as graph inputs, with an empty shape.
            return [] if array is None else array.shape

        # input_ids is built as (batch, seq) and word_embedding as (vocab, hidden) below.
        batch, seq_len = input_ids.shape
        hidden = word_embedding.shape[1]

        input_specs = [
            ("input_ids", TensorProto.INT32, list(input_ids.shape)),
            ("segment_ids", TensorProto.INT32, optional_shape(segment_ids)),
            ("word_embedding", TensorProto.FLOAT, list(word_embedding.shape)),
            ("position_embedding", TensorProto.FLOAT, list(position_embedding.shape)),
            ("segment_embedding", TensorProto.FLOAT, optional_shape(segment_embedding)),
            ("gamma", TensorProto.FLOAT, list(gamma.shape)),
            ("beta", TensorProto.FLOAT, list(beta.shape)),
        ]
        graph = helper.make_graph(
            [node],
            "embedlayernormalization_test",
            inputs=[helper.make_tensor_value_info(*spec) for spec in input_specs],
            outputs=[
                helper.make_tensor_value_info(
                    "output", TensorProto.FLOAT, [batch, seq_len, hidden]
                ),
                helper.make_tensor_value_info("mask_index", TensorProto.INT32, [batch]),
            ],
        )
        model = helper.make_model(graph, producer_name="embedlayernormalization_test")

        feed = {
            "input_ids": input_ids,
            "segment_ids": segment_ids,
            "word_embedding": word_embedding,
            "position_embedding": position_embedding,
            "segment_embedding": segment_embedding,
            "gamma": gamma,
            "beta": beta,
        }
        check_correctness(model, inputs=feed)
        # TODO(@anwang2009): onnxruntime v1.9.0 wants an empty array for an omitted optional
        # input while v1.10.0+ wants None, so a direct verify_with_ort_with_inputs comparison
        # (rtol=atol=1e-4) stays disabled for now.

    hidden_size = 384
    batch_size = 4
    sequence_length = 3
    vocab_size = 5

    input_ids = np.full((batch_size, sequence_length), 3, dtype="int32")
    segment_ids = np.zeros((batch_size, sequence_length), dtype="int32")
    word_embedding = np.full((vocab_size, hidden_size), 1, dtype="float32")
    position_embedding = np.full((sequence_length, hidden_size), 2, dtype="float32")
    segment_embedding = np.full((vocab_size, hidden_size), 3, dtype="float32")

    gamma = np.random.uniform(0.5, 0.7, hidden_size).astype("float32")
    beta = np.random.randn(hidden_size).astype("float32") * 0.1

    # Every operand supplied.
    verify_embedlayernormalization(
        input_ids, segment_ids, word_embedding, position_embedding, segment_embedding, gamma, beta
    )
    # Segment ids and segment embedding omitted.
    verify_embedlayernormalization(
        input_ids, None, word_embedding, position_embedding, None, gamma, beta
    )
+
+
def create_reduce_test_parameters():
    """Return ``(op_name, dynamic)`` pairs covering every ONNX Reduce* operator.

    All ops are listed first with ``dynamic=True``, then again with ``dynamic=False``.
    """
    reduce_ops = (
        "ReduceMax",
        "ReduceMean",
        "ReduceMin",
        "ReduceProd",
        "ReduceSum",
        "ReduceSumSquare",
        "ReduceLogSum",
        "ReduceLogSumExp",
        "ReduceL1",
        "ReduceL2",
    )
    return [(op_name, dynamic) for dynamic in (True, False) for op_name in reduce_ops]
+
+
@pytest.mark.parametrize("func, dynamic", create_reduce_test_parameters())
def test_all_reduce_funcs(func, dynamic):
    """Compare an ONNX Reduce* op (opset 11) against onnxruntime.

    Covers reducing over all axes (``axes`` attribute omitted) and over explicit
    axes, with and without ``keepdims``. With ``dynamic`` every dim is declared "?".
    """

    def verify_reduce_func(func, data, axis, keepdims):
        inshape = data.shape
        # The expected output shape is taken from numpy's reduction over the same axes.
        outshape = np.sum(data, axis=axis, keepdims=keepdims == 1).shape

        if axis:
            node = onnx.helper.make_node(
                func, inputs=["x"], outputs=["y"], axes=axis, keepdims=keepdims
            )
        else:
            node = onnx.helper.make_node(func, inputs=["x"], outputs=["y"], keepdims=keepdims)

        if dynamic:
            in_list = ["?"] * len(inshape)
            out_list = ["?"] * len(outshape)
        else:
            in_list = list(inshape)
            out_list = list(outshape)
        graph = helper.make_graph(
            [node],
            "reduce_test",
            inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, in_list)],
            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_list)],
        )

        model = helper.make_model(graph, producer_name="reduce_test")
        check_correctness(model, {"x": data}, opset=11)

    # (input shape, axes); axes=None reduces over every axis.
    cases = [
        ((3, 2, 2), None),
        ((3, 2, 3), None),
        ((3, 3, 3), (1,)),
        ((3, 3, 3, 1), (1, 2)),
        ((3, 3, 3, 1), (1,)),
        ((1, 3, 4, 1), (1,)),
    ]
    for keepdims in [True, False]:
        for shape, axis in cases:
            verify_reduce_func(
                func,
                np.random.randn(*shape).astype(np.float32),
                axis=axis,
                keepdims=keepdims,
            )
+
+
@pytest.mark.parametrize("in_dtype", [np.float32, np.int32])
@pytest.mark.parametrize("axis", [None, 0, 1, 2])
@pytest.mark.parametrize("keepdims", [None, True, False])
def test_arg_min_max(in_dtype, axis, keepdims):
    """Check ArgMax/ArgMin for each input dtype, axis and keepdims setting.

    ``None`` for ``axis``/``keepdims`` leaves the attribute off the node. The
    expected output shape then assumes axis 0 and keepdims on.
    """

    def verify_arg_min_max(input_dim, in_dtype, op_name="ArgMax", axis=None, keepdims=None):
        a_np1 = np.random.uniform(-10, 10, input_dim).astype(in_dtype)
        out_shape = list(a_np1.shape)
        def_axis = axis if axis is not None else 0
        if keepdims == 1 or keepdims is None:
            out_shape[def_axis] = 1
        else:
            out_shape.pop(def_axis)

        node = helper.make_node(op_name, inputs=["a_np1"], outputs=["out"])
        if keepdims is not None:
            node.attribute.append(helper.make_attribute("keepdims", keepdims))
        if axis is not None:
            node.attribute.append(helper.make_attribute("axis", axis))

        # Declare the input with the parametrized dtype so float32 is actually exercised,
        # not only int32.
        in_elem_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
        graph = helper.make_graph(
            [node],
            "argreduce_test",
            inputs=[helper.make_tensor_value_info("a_np1", in_elem_type, list(a_np1.shape))],
            outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, list(out_shape))],
        )

        model = helper.make_model(graph, producer_name="arg_min_max_test")
        # a_np1 only fixes the declared shape; no explicit inputs are passed here.
        check_correctness(model)

    verify_arg_min_max([3, 4, 4], in_dtype, "ArgMax", axis, keepdims)
    verify_arg_min_max([3, 4, 4], in_dtype, "ArgMin", axis, keepdims)
+
+
@pytest.mark.parametrize("dynamic", [False, True])
# TODO(jwfromm) Current approach to dynamic expand is technically not well formed.
# Reenable once fixed.
@pytest.mark.skip("Produces ill-formed IR")
def test_expand(dynamic):
    """Check Expand with its target shape supplied by a Constant node (currently skipped)."""
    if dynamic:
        # TODO: Support dynamic shape for Expand
        pytest.skip("Dynamic expand is not supported yet")

    def _test_expand(name, data, shape, ref_data):
        shape_array = np.array(shape)
        # The target shape is emitted as a Constant node, not as a graph input.
        shape_node = onnx.helper.make_node(
            "Constant",
            inputs=[],
            outputs=["shape"],
            value=onnx.helper.make_tensor(
                name="const_tensor",
                data_type=onnx.TensorProto.INT64,
                dims=shape_array.shape,
                vals=shape_array.flatten().astype("int64"),
            ),
        )
        expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])

        in_shape = list(data.shape)
        out_shape = list(ref_data.shape)
        if dynamic:
            in_shape = ["?"] * len(in_shape)
            out_shape = ["?"] * len(out_shape)
        graph = helper.make_graph(
            [shape_node, expand_node],
            "expand_test",
            inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, in_shape)],
            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)],
        )

        model = helper.make_model(graph, producer_name=name)
        check_correctness(model, inputs={"in": data})

    in_shape = (3, 1)
    shape = (3, 4)
    data = np.random.uniform(size=in_shape).astype(np.float32)
    # Broadcasting (3, 1) to (3, 4) equals tiling the single column 4 times.
    ref_data = np.tile(data, 4)
    _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data)
+
+
# TODO(jwfromm) The importer currently produces ill-formed IR for ConstantOfShape (possibly
# the same dynamic-shape issue that disables test_expand). Reenable once fixed.
@pytest.mark.skip("Produces ill-formed IR")
def test_constantofshape():
    """Check ConstantOfShape filling the shape held in an INT64 tensor with a scalar value."""

    def verify_constantofshape(input_dim, value, dtype):
        tensor_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
        fill_node = helper.make_node(
            "ConstantOfShape",
            ["input"],
            ["output"],
            value=helper.make_tensor("value", tensor_type, (1,), (value,)),
        )

        shape_len = len(input_dim)
        # "input" is declared as a graph input and also given an initializer with the same shape.
        graph = helper.make_graph(
            [fill_node],
            "fill_test",
            [helper.make_tensor_value_info("input", TensorProto.INT64, [shape_len])],
            initializer=[
                helper.make_tensor(
                    "input",
                    TensorProto.INT64,
                    [shape_len],
                    np.asarray(input_dim).astype("int64"),
                )
            ],
            outputs=[helper.make_tensor_value_info("output", tensor_type, input_dim)],
        )

        model = helper.make_model(graph, producer_name="fill_test")
        input_np = np.array(input_dim).astype("int64")
        check_correctness(model, inputs={"input": input_np})

    verify_constantofshape((2, 3, 4, 5), 10, "float32")
    verify_constantofshape((3, 3), 0, "int32")
    verify_constantofshape((1, 2, 3), -1, "float32")
+
+
def test_slice():
    """Check Slice with starts/ends/axes/steps supplied as INT64 initializers."""

    def verify_slice(data_shape, output_shape, starts, ends, axes=None, steps=None):
        def to_int64(value):
            # Lists become int64 arrays; arrays and None pass through unchanged.
            return np.array(value, "int64") if isinstance(value, list) else value

        starts, ends, axes, steps = (to_int64(v) for v in (starts, ends, axes, steps))

        slice_inputs = ["x", "starts", "ends"]
        initializer = [
            helper.make_tensor("starts", TensorProto.INT64, starts.shape, starts),
            helper.make_tensor("ends", TensorProto.INT64, ends.shape, ends),
        ]
        if axes is not None:
            initializer.append(helper.make_tensor("axes", TensorProto.INT64, axes.shape, axes))
            slice_inputs.append("axes")
        if steps is not None:
            if axes is None:
                # Optional inputs are positional: an empty name keeps `steps` in the fifth
                # slot instead of it being read as `axes`.
                slice_inputs.append("")
            initializer.append(helper.make_tensor("steps", TensorProto.INT64, steps.shape, steps))
            slice_inputs.append("steps")

        slice_node = helper.make_node("Slice", inputs=slice_inputs, outputs=["y"])

        graph = helper.make_graph(
            [slice_node],
            "slice_test",
            inputs=[
                helper.make_tensor_value_info("x", TensorProto.FLOAT, data_shape),
            ],
            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)],
            initializer=initializer,
        )

        model = helper.make_model(graph, producer_name="slice_test")
        check_correctness(model)

    # All parameters set.
    verify_slice([20, 10, 5], [3, 10, 5], starts=[0, 0], ends=[3, 10], axes=[0, 1], steps=[1, 1])
    # Default axes and steps.
    verify_slice([20, 10, 5], [3, 10, 5], starts=[0, 0], ends=[3, 10])
    # Negative steps; out-of-range starts are clamped to the last index.
    verify_slice(
        [20, 10, 5],
        [19, 3, 2],
        starts=[20, 10, 4],
        ends=[0, 0, 1],
        steps=[-1, -3, -2],
        axes=[0, 1, 2],
    )
    # Slicing only the trailing axes: axis 0 is untouched, axis 1 keeps 0:3 and
    # axis 2's end of 10 is clamped to its size 5.
    verify_slice([20, 10, 5], [20, 3, 5], starts=[0, 0], ends=[3, 10], axes=[1, 2])

    # TODO (gigiblender): Enable once the importer supports `steps` without `axes`:
    # verify_slice(
    #     [20, 10, 5],
    #     [19, 3, 2],
    #     starts=[20, 10, 4],
    #     ends=[0, 0, 1],
    #     steps=[-1, -3, -2],
    # )
+
+
# TODO Enable dynamism
@pytest.mark.parametrize("dynamic", [False])
def test_attention(dynamic):
    """Check the com.microsoft Attention contrib op with a mask and relative position bias."""

    def verify_attention(
        input_,
        weight,
        bias,
        mask_index,
        num_heads,
        mask_filter_value,
        qkv_hidden_sizes,
        relative_position_bias,
    ):
        # The empty fifth input name omits the optional "past" operand
        # (relative_position_bias comes after it).
        # mask_filter_value is accepted but not forwarded (see the TODO below).
        node = onnx.helper.make_node(
            "Attention",
            inputs=["input", "weight", "bias", "mask_index", "", "relative_position_bias"],
            outputs=["output"],
            domain="com.microsoft",
            num_heads=num_heads,
            # TODO(jwfromm) OnnxRT doesnt work with this attribute, figure out why not.
            # mask_filter_value=mask_filter_value,
            qkv_hidden_sizes=qkv_hidden_sizes,
        )

        def declared_shape(array):
            # In dynamic mode every dimension of every operand is declared unknown.
            if dynamic:
                return ["?"] * len(array.shape)
            return list(array.shape)

        graph = helper.make_graph(
            [node],
            "attention_test",
            inputs=[
                helper.make_tensor_value_info("input", TensorProto.FLOAT, declared_shape(input_)),
                helper.make_tensor_value_info("weight", TensorProto.FLOAT, declared_shape(weight)),
                helper.make_tensor_value_info("bias", TensorProto.FLOAT, declared_shape(bias)),
                helper.make_tensor_value_info(
                    "mask_index", TensorProto.INT32, declared_shape(mask_index)
                ),
                helper.make_tensor_value_info(
                    "relative_position_bias",
                    TensorProto.FLOAT,
                    declared_shape(relative_position_bias),
                ),
            ],
            outputs=[
                helper.make_tensor_value_info(
                    "output", TensorProto.FLOAT, declared_shape(input_)
                ),
            ],
        )

        model = helper.make_model(graph, producer_name="attention_test")

        check_correctness(
            model,
            inputs={
                "input": input_,
                "weight": weight,
                "bias": bias,
                "mask_index": mask_index,
                "relative_position_bias": relative_position_bias,
            },
        )
        # The "present" output should be nullptr when the "past" input isn't included, but
        # ort seems to require an output shape for it, so a direct
        # verify_with_ort_with_inputs comparison (rtol=atol=1e-4) stays disabled.

    input_hidden_size = 128
    batch_size = 4
    sequence_length = 4
    num_heads = 12
    qkv_hidden_sizes = [192, 192, 96]
    mask_filter_value = -512.0

    dtype = "float32"
    input_array = np.random.random((batch_size, sequence_length, input_hidden_size)).astype(dtype)
    weight = np.random.normal(size=(input_hidden_size, sum(qkv_hidden_sizes))).astype(dtype) * 0.1
    bias = np.random.randn(sum(qkv_hidden_sizes)).astype(dtype)
    mask_index = np.random.randint(2, size=(batch_size, sequence_length)).astype("int32")
    relative_position_bias = np.random.randn(
        batch_size, num_heads, sequence_length, sequence_length
    ).astype(dtype)

    verify_attention(
        input_array,
        weight,
        bias,
        mask_index,
        num_heads,
        mask_filter_value,
        qkv_hidden_sizes,
        relative_position_bias,
    )
+
+
@pytest.mark.parametrize("dynamic", [True, False])
def test_pad(dynamic):
    """Check Pad in constant and reflect modes with `pads` given as an INT64 initializer."""
    if dynamic:
        pytest.skip("Dynamic pad not supported")

    def verify_pad(input_shape, pads, mode="constant", value=0.0):
        indata = np.random.normal(size=input_shape).astype(np.float32)
        # `pads` lists all begin pads, then all end pads; numpy wants (begin, end) per axis.
        len_dim = len(pads) // 2
        np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)]
        pads = np.array(pads)
        pads_tensor = helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads)

        if mode in ["edge", "reflect"]:
            outdata = np.pad(indata, pad_width=np_pads, mode=mode)
            node = helper.make_node("Pad", inputs=["input", "pads"], outputs=["output"], mode=mode)
            initializer = [pads_tensor]
        else:
            outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value)
            node = helper.make_node(
                "Pad",
                inputs=["input", "pads", "constant_value"],
                outputs=["output"],
                mode="constant",
            )
            initializer = [
                pads_tensor,
                helper.make_tensor("constant_value", TensorProto.FLOAT, (1,), [value]),
            ]

        # outdata is only used for its shape; check_correctness is called without inputs.
        graph = helper.make_graph(
            [node],
            "pad_test",
            inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))],
            initializer=initializer,
            outputs=[
                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape))
            ],
        )
        model = helper.make_model(graph, producer_name="pad_test")
        check_correctness(model)

    verify_pad((2, 2), [0, 1, 0, 0], "constant", 0.0)
    verify_pad((2, 3), [1, 0, 0, 1], "constant", 0.0)
    verify_pad((3, 2), [0, 0, 1, 0], "constant", 5.0)
    verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect")
+
+
@pytest.mark.parametrize("fp_arith", [np.float16, np.float32])
@pytest.mark.parametrize("dynamic", [True, False])
def test_split(fp_arith, dynamic):
    """Check Split with explicit, implicit (even) and single-output splits.

    Before opset 13 the split sizes are passed as an attribute. From opset 13
    they are the node's second input, backed here by an INT64 initializer.
    """

    def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, opset=11):
        indata = np.random.normal(size=indata_shape).astype(fp_arith)
        input_names = ["input"]
        initializer = []

        num_outputs = len(split) if split else len(outdata_shapes)

        indata_shape = list(indata.shape)
        if dynamic:
            indata_shape = ["?"] * len(indata.shape)
            outdata_shapes = [["?"] * len(o) for o in outdata_shapes]

        elem_type = mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype]
        inputs = [helper.make_tensor_value_info("input", elem_type, indata_shape)]

        if pass_split and opset >= 13:
            np_split = np.array(split).astype(np.int64)
            initializer.append(
                helper.make_tensor("split", TensorProto.INT64, list(np_split.shape), np_split)
            )
            # The initializer only takes effect if the node actually consumes it.
            input_names.append("split")

        node = helper.make_node(
            "Split",
            inputs=input_names,
            outputs=[f"output_{i}" for i in range(num_outputs)],
            axis=axis,
        )
        if pass_split and opset < 13:
            node.attribute.append(helper.make_attribute("split", split))

        graph = helper.make_graph(
            [node],
            "split_test",
            inputs=inputs,
            initializer=initializer,
            outputs=[
                helper.make_tensor_value_info(f"output_{i}", elem_type, list(outdata_shapes[i]))
                for i in range(num_outputs)
            ],
        )
        model = helper.make_model(graph, producer_name="split_test")
        check_correctness(model, inputs={"input": indata}, opset=opset)

    # 1D
    verify_split(6, [[2], [2], [2]], [2, 2, 2])
    verify_split(6, [[2], [2], [2]], [2, 2, 2], pass_split=False)
    verify_split(6, [[2], [1], [3]], [2, 1, 3])
    verify_split(6, [[2], [1], [3]], [2, 1, 3], opset=13)
    # 2D
    verify_split((4, 4), [[2, 2], [2, 2]], [2, 2], axis=1)
    verify_split((4, 4), [[2, 2], [2, 2]], [2, 2], axis=1, opset=13)
    # Split evenly (unstack)
    verify_split(3, [[1], [1], [1]], False, pass_split=False)
    # Split a single value to a single value
    verify_split(1, [[1]], [1], pass_split=True)
    # Test that the default case modifies nothing when the split list has length one
    verify_split((1, 2), [[2]], [2], axis=1)
    verify_split((1, 2), [[2]], [1])
+
+
@pytest.mark.parametrize("dynamic", [True, False])
def test_tile(dynamic):
    """Check Tile with random constant repeats, optionally with every dim declared unknown."""

    def verify_tile(in_shape, repeats, out_shape):
        node = helper.make_node("Tile", inputs=["input", "repeats"], outputs=["out"])

        if dynamic:
            # Build explicit input data before erasing the shape. Presumably
            # check_correctness cannot synthesize inputs for "?" dims.
            indata = np.random.normal(size=in_shape).astype(np.float32)
            in_shape = ["?"] * len(in_shape)
            out_shape = ["?"] * len(out_shape)

        graph = helper.make_graph(
            [node],
            "tile_test",
            inputs=[
                helper.make_tensor_value_info("input", TensorProto.FLOAT, in_shape),
            ],
            initializer=[
                helper.make_tensor("repeats", TensorProto.INT64, list(repeats.shape), repeats)
            ],
            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)],
        )

        model = helper.make_model(graph, producer_name="tile_test")

        if dynamic:
            check_correctness(model, {"input": indata})
        else:
            check_correctness(model)

    x = np.random.rand(2, 3, 4, 5).astype(np.float32)
    repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
    z_array = np.tile(x, repeats)
    verify_tile(x.shape, repeats, z_array.shape)
+
+
def test_resize():
    """Check cubic Resize scaling a [1, 3, 32, 32] input 2x along its last two axes."""
    # The empty second input name omits the optional roi operand.
    resize_node = helper.make_node("Resize", ["X", "", "scales"], ["Y"], mode="cubic")
    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 32, 32])
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 64, 64])
    scales_init = helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0])

    graph = helper.make_graph(
        [resize_node],
        "resize_test",
        inputs=[x_info],
        initializer=[scales_init],
        outputs=[y_info],
    )
    check_correctness(helper.make_model(graph, producer_name="resize_test"))
+
+
def test_einsum():
    """Check Einsum with equation "ij->i" (summing each row of a 3x4 input)."""
    x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4])
    y_info = helper.make_tensor_value_info("y", TensorProto.FLOAT, [3])
    einsum_node = helper.make_node("Einsum", ["x"], ["y"], equation="ij->i")

    graph = helper.make_graph([einsum_node], "einsum_test", inputs=[x_info], outputs=[y_info])
    model = helper.make_model(graph, producer_name="einsum_test")
    check_correctness(model)
+
+
def test_range():
    """Check Range with constant scalar INT64 operands start=1, limit=5, delta=2."""
    # Insertion order fixes the node's input order: start, limit, delta.
    scalars = {"start": 1, "limit": 5, "delta": 2}
    initializer = [
        helper.make_tensor(name, TensorProto.INT64, [], [value]) for name, value in scalars.items()
    ]
    range_node = helper.make_node("Range", list(scalars), ["output"])

    graph = helper.make_graph(
        [range_node],
        "range_test",
        inputs=[],
        initializer=initializer,
        # range(1, 5, 2) yields two elements.
        outputs=[helper.make_tensor_value_info("output", TensorProto.INT64, [2])],
    )
    check_correctness(helper.make_model(graph, producer_name="range_test"))
+
+
def test_less():
    """Check the Less comparison op via verify_compare at shape [32, 32]."""
    op_name, shape = "Less", [32, 32]
    verify_compare(op_name, shape)
+
+
def test_less_equal():
    """Check the LessOrEqual comparison op via verify_compare at shape [32, 32]."""
    op_name, shape = "LessOrEqual", [32, 32]
    verify_compare(op_name, shape)
+
+
def test_batch_norm():
    """Check BatchNormalization (epsilon=1e-2) on a [2, 3, 4, 5] input at opset 15."""
    data_shape = [2, 3, 4, 5]
    # Scale, bias, mean and variance, each a length-3 vector matching dim 1 of x.
    param_names = ["s", "bias", "mean", "var"]

    batch_norm_node = helper.make_node(
        "BatchNormalization", ["x", *param_names], ["y"], epsilon=1e-2
    )
    inputs = [helper.make_tensor_value_info("x", TensorProto.FLOAT, data_shape)]
    inputs += [helper.make_tensor_value_info(name, TensorProto.FLOAT, [3]) for name in param_names]

    graph = helper.make_graph(
        [batch_norm_node],
        "batch_norm_test",
        inputs=inputs,
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, data_shape)],
    )
    model = helper.make_model(graph, producer_name="batch_norm_test")
    check_correctness(model, opset=15)
+
+
def test_max_pool():
    """Check 2D MaxPool with explicit padding, strides, and each auto_pad mode."""
    # (auto_pad, kernel_shape, pads, strides)
    configs = [
        ("NOTSET", [3, 3], [1, 1, 1, 1], [1, 1]),  # explicit padding
        ("NOTSET", [3, 3], [1, 1, 1, 1], [2, 2]),  # explicit padding with stride
        ("SAME_UPPER", [3, 7], None, [3, 2]),  # stride with auto-padding
        ("SAME_LOWER", [3, 3], None, [2, 2]),
        ("VALID", [3, 3], None, [2, 2]),
    ]
    for auto_pad, kernel_shape, pads, strides in configs:
        verify_unary(
            "MaxPool",
            [1, 1, 32, 32],
            dict(auto_pad=auto_pad, kernel_shape=kernel_shape, pads=pads, strides=strides),
        )
+
+
def test_global_average_pool():
    """Check GlobalAveragePool via verify_unary on a [1, 3, 32, 32] input."""
    input_shape = [1, 3, 32, 32]
    verify_unary("GlobalAveragePool", input_shape)
+
+
def test_flatten():
    """Check Flatten at axis 0, at the (negative) last axis, and at a middle axis."""
    for axis in (0, -1, 2):
        verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": axis})
+
+
def test_greater():
    """Check the Greater comparison op via verify_compare on square and non-square shapes."""
    for shape in ([32, 32], [64, 16]):
        verify_compare("Greater", shape)
+
+
def test_onehot():
    """Check OneHot along axis 1 with depth 10 and `values` [3, 1] on [2, 2] INT64 indices."""
    depth = 10
    indices = np.array([[1, 9], [2, 4]], dtype="int64")
    one_hot_node = helper.make_node("OneHot", ["indices", "depth", "values"], ["y"], axis=1)

    graph = helper.make_graph(
        [one_hot_node],
        "one_hot_test",
        inputs=[helper.make_tensor_value_info("indices", TensorProto.INT64, list(indices.shape))],
        initializer=[
            helper.make_tensor("depth", TensorProto.INT64, [], [depth]),
            helper.make_tensor("values", TensorProto.FLOAT, [2], [3, 1]),
        ],
        # axis=1 places the depth dimension between the two index dimensions.
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, depth, 2])],
    )
    model = helper.make_model(graph, producer_name="one_hot_test")
    check_correctness(model, inputs={"indices": indices})
+
+
def test_reciprocal():
    """Check Reciprocal via verify_unary on a [3, 32, 32] input."""
    input_shape = [3, 32, 32]
    verify_unary("Reciprocal", input_shape)
+
+
if __name__ == "__main__":
    # Running this file directly hands control to TVM's test runner.
    tvm.testing.main()