This is an automated email from the ASF dual-hosted git repository.

madjam pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1a3faa6  Initial commit of an MXNet converter. (#7413)
1a3faa6 is described below

commit 1a3faa63f2a24820427e6454f5d6eaa72ea636c1
Author: Krishna Sridhar <1875987+srik...@users.noreply.github.com>
AuthorDate: Thu Aug 10 15:22:07 2017 -0700

    Initial commit of an MXNet converter. (#7413)
---
 tools/coreml/__init__.py            |  18 ++
 tools/coreml/_layers.py             | 397 ++++++++++++++++++++++++++++++
 tools/coreml/_mxnet_converter.py    | 210 ++++++++++++++++
 tools/coreml/test_mxnet_converer.py | 477 ++++++++++++++++++++++++++++++++++++
 4 files changed, 1102 insertions(+)

diff --git a/tools/coreml/__init__.py b/tools/coreml/__init__.py
new file mode 100644
index 0000000..e56490a
--- /dev/null
+++ b/tools/coreml/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from _mxnet_converter import *
diff --git a/tools/coreml/_layers.py b/tools/coreml/_layers.py
new file mode 100644
index 0000000..5148984
--- /dev/null
+++ b/tools/coreml/_layers.py
@@ -0,0 +1,397 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as _np
+
+def _get_input_output_name(net, node, index = 0):
+    name = node['name']
+    inputs = node['inputs']
+
+    if index == 'all':
+        input_name = [_get_node_name(net, inputs[id][0]) for id in range(len(inputs))]
+    elif type(index) == int:
+        input_name = _get_node_name(net, inputs[0][0])
+    else:
+        input_name = [_get_node_name(net, inputs[id][0]) for id in index]
+    return input_name, name
+
+def _get_node_name(net, node_id):
+    return net['nodes'][node_id]['name']
+
+def _get_node_shape(net, node_id):
+    return net['nodes'][node_id]['shape']
+
+def convert_transpose(net, node, model, builder):
+    """Convert a transpose layer from mxnet to coreml.
+
+    Parameters
+    ----------
+    network: net
+        A mxnet network object.
+
+    layer: node
+        Node to convert.
+
+    model: model
+        An model for MXNet
+
+    builder: NeuralNetworkBuilder
+        A neural network builder object.
+    """
+    input_name, output_name = _get_input_output_name(net, node)
+    name = node['name']
+    param = node['attr']
+    from ast import literal_eval
+    axes = literal_eval(param['axes'])
+    builder.add_permute(name, input_name, output_name, axes)
+
+def convert_flatten(net, node, model, builder):
+    """Convert a flatten layer from mxnet to coreml.
+
+    Parameters
+    ----------
+    network: net
+        A mxnet network object.
+
+    layer: node
+        Node to convert.
+
+    model: model
+        An model for MXNet
+
+    builder: NeuralNetworkBuilder
+        A neural network builder object.
+    """
+    input_name, output_name = _get_input_output_name(net, node)
+    name = node['name']
+    builder.add_flatten(0, name, input_name, output_name)
+
+def convert_softmax(net, node, model, builder):
+    """Convert a softmax layer from mxnet to coreml.
+
+    Parameters
+    ----------
+    network: net
+        A mxnet network object.
+
+    layer: node
+        Node to convert.
+
+    model: model
+        An model for MXNet
+
+    builder: NeuralNetworkBuilder
+        A neural network builder object.
+    """
+    input_name, output_name = _get_input_output_name(net, node)
+    name = node['name']
+    builder.add_softmax(name = name,
+                        input_name = input_name,
+                        output_name = output_name)
+
+def convert_activation(net, node, model, builder):
+    """Convert an activation layer from mxnet to coreml.
+
+    Parameters
+    ----------
+    network: net
+        A mxnet network object.
+
+    layer: node
+        Node to convert.
+
+    model: model
+        An model for MXNet
+
+    builder: NeuralNetworkBuilder
+        A neural network builder object.
+    """
+    input_name, output_name = _get_input_output_name(net, node)
+    name = node['name']
+    mx_non_linearity = node['attr']['act_type']
+    if mx_non_linearity == 'relu':
+        non_linearity = 'RELU'
+    elif mx_non_linearity == 'tanh':
+        non_linearity = 'TANH'
+    elif mx_non_linearity == 'sigmoid':
+        non_linearity = 'SIGMOID'
+    else:
+        raise TypeError('Unknown activation type %s' % mx_non_linearity)
+    builder.add_activation(name = name,
+                           non_linearity = non_linearity,
+                           input_name = input_name,
+                           output_name = output_name)
+
+def convert_elementwise_add(net, node, model, builder):
+    """Convert an elementwise add layer from mxnet to coreml.
+
+        Parameters
+        ----------
+        network: net
+        A mxnet network object.
+
+        layer: node
+        Node to convert.
+
+        model: model
+        An model for MXNet
+
+        builder: NeuralNetworkBuilder
+        A neural network builder object.
+        """
+
+    input_names, output_name = _get_input_output_name(net, node,[0,1])
+    name = node['name']
+
+    builder.add_elementwise(name, input_names, output_name, 'ADD')
+
+def convert_dense(net, node, model, builder):
+    """Convert a dense layer from mxnet to coreml.
+
+    Parameters
+    ----------
+    network: net
+        A mxnet network object.
+
+    layer: node
+        Node to convert.
+
+    model: model
+        An model for MXNet
+
+    builder: NeuralNetworkBuilder
+        A neural network builder object.
+    """
+    input_name, output_name = _get_input_output_name(net, node)
+    param = node['attr']
+    has_bias = True
+    name = node['name']
+
+    inputs = node['inputs']
+    outputs = node['outputs']
+    args = model.arg_params
+    W = args[_get_node_name(net, inputs[1][0])].asnumpy()
+    if has_bias:
+        Wb = args[_get_node_name(net, inputs[2][0])].asnumpy()
+    else:
+        Wb = None
+    nC, nB = W.shape
+
+    builder.add_inner_product(name = name,
+            W = W,
+            Wb = Wb,
+            nB = nB,
+            nC = nC,
+            has_bias = has_bias,
+            input_name = input_name,
+            output_name = output_name)
+
+def convert_convolution(net, node, model, builder):
+    """Convert a convolution layer from mxnet to coreml.
+
+    Parameters
+    ----------
+    network: net
+        A mxnet network object.
+
+    layer: node
+        Node to convert.
+
+    model: model
+        An model for MXNet
+
+    builder: NeuralNetworkBuilder
+        A neural network builder object.
+    """
+    input_name, output_name = _get_input_output_name(net, node)
+    name = node['name']
+    param = node['attr']
+    inputs = node['inputs']
+    outputs = node['outputs']
+    args = model.arg_params
+
+    from ast import literal_eval
+
+    if 'no_bias' in param.keys():
+        has_bias = not literal_eval(param['no_bias'])
+    else:
+        has_bias = True
+
+    border_mode = "same" if literal_eval(param['pad']) != (0, 0) else 'valid'
+    border_mode = "valid"
+    n_filters = int(param['num_filter'])
+    output_shape = None  # (needed for de-conv)
+
+    W = args[_get_node_name(net, inputs[1][0])].asnumpy()
+    if has_bias:
+        Wb = args[_get_node_name(net, inputs[2][0])].asnumpy()
+    else:
+        Wb = None
+
+    n_filters, channels = W.shape[0:2]
+    stride_height, stride_width = literal_eval(param['stride'])
+    kernel_height, kernel_width = literal_eval(param['kernel'])
+
+    W = W.transpose((2, 3, 1, 0))
+    builder.add_convolution(name = name,
+             kernelChannels = channels,
+             outputChannels = n_filters,
+             height = kernel_height,
+             width = kernel_width,
+             stride_height = stride_height,
+             stride_width = stride_width,
+             borderMode = border_mode,
+             groups = 1,
+             W = W,
+             b = Wb,
+             has_bias = has_bias,
+             is_deconv = False,
+             output_shape = output_shape,
+             input_name = input_name,
+             output_name = output_name)
+
+    # Add padding if there is any
+    convLayer = builder.nn_spec.layers[-1].convolution
+    pad = literal_eval(param['pad'])
+    for i in range(len(pad)):
+        convLayer.valid.paddingAmounts.borderAmounts[i].startEdgeSize = pad[i]
+        convLayer.valid.paddingAmounts.borderAmounts[i].endEdgeSize = pad[i]
+
+def convert_pooling(net, node, model, builder):
+    """Convert a pooling layer from mxnet to coreml.
+
+    Parameters
+    ----------
+    network: net
+        A mxnet network object.
+
+    layer: node
+        Node to convert.
+
+    model: model
+        An model for MXNet
+
+    builder: NeuralNetworkBuilder
+        A neural network builder object.
+    """
+    input_name, output_name = _get_input_output_name(net, node)
+    name = node['name']
+    inputs = node['inputs']
+    param = node['attr']
+    outputs = node['outputs']
+    args = model.arg_params
+
+    layer_type_mx = param['pool_type']
+    if layer_type_mx == 'max':
+        layer_type= 'MAX'
+    elif layer_type_mx == 'avg':
+        layer_type = 'AVERAGE'
+    else:
+        raise TypeError("Pooling type %s not supported" % layer_type_mx)
+
+    from ast import literal_eval
+    stride_height, stride_width = literal_eval(param['stride'])
+    kernel_width, kernel_height = literal_eval(param['kernel'])
+
+    padding_type = 'VALID'
+    if 'global_pool' in param.keys():
+        is_global = literal_eval(param['global_pool'])
+    else:
+        is_global = False
+    builder.add_pooling(name = name,
+        height = kernel_height,
+        width = kernel_width,
+        stride_height = stride_height,
+        stride_width = stride_width,
+        layer_type = layer_type,
+        padding_type = padding_type,
+        exclude_pad_area = False,
+        is_global = is_global,
+        input_name = input_name,
+        output_name = output_name)
+
+    # Add padding if there is any
+    poolingLayer = builder.nn_spec.layers[-1].pooling
+    pad = literal_eval(param['pad'])
+    for i in range(len(pad)):
+        poolingLayer.valid.paddingAmounts.borderAmounts[i].startEdgeSize = pad[i]
+        poolingLayer.valid.paddingAmounts.borderAmounts[i].endEdgeSize = pad[i]
+
+def convert_batchnorm(net, node, model, builder):
+    """Convert a transpose layer from mxnet to coreml.
+
+    Parameters
+    ----------
+    network: net
+        A mxnet network object.
+
+    layer: node
+        Node to convert.
+
+    model: model
+        An model for MXNet
+
+    builder: NeuralNetworkBuilder
+        A neural network builder object.
+    """
+    input_name, output_name = _get_input_output_name(net, node)
+    name = node['name']
+    param = node['attr']
+    inputs = node['inputs']
+    outputs = node['outputs']
+    args = model.arg_params
+    aux = model.aux_params
+
+    gamma = args[_get_node_name(net, inputs[1][0])].asnumpy()
+    beta = args[_get_node_name(net, inputs[2][0])].asnumpy()
+    mean = aux[_get_node_name(net, inputs[3][0])].asnumpy()
+    variance = aux[_get_node_name(net, inputs[4][0])].asnumpy()
+
+    nb_channels = gamma.shape[0]
+
+    builder.add_batchnorm(
+        name = name,
+        channels = nb_channels,
+        gamma = gamma,
+        beta = beta,
+        mean = mean,
+        variance = variance,
+        input_name = input_name,
+        output_name = output_name)
+
+def convert_concat(net, node, model, builder):
+    """Convert concat layer from mxnet to coreml.
+
+    Parameters
+    ----------
+    network: net
+    A mxnet network object.
+
+    layer: node
+    Node to convert.
+
+    model: model
+    An model for MXNet
+
+    builder: NeuralNetworkBuilder
+    A neural network builder object.
+    """
+    # Get input and output names
+    input_names, output_name = _get_input_output_name(net, node, 'all')
+    name = node['name']
+    mode = 'CONCAT'
+    builder.add_elementwise(name = name, input_names = input_names,
+            output_name = output_name, mode = mode)
diff --git a/tools/coreml/_mxnet_converter.py b/tools/coreml/_mxnet_converter.py
new file mode 100644
index 0000000..88a980c
--- /dev/null
+++ b/tools/coreml/_mxnet_converter.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import _layers
+import coremltools as _coremltools
+import coremltools.models.datatypes as _datatypes
+from coremltools.models import neural_network as _neural_network
+
+import json as _json
+import mxnet as _mxnet
+import numpy as _np
+
+_MXNET_LAYER_REGISTRY  = {
+    'FullyConnected' : _layers.convert_dense,
+    'Activation'     : _layers.convert_activation,
+    'SoftmaxOutput'  : _layers.convert_softmax,
+    'Convolution'    : _layers.convert_convolution,
+    'Pooling'        : _layers.convert_pooling,
+    'Flatten'        : _layers.convert_flatten,
+    'transpose'      : _layers.convert_transpose,
+    'Concat'         : _layers.convert_concat,
+    'BatchNorm'      : _layers.convert_batchnorm,
+    'elemwise_add'   : _layers.convert_elementwise_add,
+}
+
+_MXNET_SKIP_LAYERS = [
+    '_MulScalar',
+]
+
+def _mxnet_remove_batch(input_data):
+    for blob in input_data:
+        input_data[blob] = _np.reshape(input_data[blob], input_data[blob].shape[1:])
+    return input_data
+
+def check_error(model, path, shapes, output = 'softmax_output', verbose = True):
+    """
+    Check the difference between predictions from MXNet and CoreML.
+    """
+    coreml_model = _coremltools.models.MLModel(path)
+    input_data = {}
+    input_data_copy = {}
+    for ip in shapes:
+        input_data[ip] = _np.random.rand(*shapes[ip]).astype('f')
+        input_data_copy[ip] = _np.copy(input_data[ip])
+
+    dataIter = _mxnet.io.NDArrayIter(input_data_copy)
+    mx_out = model.predict(dataIter).flatten()
+
+    e_out_dict = coreml_model.predict(_mxnet_remove_batch(input_data))
+    e_out = e_out_dict[output].flatten()
+    error = _np.linalg.norm(e_out - mx_out)
+
+    if verbose:
+        print "First few predictions from CoreML : %s" % e_out[0:10]
+        print "First few predictions from MXNet  : %s" % e_out[0:10]
+        print "L2 Error on random data %s" % error
+    return error
+
+def _set_input_output_layers(builder, input_names, output_names):
+    input_layers_indices = []
+    output_layers_indices = []
+    spec = builder.spec
+    layers = builder.spec.neuralNetwork.layers
+    for idx, l in enumerate(layers):
+        if set(input_names).intersection(l.input):
+            input_layers_indices.append(idx)
+        if set(output_names).intersection(l.output):
+            output_layers_indices.append(idx)
+
+    builder.input_layers_indices = input_layers_indices
+    builder.output_layers_indices = output_layers_indices
+    builder.input_layers_is1d = [False for i in input_names]
+    builder.output_layers_is1d = [False for i in output_names]
+
+def _get_layer_converter_fn(layer):
+    """Get the right converter function for MXNet
+    """
+    if layer in _MXNET_LAYER_REGISTRY:
+        return _MXNET_LAYER_REGISTRY[layer]
+    else:
+        raise TypeError("MXNet layer of type %s is not supported." % layer)
+
+def convert(model, order = None, **kwargs):
+    """Convert a keras model to the protobuf spec.
+
+    Parameters
+    ----------
+    model: MXNet model
+        A trained MXNet neural network model.
+
+    order: Order of inputs
+
+    **kwargs :
+        Provide keyword arguments of known shapes.
+
+    Returns
+    -------
+    model_spec: An object of type ModelSpec_pb.
+        Protobuf representation of the model
+    """
+    if not kwargs:
+        raise TypeError("Must provide input shape to be able to perform conversion")
+
+    def remove_batch(dim):
+        return dim[1:]
+
+    if order is None:
+        input_names = kwargs.keys()
+        input_dims  = map(remove_batch, kwargs.values())
+    else:
+        names = kwargs.keys()
+        shapes = map(remove_batch, kwargs.values())
+        input_names = [names[i] for i in order]
+        input_dims = [shapes[i] for i in order]
+
+    net = model.symbol
+
+    # Infer shapes and store in a dictionary
+    shapes = net.infer_shape(**kwargs)
+    arg_names = net.list_arguments()
+    output_names = net.list_outputs()
+    aux_names = net.list_auxiliary_states()
+    shape_dict = {}
+    for idx, op in enumerate(arg_names):
+        shape_dict[op] = shapes[0][idx]
+    for idx, op in enumerate(output_names):
+        shape_dict[op] = shapes[1][idx]
+    for idx, op in enumerate(aux_names):
+        shape_dict[op] = shapes[2][idx]
+
+
+    # Get the inputs and outputs
+    output_dims = shapes[1]
+    input_types = [_datatypes.Array(*dim) for dim in input_dims]
+    output_types = [_datatypes.Array(*dim) for dim in output_dims]
+
+    # Make the builder
+    input_features = zip(input_names, input_types)
+    output_features = zip(output_names, output_types)
+    builder = _neural_network.NeuralNetworkBuilder(input_features, output_features)
+
+    # Get out the layers
+    net = _json.loads(net.tojson())
+    nodes = net['nodes']
+    for i, node in enumerate(nodes):
+        node['id'] = i
+
+        if node['name'] in shape_dict:
+            node['shape'] = shape_dict[node['name']]
+
+        node['outputs'] = []
+        if 'inputs' in node:
+            for ip in node['inputs']:
+                nodes[ip[0]]['outputs'].append([i, 0])
+        else:
+            node['inputs'] = []
+
+    # Mark the head nodes
+    for head in net['heads']:
+        head_id = head[0]
+        head_node = nodes[head_id]
+        head_node['outputs'] = [head]
+        head_node['name'] += "_output"
+        head_node['shape'] = shape_dict[head_node['name']]
+
+    # For skipped layers, make sure nodes are modified
+    for iter, node in enumerate(nodes):
+        op = node['op']
+        inputs = node['inputs']
+        outputs = node['outputs']
+        if op in _MXNET_SKIP_LAYERS:
+            nodes[inputs[0][0]]['outputs'][0] = outputs[0]
+            nodes[outputs[0][0]]['inputs'][0] = inputs[0]
+
+    # Find the input and output names for this node
+    for iter, node in enumerate(nodes):
+        op = node['op']
+        if op == 'null' or op in _MXNET_SKIP_LAYERS:
+            continue
+        name = node['name']
+        print("%d : %s, %s" % (iter, name, op))
+        converter_func = _get_layer_converter_fn(op)
+        converter_func(net, node, model, builder)
+
+    spec = builder.spec
+    layers = spec.neuralNetwork.layers
+
+    # Set the right inputs and outputs
+    _set_input_output_layers(builder, input_names, output_names)
+    builder.set_input(input_names, input_dims)
+    builder.set_output(output_names, output_dims)
+
+    # Return the spec
+    spec = builder.spec
+    layers = spec.neuralNetwork.layers
+    return spec
diff --git a/tools/coreml/test_mxnet_converer.py b/tools/coreml/test_mxnet_converer.py
new file mode 100644
index 0000000..179d04a
--- /dev/null
+++ b/tools/coreml/test_mxnet_converer.py
@@ -0,0 +1,477 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+import mxnet as mx
+import numpy as np
+import tempfile
+import os
+import mxnet_converter
+import coremltools
+
+def _mxnet_remove_batch(input_data):
+    for blob in input_data:
+        input_data[blob] = np.reshape(input_data[blob], input_data[blob].shape[1:])
+    return input_data
+
+def _get_coreml_model(net, engine, model_path, input_shape,
+            input_names = ['data'], output_names = ['output']):
+    model = mx.model.FeedForward(net, engine, arg_params = engine.arg_dict)
+    spec = mxnet_converter.convert(model, **input_shape)
+    return coremltools.models.MLModel(spec)
+
+def set_weights(net, engine, mode = 'random'):
+    for arg in net.list_arguments():
+        if mode == 'random':
+            engine.arg_dict[arg][:] = np.random.uniform(-0.1, 0.1, engine.arg_dict[arg].shape)
+        elif mode == 'zeros':
+            engine.arg_dict[arg][:] = np.zeros(engine.arg_dict[arg].shape)
+        elif mode == 'ones':
+            engine.arg_dict[arg][:] = np.ones(engine.arg_dict[arg].shape)
+    return net
+
+class MXNetSingleLayerTest(unittest.TestCase):
+    """
+    Unit test class for testing mxnet converter.
+    """
+    def _test_mxnet_model(self, net, engine, delta = 1e-3, **input_shape):
+
+        # Generate some dummy data
+        input_data = {}
+        for ip in input_shape:
+            input_data[ip] = engine.arg_dict[ip].asnumpy()
+        output_blob = net.list_outputs()[0]
+
+        # Make predictions from mxnet (only works on single output for now)
+        mxnet_preds = engine.forward()[0].asnumpy().flatten()
+
+        # Get predictions from coreml
+        model_path = os.path.join(tempfile.mkdtemp(), 'mxnet.mlmodel')
+        model = _get_coreml_model(net, engine, model_path, input_shape, input_data.keys())
+        coreml_preds = model.predict(_mxnet_remove_batch(input_data)).values()[0].flatten()
+
+        # Check prediction accuracy
+        self.assertEquals(len(mxnet_preds), len(coreml_preds))
+        for i in range(len(mxnet_preds)):
+            self.assertAlmostEquals(mxnet_preds[i], coreml_preds[i], delta = delta)
+
+    def test_tiny_inner_product_zero_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 10)
+
+        # Define a model
+        net = mx.sym.Variable('data')
+        net = mx.sym.FullyConnected(data = net, name = 'fc1', num_hidden = 5)
+        engine = net.simple_bind(ctx=mx.cpu(), data=input_shape)
+
+        # Set some random weights
+        set_weights(net, engine, mode = 'zeros')
+
+        # Test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_really_tiny_inner_product_ones_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 1)
+
+        # Define a model
+        net = mx.sym.Variable('data')
+        net = mx.sym.FullyConnected(data = net, name = 'fc1', num_hidden = 1)
+        engine = net.simple_bind(ctx=mx.cpu(), data=input_shape)
+
+        # Set some random weights
+        set_weights(net, engine, mode = 'ones')
+
+        # Test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_really_tiny_2_inner_product_ones_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 1)
+
+        # Define a model
+        net = mx.sym.Variable('data')
+        net = mx.sym.FullyConnected(data = net, name = 'fc1', num_hidden = 5)
+        engine = net.simple_bind(ctx=mx.cpu(), data=input_shape)
+
+        # Set some random weights
+        set_weights(net, engine, mode = 'ones')
+
+        # Test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_tiny_inner_product_ones_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 10)
+
+        # Define a model
+        net = mx.sym.Variable('data')
+        net = mx.sym.FullyConnected(data = net, name = 'fc1', num_hidden = 5)
+        engine = net.simple_bind(ctx=mx.cpu(), data=input_shape)
+
+        # Set some random weights
+        set_weights(net, engine, mode = 'ones')
+
+        # Test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_tiny_inner_product_random_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 10)
+
+        # Define a model
+        net = mx.sym.Variable('data')
+        net = mx.sym.FullyConnected(data = net, name = 'fc1', num_hidden = 5)
+        engine = net.simple_bind(ctx=mx.cpu(), data=input_shape)
+
+        # Set some random weights
+        set_weights(net, engine, mode = 'random')
+
+        # Test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_tiny_softmax_random_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 10)
+
+        # Define a model
+        net = mx.sym.Variable('data')
+        net = mx.sym.FullyConnected(data = net, name = 'fc1', num_hidden = 5)
+        net = mx.sym.SoftmaxOutput(net, name = 'softmax')
+        engine = net.simple_bind(ctx = mx.cpu(), data = input_shape)
+
+        # Set some random weights
+        set_weights(net, engine, mode = 'random')
+
+        # Test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_tiny_relu_activation_random_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 10)
+
+        # Define a model
+        net = mx.sym.Variable('data')
+        net = mx.sym.FullyConnected(data = net, name = 'fc1', num_hidden = 5)
+        net = mx.sym.Activation(net, name = 'relu1', act_type = "relu")
+        engine = net.simple_bind(ctx = mx.cpu(), data = input_shape)
+
+        # Set some random weights
+        set_weights(net, engine, mode = 'random')
+
+        # Test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_tiny_sigmoid_activation_random_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 10)
+
+        # Define a model
+        net = mx.sym.Variable('data')
+        net = mx.sym.FullyConnected(data = net, name = 'fc1', num_hidden = 5)
+        net = mx.sym.Activation(net, name = 'sigmoid1', act_type = "sigmoid")
+        engine = net.simple_bind(ctx = mx.cpu(), data = input_shape)
+
+        # Set some random weights
+        set_weights(net, engine, mode = 'random')
+
+        # Test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_tiny_tanh_activation_random_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 10)
+
+        # Define a model
+        net = mx.sym.Variable('data')
+        net = mx.sym.FullyConnected(data = net, name = 'fc1', num_hidden = 5)
+        net = mx.sym.Activation(net, name = 'tanh1', act_type = "tanh")
+        engine = net.simple_bind(ctx = mx.cpu(), data = input_shape)
+
+        # Set some random weights
+        set_weights(net, engine, mode = 'random')
+
+        # Test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_really_tiny_conv_random_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 1, 10, 10)
+        num_filter = 1
+        kernel = (1 ,1)
+        stride = (1, 1)
+        pad = (0, 0)
+
+        # Define a model
+        net = mx.sym.Variable('data')
+        net = mx.symbol.Convolution(data = net, num_filter = num_filter, kernel=kernel,
+                stride = stride, pad = pad, name = 'conv_1')
+        engine = net.simple_bind(ctx = mx.cpu(), data = input_shape)
+
+        # Set some random weights
+        set_weights(net, engine, mode = 'random')
+
+        # Test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_tiny_conv_ones_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 1, 10, 10)
+        num_filter = 1
+        kernel = (5 ,5)
+        stride = (1, 1)
+        pad = (0, 0)
+
+        # Define a model
+        net = mx.sym.Variable('data')
+        net = mx.symbol.Convolution(data = net, num_filter = num_filter, kernel=kernel,
+                stride = stride, pad = pad, name = 'conv_1')
+        engine = net.simple_bind(ctx = mx.cpu(), data = input_shape)
+
+        # Set some random weights
+        set_weights(net, engine, mode = 'ones')
+
+        # Test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_tiny_conv_random_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 1, 10, 10)
+        num_filter = 1
+        kernel = (5 ,5)
+        stride = (1, 1)
+        pad = (0, 0)
+
+        # define a model
+        net = mx.sym.Variable('data')
+        net = mx.symbol.Convolution(data = net, num_filter = num_filter, kernel=kernel,
+                stride = stride, pad = pad, name = 'conv_1')
+        engine = net.simple_bind(ctx = mx.cpu(), data = input_shape)
+
+        # set some random weights
+        set_weights(net, engine, mode = 'random')
+
+        # test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_tiny_asym_conv_random_input(self):
+        np.random.seed(1988)
+        input_shape = (1, 1, 10, 10)
+        num_filter = 1
+        kernel = (5 ,3)
+        stride = (1, 1)
+        pad = (0, 0)
+
+        # define a model
+        net = mx.sym.Variable('data')
+        net = mx.symbol.Convolution(data = net, num_filter = num_filter, kernel=kernel,
+                stride = stride, pad = pad, name = 'conv_1')
+        engine = net.simple_bind(ctx = mx.cpu(), data = input_shape)
+
+        # set some random weights
+        set_weights(net, engine, mode = 'random')
+
+        # test the mxnet model
+        self._test_mxnet_model(net, engine, data = input_shape)
+
+    def test_tiny_asym_conv_random_asym_input(self):
+        """Run a 16-filter 5x3 convolution plus tanh on a (1, 1, 28, 18) input."""
+        np.random.seed(1988)
+        data_shape = (1, 1, 28, 18)
+        conv_params = dict(
+            num_filter=16,
+            kernel=(5, 3),
+            stride=(1, 1),
+            pad=(0, 0),
+            dilate=(1, 1),
+            name='conv_1',
+        )
+
+        # Build convolution -> tanh activation, then bind the graph on the CPU.
+        conv = mx.symbol.Convolution(data=mx.sym.Variable('data'), **conv_params)
+        symbol = mx.sym.Activation(conv, name='tanh', act_type="tanh")
+        executor = symbol.simple_bind(ctx=mx.cpu(), data=data_shape)
+
+        # Fill the bound parameters via set_weights in 'random' mode, then test.
+        set_weights(symbol, executor, mode='random')
+        self._test_mxnet_model(symbol, executor, data=data_shape)
+
+    def test_tiny_conv_pooling_random_input(self):
+        """Run a 5x5 convolution then 5x5 max pooling on a (1, 1, 10, 10) input."""
+        np.random.seed(1988)
+        data_shape = (1, 1, 10, 10)
+        # Kernel, stride and pad are shared by the convolution and pooling layers.
+        window = dict(kernel=(5, 5), stride=(1, 1), pad=(0, 0))
+
+        # Build convolution -> max pooling, then bind the graph on the CPU.
+        conv = mx.symbol.Convolution(
+            data=mx.sym.Variable('data'), num_filter=1, name='conv_1', **window
+        )
+        symbol = mx.symbol.Pooling(
+            data=conv, name='pool_1', pool_type='max', **window
+        )
+        executor = symbol.simple_bind(ctx=mx.cpu(), data=data_shape)
+
+        # Fill the bound parameters via set_weights in 'random' mode.
+        set_weights(symbol, executor, mode='random')
+
+        # Hand the bound model to self._test_mxnet_model.
+        self._test_mxnet_model(symbol, executor, data=data_shape)
+
+    def test_really_tiny_conv_random_3d_input(self):
+        """Run one 1-filter 1x1 convolution on a (1, 3, 10, 10) input in 'random' mode."""
+        np.random.seed(1988)
+        data_shape = (1, 3, 10, 10)
+        conv_params = dict(
+            num_filter=1,
+            kernel=(1, 1),
+            stride=(1, 1),
+            pad=(0, 0),
+            name='conv_1',
+        )
+
+        # Build the one-layer graph and bind it on the CPU.
+        symbol = mx.symbol.Convolution(data=mx.sym.Variable('data'), **conv_params)
+        executor = symbol.simple_bind(ctx=mx.cpu(), data=data_shape)
+
+        # Fill the bound parameters via set_weights in 'random' mode, then test.
+        set_weights(symbol, executor, mode='random')
+        self._test_mxnet_model(symbol, executor, data=data_shape)
+
+    def test_really_tiny_conv_random_input_multi_filter(self):
+        """Run one 64-filter 1x1 convolution on a (1, 1, 10, 10) input in 'random' mode."""
+        np.random.seed(1988)
+        data_shape = (1, 1, 10, 10)
+        conv_params = dict(
+            num_filter=64,
+            kernel=(1, 1),
+            stride=(1, 1),
+            pad=(0, 0),
+            name='conv_1',
+        )
+
+        # Build the one-layer graph and bind it on the CPU.
+        symbol = mx.symbol.Convolution(data=mx.sym.Variable('data'), **conv_params)
+        executor = symbol.simple_bind(ctx=mx.cpu(), data=data_shape)
+
+        # Fill the bound parameters via set_weights in 'random' mode, then test.
+        set_weights(symbol, executor, mode='random')
+        self._test_mxnet_model(symbol, executor, data=data_shape)
+
+    def test_tiny_conv_random_3d_input(self):
+        """Run one 1-filter 5x5 convolution on a (1, 3, 10, 10) input in 'random' mode."""
+        np.random.seed(1988)
+        data_shape = (1, 3, 10, 10)
+        conv_params = dict(
+            num_filter=1,
+            kernel=(5, 5),
+            stride=(1, 1),
+            pad=(0, 0),
+            name='conv_1',
+        )
+
+        # Build the one-layer graph and bind it on the CPU.
+        symbol = mx.symbol.Convolution(data=mx.sym.Variable('data'), **conv_params)
+        executor = symbol.simple_bind(ctx=mx.cpu(), data=data_shape)
+
+        # Fill the bound parameters via set_weights in 'random' mode, then test.
+        set_weights(symbol, executor, mode='random')
+        self._test_mxnet_model(symbol, executor, data=data_shape)
+
+    def test_tiny_conv_random_input_multi_filter(self):
+        """Run one 64-filter 5x5 convolution on a (1, 1, 10, 10) input in 'random' mode."""
+        np.random.seed(1988)
+        data_shape = (1, 1, 10, 10)
+        conv_params = dict(
+            num_filter=64,
+            kernel=(5, 5),
+            stride=(1, 1),
+            pad=(0, 0),
+            name='conv_1',
+        )
+
+        # Build the one-layer graph and bind it on the CPU.
+        symbol = mx.symbol.Convolution(data=mx.sym.Variable('data'), **conv_params)
+        executor = symbol.simple_bind(ctx=mx.cpu(), data=data_shape)
+
+        # Fill the bound parameters via set_weights in 'random' mode, then test.
+        set_weights(symbol, executor, mode='random')
+        self._test_mxnet_model(symbol, executor, data=data_shape)
+
+    def test_conv_random(self):
+        """Run one 64-filter 5x5 convolution on a (1, 3, 10, 10) input in 'random' mode."""
+        np.random.seed(1988)
+        data_shape = (1, 3, 10, 10)
+        conv_params = dict(
+            num_filter=64,
+            kernel=(5, 5),
+            stride=(1, 1),
+            pad=(0, 0),
+            name='conv_1',
+        )
+
+        # Build the one-layer graph and bind it on the CPU.
+        symbol = mx.symbol.Convolution(data=mx.sym.Variable('data'), **conv_params)
+        executor = symbol.simple_bind(ctx=mx.cpu(), data=data_shape)
+
+        # Fill the bound parameters via set_weights in 'random' mode, then test.
+        set_weights(symbol, executor, mode='random')
+        self._test_mxnet_model(symbol, executor, data=data_shape)
+
+    def test_flatten(self):
+        """Run conv -> flatten -> fully connected -> softmax on a (1, 3, 10, 10) input."""
+        np.random.seed(1988)
+        data_shape = (1, 3, 10, 10)
+        conv_params = dict(
+            num_filter=64,
+            kernel=(5, 5),
+            stride=(1, 1),
+            pad=(0, 0),
+            name='conv_1',
+        )
+
+        # Chain the four layers, each one consuming the previous symbol.
+        symbol = mx.symbol.Convolution(data=mx.sym.Variable('data'), **conv_params)
+        symbol = mx.sym.Flatten(data=symbol, name='flatten1')
+        symbol = mx.sym.FullyConnected(data=symbol, name='fc1', num_hidden=5)
+        symbol = mx.sym.SoftmaxOutput(symbol, name='softmax')
+        executor = symbol.simple_bind(ctx=mx.cpu(), data=data_shape)
+
+        # Fill the bound parameters via set_weights in 'random' mode, then test.
+        set_weights(symbol, executor, mode='random')
+        self._test_mxnet_model(symbol, executor, data=data_shape)
+
+    def test_transpose(self):
+        """Run an axes=(0, 1, 2, 3) transpose ahead of a 64-filter 5x5 convolution."""
+        np.random.seed(1988)
+        data_shape = (1, 3, 10, 10)
+        conv_params = dict(
+            num_filter=64,
+            kernel=(5, 5),
+            stride=(1, 1),
+            pad=(0, 0),
+            name='conv_1',
+        )
+
+        # axes=(0, 1, 2, 3) lists the four dimensions in their existing order.
+        source = mx.sym.Variable('data')
+        permuted = mx.sym.transpose(data=source, name='transpose', axes=(0, 1, 2, 3))
+        symbol = mx.symbol.Convolution(data=permuted, **conv_params)
+        executor = symbol.simple_bind(ctx=mx.cpu(), data=data_shape)
+
+        set_weights(symbol, executor, mode='random')
+        self._test_mxnet_model(symbol, executor, data=data_shape)

-- 
To stop receiving notification emails like this one, please contact
['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].

Reply via email to