Add support for parsing TVM model objects from the model
archive buffer, verifying that all expected objects are
present, and copying the TVM model objects to internal buffers.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
Signed-off-by: Anup Prabhu <apra...@marvell.com>
---
 drivers/ml/cnxk/cnxk_ml_ops.c    |  5 ++-
 drivers/ml/cnxk/mvtvm_ml_model.c | 57 +++++++++++++++++++++++++++++
 drivers/ml/cnxk/mvtvm_ml_model.h |  2 ++
 drivers/ml/cnxk/mvtvm_ml_ops.c   | 62 ++++++++++++++++++++++++++++++++
 drivers/ml/cnxk/mvtvm_ml_ops.h   |  3 ++
 drivers/ml/cnxk/mvtvm_ml_stubs.c | 11 ++++++
 drivers/ml/cnxk/mvtvm_ml_stubs.h |  3 ++
 7 files changed, 142 insertions(+), 1 deletion(-)

diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index ebc78e36e9..85b37161d2 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -1079,7 +1079,10 @@ cnxk_ml_model_load(struct rte_ml_dev *dev, struct 
rte_ml_model_params *params, u
                model, PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_model), 
dev_info.align_size));
        dev->data->models[lcl_model_id] = model;
 
-       ret = cn10k_ml_model_load(cnxk_mldev, params, model);
+       if (type == ML_CNXK_MODEL_TYPE_GLOW)
+               ret = cn10k_ml_model_load(cnxk_mldev, params, model);
+       else
+               ret = mvtvm_ml_model_load(cnxk_mldev, params, model);
        if (ret != 0)
                goto error;
 
diff --git a/drivers/ml/cnxk/mvtvm_ml_model.c b/drivers/ml/cnxk/mvtvm_ml_model.c
index ab5f8baa67..4c9a080c05 100644
--- a/drivers/ml/cnxk/mvtvm_ml_model.c
+++ b/drivers/ml/cnxk/mvtvm_ml_model.c
@@ -53,3 +53,60 @@ mvtvm_ml_model_type_get(struct rte_ml_model_params *params)
 
        return ML_CNXK_MODEL_TYPE_TVM;
 }
+
+int
+mvtvm_ml_model_blob_parse(struct rte_ml_model_params *params, struct 
mvtvm_ml_model_object *object)
+{
+       bool object_found[ML_MVTVM_MODEL_OBJECT_MAX] = {false, false, false};
+       struct archive_entry *entry;
+       struct archive *a;
+       uint8_t i;
+       int ret;
+
+       /* Open archive */
+       a = archive_read_new();
+       archive_read_support_filter_all(a);
+       archive_read_support_format_all(a);
+
+       ret = archive_read_open_memory(a, params->addr, params->size);
+       if (ret != ARCHIVE_OK)
+               return archive_errno(a);
+
+       /* Read archive */
+       while (archive_read_next_header(a, &entry) == ARCHIVE_OK) {
+               for (i = 0; i < ML_MVTVM_MODEL_OBJECT_MAX; i++) {
+                       if (!object_found[i] &&
+                           (strcmp(archive_entry_pathname(entry), 
mvtvm_object_list[i]) == 0)) {
+                               memcpy(object[i].name, mvtvm_object_list[i], 
RTE_ML_STR_MAX);
+                               object[i].size = archive_entry_size(entry);
+                               object[i].buffer = rte_malloc(NULL, 
object[i].size, 0);
+
+                               if (archive_read_data(a, object[i].buffer, 
object[i].size) !=
+                                   object[i].size) {
+                                       plt_err("Failed to read object from 
model archive: %s",
+                                               object[i].name);
+                                       goto error;
+                               }
+                               object_found[i] = true;
+                       }
+               }
+               archive_read_data_skip(a);
+       }
+
+       /* Check if all objects are parsed */
+       for (i = 0; i < ML_MVTVM_MODEL_OBJECT_MAX; i++) {
+               if (!object_found[i]) {
+                       plt_err("Object %s not found in archive!\n", 
mvtvm_object_list[i]);
+                       goto error;
+               }
+       }
+       return 0;
+
+error:
+       for (i = 0; i < ML_MVTVM_MODEL_OBJECT_MAX; i++) {
+               if (object[i].buffer != NULL)
+                       rte_free(object[i].buffer);
+       }
+
+       return -EINVAL;
+}
diff --git a/drivers/ml/cnxk/mvtvm_ml_model.h b/drivers/ml/cnxk/mvtvm_ml_model.h
index b6162fceec..b11b66f495 100644
--- a/drivers/ml/cnxk/mvtvm_ml_model.h
+++ b/drivers/ml/cnxk/mvtvm_ml_model.h
@@ -44,5 +44,7 @@ struct mvtvm_ml_model_data {
 };
 
 enum cnxk_ml_model_type mvtvm_ml_model_type_get(struct rte_ml_model_params 
*params);
+int mvtvm_ml_model_blob_parse(struct rte_ml_model_params *params,
+                             struct mvtvm_ml_model_object *object);
 
 #endif /* _MVTVM_ML_MODEL_H_ */
diff --git a/drivers/ml/cnxk/mvtvm_ml_ops.c b/drivers/ml/cnxk/mvtvm_ml_ops.c
index 88c6d5a864..e2413b6b15 100644
--- a/drivers/ml/cnxk/mvtvm_ml_ops.c
+++ b/drivers/ml/cnxk/mvtvm_ml_ops.c
@@ -8,8 +8,12 @@
 #include <rte_mldev_pmd.h>
 
 #include "cnxk_ml_dev.h"
+#include "cnxk_ml_model.h"
 #include "cnxk_ml_ops.h"
 
+/* ML model macros */
+#define MVTVM_ML_MODEL_MEMZONE_NAME "ml_mvtvm_model_mz"
+
 int
 mvtvm_ml_dev_configure(struct cnxk_ml_dev *cnxk_mldev, const struct 
rte_ml_dev_config *conf)
 {
@@ -39,3 +43,61 @@ mvtvm_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev)
 
        return ret;
 }
+
+int
+mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params 
*params,
+                   struct cnxk_ml_model *model)
+{
+       struct mvtvm_ml_model_object object[ML_MVTVM_MODEL_OBJECT_MAX];
+       char str[RTE_MEMZONE_NAMESIZE];
+       const struct plt_memzone *mz;
+       size_t model_object_size = 0;
+       uint64_t mz_size = 0;
+       int ret;
+
+       RTE_SET_USED(cnxk_mldev);
+
+       ret = mvtvm_ml_model_blob_parse(params, object);
+       if (ret != 0)
+               return ret;
+
+       model_object_size = RTE_ALIGN_CEIL(object[0].size, 
RTE_CACHE_LINE_MIN_SIZE) +
+                           RTE_ALIGN_CEIL(object[1].size, 
RTE_CACHE_LINE_MIN_SIZE) +
+                           RTE_ALIGN_CEIL(object[2].size, 
RTE_CACHE_LINE_MIN_SIZE);
+       mz_size += model_object_size;
+
+       /* Allocate memzone for model object */
+       snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", 
MVTVM_ML_MODEL_MEMZONE_NAME, model->model_id);
+       mz = plt_memzone_reserve_aligned(str, mz_size, 0, ML_CN10K_ALIGN_SIZE);
+       if (!mz) {
+               plt_err("plt_memzone_reserve failed : %s", str);
+               return -ENOMEM;
+       }
+
+       /* Copy mod.so */
+       model->mvtvm.object.so.addr = mz->addr;
+       model->mvtvm.object.so.size = object[0].size;
+       rte_memcpy(model->mvtvm.object.so.name, object[0].name, 
TVMDP_NAME_STRLEN);
+       rte_memcpy(model->mvtvm.object.so.addr, object[0].buffer, 
object[0].size);
+       rte_free(object[0].buffer);
+
+       /* Copy mod.json */
+       model->mvtvm.object.json.addr =
+               RTE_PTR_ADD(model->mvtvm.object.so.addr,
+                           RTE_ALIGN_CEIL(model->mvtvm.object.so.size, 
RTE_CACHE_LINE_MIN_SIZE));
+       model->mvtvm.object.json.size = object[1].size;
+       rte_memcpy(model->mvtvm.object.json.name, object[1].name, 
TVMDP_NAME_STRLEN);
+       rte_memcpy(model->mvtvm.object.json.addr, object[1].buffer, 
object[1].size);
+       rte_free(object[1].buffer);
+
+       /* Copy mod.params */
+       model->mvtvm.object.params.addr =
+               RTE_PTR_ADD(model->mvtvm.object.json.addr,
+                           RTE_ALIGN_CEIL(model->mvtvm.object.json.size, 
RTE_CACHE_LINE_MIN_SIZE));
+       model->mvtvm.object.params.size = object[2].size;
+       rte_memcpy(model->mvtvm.object.params.name, object[2].name, 
TVMDP_NAME_STRLEN);
+       rte_memcpy(model->mvtvm.object.params.addr, object[2].buffer, 
object[2].size);
+       rte_free(object[2].buffer);
+
+       return 0;
+}
diff --git a/drivers/ml/cnxk/mvtvm_ml_ops.h b/drivers/ml/cnxk/mvtvm_ml_ops.h
index 305b4681ed..6607537599 100644
--- a/drivers/ml/cnxk/mvtvm_ml_ops.h
+++ b/drivers/ml/cnxk/mvtvm_ml_ops.h
@@ -12,8 +12,11 @@
 #include <rte_mldev.h>
 
 struct cnxk_ml_dev;
+struct cnxk_ml_model;
 
 int mvtvm_ml_dev_configure(struct cnxk_ml_dev *cnxk_mldev, const struct 
rte_ml_dev_config *conf);
 int mvtvm_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev);
+int mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct 
rte_ml_model_params *params,
+                       struct cnxk_ml_model *model);
 
 #endif /* _MVTVM_ML_OPS_H_ */
diff --git a/drivers/ml/cnxk/mvtvm_ml_stubs.c b/drivers/ml/cnxk/mvtvm_ml_stubs.c
index a7352840a6..7f3b3abb2e 100644
--- a/drivers/ml/cnxk/mvtvm_ml_stubs.c
+++ b/drivers/ml/cnxk/mvtvm_ml_stubs.c
@@ -33,3 +33,14 @@ mvtvm_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev)
 
        return 0;
 }
+
/*
 * Placeholder used when the driver is built without TVM support.
 * No TVM model can be loaded in this configuration, so every request is
 * rejected with -EINVAL and none of the arguments are touched.
 */
int
mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *params,
		    struct cnxk_ml_model *model)
{
	/* Arguments are intentionally ignored. */
	RTE_SET_USED(model);
	RTE_SET_USED(params);
	RTE_SET_USED(cnxk_mldev);

	return -EINVAL;
}
diff --git a/drivers/ml/cnxk/mvtvm_ml_stubs.h b/drivers/ml/cnxk/mvtvm_ml_stubs.h
index 467a9d39e5..4bb1772ef4 100644
--- a/drivers/ml/cnxk/mvtvm_ml_stubs.h
+++ b/drivers/ml/cnxk/mvtvm_ml_stubs.h
@@ -8,9 +8,12 @@
 #include <rte_mldev.h>
 
 struct cnxk_ml_dev;
+struct cnxk_ml_model;
 
 enum cnxk_ml_model_type mvtvm_ml_model_type_get(struct rte_ml_model_params 
*params);
 int mvtvm_ml_dev_configure(struct cnxk_ml_dev *cnxk_mldev, const struct 
rte_ml_dev_config *conf);
 int mvtvm_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev);
+int mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct 
rte_ml_model_params *params,
+                       struct cnxk_ml_model *model);
 
 #endif /* _MVTVM_ML_STUBS_H_ */
-- 
2.42.0

Reply via email to