larroy commented on a change in pull request #14977: Add a utility for operator benchmarks
URL: https://github.com/apache/incubator-mxnet/pull/14977#discussion_r288287180
 
 

 ##########
 File path: benchmark/opperf/utils/ndarray_utils.py
 ##########
 @@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import mxnet as mx
+import mxnet.ndarray as nd
+
+from .profiler_utils import profile
+
+
+@profile
+def nd_forward_backward_and_profile(op, runs, *args, **kwargs):
+    """Helper function to run a given NDArray operator (op) for 'runs' number 
of times with
+    given args and kwargs. Executes both forward and backward pass.
+
+    NOTE: This is a sync call and waits for all the operations execution to 
complete.
+
+    :param op: NDArray operator (Function reference) to execute. Example: 
mx.nd.add
+    :param runs: Number of times to execute the operation
+    :param args: Arguments for the NDArray operator (op) being executed.
+    :param kwargs: Key value arguments for the NDArray operator (op) being 
executed.
+    :return: any results from NDArray operation execution
+
+    """
+    for _ in range(runs):
+        with mx.autograd.record():
+            res = op(*args, **kwargs)
+        res.backward()
+        nd.waitall()
+    return res
+
+
+@profile
+def nd_forward_and_profile(op, runs, *args, **kwargs):
+    """Helper function to run a given NDArray operator (op) for 'runs' number 
of times with
+    given args and kwargs. Executes ONLY forward pass.
+
+    NOTE: This is a sync call and waits for all the operations execution to 
complete.
+
+    :param op: NDArray operator (Function reference) to execute. Example: 
mx.nd.add
+    :param runs: Number of time to execute the operation
+    :param args: Arguments for the NDArray operator (op) being executed.
+    :param kwargs: Key value arguments for the NDArray operator (op) being 
executed.
+    :return: any results from NDArray operation execution
+    """
+    for _ in range(runs):
+        res = op(*args, **kwargs)
+        nd.waitall()
+    return res
+
+
+def get_mx_ndarray(ctx, in_tensor, dtype, initializer, attach_grad=True):
+    """Helper function to prepare a MXNet NDArray tensor in given Context 
(ctx) of type (dtype) with given
+    initializer. You can get a new Tensor by providing only "Shape" or "Numpy 
NDArray" or another MXNet NDArray as
+    "in_tensor".
+
+    NOTE: This is a sync call and waits for the Tensor to be created.
+
+    :param ctx: Context of the new MXNet NDArray Tensor.
+    :param in_tensor: Can be a tuple of shape or Numpy NDArray or MXNet 
NDArray.
+    :param dtype: Precision or Dtype of the expected Tensor. Ex: "float32", 
"Int64"
+    :param initializer: Function reference to the initialize to use. Ex: 
mx.nd.random.normal, mx.nd.zeros
+    :param attach_grad: To attach a gradient for the Tensor. Default is True.
+    :return: MXNet NDArray Tensor.
+    """
+    if isinstance(in_tensor, tuple):
+        tensor = initializer(ctx=ctx, shape=in_tensor, dtype=dtype)
+    elif isinstance(in_tensor, np.ndarray):
+        tensor = nd.array(in_tensor, ctx=ctx, dtype=dtype)
+    elif isinstance(in_tensor, nd.NDArray):
+        tensor = in_tensor.as_in_context(ctx=ctx).astype(dtype=dtype)
+
 
 Review comment:
   can we remove the whitespace?
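
For context, here is a minimal usage sketch of the helpers in this diff. It is illustrative only and not part of the pull request: the module path, input shapes, operator, and run count are assumptions, and the exact return value depends on how the @profile decorator in profiler_utils wraps the wrapped function's result.

    import mxnet as mx
    import mxnet.ndarray as nd

    # Assumes the helpers above are importable from their file path.
    from benchmark.opperf.utils.ndarray_utils import (get_mx_ndarray,
                                                      nd_forward_and_profile,
                                                      nd_forward_backward_and_profile)

    # Prepare two random input tensors on CPU; attach_grad=True so backward() can run.
    lhs = get_mx_ndarray(ctx=mx.cpu(), in_tensor=(1024, 1024), dtype='float32',
                         initializer=nd.random.normal, attach_grad=True)
    rhs = get_mx_ndarray(ctx=mx.cpu(), in_tensor=(1024, 1024), dtype='float32',
                         initializer=nd.random.normal, attach_grad=True)

    # Time the forward pass only, 10 runs of mx.nd.add.
    forward_res = nd_forward_and_profile(nd.add, 10, lhs, rhs)

    # Time forward plus backward, 10 runs.
    forward_backward_res = nd_forward_backward_and_profile(nd.add, 10, lhs, rhs)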

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
