Added cnxk wrapper functions to update model params and
fetch model info.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c | 38 ++++++---------------------
 drivers/ml/cnxk/cn10k_ml_ops.h |  5 ++--
 drivers/ml/cnxk/cnxk_ml_ops.c  | 48 ++++++++++++++++++++++++++++++++--
 3 files changed, 56 insertions(+), 35 deletions(-)

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index 40f484158a..3ff82829f0 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -1835,45 +1835,23 @@ cn10k_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, 
struct cnxk_ml_model *model)
 }
 
 int
-cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
-                       struct rte_ml_model_info *model_info)
+cn10k_ml_model_params_update(struct cnxk_ml_dev *cnxk_mldev, struct 
cnxk_ml_model *model,
+                            void *buffer)
 {
-       struct cnxk_ml_model *model;
-
-       model = dev->data->models[model_id];
-
-       if (model == NULL) {
-               plt_err("Invalid model_id = %u", model_id);
-               return -EINVAL;
-       }
-
-       rte_memcpy(model_info, model->info, sizeof(struct rte_ml_model_info));
-       model_info->input_info = ((struct rte_ml_model_info 
*)model->info)->input_info;
-       model_info->output_info = ((struct rte_ml_model_info 
*)model->info)->output_info;
-
-       return 0;
-}
-
-int
-cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void 
*buffer)
-{
-       struct cnxk_ml_model *model;
-
-       model = dev->data->models[model_id];
+       struct cnxk_ml_layer *layer;
 
-       if (model == NULL) {
-               plt_err("Invalid model_id = %u", model_id);
-               return -EINVAL;
-       }
+       RTE_SET_USED(cnxk_mldev); /* not otherwise referenced in this function */
 
        if (model->state == ML_CNXK_MODEL_STATE_UNKNOWN)
                return -1;
        else if (model->state != ML_CNXK_MODEL_STATE_LOADED)
                return -EBUSY;
 
+       layer = &model->layer[0]; /* only layer 0's weights & bias are updated */
+
        /* Update model weights & bias */
-       rte_memcpy(model->layer[0].glow.addr.wb_load_addr, buffer,
-                  model->layer[0].glow.metadata.weights_bias.file_size);
+       rte_memcpy(layer->glow.addr.wb_load_addr, buffer,
+                  layer->glow.metadata.weights_bias.file_size);
 
        return 0;
 }
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index a222a43d55..ef12069f0d 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -317,9 +317,8 @@ int cn10k_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, 
struct rte_ml_model_para
 int cn10k_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model 
*model);
 int cn10k_ml_model_start(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model 
*model);
 int cn10k_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model 
*model);
-int cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
-                           struct rte_ml_model_info *model_info);
-int cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, 
void *buffer);
+int cn10k_ml_model_params_update(struct cnxk_ml_dev *cnxk_mldev, struct 
cnxk_ml_model *model,
+                                void *buffer);
 
 /* I/O ops */
 int cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id,
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index b61ed45876..9ce37fcfd1 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -604,6 +604,50 @@ cnxk_ml_model_stop(struct rte_ml_dev *dev, uint16_t 
model_id)
        return cn10k_ml_model_stop(cnxk_mldev, model);
 }
 
+/* Copy the model's rte_ml_model_info (model->info) into the caller's buffer. */
+static int
+cnxk_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
+                      struct rte_ml_model_info *model_info)
+{
+       struct cnxk_ml_model *model = NULL;
+
+       if ((dev == NULL) || (model_info == NULL))
+               return -EINVAL;
+
+       /* Reject ids beyond the configured model count before indexing models[] */
+       if (model_id < dev->data->nb_models)
+               model = dev->data->models[model_id];
+       if (model == NULL) {
+               plt_err("Invalid model_id = %u", model_id);
+               return -EINVAL;
+       }
+
+       /* Struct copy also carries the input_info/output_info pointers */
+       *model_info = *(struct rte_ml_model_info *)model->info;
+
+       return 0;
+}
+
+/* Validate model_id, then forward the weights & bias update to cn10k. */
+static int
+cnxk_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer)
+{
+       struct cnxk_ml_model *model = NULL;
+
+       if ((dev == NULL) || (buffer == NULL))
+               return -EINVAL;
+
+       /* Reject ids beyond the configured model count before indexing models[] */
+       if (model_id < dev->data->nb_models)
+               model = dev->data->models[model_id];
+       if (model == NULL) {
+               plt_err("Invalid model_id = %u", model_id);
+               return -EINVAL;
+       }
+
+       return cn10k_ml_model_params_update(dev->data->dev_private, model, buffer);
+}
+
 struct rte_ml_dev_ops cnxk_ml_ops = {
        /* Device control ops */
        .dev_info_get = cnxk_ml_dev_info_get,
@@ -631,8 +675,8 @@ struct rte_ml_dev_ops cnxk_ml_ops = {
        .model_unload = cnxk_ml_model_unload,
        .model_start = cnxk_ml_model_start,
        .model_stop = cnxk_ml_model_stop,
-       .model_info_get = cn10k_ml_model_info_get,
-       .model_params_update = cn10k_ml_model_params_update,
+       .model_info_get = cnxk_ml_model_info_get,
+       .model_params_update = cnxk_ml_model_params_update,
 
        /* I/O ops */
        .io_quantize = cn10k_ml_io_quantize,
-- 
2.42.0

Reply via email to