maheshambule commented on a change in pull request #5617:
URL: https://github.com/apache/incubator-tvm/pull/5617#discussion_r431643670



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -3155,6 +3176,91 @@ def _convert_control_flow_operator(self, node, inputs, 
attrs, control_flow_node_
 
         return op
 
+    def _partition_call_operator(self, inputs, attr):
+        """
+        Convert the Relay Partition call ops into Relay Function calls and
+        function definitions from Tensorflow graph library attribute to Relay 
global
+        functions
+
+        Parameters
+        ----------
+        node: TensorFlow graph node object.
+            A TensorFlow graph node object.
+
+        inputs : List[tvm.relay.Expr]
+            List of input symbols.
+
+        attrs : Dict[tvm.Attrs]
+            Dict of operator attributes.
+
+        Returns
+        -------
+        op : tvm.relay.Expr
+            Converted relay expression.
+        """
+
+        try:
+            from tensorflow.python.framework import function_def_to_graph
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import tensorflow which is required {}".format(e))
+
+        main_graph_proto = self._main_graph_proto
+        outer_graph_def = main_graph_proto._graph
+
+        node_func_name = attr.get('f').name
+        func = next((f for f in outer_graph_def.library.function
+                     if f.signature.name == node_func_name), None)
+        if func:
+            devices = set(node.device for node in func.node_def)
+            if len(devices) > 1:
+                raise Exception("Found inconsistent Device assignment in the "\
+                                "Stateful Partitioned SubGraph. Rejecting "\
+                                "the subgraph ")
+            # Convert function definition to graph
+            func_input_shapes = func.attr["_input_shapes"].list.shape
+            subgraph, _ = function_def_to_graph.\
+                function_def_to_graph_def(func, func_input_shapes)
+
+            # Computing subgraph's input shape dictionary
+            subgraph_shape_dict, input_expr_dict = {}, {}
+            for f_arg, input in zip(func.signature.input_arg, inputs):
+                input_expr_dict[f_arg.name] = input
+                subgraph_shape_dict[f_arg.name] = _infer_shape(input, 
main_graph_proto._mod)
+
+            func_name = 'func_{}'.format(func.signature.name)
+            try:
+                global_func = main_graph_proto._mod[func_name]

Review comment:
       Yes, it is. We add the function definition to the main module, so if there
are multiple calls to the same function, we don't need to add the definition
again; we only need to add a new function call node.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to