Added cnxk wrapper functions to update model params and fetch model info. Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com> --- drivers/ml/cnxk/cn10k_ml_ops.c | 38 ++++++--------------------- drivers/ml/cnxk/cn10k_ml_ops.h | 5 ++-- drivers/ml/cnxk/cnxk_ml_ops.c | 48 ++++++++++++++++++++++++++++++++-- 3 files changed, 56 insertions(+), 35 deletions(-)
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c index 40f484158a..3ff82829f0 100644 --- a/drivers/ml/cnxk/cn10k_ml_ops.c +++ b/drivers/ml/cnxk/cn10k_ml_ops.c @@ -1835,45 +1835,23 @@ cn10k_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model) } int -cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id, - struct rte_ml_model_info *model_info) +cn10k_ml_model_params_update(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model, + void *buffer) { - struct cnxk_ml_model *model; - - model = dev->data->models[model_id]; - - if (model == NULL) { - plt_err("Invalid model_id = %u", model_id); - return -EINVAL; - } - - rte_memcpy(model_info, model->info, sizeof(struct rte_ml_model_info)); - model_info->input_info = ((struct rte_ml_model_info *)model->info)->input_info; - model_info->output_info = ((struct rte_ml_model_info *)model->info)->output_info; - - return 0; -} - -int -cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer) -{ - struct cnxk_ml_model *model; - - model = dev->data->models[model_id]; + struct cnxk_ml_layer *layer; - if (model == NULL) { - plt_err("Invalid model_id = %u", model_id); - return -EINVAL; - } + RTE_SET_USED(cnxk_mldev); if (model->state == ML_CNXK_MODEL_STATE_UNKNOWN) return -1; else if (model->state != ML_CNXK_MODEL_STATE_LOADED) return -EBUSY; + layer = &model->layer[0]; + /* Update model weights & bias */ - rte_memcpy(model->layer[0].glow.addr.wb_load_addr, buffer, - model->layer[0].glow.metadata.weights_bias.file_size); + rte_memcpy(layer->glow.addr.wb_load_addr, buffer, + layer->glow.metadata.weights_bias.file_size); return 0; } diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h index a222a43d55..ef12069f0d 100644 --- a/drivers/ml/cnxk/cn10k_ml_ops.h +++ b/drivers/ml/cnxk/cn10k_ml_ops.h @@ -317,9 +317,8 @@ int cn10k_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_para int cn10k_ml_model_unload(struct 
cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model); int cn10k_ml_model_start(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model); int cn10k_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model); -int cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id, - struct rte_ml_model_info *model_info); -int cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer); +int cn10k_ml_model_params_update(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model, + void *buffer); /* I/O ops */ int cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c index b61ed45876..9ce37fcfd1 100644 --- a/drivers/ml/cnxk/cnxk_ml_ops.c +++ b/drivers/ml/cnxk/cnxk_ml_ops.c @@ -604,6 +604,50 @@ cnxk_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id) return cn10k_ml_model_stop(cnxk_mldev, model); } +static int +cnxk_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id, + struct rte_ml_model_info *model_info) +{ + struct rte_ml_model_info *info; + struct cnxk_ml_model *model; + + if ((dev == NULL) || (model_info == NULL)) + return -EINVAL; + + model = dev->data->models[model_id]; + if (model == NULL) { + plt_err("Invalid model_id = %u", model_id); + return -EINVAL; + } + + info = (struct rte_ml_model_info *)model->info; + rte_memcpy(model_info, info, sizeof(struct rte_ml_model_info)); + model_info->input_info = info->input_info; + model_info->output_info = info->output_info; + + return 0; +} + +static int +cnxk_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer) +{ + struct cnxk_ml_dev *cnxk_mldev; + struct cnxk_ml_model *model; + + if ((dev == NULL) || (buffer == NULL)) + return -EINVAL; + + cnxk_mldev = dev->data->dev_private; + + model = dev->data->models[model_id]; + if (model == NULL) { + plt_err("Invalid model_id = %u", model_id); + return -EINVAL; + } + + return 
cn10k_ml_model_params_update(cnxk_mldev, model, buffer); +} + struct rte_ml_dev_ops cnxk_ml_ops = { /* Device control ops */ .dev_info_get = cnxk_ml_dev_info_get, @@ -631,8 +675,8 @@ struct rte_ml_dev_ops cnxk_ml_ops = { .model_unload = cnxk_ml_model_unload, .model_start = cnxk_ml_model_start, .model_stop = cnxk_ml_model_stop, - .model_info_get = cn10k_ml_model_info_get, - .model_params_update = cn10k_ml_model_params_update, + .model_info_get = cnxk_ml_model_info_get, + .model_params_update = cnxk_ml_model_params_update, /* I/O ops */ .io_quantize = cn10k_ml_io_quantize, -- 2.42.0