Implement CNXK wrapper functions for dev_info_get,
dev_configure, dev_close, dev_start and dev_stop. The
wrapper functions allocate and release the resources
common to the ML driver and invoke the device-specific
(CN10K) functions.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c | 230 ++------------------------
 drivers/ml/cnxk/cn10k_ml_ops.h |  16 +-
 drivers/ml/cnxk/cnxk_ml_dev.h  |   3 +
 drivers/ml/cnxk/cnxk_ml_ops.c  | 286 ++++++++++++++++++++++++++++++++-
 drivers/ml/cnxk/cnxk_ml_ops.h  |   3 +
 5 files changed, 314 insertions(+), 224 deletions(-)

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index a44fb26215..f8c51ab394 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -101,7 +101,7 @@ qp_memzone_name_get(char *name, int size, int dev_id, int 
qp_id)
        snprintf(name, size, "cnxk_ml_qp_mem_%u:%u", dev_id, qp_id);
 }
 
-static int
+int
 cnxk_ml_qp_destroy(const struct rte_ml_dev *dev, struct cnxk_ml_qp *qp)
 {
        const struct rte_memzone *qp_mem;
@@ -861,20 +861,12 @@ cn10k_ml_cache_model_data(struct rte_ml_dev *dev, 
uint16_t model_id)
 }
 
 int
-cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
+cn10k_ml_dev_info_get(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_dev_info 
*dev_info)
 {
        struct cn10k_ml_dev *cn10k_mldev;
-       struct cnxk_ml_dev *cnxk_mldev;
 
-       if (dev_info == NULL)
-               return -EINVAL;
-
-       cnxk_mldev = dev->data->dev_private;
        cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
-       memset(dev_info, 0, sizeof(struct rte_ml_dev_info));
-       dev_info->driver_name = dev->device->driver->name;
-       dev_info->max_models = ML_CNXK_MAX_MODELS;
        if (cn10k_mldev->hw_queue_lock)
                dev_info->max_queue_pairs = ML_CN10K_MAX_QP_PER_DEVICE_SL;
        else
@@ -889,143 +881,17 @@ cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct 
rte_ml_dev_info *dev_info)
 }
 
 int
-cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config 
*conf)
+cn10k_ml_dev_configure(struct cnxk_ml_dev *cnxk_mldev, const struct 
rte_ml_dev_config *conf)
 {
-       struct rte_ml_dev_info dev_info;
        struct cn10k_ml_dev *cn10k_mldev;
-       struct cnxk_ml_dev *cnxk_mldev;
-       struct cnxk_ml_model *model;
        struct cn10k_ml_ocm *ocm;
-       struct cnxk_ml_qp *qp;
-       uint16_t model_id;
-       uint32_t mz_size;
        uint16_t tile_id;
-       uint16_t qp_id;
        int ret;
 
-       if (dev == NULL || conf == NULL)
-               return -EINVAL;
+       RTE_SET_USED(conf); /* conf is validated and applied by cnxk_ml_dev_configure() */
 
-       /* Get CN10K device handle */
-       cnxk_mldev = dev->data->dev_private;
        cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
-       cn10k_ml_dev_info_get(dev, &dev_info);
-       if (conf->nb_models > dev_info.max_models) {
-               plt_err("Invalid device config, nb_models > %u\n", 
dev_info.max_models);
-               return -EINVAL;
-       }
-
-       if (conf->nb_queue_pairs > dev_info.max_queue_pairs) {
-               plt_err("Invalid device config, nb_queue_pairs > %u\n", 
dev_info.max_queue_pairs);
-               return -EINVAL;
-       }
-
-       if (cnxk_mldev->state == ML_CNXK_DEV_STATE_PROBED) {
-               plt_ml_dbg("Configuring ML device, nb_queue_pairs = %u, 
nb_models = %u",
-                          conf->nb_queue_pairs, conf->nb_models);
-
-               /* Load firmware */
-               ret = cn10k_ml_fw_load(cnxk_mldev);
-               if (ret != 0)
-                       return ret;
-       } else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_CONFIGURED) {
-               plt_ml_dbg("Re-configuring ML device, nb_queue_pairs = %u, 
nb_models = %u",
-                          conf->nb_queue_pairs, conf->nb_models);
-       } else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_STARTED) {
-               plt_err("Device can't be reconfigured in started state\n");
-               return -ENOTSUP;
-       } else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_CLOSED) {
-               plt_err("Device can't be reconfigured after close\n");
-               return -ENOTSUP;
-       }
-
-       /* Configure queue-pairs */
-       if (dev->data->queue_pairs == NULL) {
-               mz_size = sizeof(dev->data->queue_pairs[0]) * 
conf->nb_queue_pairs;
-               dev->data->queue_pairs =
-                       rte_zmalloc("cn10k_mldev_queue_pairs", mz_size, 
RTE_CACHE_LINE_SIZE);
-               if (dev->data->queue_pairs == NULL) {
-                       dev->data->nb_queue_pairs = 0;
-                       plt_err("Failed to get memory for queue_pairs, 
nb_queue_pairs %u",
-                               conf->nb_queue_pairs);
-                       return -ENOMEM;
-               }
-       } else { /* Re-configure */
-               void **queue_pairs;
-
-               /* Release all queue pairs as ML spec doesn't support 
queue_pair_destroy. */
-               for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
-                       qp = dev->data->queue_pairs[qp_id];
-                       if (qp != NULL) {
-                               ret = cn10k_ml_dev_queue_pair_release(dev, 
qp_id);
-                               if (ret < 0)
-                                       return ret;
-                       }
-               }
-
-               queue_pairs = dev->data->queue_pairs;
-               queue_pairs =
-                       rte_realloc(queue_pairs, sizeof(queue_pairs[0]) * 
conf->nb_queue_pairs,
-                                   RTE_CACHE_LINE_SIZE);
-               if (queue_pairs == NULL) {
-                       dev->data->nb_queue_pairs = 0;
-                       plt_err("Failed to realloc queue_pairs, nb_queue_pairs 
= %u",
-                               conf->nb_queue_pairs);
-                       ret = -ENOMEM;
-                       goto error;
-               }
-
-               memset(queue_pairs, 0, sizeof(queue_pairs[0]) * 
conf->nb_queue_pairs);
-               dev->data->queue_pairs = queue_pairs;
-       }
-       dev->data->nb_queue_pairs = conf->nb_queue_pairs;
-
-       /* Allocate ML models */
-       if (dev->data->models == NULL) {
-               mz_size = sizeof(dev->data->models[0]) * conf->nb_models;
-               dev->data->models = rte_zmalloc("cn10k_mldev_models", mz_size, 
RTE_CACHE_LINE_SIZE);
-               if (dev->data->models == NULL) {
-                       dev->data->nb_models = 0;
-                       plt_err("Failed to get memory for ml_models, nb_models 
%u",
-                               conf->nb_models);
-                       ret = -ENOMEM;
-                       goto error;
-               }
-       } else {
-               /* Re-configure */
-               void **models;
-
-               /* Stop and unload all models */
-               for (model_id = 0; model_id < dev->data->nb_models; model_id++) 
{
-                       model = dev->data->models[model_id];
-                       if (model != NULL) {
-                               if (model->state == 
ML_CNXK_MODEL_STATE_STARTED) {
-                                       if (cn10k_ml_model_stop(dev, model_id) 
!= 0)
-                                               plt_err("Could not stop model 
%u", model_id);
-                               }
-                               if (model->state == ML_CNXK_MODEL_STATE_LOADED) 
{
-                                       if (cn10k_ml_model_unload(dev, 
model_id) != 0)
-                                               plt_err("Could not unload model 
%u", model_id);
-                               }
-                               dev->data->models[model_id] = NULL;
-                       }
-               }
-
-               models = dev->data->models;
-               models = rte_realloc(models, sizeof(models[0]) * 
conf->nb_models,
-                                    RTE_CACHE_LINE_SIZE);
-               if (models == NULL) {
-                       dev->data->nb_models = 0;
-                       plt_err("Failed to realloc ml_models, nb_models = %u", 
conf->nb_models);
-                       ret = -ENOMEM;
-                       goto error;
-               }
-               memset(models, 0, sizeof(models[0]) * conf->nb_models);
-               dev->data->models = models;
-       }
-       dev->data->nb_models = conf->nb_models;
-
        ocm = &cn10k_mldev->ocm;
        ocm->num_tiles = ML_CN10K_OCM_NUMTILES;
        ocm->size_per_tile = ML_CN10K_OCM_TILESIZE;
@@ -1038,8 +904,7 @@ cn10k_ml_dev_configure(struct rte_ml_dev *dev, const 
struct rte_ml_dev_config *c
                rte_zmalloc("ocm_mask", ocm->mask_words * ocm->num_tiles, 
RTE_CACHE_LINE_SIZE);
        if (ocm->ocm_mask == NULL) {
                plt_err("Unable to allocate memory for OCM mask");
-               ret = -ENOMEM;
-               goto error;
+               return -ENOMEM;
        }
 
        for (tile_id = 0; tile_id < ocm->num_tiles; tile_id++) {
@@ -1050,10 +915,10 @@ cn10k_ml_dev_configure(struct rte_ml_dev *dev, const 
struct rte_ml_dev_config *c
        rte_spinlock_init(&ocm->lock);
 
        /* Initialize xstats */
-       ret = cn10k_ml_xstats_init(dev);
+       ret = cn10k_ml_xstats_init(cnxk_mldev->mldev);
        if (ret != 0) {
                plt_err("Failed to initialize xstats");
-               goto error;
+               return ret; /* NOTE(review): ocm->ocm_mask allocated above is not freed on this path */
        }
 
        /* Set JCMDQ enqueue function */
@@ -1067,77 +932,25 @@ cn10k_ml_dev_configure(struct rte_ml_dev *dev, const 
struct rte_ml_dev_config *c
        cn10k_mldev->set_poll_ptr = cn10k_ml_set_poll_ptr;
        cn10k_mldev->get_poll_ptr = cn10k_ml_get_poll_ptr;
 
-       dev->enqueue_burst = cn10k_ml_enqueue_burst;
-       dev->dequeue_burst = cn10k_ml_dequeue_burst;
-       dev->op_error_get = cn10k_ml_op_error_get;
-
-       cnxk_mldev->nb_models_loaded = 0;
-       cnxk_mldev->nb_models_started = 0;
-       cnxk_mldev->nb_models_stopped = 0;
-       cnxk_mldev->nb_models_unloaded = 0;
-       cnxk_mldev->state = ML_CNXK_DEV_STATE_CONFIGURED;
+       cnxk_mldev->mldev->enqueue_burst = cn10k_ml_enqueue_burst;
+       cnxk_mldev->mldev->dequeue_burst = cn10k_ml_dequeue_burst;
+       cnxk_mldev->mldev->op_error_get = cn10k_ml_op_error_get;
 
        return 0;
-
-error:
-       rte_free(dev->data->queue_pairs);
-
-       rte_free(dev->data->models);
-
-       return ret;
 }
 
 int
-cn10k_ml_dev_close(struct rte_ml_dev *dev)
+cn10k_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev)
 {
        struct cn10k_ml_dev *cn10k_mldev;
-       struct cnxk_ml_dev *cnxk_mldev;
-       struct cnxk_ml_model *model;
-       struct cnxk_ml_qp *qp;
-       uint16_t model_id;
-       uint16_t qp_id;
 
-       if (dev == NULL)
-               return -EINVAL;
-
-       cnxk_mldev = dev->data->dev_private;
        cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
        /* Release ocm_mask memory */
        rte_free(cn10k_mldev->ocm.ocm_mask);
 
-       /* Stop and unload all models */
-       for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
-               model = dev->data->models[model_id];
-               if (model != NULL) {
-                       if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
-                               if (cn10k_ml_model_stop(dev, model_id) != 0)
-                                       plt_err("Could not stop model %u", 
model_id);
-                       }
-                       if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
-                               if (cn10k_ml_model_unload(dev, model_id) != 0)
-                                       plt_err("Could not unload model %u", 
model_id);
-                       }
-                       dev->data->models[model_id] = NULL;
-               }
-       }
-
-       rte_free(dev->data->models);
-
-       /* Destroy all queue pairs */
-       for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
-               qp = dev->data->queue_pairs[qp_id];
-               if (qp != NULL) {
-                       if (cnxk_ml_qp_destroy(dev, qp) != 0)
-                               plt_err("Could not destroy queue pair %u", 
qp_id);
-                       dev->data->queue_pairs[qp_id] = NULL;
-               }
-       }
-
-       rte_free(dev->data->queue_pairs);
-
        /* Un-initialize xstats */
-       cn10k_ml_xstats_uninit(dev);
+       cn10k_ml_xstats_uninit(cnxk_mldev->mldev);
 
        /* Unload firmware */
        cn10k_ml_fw_unload(cnxk_mldev);
@@ -1154,20 +967,15 @@ cn10k_ml_dev_close(struct rte_ml_dev *dev)
        roc_ml_reg_write64(&cn10k_mldev->roc, 0, ML_MLR_BASE);
        plt_ml_dbg("ML_MLR_BASE = 0x%016lx", 
roc_ml_reg_read64(&cn10k_mldev->roc, ML_MLR_BASE));
 
-       cnxk_mldev->state = ML_CNXK_DEV_STATE_CLOSED;
-
-       /* Remove PCI device */
-       return rte_dev_remove(dev->device);
+       return 0; /* state update and rte_dev_remove() are done by cnxk_ml_dev_close() */
 }
 
 int
-cn10k_ml_dev_start(struct rte_ml_dev *dev)
+cn10k_ml_dev_start(struct cnxk_ml_dev *cnxk_mldev)
 {
        struct cn10k_ml_dev *cn10k_mldev;
-       struct cnxk_ml_dev *cnxk_mldev;
        uint64_t reg_val64;
 
-       cnxk_mldev = dev->data->dev_private;
        cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
        reg_val64 = roc_ml_reg_read64(&cn10k_mldev->roc, ML_CFG);
@@ -1175,19 +983,15 @@ cn10k_ml_dev_start(struct rte_ml_dev *dev)
        roc_ml_reg_write64(&cn10k_mldev->roc, reg_val64, ML_CFG);
        plt_ml_dbg("ML_CFG => 0x%016lx", roc_ml_reg_read64(&cn10k_mldev->roc, 
ML_CFG));
 
-       cnxk_mldev->state = ML_CNXK_DEV_STATE_STARTED;
-
        return 0;
 }
 
 int
-cn10k_ml_dev_stop(struct rte_ml_dev *dev)
+cn10k_ml_dev_stop(struct cnxk_ml_dev *cnxk_mldev)
 {
        struct cn10k_ml_dev *cn10k_mldev;
-       struct cnxk_ml_dev *cnxk_mldev;
        uint64_t reg_val64;
 
-       cnxk_mldev = dev->data->dev_private;
        cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
        reg_val64 = roc_ml_reg_read64(&cn10k_mldev->roc, ML_CFG);
@@ -1195,8 +999,6 @@ cn10k_ml_dev_stop(struct rte_ml_dev *dev)
        roc_ml_reg_write64(&cn10k_mldev->roc, reg_val64, ML_CFG);
        plt_ml_dbg("ML_CFG => 0x%016lx", roc_ml_reg_read64(&cn10k_mldev->roc, 
ML_CFG));
 
-       cnxk_mldev->state = ML_CNXK_DEV_STATE_CONFIGURED;
-
        return 0;
 }
 
@@ -1217,7 +1019,7 @@ cn10k_ml_dev_queue_pair_setup(struct rte_ml_dev *dev, 
uint16_t queue_pair_id,
        if (dev->data->queue_pairs[queue_pair_id] != NULL)
                cn10k_ml_dev_queue_pair_release(dev, queue_pair_id);
 
-       cn10k_ml_dev_info_get(dev, &dev_info);
+       cnxk_ml_dev_info_get(dev, &dev_info); /* only fails on NULL args; dev is already dereferenced above */
        if ((qp_conf->nb_desc > dev_info.max_desc) || (qp_conf->nb_desc == 0)) {
                plt_err("Could not setup queue pair for %u descriptors", 
qp_conf->nb_desc);
                return -EINVAL;
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index 16480b9ad8..d50b5bede7 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -10,6 +10,9 @@
 
 #include <roc_api.h>
 
+struct cnxk_ml_dev;
+struct cnxk_ml_qp;
+
 /* Firmware version string length */
 #define MLDEV_FIRMWARE_VERSION_LENGTH 32
 
@@ -286,11 +289,11 @@ struct cn10k_ml_req {
 };
 
 /* Device ops */
-int cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info 
*dev_info);
-int cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct 
rte_ml_dev_config *conf);
-int cn10k_ml_dev_close(struct rte_ml_dev *dev);
-int cn10k_ml_dev_start(struct rte_ml_dev *dev);
-int cn10k_ml_dev_stop(struct rte_ml_dev *dev);
+int cn10k_ml_dev_info_get(struct cnxk_ml_dev *cnxk_mldev, struct 
rte_ml_dev_info *dev_info);
+int cn10k_ml_dev_configure(struct cnxk_ml_dev *cnxk_mldev, const struct 
rte_ml_dev_config *conf);
+int cn10k_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev);
+int cn10k_ml_dev_start(struct cnxk_ml_dev *cnxk_mldev);
+int cn10k_ml_dev_stop(struct cnxk_ml_dev *cnxk_mldev);
 int cn10k_ml_dev_dump(struct rte_ml_dev *dev, FILE *fp);
 int cn10k_ml_dev_selftest(struct rte_ml_dev *dev);
 int cn10k_ml_dev_queue_pair_setup(struct rte_ml_dev *dev, uint16_t 
queue_pair_id,
@@ -336,4 +339,7 @@ __rte_hot int cn10k_ml_op_error_get(struct rte_ml_dev *dev, 
struct rte_ml_op *op
                                    struct rte_ml_op_error *error);
 __rte_hot int cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op 
*op);
 
+/* Temporarily non-static: cnxk_ml_qp_destroy() is also called from cnxk_ml_ops.c */
+int cnxk_ml_qp_destroy(const struct rte_ml_dev *dev, struct cnxk_ml_qp *qp);
+
 #endif /* _CN10K_ML_OPS_H_ */
diff --git a/drivers/ml/cnxk/cnxk_ml_dev.h b/drivers/ml/cnxk/cnxk_ml_dev.h
index 51315de622..02605fa28f 100644
--- a/drivers/ml/cnxk/cnxk_ml_dev.h
+++ b/drivers/ml/cnxk/cnxk_ml_dev.h
@@ -53,6 +53,9 @@ struct cnxk_ml_dev {
 
        /* CN10K device structure */
        struct cn10k_ml_dev cn10k_mldev;
+
+       /* Maximum number of layers, set on configure from the firmware max_models capability */
+       uint64_t max_nb_layers;
 };
 
 #endif /* _CNXK_ML_DEV_H_ */
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index 03402681c5..07a4daabc5 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -5,15 +5,345 @@
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
+#include "cnxk_ml_dev.h"
+#include "cnxk_ml_io.h"
+#include "cnxk_ml_model.h"
 #include "cnxk_ml_ops.h"
 
+/*
+ * Fill in the rte_mldev device information (dev_info_get op).
+ *
+ * Zeroes dev_info, sets the fields common to all cnxk ML devices (driver
+ * name and maximum number of models) and then calls cn10k_ml_dev_info_get()
+ * for the device specific limits.
+ *
+ * Returns 0 on success, -EINVAL if dev or dev_info is NULL, otherwise the
+ * value returned by cn10k_ml_dev_info_get().
+ */
+int
+cnxk_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
+{
+       struct cnxk_ml_dev *cnxk_mldev;
+
+       if (dev == NULL || dev_info == NULL)
+               return -EINVAL;
+
+       cnxk_mldev = dev->data->dev_private;
+
+       memset(dev_info, 0, sizeof(struct rte_ml_dev_info));
+       dev_info->driver_name = dev->device->driver->name;
+       dev_info->max_models = ML_CNXK_MAX_MODELS;
+
+       return cn10k_ml_dev_info_get(cnxk_mldev, dev_info);
+}
+
+/*
+ * Configure the device (dev_configure op).
+ *
+ * Validates conf against the limits reported by cnxk_ml_dev_info_get(),
+ * loads the firmware on the first configuration (PROBED state) and
+ * (re)allocates the queue pair and model arrays in dev->data before calling
+ * cn10k_ml_dev_configure(). On re-configuration all existing queue pairs are
+ * released and all models are stopped and unloaded first. Configuration is
+ * rejected in the STARTED and CLOSED states.
+ *
+ * Failures taken through the error label free both arrays and reset their
+ * pointers and counts, so that a later close or re-configure does not touch
+ * freed memory.
+ */
+static int
+cnxk_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *conf)
+{
+       struct rte_ml_dev_info dev_info;
+       struct cnxk_ml_dev *cnxk_mldev;
+       struct cnxk_ml_model *model;
+       struct cnxk_ml_qp *qp;
+       uint16_t model_id;
+       uint32_t mz_size;
+       uint16_t qp_id;
+       int ret;
+
+       if (dev == NULL || conf == NULL)
+               return -EINVAL;
+
+       /* Get CNXK device handle */
+       cnxk_mldev = dev->data->dev_private;
+
+       ret = cnxk_ml_dev_info_get(dev, &dev_info);
+       if (ret != 0)
+               return ret;
+
+       if (conf->nb_models > dev_info.max_models) {
+               plt_err("Invalid device config, nb_models > %u\n", dev_info.max_models);
+               return -EINVAL;
+       }
+
+       if (conf->nb_queue_pairs > dev_info.max_queue_pairs) {
+               plt_err("Invalid device config, nb_queue_pairs > %u\n", dev_info.max_queue_pairs);
+               return -EINVAL;
+       }
+
+       if (cnxk_mldev->state == ML_CNXK_DEV_STATE_PROBED) {
+               plt_ml_dbg("Configuring ML device, nb_queue_pairs = %u, nb_models = %u",
+                          conf->nb_queue_pairs, conf->nb_models);
+
+               /* Load firmware */
+               ret = cn10k_ml_fw_load(cnxk_mldev);
+               if (ret != 0)
+                       return ret;
+       } else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_CONFIGURED) {
+               plt_ml_dbg("Re-configuring ML device, nb_queue_pairs = %u, nb_models = %u",
+                          conf->nb_queue_pairs, conf->nb_models);
+       } else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_STARTED) {
+               plt_err("Device can't be reconfigured in started state\n");
+               return -ENOTSUP;
+       } else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_CLOSED) {
+               plt_err("Device can't be reconfigured after close\n");
+               return -ENOTSUP;
+       }
+
+       /* Configure queue-pairs */
+       if (dev->data->queue_pairs == NULL) {
+               mz_size = sizeof(dev->data->queue_pairs[0]) * conf->nb_queue_pairs;
+               dev->data->queue_pairs =
+                       rte_zmalloc("cnxk_mldev_queue_pairs", mz_size, RTE_CACHE_LINE_SIZE);
+               if (dev->data->queue_pairs == NULL) {
+                       dev->data->nb_queue_pairs = 0;
+                       plt_err("Failed to get memory for queue_pairs, nb_queue_pairs %u",
+                               conf->nb_queue_pairs);
+                       return -ENOMEM;
+               }
+       } else { /* Re-configure */
+               void **queue_pairs;
+
+               /* Release all queue pairs as ML spec doesn't support queue_pair_destroy. */
+               for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
+                       qp = dev->data->queue_pairs[qp_id];
+                       if (qp != NULL) {
+                               ret = cn10k_ml_dev_queue_pair_release(dev, qp_id);
+                               if (ret < 0)
+                                       return ret;
+                       }
+               }
+
+               queue_pairs = dev->data->queue_pairs;
+               queue_pairs =
+                       rte_realloc(queue_pairs, sizeof(queue_pairs[0]) * conf->nb_queue_pairs,
+                                   RTE_CACHE_LINE_SIZE);
+               if (queue_pairs == NULL) {
+                       dev->data->nb_queue_pairs = 0;
+                       plt_err("Failed to realloc queue_pairs, nb_queue_pairs = %u",
+                               conf->nb_queue_pairs);
+                       ret = -ENOMEM;
+                       goto error;
+               }
+
+               memset(queue_pairs, 0, sizeof(queue_pairs[0]) * conf->nb_queue_pairs);
+               dev->data->queue_pairs = queue_pairs;
+       }
+       dev->data->nb_queue_pairs = conf->nb_queue_pairs;
+
+       /* Allocate ML models */
+       if (dev->data->models == NULL) {
+               mz_size = sizeof(dev->data->models[0]) * conf->nb_models;
+               dev->data->models = rte_zmalloc("cnxk_mldev_models", mz_size, RTE_CACHE_LINE_SIZE);
+               if (dev->data->models == NULL) {
+                       dev->data->nb_models = 0;
+                       plt_err("Failed to get memory for ml_models, nb_models %u",
+                               conf->nb_models);
+                       ret = -ENOMEM;
+                       goto error;
+               }
+       } else {
+               /* Re-configure */
+               void **models;
+
+               /* Stop and unload all models */
+               for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
+                       model = dev->data->models[model_id];
+                       if (model != NULL) {
+                               if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
+                                       if (cn10k_ml_model_stop(dev, model_id) != 0)
+                                               plt_err("Could not stop model %u", model_id);
+                               }
+                               if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
+                                       if (cn10k_ml_model_unload(dev, model_id) != 0)
+                                               plt_err("Could not unload model %u", model_id);
+                               }
+                               dev->data->models[model_id] = NULL;
+                       }
+               }
+
+               models = dev->data->models;
+               models = rte_realloc(models, sizeof(models[0]) * conf->nb_models,
+                                    RTE_CACHE_LINE_SIZE);
+               if (models == NULL) {
+                       dev->data->nb_models = 0;
+                       plt_err("Failed to realloc ml_models, nb_models = %u", conf->nb_models);
+                       ret = -ENOMEM;
+                       goto error;
+               }
+               memset(models, 0, sizeof(models[0]) * conf->nb_models);
+               dev->data->models = models;
+       }
+       dev->data->nb_models = conf->nb_models;
+
+       ret = cn10k_ml_dev_configure(cnxk_mldev, conf);
+       if (ret != 0) {
+               plt_err("Failed to configure CN10K ML Device");
+               goto error;
+       }
+
+       /* Set device capabilities */
+       cnxk_mldev->max_nb_layers =
+               cnxk_mldev->cn10k_mldev.fw.req->cn10k_req.jd.fw_load.cap.s.max_models;
+
+       cnxk_mldev->nb_models_loaded = 0;
+       cnxk_mldev->nb_models_started = 0;
+       cnxk_mldev->nb_models_stopped = 0;
+       cnxk_mldev->nb_models_unloaded = 0;
+       cnxk_mldev->state = ML_CNXK_DEV_STATE_CONFIGURED;
+
+       return 0;
+
+error:
+       /* NOTE(review): firmware loaded above in the PROBED state is not unloaded here */
+       rte_free(dev->data->queue_pairs);
+       dev->data->queue_pairs = NULL;
+       dev->data->nb_queue_pairs = 0;
+
+       rte_free(dev->data->models);
+       dev->data->models = NULL;
+       dev->data->nb_models = 0;
+
+       return ret;
+}
+
+/*
+ * Close the device (dev_close op).
+ *
+ * Stops and unloads all models, destroys all queue pairs and frees both
+ * arrays first. cn10k_ml_dev_close() unloads the firmware and frees the OCM
+ * mask, which model stop/unload are expected to rely on, so it runs last.
+ * Finally marks the device CLOSED and removes the underlying PCI device.
+ */
+static int
+cnxk_ml_dev_close(struct rte_ml_dev *dev)
+{
+       struct cnxk_ml_dev *cnxk_mldev;
+       struct cnxk_ml_model *model;
+       struct cnxk_ml_qp *qp;
+       uint16_t model_id;
+       uint16_t qp_id;
+
+       if (dev == NULL)
+               return -EINVAL;
+
+       cnxk_mldev = dev->data->dev_private;
+
+       /* Stop and unload all models */
+       for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
+               model = dev->data->models[model_id];
+               if (model != NULL) {
+                       if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
+                               if (cn10k_ml_model_stop(dev, model_id) != 0)
+                                       plt_err("Could not stop model %u", model_id);
+                       }
+                       if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
+                               if (cn10k_ml_model_unload(dev, model_id) != 0)
+                                       plt_err("Could not unload model %u", model_id);
+                       }
+                       dev->data->models[model_id] = NULL;
+               }
+       }
+
+       rte_free(dev->data->models);
+       dev->data->models = NULL;
+       dev->data->nb_models = 0;
+
+       /* Destroy all queue pairs */
+       for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
+               qp = dev->data->queue_pairs[qp_id];
+               if (qp != NULL) {
+                       if (cnxk_ml_qp_destroy(dev, qp) != 0)
+                               plt_err("Could not destroy queue pair %u", qp_id);
+                       dev->data->queue_pairs[qp_id] = NULL;
+               }
+       }
+
+       rte_free(dev->data->queue_pairs);
+       dev->data->queue_pairs = NULL;
+       dev->data->nb_queue_pairs = 0;
+
+       /* Release CN10K resources (OCM mask, xstats, firmware) */
+       if (cn10k_ml_dev_close(cnxk_mldev) != 0)
+               plt_err("Failed to close CN10K ML Device");
+
+       cnxk_mldev->state = ML_CNXK_DEV_STATE_CLOSED;
+
+       /* Remove PCI device */
+       return rte_dev_remove(dev->device);
+}
+
+/*
+ * Start the device (dev_start op). The device state is set to STARTED only
+ * when cn10k_ml_dev_start() succeeds.
+ */
+static int
+cnxk_ml_dev_start(struct rte_ml_dev *dev)
+{
+       struct cnxk_ml_dev *cnxk_mldev;
+       int ret;
+
+       if (dev == NULL)
+               return -EINVAL;
+
+       cnxk_mldev = dev->data->dev_private;
+
+       ret = cn10k_ml_dev_start(cnxk_mldev);
+       if (ret != 0) {
+               plt_err("Failed to start CN10K ML Device");
+               return ret;
+       }
+
+       cnxk_mldev->state = ML_CNXK_DEV_STATE_STARTED;
+
+       return 0;
+}
+
+/*
+ * Stop the device (dev_stop op). The device state returns to CONFIGURED only
+ * when cn10k_ml_dev_stop() succeeds.
+ */
+static int
+cnxk_ml_dev_stop(struct rte_ml_dev *dev)
+{
+       struct cnxk_ml_dev *cnxk_mldev;
+       int ret;
+
+       if (dev == NULL)
+               return -EINVAL;
+
+       cnxk_mldev = dev->data->dev_private;
+
+       ret = cn10k_ml_dev_stop(cnxk_mldev);
+       if (ret != 0) {
+               plt_err("Failed to stop CN10K ML Device");
+               return ret;
+       }
+
+       cnxk_mldev->state = ML_CNXK_DEV_STATE_CONFIGURED;
+
+       return 0;
+}
+
 struct rte_ml_dev_ops cnxk_ml_ops = {
        /* Device control ops */
-       .dev_info_get = cn10k_ml_dev_info_get,
-       .dev_configure = cn10k_ml_dev_configure,
-       .dev_close = cn10k_ml_dev_close,
-       .dev_start = cn10k_ml_dev_start,
-       .dev_stop = cn10k_ml_dev_stop,
+       .dev_info_get = cnxk_ml_dev_info_get,
+       .dev_configure = cnxk_ml_dev_configure,
+       .dev_close = cnxk_ml_dev_close,
+       .dev_start = cnxk_ml_dev_start,
+       .dev_stop = cnxk_ml_dev_stop,
        .dev_dump = cn10k_ml_dev_dump,
        .dev_selftest = cn10k_ml_dev_selftest,
 
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.h b/drivers/ml/cnxk/cnxk_ml_ops.h
index a925c07580..2996928d7d 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.h
+++ b/drivers/ml/cnxk/cnxk_ml_ops.h
@@ -62,4 +62,7 @@ struct cnxk_ml_qp {
 
 extern struct rte_ml_dev_ops cnxk_ml_ops;
 
+/* Temporarily non-static: cnxk_ml_dev_info_get() is also called from cn10k_ml_ops.c */
+int cnxk_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info 
*dev_info);
+
 #endif /* _CNXK_ML_OPS_H_ */
-- 
2.42.0

Reply via email to