This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 02fe0c5f0d [Frontend][ArgParse] Pass default values to target compiler(#13264) (#17014)
02fe0c5f0d is described below

commit 02fe0c5f0d80fa3d67868066cfc1d5cf07c3ec05
Author: MNGanesan <mngane...@yahoo.co.uk>
AuthorDate: Wed Jun 26 15:58:11 2024 +0530

    [Frontend][ArgParse] Pass default values to target compiler(#13264) (#17014)
    
    * [Frontend][ArgParse] Pass default values to target compiler(#13264)
    
        BYOC Compiler's Config node defines the target compiler's
        command line options, along with default values. This change
        extracts the default values from the config node while constructing
        target options for the codegen/target compiler.
        Added test case for this feature as well.
    
    Signed-off-by: M N Ganesan <muthusamy...@marvell.com>
    
    * [Frontend][ArgParse] Pass default values to target compiler(#13264)
    
        BYOC Compiler's Config node defines the target compiler's
        command line options, along with default values. This change
        extracts the default values from the config node while constructing
        target options for the codegen/target compiler.
        Added test case for this feature as well.
    
    Signed-off-by: M N Ganesan <muthusamy...@marvell.com>
    
    * Lint Fix
    
    Signed-off-by: M N Ganesan <muthusamy...@marvell.com>
    
    ---------
    
    Signed-off-by: M N Ganesan <muthusamy...@marvell.com>
    Co-authored-by: M N Ganesan <muthusamy...@marvell.com>
---
 python/tvm/driver/tvmc/composite_target.py      |  8 ++++++++
 python/tvm/driver/tvmc/target.py                | 19 ++++++++++++++++++-
 tests/python/driver/tvmc/test_target_options.py | 16 ++++++++++++++++
 3 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py
index cfcf5a14c1..6c51dd1689 100644
--- a/python/tvm/driver/tvmc/composite_target.py
+++ b/python/tvm/driver/tvmc/composite_target.py
@@ -51,34 +51,42 @@ logger = logging.getLogger("TVMC")
 REGISTERED_CODEGEN = {
     "compute-library": {
         "config_key": None,
+        "pass_default": False,
         "pass_pipeline": partition_for_arm_compute_lib,
     },
     "cmsis-nn": {
         "config_key": "relay.ext.cmsisnn.options",
+        "pass_default": False,
         "pass_pipeline": partition_for_cmsisnn,
     },
     "ethos-n": {
         "config_key": "relay.ext.ethos-n.options",
+        "pass_default": False,
         "pass_pipeline": partition_for_ethosn,
     },
     "ethos-u": {
         "config_key": "relay.ext.ethos-u.options",
+        "pass_default": False,
         "pass_pipeline": partition_for_ethosu,
     },
     "bnns": {
         "config_key": None,
+        "pass_default": False,
         "pass_pipeline": partition_for_bnns,
     },
     "vitis-ai": {
         "config_key": "relay.ext.vitis_ai.options",
+        "pass_default": False,
         "pass_pipeline": partition_for_vitis_ai,
     },
     "clml": {
         "config_key": None,
+        "pass_default": False,
         "pass_pipeline": partition_for_clml,
     },
     "mrvl": {
         "config_key": "relay.ext.mrvl.options",
+        "pass_default": True,
         "pass_pipeline": partition_for_mrvl,
     },
 }
diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py
index ec8215184e..b5eee04823 100644
--- a/python/tvm/driver/tvmc/target.py
+++ b/python/tvm/driver/tvmc/target.py
@@ -69,10 +69,28 @@ def _generate_codegen_args(parser, codegen_name):
             for tvm_type, python_type in INTERNAL_TO_NATIVE_TYPE.items():
                 if field.type_info.startswith(tvm_type):
                     target_option = field.name
+                    default_value = None
+
+                    # Retrieve the default value string from attrs(field) of 
config node
+                    # Eg: "default=target_cpu_name"
+                    target_option_default_str = 
field.type_info.split("default=")[1]
+
+                    # Extract the defalut value based on the tvm type
+                    if target_option_default_str and tvm_type == 
"runtime.String":
+                        default_value = target_option_default_str
+                    elif target_option_default_str and tvm_type == "IntImm":
+                        # Extract the numeric value from the python Int 
string, Eg: T.int64(8)
+                        str_slice = target_option_default_str.split("(")[1]
+                        default_value = str_slice.split(")")[0]
+
+                    if codegen["pass_default"] is False:
+                        default_value = None
+
                     target_group.add_argument(
                         f"--target-{codegen_name}-{target_option}",
                         type=python_type,
                         help=field.description,
+                        default=default_value,
                     )
 
 
@@ -133,7 +151,6 @@ def reconstruct_target_args(args):
         codegen_options = _reconstruct_codegen_args(args, codegen_name)
         if codegen_options:
             reconstructed[codegen_name] = codegen_options
-
     return reconstructed
 
 
diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py
index 194047e7a6..d98a8d588e 100644
--- a/tests/python/driver/tvmc/test_target_options.py
+++ b/tests/python/driver/tvmc/test_target_options.py
@@ -72,6 +72,21 @@ def test_target_to_argparse_for_mrvl_hybrid():
     assert parsed.target_mrvl_mcpu == "cnf10kb"
 
 
+@tvm.testing.requires_mrvl
+def test_default_arg_for_mrvl_hybrid():
+    parser = argparse.ArgumentParser()
+    generate_target_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--target=mrvl, llvm",
+        ]
+    )
+    assert parsed.target == "mrvl, llvm"
+    assert parsed.target_mrvl_mcpu == "cn10ka"
+    assert parsed.target_mrvl_num_tiles == 8
+
+
+@tvm.testing.requires_cmsisnn
 def test_mapping_target_args():
     parser = argparse.ArgumentParser()
     generate_target_args(parser)
@@ -129,6 +144,7 @@ def test_ethosu_compiler_attrs():
     }
 
 
+@tvm.testing.requires_cmsisnn
 def test_skip_target_from_codegen():
     parser = argparse.ArgumentParser()
     generate_target_args(parser)

Reply via email to