Added support to fetch TVM model layer information and
update internal structures based on the layer information.
Set callback functions for layer load and unload, and
enable model loading using the TVMDP library. Added support
to fetch the full metadata after model load.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c   | 22 ++++++++-
 drivers/ml/cnxk/mvtvm_ml_model.h |  2 +
 drivers/ml/cnxk/mvtvm_ml_ops.c   | 83 ++++++++++++++++++++++++++++++++
 3 files changed, 106 insertions(+), 1 deletion(-)

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index db18f320527..79217165cd5 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -508,8 +508,10 @@ cn10k_ml_layer_load(void *device, uint16_t model_id, const 
char *layer_name, uin
        int qp_id;
        int ret;
 
-       PLT_SET_USED(size);
+#ifndef RTE_MLDEV_CNXK_ENABLE_MVTVM
        PLT_SET_USED(layer_name);
+#endif
+       PLT_SET_USED(size);
 
        cnxk_mldev = (struct cnxk_ml_dev *)device;
        if (cnxk_mldev == NULL) {
@@ -523,6 +525,24 @@ cn10k_ml_layer_load(void *device, uint16_t model_id, const 
char *layer_name, uin
                return -EINVAL;
        }
 
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+       if (model->type == ML_CNXK_MODEL_TYPE_TVM) {
+               for (layer_id = 0; layer_id < 
model->mvtvm.metadata.model.nb_layers; layer_id++) {
+                       if (strcmp(model->layer[layer_id].name, layer_name) == 
0)
+                               break;
+               }
+
+               if (layer_id == model->mvtvm.metadata.model.nb_layers) {
+                       plt_err("Invalid layer name: %s", layer_name);
+                       return -EINVAL;
+               }
+
+               if (model->layer[layer_id].type != ML_CNXK_LAYER_TYPE_MRVL) {
+                       plt_err("Invalid layer name / type: %s", layer_name);
+                       return -EINVAL;
+               }
+       }
+#endif
        layer = &model->layer[layer_id];
 
        ret = cn10k_ml_model_metadata_check(buffer, size);
diff --git a/drivers/ml/cnxk/mvtvm_ml_model.h b/drivers/ml/cnxk/mvtvm_ml_model.h
index 73a45a91d66..6c38217c158 100644
--- a/drivers/ml/cnxk/mvtvm_ml_model.h
+++ b/drivers/ml/cnxk/mvtvm_ml_model.h
@@ -11,6 +11,8 @@
 
 #include "cnxk_ml_io.h"
 
+struct cnxk_ml_model;
+
 /* Maximum number of objects per model */
 #define ML_MVTVM_MODEL_OBJECT_MAX 3
 
diff --git a/drivers/ml/cnxk/mvtvm_ml_ops.c b/drivers/ml/cnxk/mvtvm_ml_ops.c
index 1bdd4515771..5c30bbf6b89 100644
--- a/drivers/ml/cnxk/mvtvm_ml_ops.c
+++ b/drivers/ml/cnxk/mvtvm_ml_ops.c
@@ -9,6 +9,8 @@
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
+#include "cn10k_ml_ops.h"
+
 #include "mvtvm_ml_model.h"
 #include "mvtvm_ml_ops.h"
 
@@ -53,9 +55,13 @@ mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct 
rte_ml_model_params *
                    struct cnxk_ml_model *model)
 {
        struct mvtvm_ml_model_object object[ML_MVTVM_MODEL_OBJECT_MAX];
+       struct tvmrt_glow_callback *callback;
        char str[RTE_MEMZONE_NAMESIZE];
        const struct plt_memzone *mz;
        size_t model_object_size = 0;
+       uint16_t nb_mrvl_layers;
+       uint16_t nb_llvm_layers;
+       uint8_t layer_id = 0;
        uint64_t mz_size = 0;
        int ret;
 
@@ -103,5 +109,82 @@ mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct 
rte_ml_model_params *
        rte_memcpy(model->mvtvm.object.params.addr, object[2].buffer, 
object[2].size);
        rte_free(object[2].buffer);
 
+       /* Get metadata - stage 1 */
+       ret = tvmdp_model_metadata_get_stage1(model->mvtvm.object.json.addr,
+                                             model->mvtvm.object.json.size,
+                                             &model->mvtvm.metadata);
+       if (ret != 0) {
+               plt_err("TVMDP: Failed to parse metadata - stage 1, model_id = 
%u, error = %d",
+                       model->model_id, ret);
+               goto error;
+       }
+
+       /* Set model fields */
+       plt_strlcpy(model->name, model->mvtvm.metadata.model.name, 
TVMDP_NAME_STRLEN);
+       model->batch_size = 1;
+       model->nb_layers = model->mvtvm.metadata.model.nb_layers;
+
+       /* Update layer info */
+       nb_mrvl_layers = 0;
+       nb_llvm_layers = 0;
+       for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; 
layer_id++) {
+               strncpy(model->layer[layer_id].name,
+                       model->mvtvm.metadata.model.layer[layer_id].name, 
TVMDP_NAME_STRLEN);
+               if (strcmp(model->mvtvm.metadata.model.layer[layer_id].type, 
"mrvl") == 0 ||
+                   strcmp(model->mvtvm.metadata.model.layer[layer_id].type, 
"MRVL") == 0) {
+                       model->layer[layer_id].type = ML_CNXK_LAYER_TYPE_MRVL;
+                       nb_mrvl_layers++;
+               } else if 
(strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "llvm") == 0 ||
+                          
strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "LLVM") == 0) {
+                       model->layer[layer_id].type = ML_CNXK_LAYER_TYPE_LLVM;
+                       nb_llvm_layers++;
+               }
+       }
+
+       if ((nb_llvm_layers == 0) && (nb_mrvl_layers == 0)) {
+               plt_err("Invalid model, nb_llvm_layers = %u, nb_mrvl_layers = 
%u", nb_llvm_layers,
+                       nb_mrvl_layers);
+               goto error;
+       }
+
+       /* Set model subtype */
+       if ((nb_llvm_layers == 0) && (nb_mrvl_layers == 1))
+               model->subtype = ML_CNXK_MODEL_SUBTYPE_TVM_MRVL;
+       else if ((nb_llvm_layers > 0) && (nb_mrvl_layers == 0))
+               model->subtype = ML_CNXK_MODEL_SUBTYPE_TVM_LLVM;
+       else
+               model->subtype = ML_CNXK_MODEL_SUBTYPE_TVM_HYBRID;
+
+       /* Set callback function array */
+       if (model->subtype != ML_CNXK_MODEL_SUBTYPE_TVM_LLVM) {
+               callback = &model->mvtvm.cb;
+               callback->tvmrt_glow_layer_load = cn10k_ml_layer_load;
+               callback->tvmrt_glow_layer_unload = cn10k_ml_layer_unload;
+       } else {
+               callback = NULL;
+       }
+
+       /* Initialize model in TVMDP */
+       ret = tvmdp_model_load(cnxk_mldev, model->model_id, (void 
*)(&model->mvtvm.object),
+                              callback);
+       if (ret != 0) {
+               plt_err("TVMDP: Model load failed, model_id = %u, error = %d", 
model->model_id,
+                       ret);
+               goto error;
+       }
+
+       /* Get model metadata - stage 2 */
+       ret = tvmdp_model_metadata_get_stage2(model->model_id, 
&model->mvtvm.metadata);
+       if (ret != 0) {
+               plt_err("TVMDP: Failed to get metadata, model_id = %u, error = 
%d\n",
+                       model->model_id, ret);
+               goto error;
+       }
+
        return 0;
+
+error:
+       rte_memzone_free(mz);
+
+       return ret;
 }
-- 
2.41.0

Reply via email to