Enable unloading of TVM models using the external TVMDP library. Update
the layer unload callback to support models with multiple layers.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
Signed-off-by: Anup Prabhu <apra...@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c | 20 ++++++++++++++++++++
 drivers/ml/cnxk/cnxk_ml_ops.c  |  9 +++++++--
 drivers/ml/cnxk/mvtvm_ml_ops.c | 28 ++++++++++++++++++++++++++++
 drivers/ml/cnxk/mvtvm_ml_ops.h |  1 +
 4 files changed, 56 insertions(+), 2 deletions(-)

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index 79217165cd5..85d0a9e18bb 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -725,7 +725,9 @@ cn10k_ml_layer_unload(void *device, uint16_t model_id, 
const char *layer_name)
        uint16_t layer_id = 0;
        int ret;
 
+#ifndef RTE_MLDEV_CNXK_ENABLE_MVTVM
        PLT_SET_USED(layer_name);
+#endif
 
        cnxk_mldev = (struct cnxk_ml_dev *)device;
        if (cnxk_mldev == NULL) {
@@ -739,6 +741,24 @@ cn10k_ml_layer_unload(void *device, uint16_t model_id, 
const char *layer_name)
                return -EINVAL;
        }
 
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+       if (model->type == ML_CNXK_MODEL_TYPE_TVM) {
+               for (layer_id = 0; layer_id < 
model->mvtvm.metadata.model.nb_layers; layer_id++) {
+                       if (strcmp(model->layer[layer_id].name, layer_name) == 
0)
+                               break;
+               }
+
+               if (layer_id == model->mvtvm.metadata.model.nb_layers) {
+                       plt_err("Invalid layer name: %s", layer_name);
+                       return -EINVAL;
+               }
+
+               if (model->layer[layer_id].type != ML_CNXK_LAYER_TYPE_MRVL) {
+                       plt_err("Invalid layer name / type: %s", layer_name);
+                       return -EINVAL;
+               }
+       }
+#endif
        layer = &model->layer[layer_id];
 
        snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u_%u", 
CN10K_ML_LAYER_MEMZONE_NAME,
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index a99367089b4..d8eadcb8121 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -1182,7 +1182,7 @@ cnxk_ml_model_unload(struct rte_ml_dev *dev, uint16_t 
model_id)
        struct cnxk_ml_model *model;
 
        char str[RTE_MEMZONE_NAMESIZE];
-       int ret;
+       int ret = 0;
 
        if (dev == NULL)
                return -EINVAL;
@@ -1200,7 +1200,12 @@ cnxk_ml_model_unload(struct rte_ml_dev *dev, uint16_t 
model_id)
                return -EBUSY;
        }
 
-       ret = cn10k_ml_model_unload(cnxk_mldev, model);
+       if (model->type == ML_CNXK_MODEL_TYPE_GLOW)
+               ret = cn10k_ml_model_unload(cnxk_mldev, model);
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+       else
+               ret = mvtvm_ml_model_unload(cnxk_mldev, model);
+#endif
        if (ret != 0)
                return ret;
 
diff --git a/drivers/ml/cnxk/mvtvm_ml_ops.c b/drivers/ml/cnxk/mvtvm_ml_ops.c
index a783e16e6eb..1edbfb0dcc3 100644
--- a/drivers/ml/cnxk/mvtvm_ml_ops.c
+++ b/drivers/ml/cnxk/mvtvm_ml_ops.c
@@ -191,3 +191,31 @@ mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct 
rte_ml_model_params *
 
        return ret;
 }
+
+int
+mvtvm_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model 
*model)
+{
+       char str[RTE_MEMZONE_NAMESIZE];
+       const struct plt_memzone *mz;
+       int ret;
+
+       RTE_SET_USED(cnxk_mldev); /* device handle is not needed for TVM unload */
+
+       /* Unload the model from TVMDP first, then free the model's memzone */
+       ret = tvmdp_model_unload(model->model_id);
+       if (ret != 0) {
+               plt_err("TVMDP: Model unload failed, model_id = %u, error = 
%d", model->model_id,
+                       ret);
+               return ret;
+       }
+
+       snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", 
MVTVM_ML_MODEL_MEMZONE_NAME, model->model_id);
+       mz = rte_memzone_lookup(str);
+       if (mz == NULL) {
+               plt_err("Memzone lookup failed for TVM model: model_id = %u, mz 
= %s",
+                       model->model_id, str);
+               return -EINVAL; /* NOTE: TVMDP unload already succeeded here */
+       }
+
+       return plt_memzone_free(mz);
+}
diff --git a/drivers/ml/cnxk/mvtvm_ml_ops.h b/drivers/ml/cnxk/mvtvm_ml_ops.h
index ca8f57992da..8b4db20fe94 100644
--- a/drivers/ml/cnxk/mvtvm_ml_ops.h
+++ b/drivers/ml/cnxk/mvtvm_ml_ops.h
@@ -14,5 +14,6 @@ int mvtvm_ml_dev_configure(struct cnxk_ml_dev *cnxk_mldev, 
const struct rte_ml_d
 int mvtvm_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev);
 int mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct 
rte_ml_model_params *params,
                        struct cnxk_ml_model *model);
+int mvtvm_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model 
*model);
 
 #endif /* _MVTVM_ML_OPS_H_ */
-- 
2.41.0

Reply via email to