vvchernov commented on code in PR #13802:
URL: https://github.com/apache/tvm/pull/13802#discussion_r1087003563


##########
tests/python/frontend/onnx/test_forward.py:
##########
@@ -6663,6 +6663,105 @@ def verify_qlinearsigmoid(a_shape):
     verify_qlinearsigmoid([])
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_random_bernoulli(target, dev):
+    """test_random_bernoulli"""
+
+    def verify_bernoulli_with_ort(
+        shape,
+        in_dtype="float32",
+        out_dtype="int32",
+        seed=None,
+        out_shape=None,
+        target=target,
+        dev=dev,
+        use_vm=False,
+        opset=None,
+        freeze_params=False,
+        rtol=0.1,
+        atol=0.1,
+        opt_level=1,
+        convert_config=None,
+    ):
+        def get_bernoulli_model(shape, in_dtype="float32", out_dtype="int32", 
seed=None):
+            onnx_itype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
+            onnx_otype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(out_dtype)]
+            node = helper.make_node(
+                "Bernoulli",
+                ["input"],
+                ["output"],
+            )
+            dtype_attr = helper.make_attribute("dtype", onnx_otype)
+            node.attribute.append(dtype_attr)
+            if seed is not None:
+                seed_attr = helper.make_attribute("seed", seed)
+                node.attribute.append(seed_attr)
+
+            graph = helper.make_graph(
+                [node],
+                "random_bernoulli_test",
+                inputs=[helper.make_tensor_value_info("input", onnx_itype, 
list(shape))],
+                outputs=[helper.make_tensor_value_info("output", onnx_otype, 
list(shape))],
+            )
+            return helper.make_model(graph, 
producer_name="random_bernoulli_test")
+
+        inputs = np.random.uniform(size=shape).astype(in_dtype)
+        if seed is None:
+            ort_seed = None
+        else:
+            ort_seed = float(seed)
+        model = get_bernoulli_model(shape, in_dtype, out_dtype, ort_seed)
+        if opset is not None:
+            model.opset_import[0].version = opset
+
+        ort_out = get_onnxruntime_output(model, inputs)
+        if use_vm:
+            tvm_out = get_tvm_output_with_vm(
+                model,
+                inputs,
+                target,
+                dev,
+                opset=opset,
+                freeze_params=freeze_params,
+                convert_config=convert_config,
+            )
+        else:
+            tvm_out = get_tvm_output(
+                model,
+                inputs,
+                target,
+                dev,
+                out_shape,
+                opset=opset,
+                opt_level=opt_level,
+                convert_config=convert_config,
+            )
+
+        if not isinstance(tvm_out, list):
+            tvm_out = [tvm_out]
+        if not isinstance(ort_out, list):
+            ort_out = [ort_out]
+        for tvm_val, ort_val in zip(tvm_out, ort_out):
+            tvm.testing.assert_allclose(ort_val.mean(), tvm_val.mean(), 
rtol=rtol, atol=atol)

Review Comment:
   All points have been addressed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to