piiswrong closed pull request #8989: Symbol __getitem__ using list_outputs() is too expensive
URL: https://github.com/apache/incubator-mxnet/pull/8989
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index faa453529e..d34b194554 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1051,6 +1051,16 @@ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
 MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
                                   mx_uint *out_size,
                                   const char ***out_str_array);
+
+/*!
+ * \brief Get number of outputs of the symbol.
+ * \param symbol The symbol
+ * \param out_size number of outputs
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
+                                     mx_uint *output_count);
+
 /*!
  * \brief Get a symbol that contains all the internals.
  * \param symbol The symbol
@@ -1077,6 +1087,7 @@ MXNET_DLL int MXSymbolGetChildren(SymbolHandle symbol,
 MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
                                 mx_uint index,
                                 SymbolHandle *out);
+
 /*!
  * \brief List auxiliary states in the symbol.
  * \param symbol the symbol
diff --git a/nnvm b/nnvm
index 8d79cfd0b4..7a052d6784 160000
--- a/nnvm
+++ b/nnvm
@@ -1 +1 @@
-Subproject commit 8d79cfd0b42fbe9f6ad75886d495065d5500b9dd
+Subproject commit 7a052d678455f1c96538c1cc5a25f11115363558
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index ce7776d948..22212b0bdb 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -491,14 +491,16 @@ def __getitem__(self, index):
             Indexing key
 
         """
-        output_names = self.list_outputs()
+        output_count = len(self)
         if isinstance(index, py_slice):
             start = 0 if index.start is None else index.start
-            stop = len(output_names) if index.stop is None else index.stop
+            stop = output_count if index.stop is None else index.stop
             step = 1 if index.step is None else index.step
             return Group([self[i] for i in range(start, stop, step)])
 
         if isinstance(index, string_types):
+            # Returning this list of names is expensive. Some symbols may have hundreds of outputs
+            output_names = self.list_outputs()
             idx = None
             for i, name in enumerate(output_names):
                 if name == index:
@@ -511,7 +513,7 @@ def __getitem__(self, index):
 
         if not isinstance(index, int):
             raise TypeError('Symbol only support integer index to fetch i-th output')
-        if index >= len(output_names):
+        if index >= output_count:
             # Important, python determines the end by this exception
             raise IndexError
         handle = SymbolHandle()
@@ -745,6 +747,25 @@ def list_outputs(self):
             self.handle, ctypes.byref(size), ctypes.byref(sarr)))
         return [py_str(sarr[i]) for i in range(size.value)]
 
+    def __len__(self):
+        """Get number of outputs for the symbol.
+
+        Example
+        -------
+        >>> a = mx.sym.var('a')
+        >>> b = mx.sym.var('b')
+        >>> c = a + b
+        >>> len(c)
+
+        Returns
+        -------
+        len(self): Number of outputs
+            Number of outputs
+        """
+        output_count = mx_uint()
+        check_call(_LIB.MXSymbolGetNumOutputs(self.handle, ctypes.byref(output_count)))
+        return output_count.value
+
     def list_auxiliary_states(self):
         """Lists all the auxiliary states in the symbol.
 
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index dad71b0816..3668af0600 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -310,6 +310,11 @@ int MXSymbolListOutputs(SymbolHandle symbol,
   return NNSymbolListOutputNames(symbol, out_size, out_str_array);
 }
 
+int MXSymbolGetNumOutputs(SymbolHandle symbol,
+                           mx_uint *output_count) {
+  return NNSymbolGetNumOutputs(symbol, output_count);
+}
+
 int MXSymbolCompose(SymbolHandle sym,
                     const char *name,
                     mx_uint num_args,
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 30e76a272e..8fba1cc98c 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -46,6 +46,7 @@ def test_symbol_compose():
     composed = net2(fc3_data=net1, name='composed')
     multi_out = mx.symbol.Group([composed, net1])
     assert len(multi_out.list_outputs()) == 2
+    assert len(multi_out) == 2
 
 
 def test_symbol_copy():
@@ -72,7 +73,9 @@ def test_symbol_children():
     net1 = mx.symbol.FullyConnected(data=oldfc, name='fc2', num_hidden=100)
 
     assert net1.get_children().list_outputs() == ['fc1_output', 'fc2_weight', 'fc2_bias']
+    assert len(net1.get_children()) == 3
     assert net1.get_children().get_children().list_outputs() == ['data', 'fc1_weight', 'fc1_bias']
+    assert len(net1.get_children().get_children()) == 3
     assert net1.get_children()['fc2_weight'].list_arguments() == ['fc2_weight']
     assert net1.get_children()['fc2_weight'].get_children() is None
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to