Add support to parse TVM model objects from the model
archive buffer. Check that all expected objects are present
in the archive and copy the TVM model objects to internal buffers.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
Signed-off-by: Anup Prabhu <apra...@marvell.com>
---
 drivers/ml/cnxk/cnxk_ml_ops.c    | 14 +++++--
 drivers/ml/cnxk/mvtvm_ml_model.c | 62 +++++++++++++++++++++++++++++++
 drivers/ml/cnxk/mvtvm_ml_model.h |  3 ++
 drivers/ml/cnxk/mvtvm_ml_ops.c   | 63 ++++++++++++++++++++++++++++++++
 drivers/ml/cnxk/mvtvm_ml_ops.h   |  3 ++
 5 files changed, 142 insertions(+), 3 deletions(-)

diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index cbb701f20bb..a99367089b4 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -1149,9 +1149,17 @@ cnxk_ml_model_load(struct rte_ml_dev *dev, struct 
rte_ml_model_params *params, u
                model, PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_model), 
dev_info.align_size));
        dev->data->models[lcl_model_id] = model;
 
-       ret = cn10k_ml_model_load(cnxk_mldev, params, model);
-       if (ret != 0)
-               goto error;
+       if (type == ML_CNXK_MODEL_TYPE_GLOW) {
+               ret = cn10k_ml_model_load(cnxk_mldev, params, model);
+               if (ret != 0)
+                       goto error;
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+       } else {
+               ret = mvtvm_ml_model_load(cnxk_mldev, params, model);
+               if (ret != 0)
+                       goto error;
+#endif
+       }
 
        plt_spinlock_init(&model->lock);
        model->state = ML_CNXK_MODEL_STATE_LOADED;
diff --git a/drivers/ml/cnxk/mvtvm_ml_model.c b/drivers/ml/cnxk/mvtvm_ml_model.c
index 64622675345..425a682209f 100644
--- a/drivers/ml/cnxk/mvtvm_ml_model.c
+++ b/drivers/ml/cnxk/mvtvm_ml_model.c
@@ -2,10 +2,72 @@
  * Copyright (c) 2023 Marvell.
  */
 
+#include <archive.h>
+#include <archive_entry.h>
+
 #include <rte_mldev.h>
 
+#include <roc_api.h>
+
 #include "mvtvm_ml_model.h"
 
 /* Objects list */
 char mvtvm_object_list[ML_MVTVM_MODEL_OBJECT_MAX][RTE_ML_STR_MAX] = {"mod.so", 
"mod.json",
                                                                     
"mod.params"};
+
+/*
+ * Parse a TVM model archive and extract the objects listed in mvtvm_object_list.
+ *
+ * The archive is read from params->addr / params->size. For every name in
+ * mvtvm_object_list, the first archive entry with a matching pathname is copied
+ * into a buffer allocated with rte_malloc() and described by object[i]
+ * (name, size, buffer).
+ *
+ * @param params  Model parameters holding the archive buffer.
+ * @param object  Array of ML_MVTVM_MODEL_OBJECT_MAX entries to fill.
+ *
+ * @return 0 on success; the caller then owns every object[i].buffer and must
+ *         release it with rte_free(). On failure a negative errno value is
+ *         returned, and every object[i].buffer is freed and reset to NULL.
+ */
+int
+mvtvm_ml_model_blob_parse(struct rte_ml_model_params *params, struct mvtvm_ml_model_object *object)
+{
+	bool object_found[ML_MVTVM_MODEL_OBJECT_MAX] = {false, false, false};
+	struct archive_entry *entry;
+	const char *pathname;
+	struct archive *a;
+	int64_t entry_size;
+	ssize_t nread;
+	uint8_t i;
+	int ret;
+
+	/* The caller's array may be uninitialized: start from a known state so the
+	 * error path only releases buffers allocated by this call.
+	 */
+	for (i = 0; i < ML_MVTVM_MODEL_OBJECT_MAX; i++)
+		object[i].buffer = NULL;
+
+	/* Open archive */
+	a = archive_read_new();
+	if (a == NULL) {
+		plt_err("Failed to allocate model archive reader");
+		return -ENOMEM;
+	}
+	archive_read_support_filter_all(a);
+	archive_read_support_format_all(a);
+
+	if (archive_read_open_memory(a, params->addr, params->size) != ARCHIVE_OK) {
+		plt_err("Failed to open model archive");
+		ret = -EINVAL;
+		goto error;
+	}
+
+	/* Read archive */
+	while (archive_read_next_header(a, &entry) == ARCHIVE_OK) {
+		pathname = archive_entry_pathname(entry);
+		for (i = 0; pathname != NULL && i < ML_MVTVM_MODEL_OBJECT_MAX; i++) {
+			if (object_found[i] || strcmp(pathname, mvtvm_object_list[i]) != 0)
+				continue;
+
+			/* rte_malloc() cannot back a zero-byte object and a negative
+			 * size is invalid, so reject both before allocating.
+			 */
+			entry_size = archive_entry_size(entry);
+			if (entry_size <= 0) {
+				plt_err("Invalid object size in model archive: %s",
+					mvtvm_object_list[i]);
+				ret = -EINVAL;
+				goto error;
+			}
+
+			memcpy(object[i].name, mvtvm_object_list[i], RTE_ML_STR_MAX);
+			object[i].size = entry_size;
+			object[i].buffer = rte_malloc(NULL, object[i].size, 0);
+			if (object[i].buffer == NULL) {
+				plt_err("Failed to allocate memory for object: %s",
+					object[i].name);
+				ret = -ENOMEM;
+				goto error;
+			}
+
+			/* archive_read_data() returns a negative ARCHIVE_* code on
+			 * error; check it before comparing against the entry size.
+			 */
+			nread = archive_read_data(a, object[i].buffer, object[i].size);
+			if (nread < 0 || (int64_t)nread != entry_size) {
+				plt_err("Failed to read object from model archive: %s",
+					object[i].name);
+				ret = -EINVAL;
+				goto error;
+			}
+			object_found[i] = true;
+			break; /* an entry matches at most one object name */
+		}
+		archive_read_data_skip(a);
+	}
+
+	/* Check if all objects are parsed */
+	for (i = 0; i < ML_MVTVM_MODEL_OBJECT_MAX; i++) {
+		if (!object_found[i]) {
+			plt_err("Object %s not found in archive!", mvtvm_object_list[i]);
+			ret = -EINVAL;
+			goto error;
+		}
+	}
+
+	archive_read_free(a);
+	return 0;
+
+error:
+	/* rte_free(NULL) is a no-op, so every slot can be released unconditionally */
+	for (i = 0; i < ML_MVTVM_MODEL_OBJECT_MAX; i++) {
+		rte_free(object[i].buffer);
+		object[i].buffer = NULL;
+	}
+	archive_read_free(a);
+
+	return ret;
+}
diff --git a/drivers/ml/cnxk/mvtvm_ml_model.h b/drivers/ml/cnxk/mvtvm_ml_model.h
index 1f6b435be02..73a45a91d66 100644
--- a/drivers/ml/cnxk/mvtvm_ml_model.h
+++ b/drivers/ml/cnxk/mvtvm_ml_model.h
@@ -43,4 +43,7 @@ struct mvtvm_ml_model_data {
        struct cnxk_ml_io_info info;
 };
 
+/*
+ * Parse the model archive described by params (addr/size) and extract each
+ * object named in mvtvm_object_list into object[], which must have
+ * ML_MVTVM_MODEL_OBJECT_MAX entries. On success each object[i].buffer is
+ * allocated with rte_malloc() and must be released by the caller with
+ * rte_free(). Returns 0 on success, non-zero on failure.
+ */
+int mvtvm_ml_model_blob_parse(struct rte_ml_model_params *params,
+                             struct mvtvm_ml_model_object *object);
+
 #endif /* _MVTVM_ML_MODEL_H_ */
diff --git a/drivers/ml/cnxk/mvtvm_ml_ops.c b/drivers/ml/cnxk/mvtvm_ml_ops.c
index 0e1fc527daa..1bdd4515771 100644
--- a/drivers/ml/cnxk/mvtvm_ml_ops.c
+++ b/drivers/ml/cnxk/mvtvm_ml_ops.c
@@ -9,9 +9,14 @@
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
+#include "mvtvm_ml_model.h"
 #include "mvtvm_ml_ops.h"
 
 #include "cnxk_ml_dev.h"
+#include "cnxk_ml_model.h"
+
+/* ML model macros */
+#define MVTVM_ML_MODEL_MEMZONE_NAME "ml_mvtvm_model_mz"
 
 int
 mvtvm_ml_dev_configure(struct cnxk_ml_dev *cnxk_mldev, const struct 
rte_ml_dev_config *conf)
@@ -42,3 +47,61 @@ mvtvm_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev)
 
        return ret;
 }
+
+/*
+ * Load a TVM model from the archive in params.
+ *
+ * The archive is parsed with mvtvm_ml_model_blob_parse(). The three objects
+ * (mod.so, mod.json, mod.params, in mvtvm_object_list order) are then copied
+ * back-to-back, each at a cache-line aligned offset, into a single memzone
+ * named "<MVTVM_ML_MODEL_MEMZONE_NAME>_<model_id>". Their addresses, sizes and
+ * names are recorded in model->mvtvm.object. The parser's temporary buffers
+ * are always released before returning.
+ *
+ * @return 0 on success, non-zero on failure (-ENOMEM if the memzone cannot
+ *         be reserved).
+ */
+int
+mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *params,
+		    struct cnxk_ml_model *model)
+{
+	struct mvtvm_ml_model_object object[ML_MVTVM_MODEL_OBJECT_MAX];
+	char str[RTE_MEMZONE_NAMESIZE];
+	const struct plt_memzone *mz;
+	size_t model_object_size;
+	uint8_t i;
+	int ret;
+
+	RTE_SET_USED(cnxk_mldev);
+
+	/* Zero the descriptors: the parser's error path releases any non-NULL
+	 * buffer, so they must not contain stack garbage.
+	 */
+	memset(object, 0, sizeof(object));
+
+	ret = mvtvm_ml_model_blob_parse(params, object);
+	if (ret != 0)
+		return ret;
+
+	/* Each object occupies a cache-line aligned slot in one memzone */
+	model_object_size = RTE_ALIGN_CEIL(object[0].size, RTE_CACHE_LINE_MIN_SIZE) +
+			    RTE_ALIGN_CEIL(object[1].size, RTE_CACHE_LINE_MIN_SIZE) +
+			    RTE_ALIGN_CEIL(object[2].size, RTE_CACHE_LINE_MIN_SIZE);
+
+	/* Allocate memzone for model object */
+	snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", MVTVM_ML_MODEL_MEMZONE_NAME, model->model_id);
+	mz = plt_memzone_reserve_aligned(str, model_object_size, 0, ML_CN10K_ALIGN_SIZE);
+	if (mz == NULL) {
+		plt_err("plt_memzone_reserve failed : %s", str);
+		ret = -ENOMEM;
+		goto free_objects;
+	}
+
+	/* Names are copied with snprintf so the copy is bounded by the destination
+	 * length and always NUL-terminated, whatever the source buffer size is.
+	 */
+
+	/* Copy mod.so */
+	model->mvtvm.object.so.addr = mz->addr;
+	model->mvtvm.object.so.size = object[0].size;
+	snprintf(model->mvtvm.object.so.name, TVMDP_NAME_STRLEN, "%s", object[0].name);
+	rte_memcpy(model->mvtvm.object.so.addr, object[0].buffer, object[0].size);
+
+	/* Copy mod.json */
+	model->mvtvm.object.json.addr =
+		RTE_PTR_ADD(model->mvtvm.object.so.addr,
+			    RTE_ALIGN_CEIL(model->mvtvm.object.so.size, RTE_CACHE_LINE_MIN_SIZE));
+	model->mvtvm.object.json.size = object[1].size;
+	snprintf(model->mvtvm.object.json.name, TVMDP_NAME_STRLEN, "%s", object[1].name);
+	rte_memcpy(model->mvtvm.object.json.addr, object[1].buffer, object[1].size);
+
+	/* Copy mod.params */
+	model->mvtvm.object.params.addr =
+		RTE_PTR_ADD(model->mvtvm.object.json.addr,
+			    RTE_ALIGN_CEIL(model->mvtvm.object.json.size, RTE_CACHE_LINE_MIN_SIZE));
+	model->mvtvm.object.params.size = object[2].size;
+	snprintf(model->mvtvm.object.params.name, TVMDP_NAME_STRLEN, "%s", object[2].name);
+	rte_memcpy(model->mvtvm.object.params.addr, object[2].buffer, object[2].size);
+
+	ret = 0;
+
+free_objects:
+	/* The parser's temporary buffers are not needed on any path past this point */
+	for (i = 0; i < ML_MVTVM_MODEL_OBJECT_MAX; i++)
+		rte_free(object[i].buffer);
+
+	return ret;
+}
diff --git a/drivers/ml/cnxk/mvtvm_ml_ops.h b/drivers/ml/cnxk/mvtvm_ml_ops.h
index 988f3a1fd5e..ca8f57992da 100644
--- a/drivers/ml/cnxk/mvtvm_ml_ops.h
+++ b/drivers/ml/cnxk/mvtvm_ml_ops.h
@@ -8,8 +8,11 @@
 #include <rte_mldev.h>
 
 struct cnxk_ml_dev;
+struct cnxk_ml_model;
 
 int mvtvm_ml_dev_configure(struct cnxk_ml_dev *cnxk_mldev, const struct 
rte_ml_dev_config *conf);
 int mvtvm_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev);
+/*
+ * Load a TVM model: parse the model archive in params and copy its objects
+ * into a memzone referenced from model->mvtvm.object.
+ * Returns 0 on success, non-zero on failure.
+ */
+int mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct 
rte_ml_model_params *params,
+                       struct cnxk_ml_model *model);
 
 #endif /* _MVTVM_ML_OPS_H_ */
-- 
2.41.0

Reply via email to