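Implement cnxk wrapper functions to load and unload ML models.
The wrapper functions invoke the cn10k model load and unload
functions. Also drop the separate run addresses of the model
sections; the model start job now uses the init section load
address as the destination DDR address.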

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_model.c | 244 ++++++++++++-------------
 drivers/ml/cnxk/cn10k_ml_model.h |  26 ++-
 drivers/ml/cnxk/cn10k_ml_ops.c   | 296 ++++++++++++++++++-------------
 drivers/ml/cnxk/cn10k_ml_ops.h   |  12 +-
 drivers/ml/cnxk/cnxk_ml_dev.h    |  15 ++
 drivers/ml/cnxk/cnxk_ml_ops.c    | 144 ++++++++++++++-
 drivers/ml/cnxk/cnxk_ml_ops.h    |   2 +
 7 files changed, 462 insertions(+), 277 deletions(-)

diff --git a/drivers/ml/cnxk/cn10k_ml_model.c b/drivers/ml/cnxk/cn10k_ml_model.c
index d2f1c761be..48d70027ca 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.c
+++ b/drivers/ml/cnxk/cn10k_ml_model.c
@@ -316,42 +316,31 @@ cn10k_ml_layer_addr_update(struct cnxk_ml_layer *layer, 
uint8_t *buffer, uint8_t
 {
        struct cn10k_ml_model_metadata *metadata;
        struct cn10k_ml_layer_addr *addr;
-       size_t model_data_size;
        uint8_t *dma_addr_load;
-       uint8_t *dma_addr_run;
        int fpos;
 
        metadata = &layer->glow.metadata;
        addr = &layer->glow.addr;
-       model_data_size = metadata->init_model.file_size + 
metadata->main_model.file_size +
-                         metadata->finish_model.file_size + 
metadata->weights_bias.file_size;
 
        /* Base address */
        addr->base_dma_addr_load = base_dma_addr;
-       addr->base_dma_addr_run = PLT_PTR_ADD(addr->base_dma_addr_load, 
model_data_size);
 
        /* Init section */
        dma_addr_load = addr->base_dma_addr_load;
-       dma_addr_run = addr->base_dma_addr_run;
        fpos = sizeof(struct cn10k_ml_model_metadata);
        addr->init_load_addr = dma_addr_load;
-       addr->init_run_addr = dma_addr_run;
        rte_memcpy(dma_addr_load, PLT_PTR_ADD(buffer, fpos), 
metadata->init_model.file_size);
 
        /* Main section */
        dma_addr_load += metadata->init_model.file_size;
-       dma_addr_run += metadata->init_model.file_size;
        fpos += metadata->init_model.file_size;
        addr->main_load_addr = dma_addr_load;
-       addr->main_run_addr = dma_addr_run;
        rte_memcpy(dma_addr_load, PLT_PTR_ADD(buffer, fpos), 
metadata->main_model.file_size);
 
        /* Finish section */
        dma_addr_load += metadata->main_model.file_size;
-       dma_addr_run += metadata->main_model.file_size;
        fpos += metadata->main_model.file_size;
        addr->finish_load_addr = dma_addr_load;
-       addr->finish_run_addr = dma_addr_run;
        rte_memcpy(dma_addr_load, PLT_PTR_ADD(buffer, fpos), 
metadata->finish_model.file_size);
 
        /* Weights and Bias section */
@@ -363,142 +352,148 @@ cn10k_ml_layer_addr_update(struct cnxk_ml_layer *layer, 
uint8_t *buffer, uint8_t
 }
 
 void
-cn10k_ml_layer_info_update(struct cnxk_ml_layer *layer)
+cn10k_ml_layer_io_info_set(struct cnxk_ml_io_info *io_info,
+                          struct cn10k_ml_model_metadata *metadata)
 {
-       struct cn10k_ml_model_metadata *metadata;
        uint8_t i;
        uint8_t j;
 
-       metadata = &layer->glow.metadata;
-
        /* Inputs */
-       layer->info.nb_inputs = metadata->model.num_input;
-       layer->info.total_input_sz_d = 0;
-       layer->info.total_input_sz_q = 0;
+       io_info->nb_inputs = metadata->model.num_input;
+       io_info->total_input_sz_d = 0;
+       io_info->total_input_sz_q = 0;
        for (i = 0; i < metadata->model.num_input; i++) {
                if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
-                       rte_strscpy(layer->info.input[i].name,
-                                   (char *)metadata->input1[i].input_name, 
MRVL_ML_INPUT_NAME_LEN);
-                       layer->info.input[i].dtype = 
metadata->input1[i].input_type;
-                       layer->info.input[i].qtype = 
metadata->input1[i].model_input_type;
-                       layer->info.input[i].nb_dims = 4;
-                       layer->info.input[i].shape[0] = 
metadata->input1[i].shape.w;
-                       layer->info.input[i].shape[1] = 
metadata->input1[i].shape.x;
-                       layer->info.input[i].shape[2] = 
metadata->input1[i].shape.y;
-                       layer->info.input[i].shape[3] = 
metadata->input1[i].shape.z;
-                       layer->info.input[i].nb_elements =
+                       rte_strscpy(io_info->input[i].name, (char 
*)metadata->input1[i].input_name,
+                                   MRVL_ML_INPUT_NAME_LEN);
+                       io_info->input[i].dtype = 
metadata->input1[i].input_type;
+                       io_info->input[i].qtype = 
metadata->input1[i].model_input_type;
+                       io_info->input[i].nb_dims = 4;
+                       io_info->input[i].shape[0] = 
metadata->input1[i].shape.w;
+                       io_info->input[i].shape[1] = 
metadata->input1[i].shape.x;
+                       io_info->input[i].shape[2] = 
metadata->input1[i].shape.y;
+                       io_info->input[i].shape[3] = 
metadata->input1[i].shape.z;
+                       io_info->input[i].nb_elements =
                                metadata->input1[i].shape.w * 
metadata->input1[i].shape.x *
                                metadata->input1[i].shape.y * 
metadata->input1[i].shape.z;
-                       layer->info.input[i].sz_d =
-                               layer->info.input[i].nb_elements *
+                       io_info->input[i].sz_d =
+                               io_info->input[i].nb_elements *
                                
rte_ml_io_type_size_get(metadata->input1[i].input_type);
-                       layer->info.input[i].sz_q =
-                               layer->info.input[i].nb_elements *
+                       io_info->input[i].sz_q =
+                               io_info->input[i].nb_elements *
                                
rte_ml_io_type_size_get(metadata->input1[i].model_input_type);
-                       layer->info.input[i].scale = metadata->input1[i].qscale;
+                       io_info->input[i].scale = metadata->input1[i].qscale;
 
-                       layer->info.total_input_sz_d += 
layer->info.input[i].sz_d;
-                       layer->info.total_input_sz_q += 
layer->info.input[i].sz_q;
+                       io_info->total_input_sz_d += io_info->input[i].sz_d;
+                       io_info->total_input_sz_q += io_info->input[i].sz_q;
 
                        plt_ml_dbg(
-                               "index = %u, input1[%u] - w:%u x:%u y:%u z:%u, 
sz_d = %u sz_q = %u",
-                               layer->index, i, metadata->input1[i].shape.w,
+                               "layer_name = %s, input1[%u] - w:%u x:%u y:%u 
z:%u, sz_d = %u sz_q = %u",
+                               metadata->model.name, i, 
metadata->input1[i].shape.w,
                                metadata->input1[i].shape.x, 
metadata->input1[i].shape.y,
-                               metadata->input1[i].shape.z, 
layer->info.input[i].sz_d,
-                               layer->info.input[i].sz_q);
+                               metadata->input1[i].shape.z, 
io_info->input[i].sz_d,
+                               io_info->input[i].sz_q);
                } else {
                        j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
 
-                       rte_strscpy(layer->info.input[i].name,
-                                   (char *)metadata->input2[j].input_name, 
MRVL_ML_INPUT_NAME_LEN);
-                       layer->info.input[i].dtype = 
metadata->input2[j].input_type;
-                       layer->info.input[i].qtype = 
metadata->input2[j].model_input_type;
-                       layer->info.input[i].nb_dims = 4;
-                       layer->info.input[i].shape[0] = 
metadata->input2[j].shape.w;
-                       layer->info.input[i].shape[1] = 
metadata->input2[j].shape.x;
-                       layer->info.input[i].shape[2] = 
metadata->input2[j].shape.y;
-                       layer->info.input[i].shape[3] = 
metadata->input2[j].shape.z;
-                       layer->info.input[i].nb_elements =
+                       rte_strscpy(io_info->input[i].name, (char 
*)metadata->input2[j].input_name,
+                                   MRVL_ML_INPUT_NAME_LEN);
+                       io_info->input[i].dtype = 
metadata->input2[j].input_type;
+                       io_info->input[i].qtype = 
metadata->input2[j].model_input_type;
+                       io_info->input[i].nb_dims = 4;
+                       io_info->input[i].shape[0] = 
metadata->input2[j].shape.w;
+                       io_info->input[i].shape[1] = 
metadata->input2[j].shape.x;
+                       io_info->input[i].shape[2] = 
metadata->input2[j].shape.y;
+                       io_info->input[i].shape[3] = 
metadata->input2[j].shape.z;
+                       io_info->input[i].nb_elements =
                                metadata->input2[j].shape.w * 
metadata->input2[j].shape.x *
                                metadata->input2[j].shape.y * 
metadata->input2[j].shape.z;
-                       layer->info.input[i].sz_d =
-                               layer->info.input[i].nb_elements *
+                       io_info->input[i].sz_d =
+                               io_info->input[i].nb_elements *
                                
rte_ml_io_type_size_get(metadata->input2[j].input_type);
-                       layer->info.input[i].sz_q =
-                               layer->info.input[i].nb_elements *
+                       io_info->input[i].sz_q =
+                               io_info->input[i].nb_elements *
                                
rte_ml_io_type_size_get(metadata->input2[j].model_input_type);
-                       layer->info.input[i].scale = metadata->input2[j].qscale;
+                       io_info->input[i].scale = metadata->input2[j].qscale;
 
-                       layer->info.total_input_sz_d += 
layer->info.input[i].sz_d;
-                       layer->info.total_input_sz_q += 
layer->info.input[i].sz_q;
+                       io_info->total_input_sz_d += io_info->input[i].sz_d;
+                       io_info->total_input_sz_q += io_info->input[i].sz_q;
 
                        plt_ml_dbg(
-                               "index = %u, input2[%u] - w:%u x:%u y:%u z:%u, 
sz_d = %u sz_q = %u",
-                               layer->index, j, metadata->input2[j].shape.w,
+                               "layer_name = %s, input2[%u] - w:%u x:%u y:%u 
z:%u, sz_d = %u sz_q = %u",
+                               metadata->model.name, j, 
metadata->input2[j].shape.w,
                                metadata->input2[j].shape.x, 
metadata->input2[j].shape.y,
-                               metadata->input2[j].shape.z, 
layer->info.input[i].sz_d,
-                               layer->info.input[i].sz_q);
+                               metadata->input2[j].shape.z, 
io_info->input[i].sz_d,
+                               io_info->input[i].sz_q);
                }
        }
 
        /* Outputs */
-       layer->info.nb_outputs = metadata->model.num_output;
-       layer->info.total_output_sz_q = 0;
-       layer->info.total_output_sz_d = 0;
+       io_info->nb_outputs = metadata->model.num_output;
+       io_info->total_output_sz_q = 0;
+       io_info->total_output_sz_d = 0;
        for (i = 0; i < metadata->model.num_output; i++) {
                if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
-                       rte_strscpy(layer->info.output[i].name,
+                       rte_strscpy(io_info->output[i].name,
                                    (char *)metadata->output1[i].output_name,
                                    MRVL_ML_OUTPUT_NAME_LEN);
-                       layer->info.output[i].dtype = 
metadata->output1[i].output_type;
-                       layer->info.output[i].qtype = 
metadata->output1[i].model_output_type;
-                       layer->info.output[i].nb_dims = 1;
-                       layer->info.output[i].shape[0] = 
metadata->output1[i].size;
-                       layer->info.output[i].nb_elements = 
metadata->output1[i].size;
-                       layer->info.output[i].sz_d =
-                               layer->info.output[i].nb_elements *
+                       io_info->output[i].dtype = 
metadata->output1[i].output_type;
+                       io_info->output[i].qtype = 
metadata->output1[i].model_output_type;
+                       io_info->output[i].nb_dims = 1;
+                       io_info->output[i].shape[0] = metadata->output1[i].size;
+                       io_info->output[i].nb_elements = 
metadata->output1[i].size;
+                       io_info->output[i].sz_d =
+                               io_info->output[i].nb_elements *
                                
rte_ml_io_type_size_get(metadata->output1[i].output_type);
-                       layer->info.output[i].sz_q =
-                               layer->info.output[i].nb_elements *
+                       io_info->output[i].sz_q =
+                               io_info->output[i].nb_elements *
                                
rte_ml_io_type_size_get(metadata->output1[i].model_output_type);
-                       layer->info.output[i].scale = 
metadata->output1[i].dscale;
+                       io_info->output[i].scale = metadata->output1[i].dscale;
 
-                       layer->info.total_output_sz_q += 
layer->info.output[i].sz_q;
-                       layer->info.total_output_sz_d += 
layer->info.output[i].sz_d;
+                       io_info->total_output_sz_q += io_info->output[i].sz_q;
+                       io_info->total_output_sz_d += io_info->output[i].sz_d;
 
-                       plt_ml_dbg("index = %u, output1[%u] - sz_d = %u, sz_q = 
%u", layer->index,
-                                  i, layer->info.output[i].sz_d, 
layer->info.output[i].sz_q);
+                       plt_ml_dbg("layer_name = %s, output1[%u] - sz_d = %u, 
sz_q = %u",
+                                  metadata->model.name, i, 
io_info->output[i].sz_d,
+                                  io_info->output[i].sz_q);
                } else {
                        j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
 
-                       rte_strscpy(layer->info.output[i].name,
+                       rte_strscpy(io_info->output[i].name,
                                    (char *)metadata->output2[j].output_name,
                                    MRVL_ML_OUTPUT_NAME_LEN);
-                       layer->info.output[i].dtype = 
metadata->output2[j].output_type;
-                       layer->info.output[i].qtype = 
metadata->output2[j].model_output_type;
-                       layer->info.output[i].nb_dims = 1;
-                       layer->info.output[i].shape[0] = 
metadata->output2[j].size;
-                       layer->info.output[i].nb_elements = 
metadata->output2[j].size;
-                       layer->info.output[i].sz_d =
-                               layer->info.output[i].nb_elements *
+                       io_info->output[i].dtype = 
metadata->output2[j].output_type;
+                       io_info->output[i].qtype = 
metadata->output2[j].model_output_type;
+                       io_info->output[i].nb_dims = 1;
+                       io_info->output[i].shape[0] = metadata->output2[j].size;
+                       io_info->output[i].nb_elements = 
metadata->output2[j].size;
+                       io_info->output[i].sz_d =
+                               io_info->output[i].nb_elements *
                                
rte_ml_io_type_size_get(metadata->output2[j].output_type);
-                       layer->info.output[i].sz_q =
-                               layer->info.output[i].nb_elements *
+                       io_info->output[i].sz_q =
+                               io_info->output[i].nb_elements *
                                
rte_ml_io_type_size_get(metadata->output2[j].model_output_type);
-                       layer->info.output[i].scale = 
metadata->output2[j].dscale;
+                       io_info->output[i].scale = metadata->output2[j].dscale;
 
-                       layer->info.total_output_sz_q += 
layer->info.output[i].sz_q;
-                       layer->info.total_output_sz_d += 
layer->info.output[i].sz_d;
+                       io_info->total_output_sz_q += io_info->output[i].sz_q;
+                       io_info->total_output_sz_d += io_info->output[i].sz_d;
 
-                       plt_ml_dbg("index = %u, output2[%u] - sz_d = %u, sz_q = 
%u", layer->index,
-                                  j, layer->info.output[i].sz_d, 
layer->info.output[i].sz_q);
+                       plt_ml_dbg("layer_name = %s, output2[%u] - sz_d = %u, 
sz_q = %u",
+                                  metadata->model.name, j, 
io_info->output[i].sz_d,
+                                  io_info->output[i].sz_q);
                }
        }
 }
 
+struct cnxk_ml_io_info *
+cn10k_ml_model_io_info_get(struct cnxk_ml_model *model, uint16_t layer_id)
+{
+       return &model->layer[layer_id].info;
+}
+
 int
-cn10k_ml_model_ocm_pages_count(struct cn10k_ml_dev *cn10k_mldev, uint16_t 
model_id, uint8_t *buffer,
-                              uint16_t *wb_pages, uint16_t *scratch_pages)
+cn10k_ml_model_ocm_pages_count(struct cnxk_ml_dev *cnxk_mldev, struct 
cnxk_ml_layer *layer,
+                              uint8_t *buffer, uint16_t *wb_pages, uint16_t 
*scratch_pages)
 {
        struct cn10k_ml_model_metadata *metadata;
        struct cn10k_ml_ocm *ocm;
@@ -506,7 +501,7 @@ cn10k_ml_model_ocm_pages_count(struct cn10k_ml_dev 
*cn10k_mldev, uint16_t model_
        uint64_t wb_size;
 
        metadata = (struct cn10k_ml_model_metadata *)buffer;
-       ocm = &cn10k_mldev->ocm;
+       ocm = &cnxk_mldev->cn10k_mldev.ocm;
 
        /* Assume wb_size is zero for non-relocatable models */
        if (metadata->model.ocm_relocatable)
@@ -518,7 +513,7 @@ cn10k_ml_model_ocm_pages_count(struct cn10k_ml_dev 
*cn10k_mldev, uint16_t model_
                *wb_pages = wb_size / ocm->page_size + 1;
        else
                *wb_pages = wb_size / ocm->page_size;
-       plt_ml_dbg("model_id = %u, wb_size = %" PRIu64 ", wb_pages = %u", 
model_id, wb_size,
+       plt_ml_dbg("index = %u, wb_size = %" PRIu64 ", wb_pages = %u", 
layer->index, wb_size,
                   *wb_pages);
 
        scratch_size = ocm->size_per_tile - metadata->model.ocm_tmp_range_floor;
@@ -526,15 +521,15 @@ cn10k_ml_model_ocm_pages_count(struct cn10k_ml_dev 
*cn10k_mldev, uint16_t model_
                *scratch_pages = scratch_size / ocm->page_size + 1;
        else
                *scratch_pages = scratch_size / ocm->page_size;
-       plt_ml_dbg("model_id = %u, scratch_size = %" PRIu64 ", scratch_pages = 
%u", model_id,
+       plt_ml_dbg("index = %u, scratch_size = %" PRIu64 ", scratch_pages = 
%u", layer->index,
                   scratch_size, *scratch_pages);
 
        /* Check if the model can be loaded on OCM */
-       if ((*wb_pages + *scratch_pages) > cn10k_mldev->ocm.num_pages) {
+       if ((*wb_pages + *scratch_pages) > ocm->num_pages) {
                plt_err("Cannot create the model, OCM relocatable = %u",
                        metadata->model.ocm_relocatable);
                plt_err("wb_pages (%u) + scratch_pages (%u) > %u", *wb_pages, 
*scratch_pages,
-                       cn10k_mldev->ocm.num_pages);
+                       ocm->num_pages);
                return -ENOMEM;
        }
 
@@ -542,28 +537,25 @@ cn10k_ml_model_ocm_pages_count(struct cn10k_ml_dev 
*cn10k_mldev, uint16_t model_
         * prevent the library from allocating the remaining space on the tile 
to other models.
         */
        if (!metadata->model.ocm_relocatable)
-               *scratch_pages = PLT_MAX(PLT_U64_CAST(*scratch_pages),
-                                        
PLT_U64_CAST(cn10k_mldev->ocm.num_pages));
+               *scratch_pages =
+                       PLT_MAX(PLT_U64_CAST(*scratch_pages), 
PLT_U64_CAST(ocm->num_pages));
 
        return 0;
 }
 
 void
-cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cnxk_ml_model *model)
+cn10k_ml_model_info_set(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model 
*model,
+                       struct cnxk_ml_io_info *io_info, struct 
cn10k_ml_model_metadata *metadata)
 {
-       struct cn10k_ml_model_metadata *metadata;
-       struct cnxk_ml_dev *cnxk_mldev;
        struct rte_ml_model_info *info;
        struct rte_ml_io_info *output;
        struct rte_ml_io_info *input;
-       struct cnxk_ml_layer *layer;
        uint8_t i;
 
-       cnxk_mldev = dev->data->dev_private;
        metadata = &model->glow.metadata;
        info = PLT_PTR_CAST(model->info);
        input = PLT_PTR_ADD(info, sizeof(struct rte_ml_model_info));
-       output = PLT_PTR_ADD(input, metadata->model.num_input * sizeof(struct 
rte_ml_io_info));
+       output = PLT_PTR_ADD(input, ML_CNXK_MODEL_MAX_INPUT_OUTPUT * 
sizeof(struct rte_ml_io_info));
 
        /* Set model info */
        memset(info, 0, sizeof(struct rte_ml_model_info));
@@ -572,39 +564,37 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct 
cnxk_ml_model *model)
                 metadata->model.version[1], metadata->model.version[2],
                 metadata->model.version[3]);
        info->model_id = model->model_id;
-       info->device_id = dev->data->dev_id;
+       info->device_id = cnxk_mldev->mldev->data->dev_id;
        info->io_layout = RTE_ML_IO_LAYOUT_PACKED;
        info->min_batches = model->batch_size;
        info->max_batches =
                
cnxk_mldev->cn10k_mldev.fw.req->cn10k_req.jd.fw_load.cap.s.max_num_batches /
                model->batch_size;
-       info->nb_inputs = metadata->model.num_input;
+       info->nb_inputs = io_info->nb_inputs;
        info->input_info = input;
-       info->nb_outputs = metadata->model.num_output;
+       info->nb_outputs = io_info->nb_outputs;
        info->output_info = output;
        info->wb_size = metadata->weights_bias.file_size;
 
        /* Set input info */
-       layer = &model->layer[0];
        for (i = 0; i < info->nb_inputs; i++) {
-               rte_memcpy(input[i].name, layer->info.input[i].name, 
MRVL_ML_INPUT_NAME_LEN);
-               input[i].nb_dims = layer->info.input[i].nb_dims;
-               input[i].shape = &layer->info.input[i].shape[0];
-               input[i].type = layer->info.input[i].qtype;
-               input[i].nb_elements = layer->info.input[i].nb_elements;
-               input[i].size = layer->info.input[i].nb_elements *
-                               
rte_ml_io_type_size_get(layer->info.input[i].qtype);
+               rte_memcpy(input[i].name, io_info->input[i].name, 
MRVL_ML_INPUT_NAME_LEN);
+               input[i].nb_dims = io_info->input[i].nb_dims;
+               input[i].shape = &io_info->input[i].shape[0];
+               input[i].type = io_info->input[i].qtype;
+               input[i].nb_elements = io_info->input[i].nb_elements;
+               input[i].size = io_info->input[i].nb_elements *
+                               
rte_ml_io_type_size_get(io_info->input[i].qtype);
        }
 
        /* Set output info */
-       layer = &model->layer[0];
        for (i = 0; i < info->nb_outputs; i++) {
-               rte_memcpy(output[i].name, layer->info.output[i].name, 
MRVL_ML_INPUT_NAME_LEN);
-               output[i].nb_dims = layer->info.output[i].nb_dims;
-               output[i].shape = &layer->info.output[i].shape[0];
-               output[i].type = layer->info.output[i].qtype;
-               output[i].nb_elements = layer->info.output[i].nb_elements;
-               output[i].size = layer->info.output[i].nb_elements *
-                                
rte_ml_io_type_size_get(layer->info.output[i].qtype);
+               rte_memcpy(output[i].name, io_info->output[i].name, 
MRVL_ML_INPUT_NAME_LEN);
+               output[i].nb_dims = io_info->output[i].nb_dims;
+               output[i].shape = &io_info->output[i].shape[0];
+               output[i].type = io_info->output[i].qtype;
+               output[i].nb_elements = io_info->output[i].nb_elements;
+               output[i].size = io_info->output[i].nb_elements *
+                                
rte_ml_io_type_size_get(io_info->output[i].qtype);
        }
 }
diff --git a/drivers/ml/cnxk/cn10k_ml_model.h b/drivers/ml/cnxk/cn10k_ml_model.h
index 5c32f48c68..b891c9d627 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.h
+++ b/drivers/ml/cnxk/cn10k_ml_model.h
@@ -9,9 +9,11 @@
 
 #include <roc_api.h>
 
-#include "cn10k_ml_dev.h"
 #include "cn10k_ml_ocm.h"
 
+#include "cnxk_ml_io.h"
+
+struct cnxk_ml_dev;
 struct cnxk_ml_model;
 struct cnxk_ml_layer;
 struct cnxk_ml_req;
@@ -366,27 +368,15 @@ struct cn10k_ml_layer_addr {
        /* Base DMA address for load */
        void *base_dma_addr_load;
 
-       /* Base DMA address for run */
-       void *base_dma_addr_run;
-
        /* Init section load address */
        void *init_load_addr;
 
-       /* Init section run address */
-       void *init_run_addr;
-
        /* Main section load address */
        void *main_load_addr;
 
-       /* Main section run address */
-       void *main_run_addr;
-
        /* Finish section load address */
        void *finish_load_addr;
 
-       /* Finish section run address */
-       void *finish_run_addr;
-
        /* Weights and Bias base address */
        void *wb_base_addr;
 
@@ -462,9 +452,13 @@ int cn10k_ml_model_metadata_check(uint8_t *buffer, 
uint64_t size);
 void cn10k_ml_model_metadata_update(struct cn10k_ml_model_metadata *metadata);
 void cn10k_ml_layer_addr_update(struct cnxk_ml_layer *layer, uint8_t *buffer,
                                uint8_t *base_dma_addr);
-void cn10k_ml_layer_info_update(struct cnxk_ml_layer *layer);
-int cn10k_ml_model_ocm_pages_count(struct cn10k_ml_dev *cn10k_mldev, uint16_t 
model_id,
+void cn10k_ml_layer_io_info_set(struct cnxk_ml_io_info *io_info,
+                               struct cn10k_ml_model_metadata *metadata);
+struct cnxk_ml_io_info *cn10k_ml_model_io_info_get(struct cnxk_ml_model 
*model, uint16_t layer_id);
+int cn10k_ml_model_ocm_pages_count(struct cnxk_ml_dev *cnxk_mldev, struct 
cnxk_ml_layer *layer,
                                   uint8_t *buffer, uint16_t *wb_pages, 
uint16_t *scratch_pages);
-void cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cnxk_ml_model 
*model);
+void cn10k_ml_model_info_set(struct cnxk_ml_dev *cnxk_mldev, struct 
cnxk_ml_model *model,
+                            struct cnxk_ml_io_info *io_info,
+                            struct cn10k_ml_model_metadata *metadata);
 
 #endif /* _CN10K_ML_MODEL_H_ */
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index 9691cf03e3..ab05896b5e 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -15,6 +15,9 @@
 /* ML model macros */
 #define CN10K_ML_MODEL_MEMZONE_NAME "ml_cn10k_model_mz"
 
+/* ML layer macros */
+#define CN10K_ML_LAYER_MEMZONE_NAME "ml_cn10k_layer_mz"
+
 /* Debug print width */
 #define STR_LEN          12
 #define FIELD_LEN 16
@@ -273,7 +276,7 @@ cn10k_ml_prep_sp_job_descriptor(struct cn10k_ml_dev 
*cn10k_mldev, struct cnxk_ml
                req->cn10k_req.jd.model_start.extended_args = PLT_U64_CAST(
                        roc_ml_addr_ap2mlip(&cn10k_mldev->roc, 
&req->cn10k_req.extended_args));
                req->cn10k_req.jd.model_start.model_dst_ddr_addr =
-                       PLT_U64_CAST(roc_ml_addr_ap2mlip(&cn10k_mldev->roc, 
addr->init_run_addr));
+                       PLT_U64_CAST(roc_ml_addr_ap2mlip(&cn10k_mldev->roc, 
addr->init_load_addr));
                req->cn10k_req.jd.model_start.model_init_offset = 0x0;
                req->cn10k_req.jd.model_start.model_main_offset = 
metadata->init_model.file_size;
                req->cn10k_req.jd.model_start.model_finish_offset =
@@ -1261,85 +1264,188 @@ cn10k_ml_dev_selftest(struct rte_ml_dev *dev)
 }
 
+/*
+ * Load one layer of a model: validate the layer buffer, claim a free
+ * index_map slot, reserve a memzone for the layer data and initialize the
+ * layer's DMA addresses, I/O info, OCM map and xstats.
+ *
+ * @param device     Pointer to struct cnxk_ml_dev.
+ * @param model_id   ID of a model already registered in dev->data->models[].
+ * @param layer_name Not used yet; layer 0 is always loaded.
+ * @param buffer     Layer object: cn10k metadata header followed by sections.
+ * @param size       Size of @p buffer in bytes.
+ * @param index      Output: index_map slot assigned to the layer.
+ *
+ * @return 0 on success, a negative value on failure.
+ */
 int
-cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params 
*model_id)
+cn10k_ml_layer_load(void *device, uint16_t model_id, const char *layer_name, 
uint8_t *buffer,
+                   size_t size, uint16_t *index)
 {
        struct cn10k_ml_model_metadata *metadata;
        struct cnxk_ml_dev *cnxk_mldev;
        struct cnxk_ml_model *model;
+       struct cnxk_ml_layer *layer;
 
        char str[RTE_MEMZONE_NAMESIZE];
        const struct plt_memzone *mz;
-       size_t model_scratch_size;
-       size_t model_stats_size;
-       size_t model_data_size;
-       size_t model_info_size;
+       size_t layer_object_size = 0;
+       size_t layer_scratch_size;
+       size_t layer_xstats_size;
        uint8_t *base_dma_addr;
        uint16_t scratch_pages;
+       uint16_t layer_id = 0;
        uint16_t wb_pages;
        uint64_t mz_size;
        uint16_t idx;
-       bool found;
        int qp_id;
        int ret;
 
-       ret = cn10k_ml_model_metadata_check(params->addr, params->size);
+       PLT_SET_USED(layer_name);
+
+       cnxk_mldev = (struct cnxk_ml_dev *)device;
+       if (cnxk_mldev == NULL) {
+               plt_err("Invalid device = %p", device);
+               return -EINVAL;
+       }
+
+       model = cnxk_mldev->mldev->data->models[model_id];
+       if (model == NULL) {
+               plt_err("Invalid model_id = %u", model_id);
+               return -EINVAL;
+       }
+
+       layer = &model->layer[layer_id];
+
+       ret = cn10k_ml_model_metadata_check(buffer, size);
        if (ret != 0)
                return ret;
 
-       cnxk_mldev = dev->data->dev_private;
-
-       /* Find model ID */
-       found = false;
-       for (idx = 0; idx < dev->data->nb_models; idx++) {
-               if (dev->data->models[idx] == NULL) {
-                       found = true;
+       /* Get index */
+       for (idx = 0; idx < cnxk_mldev->max_nb_layers; idx++) {
+               if (!cnxk_mldev->index_map[idx].active) {
+                       layer->index = idx;
                        break;
                }
        }
 
-       if (!found) {
-               plt_err("No slots available to load new model");
-               return -ENOMEM;
+       if (idx >= cnxk_mldev->max_nb_layers) {
+               plt_err("No slots available for model layers, model_id = %u, 
layer_id = %u",
+                       model->model_id, layer_id);
+               return -ENOMEM;
        }
 
+       layer->model = model;
+
        /* Get WB and scratch pages, check if model can be loaded. */
-       ret = cn10k_ml_model_ocm_pages_count(&cnxk_mldev->cn10k_mldev, idx, 
params->addr, &wb_pages,
-                                            &scratch_pages);
+       ret = cn10k_ml_model_ocm_pages_count(cnxk_mldev, layer, buffer, 
&wb_pages, &scratch_pages);
        if (ret < 0)
                return ret;
 
-       /* Compute memzone size */
-       metadata = (struct cn10k_ml_model_metadata *)params->addr;
-       model_data_size = metadata->init_model.file_size + 
metadata->main_model.file_size +
-                         metadata->finish_model.file_size + 
metadata->weights_bias.file_size;
-       model_scratch_size = 
PLT_ALIGN_CEIL(metadata->model.ddr_scratch_range_end -
+       /* Compute layer memzone size */
+       metadata = (struct cn10k_ml_model_metadata *)buffer;
+       layer_object_size = metadata->init_model.file_size + 
metadata->main_model.file_size +
+                           metadata->finish_model.file_size + 
metadata->weights_bias.file_size;
+       layer_object_size = PLT_ALIGN_CEIL(layer_object_size, 
ML_CN10K_ALIGN_SIZE);
+       layer_scratch_size = 
PLT_ALIGN_CEIL(metadata->model.ddr_scratch_range_end -
                                                    
metadata->model.ddr_scratch_range_start + 1,
                                            ML_CN10K_ALIGN_SIZE);
-       model_data_size = PLT_ALIGN_CEIL(model_data_size, ML_CN10K_ALIGN_SIZE);
-       model_info_size = sizeof(struct rte_ml_model_info) +
-                         metadata->model.num_input * sizeof(struct 
rte_ml_io_info) +
-                         metadata->model.num_output * sizeof(struct 
rte_ml_io_info);
-       model_info_size = PLT_ALIGN_CEIL(model_info_size, ML_CN10K_ALIGN_SIZE);
-       model_stats_size = (dev->data->nb_queue_pairs + 1) * sizeof(struct 
cn10k_ml_layer_xstats);
-
-       mz_size = PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_model), 
ML_CN10K_ALIGN_SIZE) +
-                 2 * model_data_size + model_scratch_size + model_info_size +
-                 PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_req), 
ML_CN10K_ALIGN_SIZE) +
-                 model_stats_size;
+       layer_xstats_size = (cnxk_mldev->mldev->data->nb_queue_pairs + 1) *
+                           sizeof(struct cn10k_ml_layer_xstats);
 
-       /* Allocate memzone for model object and model data */
-       snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", 
CN10K_ML_MODEL_MEMZONE_NAME, idx);
+       /* Allocate memzone for model data */
+       mz_size = layer_object_size + layer_scratch_size +
+                 PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_req), 
ML_CN10K_ALIGN_SIZE) +
+                 layer_xstats_size;
+       snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u_%u", 
CN10K_ML_LAYER_MEMZONE_NAME,
+                model->model_id, layer_id);
        mz = plt_memzone_reserve_aligned(str, mz_size, 0, ML_CN10K_ALIGN_SIZE);
        if (!mz) {
                plt_err("plt_memzone_reserve failed : %s", str);
                return -ENOMEM;
        }
 
-       model = mz->addr;
-       model->cnxk_mldev = cnxk_mldev;
-       model->model_id = idx;
-       dev->data->models[idx] = model;
+       /* Copy metadata to internal buffer */
+       rte_memcpy(&layer->glow.metadata, buffer, sizeof(struct 
cn10k_ml_model_metadata));
+       cn10k_ml_model_metadata_update(&layer->glow.metadata);
+
+       /* Set layer name */
+       rte_memcpy(layer->name, layer->glow.metadata.model.name, 
MRVL_ML_MODEL_NAME_LEN);
+
+       /* Enable support for batch_size of 256 */
+       if (layer->glow.metadata.model.batch_size == 0)
+               layer->batch_size = 256;
+       else
+               layer->batch_size = layer->glow.metadata.model.batch_size;
+
+       /* Set DMA base address */
+       base_dma_addr = mz->addr;
+       cn10k_ml_layer_addr_update(layer, buffer, base_dma_addr);
+
+       /* Set scratch base address */
+       layer->glow.addr.scratch_base_addr = PLT_PTR_ADD(base_dma_addr, 
layer_object_size);
+
+       /* Update internal I/O data structure */
+       cn10k_ml_layer_io_info_set(&layer->info, &layer->glow.metadata);
+
+       /* Initialize model_mem_map */
+       memset(&layer->glow.ocm_map, 0, sizeof(struct cn10k_ml_ocm_layer_map));
+       layer->glow.ocm_map.ocm_reserved = false;
+       layer->glow.ocm_map.tilemask = 0;
+       layer->glow.ocm_map.wb_page_start = -1;
+       layer->glow.ocm_map.wb_pages = wb_pages;
+       layer->glow.ocm_map.scratch_pages = scratch_pages;
+
+       /* Set slow-path request address and state */
+       layer->glow.req = PLT_PTR_ADD(mz->addr, layer_object_size + 
layer_scratch_size);
+
+       /* Reset burst and sync stats */
+       layer->glow.burst_xstats = PLT_PTR_ADD(
+               layer->glow.req, PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_req), 
ML_CN10K_ALIGN_SIZE));
+       for (qp_id = 0; qp_id < cnxk_mldev->mldev->data->nb_queue_pairs + 1; 
qp_id++) {
+               layer->glow.burst_xstats[qp_id].hw_latency_tot = 0;
+               layer->glow.burst_xstats[qp_id].hw_latency_min = UINT64_MAX;
+               layer->glow.burst_xstats[qp_id].hw_latency_max = 0;
+               layer->glow.burst_xstats[qp_id].fw_latency_tot = 0;
+               layer->glow.burst_xstats[qp_id].fw_latency_min = UINT64_MAX;
+               layer->glow.burst_xstats[qp_id].fw_latency_max = 0;
+               layer->glow.burst_xstats[qp_id].hw_reset_count = 0;
+               layer->glow.burst_xstats[qp_id].fw_reset_count = 0;
+               layer->glow.burst_xstats[qp_id].dequeued_count = 0;
+       }
+
+       layer->glow.sync_xstats =
+               PLT_PTR_ADD(layer->glow.burst_xstats, 
cnxk_mldev->mldev->data->nb_queue_pairs *
+                                                             sizeof(struct 
cn10k_ml_layer_xstats));
+
+       /* Update xstats names; the helper takes a model ID, not a layer index */
+       cn10k_ml_xstats_model_name_update(cnxk_mldev->mldev, model->model_id);
+
+       layer->state = ML_CNXK_LAYER_STATE_LOADED;
+       cnxk_mldev->index_map[idx].model_id = model->model_id;
+       cnxk_mldev->index_map[idx].layer_id = layer_id;
+       cnxk_mldev->index_map[idx].active = true;
+       *index = idx;
+
+       return 0;
+}
+
+/*
+ * Load a model on cn10k. The cnxk_ml_model object is allocated by the caller;
+ * the model is treated as single-layer and its layer 0 is loaded here.
+ */
+int
+cn10k_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params 
*params,
+                   struct cnxk_ml_model *model)
+{
+       struct cnxk_ml_layer *layer;
+       int ret;
+
+       /* Validate the model buffer before copying any metadata */
+       ret = cn10k_ml_model_metadata_check(params->addr, params->size);
+       if (ret != 0)
+               return ret;
 
+       /* Copy metadata to internal buffer */
        rte_memcpy(&model->glow.metadata, params->addr, sizeof(struct 
cn10k_ml_model_metadata));
        cn10k_ml_model_metadata_update(&model->glow.metadata);
 
@@ -1358,99 +1447,62 @@ cn10k_ml_model_load(struct rte_ml_dev *dev, struct 
rte_ml_model_params *params,
         */
        model->nb_layers = 1;
 
-       /* Copy metadata to internal buffer */
-       rte_memcpy(&model->layer[0].glow.metadata, params->addr,
-                  sizeof(struct cn10k_ml_model_metadata));
-       cn10k_ml_model_metadata_update(&model->layer[0].glow.metadata);
-       model->layer[0].model = model;
-
-       /* Set DMA base address */
-       base_dma_addr = PLT_PTR_ADD(
-               mz->addr, PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_model), 
ML_CN10K_ALIGN_SIZE));
-       cn10k_ml_layer_addr_update(&model->layer[0], params->addr, 
base_dma_addr);
-       model->layer[0].glow.addr.scratch_base_addr =
-               PLT_PTR_ADD(base_dma_addr, 2 * model_data_size);
-
-       /* Copy data from load to run. run address to be used by MLIP */
-       rte_memcpy(model->layer[0].glow.addr.base_dma_addr_run,
-                  model->layer[0].glow.addr.base_dma_addr_load, 
model_data_size);
-
-       /* Update internal I/O data structure */
-       cn10k_ml_layer_info_update(&model->layer[0]);
-
-       /* Initialize model_mem_map */
-       memset(&model->layer[0].glow.ocm_map, 0, sizeof(struct 
cn10k_ml_ocm_layer_map));
-       model->layer[0].glow.ocm_map.ocm_reserved = false;
-       model->layer[0].glow.ocm_map.tilemask = 0;
-       model->layer[0].glow.ocm_map.wb_page_start = -1;
-       model->layer[0].glow.ocm_map.wb_pages = wb_pages;
-       model->layer[0].glow.ocm_map.scratch_pages = scratch_pages;
-
-       /* Set model info */
-       model->info = PLT_PTR_ADD(model->layer[0].glow.addr.scratch_base_addr, 
model_scratch_size);
-       cn10k_ml_model_info_set(dev, model);
-
-       /* Set slow-path request address and state */
-       model->layer[0].glow.req = PLT_PTR_ADD(model->info, model_info_size);
-
-       /* Reset burst and sync stats */
-       model->layer[0].glow.burst_xstats =
-               PLT_PTR_ADD(model->layer[0].glow.req,
-                           PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_req), 
ML_CN10K_ALIGN_SIZE));
-       for (qp_id = 0; qp_id < dev->data->nb_queue_pairs + 1; qp_id++) {
-               model->layer[0].glow.burst_xstats[qp_id].hw_latency_tot = 0;
-               model->layer[0].glow.burst_xstats[qp_id].hw_latency_min = 
UINT64_MAX;
-               model->layer[0].glow.burst_xstats[qp_id].hw_latency_max = 0;
-               model->layer[0].glow.burst_xstats[qp_id].fw_latency_tot = 0;
-               model->layer[0].glow.burst_xstats[qp_id].fw_latency_min = 
UINT64_MAX;
-               model->layer[0].glow.burst_xstats[qp_id].fw_latency_max = 0;
-               model->layer[0].glow.burst_xstats[qp_id].hw_reset_count = 0;
-               model->layer[0].glow.burst_xstats[qp_id].fw_reset_count = 0;
-               model->layer[0].glow.burst_xstats[qp_id].dequeued_count = 0;
+       /* Load layer 0 (the only layer, nb_layers = 1) and record its index */
+       layer = &model->layer[0];
+       ret = cn10k_ml_layer_load(cnxk_mldev, model->model_id, NULL, 
params->addr, params->size,
+                                 &layer->index);
+       if (ret != 0) {
+               plt_err("Model layer load failed: model_id = %u, layer_id = 
%u", model->model_id,
+                       0);
+               return ret;
        }
 
-       model->layer[0].glow.sync_xstats =
-               PLT_PTR_ADD(model->layer[0].glow.burst_xstats,
-                           dev->data->nb_queue_pairs * sizeof(struct 
cn10k_ml_layer_xstats));
-
-       plt_spinlock_init(&model->lock);
-       model->state = ML_CNXK_MODEL_STATE_LOADED;
-       dev->data->models[idx] = model;
-       cnxk_mldev->nb_models_loaded++;
-
-       /* Update xstats names */
-       cn10k_ml_xstats_model_name_update(dev, idx);
-
-       *model_id = idx;
+       cn10k_ml_model_info_set(cnxk_mldev, model, &model->layer[0].info, 
&model->glow.metadata);
 
        return 0;
 }
 
 int
-cn10k_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id)
+cn10k_ml_layer_unload(void *device, uint16_t model_id, const char *layer_name)
 {
-       char str[RTE_MEMZONE_NAMESIZE];
-       struct cnxk_ml_model *model;
        struct cnxk_ml_dev *cnxk_mldev;
+       struct cnxk_ml_model *model;
+       struct cnxk_ml_layer *layer;
 
-       cnxk_mldev = dev->data->dev_private;
-       model = dev->data->models[model_id];
+       char str[RTE_MEMZONE_NAMESIZE];
+       uint16_t layer_id = 0;
+       int ret;
 
+       PLT_SET_USED(layer_name);
+
+       cnxk_mldev = (struct cnxk_ml_dev *)device;
+       if (cnxk_mldev == NULL) {
+               plt_err("Invalid device = %p", device);
+               return -EINVAL;
+       }
+
+       model = cnxk_mldev->mldev->data->models[model_id];
        if (model == NULL) {
                plt_err("Invalid model_id = %u", model_id);
                return -EINVAL;
        }
 
-       if (model->state != ML_CNXK_MODEL_STATE_LOADED) {
-               plt_err("Cannot unload. Model in use.");
-               return -EBUSY;
-       }
+       layer = &model->layer[layer_id];
 
-       dev->data->models[model_id] = NULL;
-       cnxk_mldev->nb_models_unloaded++;
+       snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u_%u", 
CN10K_ML_LAYER_MEMZONE_NAME,
+                model->model_id, layer_id);
+       ret = plt_memzone_free(plt_memzone_lookup(str));
 
-       snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", 
CN10K_ML_MODEL_MEMZONE_NAME, model_id);
-       return plt_memzone_free(plt_memzone_lookup(str));
+       layer->state = ML_CNXK_LAYER_STATE_UNKNOWN;
+       if (cnxk_mldev->index_map != NULL)
+               cnxk_mldev->index_map[layer->index].active = false;
+       return ret;
+}
+
+int
+cn10k_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model 
*model)
+{
+       return cn10k_ml_layer_unload(cnxk_mldev, model->model_id, NULL);
 }
 
 int
@@ -1748,7 +1800,6 @@ int
 cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void 
*buffer)
 {
        struct cnxk_ml_model *model;
-       size_t size;
 
        model = dev->data->models[model_id];
 
@@ -1762,19 +1813,10 @@ cn10k_ml_model_params_update(struct rte_ml_dev *dev, 
uint16_t model_id, void *bu
        else if (model->state != ML_CNXK_MODEL_STATE_LOADED)
                return -EBUSY;
 
-       size = model->layer[0].glow.metadata.init_model.file_size +
-              model->layer[0].glow.metadata.main_model.file_size +
-              model->layer[0].glow.metadata.finish_model.file_size +
-              model->layer[0].glow.metadata.weights_bias.file_size;
-
        /* Update model weights & bias */
        rte_memcpy(model->layer[0].glow.addr.wb_load_addr, buffer,
                   model->layer[0].glow.metadata.weights_bias.file_size);
 
-       /* Copy data from load to run. run address to be used by MLIP */
-       rte_memcpy(model->layer[0].glow.addr.base_dma_addr_run,
-                  model->layer[0].glow.addr.base_dma_addr_load, size);
-
        return 0;
 }
 
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index 2d0a49d5cd..677219dfdf 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -12,6 +12,7 @@
 
 struct cnxk_ml_dev;
 struct cnxk_ml_qp;
+struct cnxk_ml_model;
 
 /* Firmware version string length */
 #define MLDEV_FIRMWARE_VERSION_LENGTH 32
@@ -311,9 +312,9 @@ int cn10k_ml_dev_xstats_reset(struct rte_ml_dev *dev, enum 
rte_ml_dev_xstats_mod
                              int32_t model_id, const uint16_t stat_ids[], 
uint16_t nb_ids);
 
 /* Slow-path ops */
-int cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params 
*params,
-                       uint16_t *model_id);
-int cn10k_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id);
+int cn10k_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct 
rte_ml_model_params *params,
+                       struct cnxk_ml_model *model);
+int cn10k_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model 
*model);
 int cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id);
 int cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id);
 int cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
@@ -339,4 +340,9 @@ __rte_hot int cn10k_ml_inference_sync(struct rte_ml_dev 
*dev, struct rte_ml_op *
 /* Misc ops */
 void cn10k_ml_qp_initialize(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_qp 
*qp);
 
+/* Layer ops */
+int cn10k_ml_layer_load(void *device, uint16_t model_id, const char 
*layer_name, uint8_t *buffer,
+                       size_t size, uint16_t *index);
+int cn10k_ml_layer_unload(void *device, uint16_t model_id, const char 
*layer_name);
+
 #endif /* _CN10K_ML_OPS_H_ */
diff --git a/drivers/ml/cnxk/cnxk_ml_dev.h b/drivers/ml/cnxk/cnxk_ml_dev.h
index 02605fa28f..1590249abd 100644
--- a/drivers/ml/cnxk/cnxk_ml_dev.h
+++ b/drivers/ml/cnxk/cnxk_ml_dev.h
@@ -31,6 +31,18 @@ enum cnxk_ml_dev_state {
        ML_CNXK_DEV_STATE_CLOSED
 };
 
+/* Index to model and layer ID map */
+struct cnxk_ml_index_map {
+       /* Model ID */
+       uint16_t model_id;
+
+       /* Layer ID */
+       uint16_t layer_id;
+
+       /* True while the slot is owned by a loaded layer */
+       bool active;
+};
+
 /* Device private data */
 struct cnxk_ml_dev {
        /* RTE device */
@@ -56,6 +68,9 @@ struct cnxk_ml_dev {
 
        /* Maximum number of layers */
        uint64_t max_nb_layers;
+
+       /* Index map */
+       struct cnxk_ml_index_map *index_map;
 };
 
 #endif /* _CNXK_ML_DEV_H_ */
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index aa56dd2276..1d8b84269d 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -10,6 +10,9 @@
 #include "cnxk_ml_model.h"
 #include "cnxk_ml_ops.h"
 
+/* ML model macros: name prefix of the cnxk model object memzone */
+#define CNXK_ML_MODEL_MEMZONE_NAME "ml_cnxk_model_mz"
+
 static void
 qp_memzone_name_get(char *name, int size, int dev_id, int qp_id)
 {
@@ -137,6 +140,7 @@ cnxk_ml_dev_configure(struct rte_ml_dev *dev, const struct 
rte_ml_dev_config *co
        uint16_t model_id;
        uint32_t mz_size;
        uint16_t qp_id;
+       uint64_t i;
        int ret;
 
        if (dev == NULL)
@@ -240,7 +244,7 @@ cnxk_ml_dev_configure(struct rte_ml_dev *dev, const struct 
rte_ml_dev_config *co
                                                plt_err("Could not stop model 
%u", model_id);
                                }
                                if (model->state == ML_CNXK_MODEL_STATE_LOADED) 
{
-                                       if (cn10k_ml_model_unload(dev, 
model_id) != 0)
+                                       if (cnxk_ml_model_unload(dev, model_id) 
!= 0)
                                                plt_err("Could not unload model 
%u", model_id);
                                }
                                dev->data->models[model_id] = NULL;
@@ -271,6 +275,23 @@ cnxk_ml_dev_configure(struct rte_ml_dev *dev, const struct 
rte_ml_dev_config *co
        cnxk_mldev->max_nb_layers =
                
cnxk_mldev->cn10k_mldev.fw.req->cn10k_req.jd.fw_load.cap.s.max_models;
 
+       /* Allocate and initialize index_map */
+       if (cnxk_mldev->index_map == NULL) {
+               cnxk_mldev->index_map =
+                       rte_zmalloc("cnxk_ml_index_map",
+                                   sizeof(struct cnxk_ml_index_map) * 
cnxk_mldev->max_nb_layers,
+                                   RTE_CACHE_LINE_SIZE);
+               if (cnxk_mldev->index_map == NULL) {
+                       plt_err("Failed to get memory for index_map, nb_layers 
%" PRIu64,
+                               cnxk_mldev->max_nb_layers);
+                       ret = -ENOMEM;
+                       goto error;
+               }
+       }
+
+       for (i = 0; i < cnxk_mldev->max_nb_layers; i++)
+               cnxk_mldev->index_map[i].active = false;
+
        cnxk_mldev->nb_models_loaded = 0;
        cnxk_mldev->nb_models_started = 0;
        cnxk_mldev->nb_models_stopped = 0;
@@ -303,6 +324,10 @@ cnxk_ml_dev_close(struct rte_ml_dev *dev)
        if (cn10k_ml_dev_close(cnxk_mldev) != 0)
                plt_err("Failed to close CN10K ML Device");
 
+       /* Models are unloaded below; cn10k_ml_layer_unload() skips a NULL map */
+       rte_free(cnxk_mldev->index_map);
+       cnxk_mldev->index_map = NULL;
+
        /* Stop and unload all models */
        for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
                model = dev->data->models[model_id];
@@ -312,7 +336,7 @@ cnxk_ml_dev_close(struct rte_ml_dev *dev)
                                        plt_err("Could not stop model %u", 
model_id);
                        }
                        if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
-                               if (cn10k_ml_model_unload(dev, model_id) != 0)
+                               if (cnxk_ml_model_unload(dev, model_id) != 0)
                                        plt_err("Could not unload model %u", 
model_id);
                        }
                        dev->data->models[model_id] = NULL;
@@ -428,6 +452,127 @@ cnxk_ml_dev_queue_pair_setup(struct rte_ml_dev *dev, 
uint16_t queue_pair_id,
        return 0;
 }
 
+/*
+ * model_load op: claim a free model ID, reserve a memzone holding the
+ * cnxk_ml_model object plus its model info area, then load it on cn10k.
+ */
+static int
+cnxk_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params, 
uint16_t *model_id)
+{
+       struct rte_ml_dev_info dev_info;
+       struct cnxk_ml_dev *cnxk_mldev;
+       struct cnxk_ml_model *model;
+
+       char str[RTE_MEMZONE_NAMESIZE];
+       const struct plt_memzone *mz;
+       uint64_t model_info_size;
+       uint16_t lcl_model_id;
+       uint64_t mz_size;
+       bool found;
+       int ret;
+
+       if (dev == NULL)
+               return -EINVAL;
+
+       cnxk_mldev = dev->data->dev_private;
+
+       /* Find model ID */
+       found = false;
+       for (lcl_model_id = 0; lcl_model_id < dev->data->nb_models; 
lcl_model_id++) {
+               if (dev->data->models[lcl_model_id] == NULL) {
+                       found = true;
+                       break;
+               }
+       }
+
+       if (!found) {
+               plt_err("No slots available to load new model");
+               return -ENOMEM;
+       }
+
+       /* Compute memzone size */
+       cnxk_ml_dev_info_get(dev, &dev_info);
+       mz_size = PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_model), 
dev_info.align_size);
+       model_info_size = sizeof(struct rte_ml_model_info) +
+                         ML_CNXK_MODEL_MAX_INPUT_OUTPUT * sizeof(struct 
rte_ml_io_info) +
+                         ML_CNXK_MODEL_MAX_INPUT_OUTPUT * sizeof(struct 
rte_ml_io_info);
+       model_info_size = PLT_ALIGN_CEIL(model_info_size, dev_info.align_size);
+       mz_size += model_info_size;
+
+       /* Allocate memzone for model object */
+       snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", 
CNXK_ML_MODEL_MEMZONE_NAME, lcl_model_id);
+       mz = plt_memzone_reserve_aligned(str, mz_size, 0, dev_info.align_size);
+       if (!mz) {
+               plt_err("Failed to allocate memory for cnxk_ml_model: %s", str);
+               return -ENOMEM;
+       }
+
+       model = mz->addr;
+       model->cnxk_mldev = cnxk_mldev;
+       model->model_id = lcl_model_id;
+       model->info = PLT_PTR_ADD(
+               model, PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_model), 
dev_info.align_size));
+       dev->data->models[lcl_model_id] = model;
+
+       ret = cn10k_ml_model_load(cnxk_mldev, params, model);
+       if (ret != 0)
+               goto error;
+
+       plt_spinlock_init(&model->lock);
+       model->state = ML_CNXK_MODEL_STATE_LOADED;
+       cnxk_mldev->nb_models_loaded++;
+
+       *model_id = lcl_model_id;
+
+       return 0;
+
+error:
+       dev->data->models[lcl_model_id] = NULL;
+       plt_memzone_free(mz);
+
+       return ret;
+}
+
+/*
+ * model_unload op, also used by dev configure/close. Only a model in the
+ * LOADED state can be unloaded; -EBUSY is returned otherwise.
+ */
+int
+cnxk_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id)
+{
+       struct cnxk_ml_dev *cnxk_mldev;
+       struct cnxk_ml_model *model;
+
+       char str[RTE_MEMZONE_NAMESIZE];
+       int ret;
+
+       if (dev == NULL)
+               return -EINVAL;
+
+       cnxk_mldev = dev->data->dev_private;
+
+       model = dev->data->models[model_id];
+       if (model == NULL) {
+               plt_err("Invalid model_id = %u", model_id);
+               return -EINVAL;
+       }
+
+       if (model->state != ML_CNXK_MODEL_STATE_LOADED) {
+               plt_err("Cannot unload. Model in use.");
+               return -EBUSY;
+       }
+
+       ret = cn10k_ml_model_unload(cnxk_mldev, model);
+       if (ret != 0)
+               return ret;
+
+       dev->data->models[model_id] = NULL;
+       cnxk_mldev->nb_models_unloaded++;
+
+       snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", 
CNXK_ML_MODEL_MEMZONE_NAME, model_id);
+       return plt_memzone_free(plt_memzone_lookup(str));
+}
+
 struct rte_ml_dev_ops cnxk_ml_ops = {
        /* Device control ops */
        .dev_info_get = cnxk_ml_dev_info_get,
@@ -451,8 +587,8 @@ struct rte_ml_dev_ops cnxk_ml_ops = {
        .dev_xstats_reset = cn10k_ml_dev_xstats_reset,
 
        /* Model ops */
-       .model_load = cn10k_ml_model_load,
-       .model_unload = cn10k_ml_model_unload,
+       .model_load = cnxk_ml_model_load,
+       .model_unload = cnxk_ml_model_unload,
        .model_start = cn10k_ml_model_start,
        .model_stop = cn10k_ml_model_stop,
        .model_info_get = cn10k_ml_model_info_get,
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.h b/drivers/ml/cnxk/cnxk_ml_ops.h
index a925c07580..bc14f6e5b9 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.h
+++ b/drivers/ml/cnxk/cnxk_ml_ops.h
@@ -62,4 +62,7 @@ struct cnxk_ml_qp {
 
 extern struct rte_ml_dev_ops cnxk_ml_ops;
 
+/* Also called from dev configure/close in cnxk_ml_ops.c */
+int cnxk_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id);
+
 #endif /* _CNXK_ML_OPS_H_ */
-- 
2.42.0


Reply via email to