srkreddy1238 commented on a change in pull request #5617:
URL: https://github.com/apache/incubator-tvm/pull/5617#discussion_r429331590



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -2896,15 +2903,29 @@ def _parse_import_prerequisites(self, graph):
         """
         missing_operators = set()
         for node in graph.node:
+            try:
+                from tensorflow.python.framework import op_def_registry

Review comment:
      Can you confirm whether op_def_registry is part of all TF versions?
   If it is not available in 1.x we shouldn't raise an error, since the frontend is still compatible with TF 1.x.
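
   For reference, a minimal sketch of a defensive lookup, assuming the registry layout may differ across TF releases (behaviour on older 1.x releases is an assumption that still needs confirming):

try:
    from tensorflow.python.framework import op_def_registry
except ImportError:
    op_def_registry = None

def _get_op_def(op_name):
    """Return the OpDef for op_name, or None when no registry is available."""
    if op_def_registry is None:
        return None
    if hasattr(op_def_registry, "_registered_ops"):
        # older TF releases expose a private dict of registered ops
        return op_def_registry._registered_ops.get(op_name)
    # newer TF releases expose a public get() accessor
    return op_def_registry.get(op_name)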

##########
File path: tests/python/frontend/tensorflow/test_forward.py
##########
@@ -3179,10 +3183,342 @@ def test_forward_isfinite():
     _verify_infiniteness_ops(tf.is_finite, "isfinite")
 
 
+def _test_spop_placeholder_one():
+    with tf.Graph().as_default():

Review comment:
      Advise using descriptive names for these tests instead of the _one / _two ... suffixes.
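
   For example (hypothetical renames; stub bodies only, one per scenario exercised):

def _test_spop_placeholder_with_defun():
    ...  # currently _test_spop_placeholder_one

def _test_spop_placeholder_with_default():
    ...  # currently _test_spop_placeholder_two

def _test_spop_placeholder_tf_function():
    ...  # currently _test_spop_placeholder_three / _four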
   

##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -3155,6 +3176,91 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_
 
         return op
 
+    def _partition_call_operator(self, inputs, attr):
+        """
+        Convert the Relay Partition call ops into Relay Function calls and
+        function definitions from Tensorflow graph library attribute to Relay global
+        functions
+
+        Parameters
+        ----------
+        node: TensorFlow graph node object.
+            A TensorFlow graph node object.
+
+        inputs : List[tvm.relay.Expr]
+            List of input symbols.
+
+        attrs : Dict[tvm.Attrs]
+            Dict of operator attributes.
+
+        Returns
+        -------
+        op : tvm.relay.Expr
+            Converted relay expression.
+        """
+
+        try:
+            from tensorflow.python.framework import function_def_to_graph
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import tensorflow which is required {}".format(e))
+
+        main_graph_proto = self._main_graph_proto
+        outer_graph_def = main_graph_proto._graph
+
+        node_func_name = attr.get('f').name
+        func = next((f for f in outer_graph_def.library.function
+                     if f.signature.name == node_func_name), None)
+        if func:
+            devices = set(node.device for node in func.node_def)
+            if len(devices) > 1:
+                raise Exception("Found inconsistent Device assignment in the "\
+                                "Stateful Partitioned SubGraph. Rejecting "\
+                                "the subgraph ")
+            # Convert function definition to graph
+            func_input_shapes = func.attr["_input_shapes"].list.shape
+            subgraph, _ = function_def_to_graph.\
+                function_def_to_graph_def(func, func_input_shapes)
+
+            # Computing subgraph's input shape dictionary
+            subgraph_shape_dict, input_expr_dict = {}, {}
+            for f_arg, input in zip(func.signature.input_arg, inputs):
+                input_expr_dict[f_arg.name] = input
+                subgraph_shape_dict[f_arg.name] = _infer_shape(input, main_graph_proto._mod)
+
+            func_name = 'func_{}'.format(func.signature.name)
+            try:
+                global_func = main_graph_proto._mod[func_name]

Review comment:
      Is this the case where the function is called multiple times within a graph?
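
   If so, a minimal sketch of the reuse path (hypothetical names and stand-in body; the PR registers converted functions as module globals):

import tvm
from tvm import relay

mod = tvm.IRModule()
func_name = "func_Forward"                         # hypothetical 'func_{signature.name}'

x = relay.var("x", shape=(2, 2), dtype="int32")
converted = relay.Function([x], relay.add(x, x))   # stand-in for the converted subgraph

existing = [gv for gv in mod.get_global_vars() if gv.name_hint == func_name]
if existing:
    gvar = existing[0]          # later call sites reuse the already-converted function
else:
    gvar = relay.GlobalVar(func_name)
    mod[gvar] = converted       # first call site converts and registers it

call_site = gvar(x)             # every (Stateful)PartitionedCall node becomes a call to gvar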

##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -2896,15 +2903,29 @@ def _parse_import_prerequisites(self, graph):
         """
         missing_operators = set()
         for node in graph.node:
+            try:
+                from tensorflow.python.framework import op_def_registry
+            except ImportError as e:
+                raise ImportError(
+                    "Unable to import tensorflow which is required 
{}".format(e))
+            getOpDef = op_def_registry._registered_ops.get if hasattr(op_def_registry,\
+                        "_registered_ops") else op_def_registry.get
+            op_def = getOpDef(node.op)
             if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault':
                 pass
             elif node.op == "Const":
                 pass
+            elif node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
+                pass
             else:
                 if any([node.op in t for t in [_identity_list, _convert_map,
                                                _convert_map_rnn,
                                                _control_flow_nodes]]):
                     pass
+                elif op_def is not None and op_def.is_stateful:
+                    raise Exception("Found a stateful operator in this graph {}. "\

Review comment:
      Better to add this to the missing-op list (with extended info if needed) instead of raising an exception here, otherwise we lose the overall list of missing ops.
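
   A minimal sketch of what I mean, reusing the names from the hunk above (the helper and message format are just a suggestion):

def _collect_missing_op(node_op, op_def, supported_ops, missing_operators):
    """Record an unsupported op instead of aborting on the first stateful one."""
    if node_op in supported_ops:
        return
    if op_def is not None and op_def.is_stateful:
        # keep the extra detail so the final error explains why the op is rejected
        missing_operators.add("{} (stateful operators are not supported)".format(node_op))
    else:
        missing_operators.add(node_op)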

##########
File path: tests/python/frontend/tensorflow/test_forward.py
##########
@@ -3179,10 +3183,342 @@ def test_forward_isfinite():
     _verify_infiniteness_ops(tf.is_finite, "isfinite")
 
 
+def _test_spop_placeholder_one():
+    with tf.Graph().as_default():
+
+        @function.Defun(*[tf.int32]*2)
+        def Forward(x,y):
+            print(x.name)
+            print(y.name)
+            b = tf.add(x, y)
+            return b
+        pl1 = tf.placeholder(tf.int32,name="pl1")
+        pl2 = tf.placeholder(tf.int32,name="pl2")
+        pl3 = tf.placeholder(tf.int32, name="pl3")
+        data = np.array([[-1, 1], [2, -2]], dtype=np.int32)
+        data2 = np.array([[-2, 3], [4, -6]], dtype=np.int32)
+        data3 = np.array([[-2, 3], [4, -6]], dtype=np.int32)
+        z1 = gen_functional_ops.StatefulPartitionedCall(args=[pl1,pl2], Tout=[tf.int32],f=Forward)
+        z2 = z1 + pl3
+        compare_tf_with_tvm([data, data2, data3], ['pl1:0', 'pl2:0', 'pl3:0'],
+                            ['StatefulPartitionedCall:0',z2.name],  mode='vm', init_global_variables=True)
+
+
+def _test_spop_placeholder_two():
+    with tf.Graph().as_default():
+        data = np.ones([1], dtype=int).astype(np.int32)
+        dataVar = tf.Variable(data, shape=data.shape)
+        pl1 = array_ops.placeholder_with_default(dataVar,shape=data.shape,name="pl1")
+        tpl = tf.convert_to_tensor(pl1, dtype=tf.int32)
+
+        @function.Defun(*[tf.int32])
+        def pl_with_default(pl):
+            return tf.expand_dims(tf.multiply(pl, pl), 0)
+
+        z = gen_functional_ops.StatefulPartitionedCall(args=[tpl], Tout=[tf.int32], f=pl_with_default)
+        compare_tf_with_tvm(data, ['pl1:0'], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)
+
+
+def _test_spop_placeholder_three():
+    with tf.Graph().as_default():
+        t1 = tf.placeholder(tf.int32, (3, 3, 3), "t1")
+        t1_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))
+        t2 = tf.placeholder(tf.int32, (3, 3, 3), "t2")
+        t2_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))
+
+        @tf.function
+        def add(x, y):
+            return tf.add(x, y, "add_t1_t2")
+
+        t3 = add(t1, t2)
+        compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True)
+
+
+def _test_spop_placeholder_four():
+    with tf.Graph().as_default():
+        t1_data = np.array([[-1, 1, 3], [2, -2, 4], [2, -3, 14]], dtype=np.int32)
+        t2_data = np.array([[-2, 1, 2], [12, -2, 14], [12, -3, 4]], dtype=np.int32)
+        t1 = tf.placeholder(tf.int32, name="t1")
+        t2 = tf.placeholder(tf.int32, name="t2")
+
+        @tf.function
+        def add(x, y):
+            return tf.add(x, y, "add_t1_t2")
+
+        t3 = add(t1, t2)
+        compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True)
+
+
+def _test_spop_function_invocation_basic():
+    with tf.Graph().as_default():
+
+        def fun1(a):
+            return tf.multiply(a,a)
+
+        def fun2(b):
+            return tf.multiply(b,10)
+
+        @tf.function
+        def fun3(x,y):
+            x = fun2(x)
+            y = fun1(y)
+            z = tf.add(x,y)
+            return z
+
+        t3 = fun3(tf.constant(10.5), tf.constant(20.4))
+
+        compare_tf_with_tvm([], [], [t3.name], mode='vm', init_global_variables=True)
+
+
+def _test_spop_function_invocation_nested():
+    with tf.Graph().as_default():
+        t1 = tf.placeholder(tf.int32, (3, 3, 3), name="t1")
+        t1_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))
+        t2 = tf.placeholder(tf.int32, name="t2")
+        t2_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3))
+
+        @tf.function
+        def myfunc(x, y):
+            return tf.add(x, y, "myfunc")
+
+        @tf.function
+        def myfunc2(x, y):
+            z = myfunc(x, y)
+            l = myfunc(z, y)
+            m = myfunc(l,z)
+            return tf.add(l, m, "myfunc2")
+
+        res1 = myfunc(t1, t2)
+        res2 = myfunc2(res1, t1)
+
+        compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [res2.name], mode='vm', init_global_variables=True)
+
+
+def _test_spop_function_invocation_no_autograph():
+    with tf.Graph().as_default():
+
+        @tf.function(autograph=False)
+        def fun1(a):
+            return tf.multiply(a,a)
+
+        @tf.function(autograph=False)
+        def fun2(b):
+            return tf.multiply(b,10)
+
+        @tf.function
+        def fun3(x,y):
+            x = fun2(x)
+            y = fun1(y)
+            z = tf.add(x,y)
+            return z
+
+        t3 = fun3(tf.constant(10.5), tf.constant(20.4))
+
+        compare_tf_with_tvm([], [], [t3.name], mode='vm', init_global_variables=True)
+
+
+def _test_spop_function_invocation_defun():
+    with tf.Graph().as_default():
+
+        def fun1(a):
+            return tf.multiply(a,a)
+
+        def fun2(b):
+            return tf.multiply(b,b)
+
+        @function.Defun(dtypes.float32, dtypes.float32, func_name="Fun3")
+        def fun3(x,y):
+            x = fun2(x)
+            y = fun1(y)
+            z = tf.add(x,y)
+            return z
+
+        op = gen_functional_ops.StatefulPartitionedCall(args=[tf.constant(10.5),tf.constant(20.4)],
+                                                        Tout=[dtypes.float32], f=fun3, name="SpopFnInvocation")
+        compare_tf_with_tvm([],[], 'SpopFnInvocation:0', mode='vm', init_global_variables=True)
+
+
+def _test_spop_arithmetic():
+    with tf.Graph().as_default():
+        @function.Defun(*[dtypes.int32]*3)
+        def arithmetic(m,x,c):
+            z = tf.add(tf.multiply(m, x), c)
+            return z
+
+        m = tf.constant(10)
+        x = tf.constant(20)
+        c = tf.constant(2)
+        spopFn = gen_functional_ops.StatefulPartitionedCall(args=[m,x,c],Tout=[tf.int32], f=arithmetic)
+
+        compare_tf_with_tvm([],[],'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)
+
+
+def _test_spop_control_flow():
+    with tf.Graph().as_default():
+
+        @function.Defun(*[dtypes.float32] * 2)
+        def Body1(x, y):
+            with ops.device("/job:localhost/replica:0/task:0/device:CPU:0"):
+                z = math_ops.multiply(x, y)
+                i = 0
+                while i<10 :
+                    i +=1
+                    if i == 5:
+                        continue
+                    z = math_ops.multiply(x, y*i)
+            return z
+
+        op = gen_functional_ops.StatefulPartitionedCall(
+            args=[constant_op.constant(32.), constant_op.constant(100.)],
+            Tout=[dtypes.float32], f=Body1)
+        compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)
+
+
+def _test_spop_variables():
+    with tf.Graph().as_default():
+        const1 = tf.constant(10)
+        const2 = tf.constant(20)
+        var1 = tf.Variable(const1, dtype=tf.int32)
+        var2 = tf.Variable(const2, dtype=tf.int32)
+
+        @function.Defun(tf.int32,tf.int32)
+        def Forward(x,y):
+            return tf.multiply(x,y)
+
+        z = gen_functional_ops.StatefulPartitionedCall(args=[var1,var2],Tout=[tf.int32], f=Forward)
+        compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', init_global_variables=True, mode="vm")
+
+
+def _test_spop_constants():
+    with tf.Graph().as_default():
+        @function.Defun(*[dtypes.int32] * 2)
+        def constantsFn(x, y):
+            vv = tf.constant([2, 3, 4], name="vv")
+            z = tf.add(vv + x, y)
+            return z
+
+        a = tf.constant(20000, name = "a")
+        b = tf.constant(40000, name = "b")
+        spopFn = gen_functional_ops.StatefulPartitionedCall(args=[a, b], Tout=[tf.int32], f=constantsFn)
+
+        compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True)
+
+
+def _test_spop_stateful():
+
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+
+        @tf.function
+        def FunctionWithStatefulOp_One(i):
+            b = tf.random.uniform(shape=[2, 4], maxval=10, dtype=tf.float32, seed=10)
+            y = tf.multiply(b, i)
+            return y
+
+        @tf.function
+        def FunctionWithStatefulOp(m, n):
+            a = tf.random.uniform(shape=[2, 4], maxval=10, dtype=tf.float32, seed = 10)
+            x = tf.multiply(a,m)
+            y = FunctionWithStatefulOp_One(n)
+            z = tf.multiply(x,y)
+            return z
+
+        op = FunctionWithStatefulOp(constant_op.constant(1.), constant_op.constant(2.))
+        with pytest.raises(Exception) as execinfo:
+            compare_tf_with_tvm([], [], [op.name], init_global_variables=True, mode="vm")
+        assert execinfo.value.args[0].startswith("Found a stateful operator in 
this graph")
+
+
+def _test_spop_device_assignment():
+
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+
+        def fun1(a):
+            with ops.device("/GPU:0"):
+                return tf.multiply(a,a)
+
+        def fun2(b):
+            with ops.device("/job:localhost/replica:0/task:0/device:CPU:1"):
+                return tf.multiply(b,b)
+
+        @function.Defun(dtypes.float32, dtypes.float32, func_name="Fun3")
+        def fun3(x,y):
+            with ops.device("/CPU:0"):
+                x = fun2(x)
+            with ops.device("/job:localhost/replica:0/task:0/device:CPU:2"):
+                y = fun1(y)
+            with ops.device("/job:localhost/replica:0/task:0/device:CPU:3"):
+                z = tf.add(x,y)
+                return z
+
+        op = gen_functional_ops.StatefulPartitionedCall(args=[tf.constant(10.5),tf.constant(20.4)],
+                                                        Tout=[dtypes.float32], f=fun3)
+        with pytest.raises(Exception) as execinfo:
+            compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0',
+                                mode='vm', init_global_variables=True)
+        assert execinfo.value.args[0].startswith("Found inconsistent Device 
assignment")
+
+
+def _test_spop_resource_variables():
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+
+        const1 = tf.constant(10)
+        const2 = tf.constant(20)
+        var1 = tf.Variable(const1, dtype=tf.int32, use_resource=True)
+        var2 = tf.Variable(const2, dtype=tf.int32, use_resource=True)
+
+        @tf.function
+        def resourceVariablesTest(x, y):
+            return tf.multiply(x, y)
+
+        op = resourceVariablesTest(var1,var2)
+        with pytest.raises(Exception) as execinfo:
+            compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0',
+                                mode='vm', init_global_variables=True)
+        assert execinfo.value.args[0].startswith("Found a stateful operator in 
this graph")
+
+def test_forward_spop():
+    # This test case is to test that TVM rejects any TF stateful operations

Review comment:
       Better to move these descriptions into the test case.
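
   For instance, the comment block could become the test's docstring (hypothetical wording; the list of calls is illustrative):

def test_forward_spop():
    """StatefulPartitionedCall / PartitionedCall tests.

    The stateful, resource-variable and device-assignment cases verify that
    the frontend rejects graphs it cannot convert; the remaining cases cover
    supported subgraph conversion.
    """
    _test_spop_stateful()
    _test_spop_device_assignment()
    _test_spop_resource_variables()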

##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -2896,15 +2903,29 @@ def _parse_import_prerequisites(self, graph):
         """
         missing_operators = set()
         for node in graph.node:
+            try:
+                from tensorflow.python.framework import op_def_registry
+            except ImportError as e:
+                raise ImportError(
+                    "Unable to import tensorflow which is required 
{}".format(e))
+            getOpDef = op_def_registry._registered_ops.get if hasattr(op_def_registry,\
+                        "_registered_ops") else op_def_registry.get
+            op_def = getOpDef(node.op)
             if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault':
                 pass
             elif node.op == "Const":
                 pass
+            elif node.op in ["PartitionedCall", "StatefulPartitionedCall"]:

Review comment:
      Should note the minimum TF version here before supporting these ops.
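
   If a runtime check is intended, something along these lines could work (the minimum version below is a placeholder, not a verified value):

from distutils.version import LooseVersion
import tensorflow as tf

def _partitioned_call_ops_supported(min_tf_version="1.15.0"):
    """Gate (Stateful)PartitionedCall support on the installed TF version."""
    # "1.15.0" is an assumed placeholder that would need confirming
    return LooseVersion(tf.__version__) >= LooseVersion(min_tf_version)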




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

