Added generic cnxk request structure. Moved common fields
from the cn10k structures to the cnxk structure. Moved job-related
structures and enumerations to the ops headers.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_dev.c   |  70 ++++---
 drivers/ml/cnxk/cn10k_ml_dev.h   | 269 +------------------------
 drivers/ml/cnxk/cn10k_ml_model.c |   6 +-
 drivers/ml/cnxk/cn10k_ml_model.h |   4 +-
 drivers/ml/cnxk/cn10k_ml_ops.c   | 329 +++++++++++++++++--------------
 drivers/ml/cnxk/cn10k_ml_ops.h   | 296 +++++++++++++++++++++++----
 drivers/ml/cnxk/cnxk_ml_ops.c    |   7 +
 drivers/ml/cnxk/cnxk_ml_ops.h    |  63 ++++++
 drivers/ml/cnxk/meson.build      |   2 +
 9 files changed, 558 insertions(+), 488 deletions(-)
 create mode 100644 drivers/ml/cnxk/cnxk_ml_ops.c
 create mode 100644 drivers/ml/cnxk/cnxk_ml_ops.h
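
Note for reviewers (below the '---', so not part of the commit message):
the core of this change is the generic cnxk_ml_req wrapper added in
cnxk_ml_ops.h. A minimal sketch of its layout, reconstructed from the
field accesses visible in the diff below (member names match the diff;
ordering and comments are illustrative, not the verbatim new file):

    /* Hypothetical reconstruction, inferred from the diff */
    struct cnxk_ml_req {
            /* Generation-specific job data; kept in a union so future
             * cnxk generations can add their own request layout next
             * to the CN10K one. */
            union {
                    struct cn10k_ml_req cn10k_req;
            };

            /* Address of the status word polled for completion; set
             * to &req->cn10k_req.status by cn10k_ml_set_poll_addr(). */
            volatile uint64_t *status;

            /* Job timeout cycle count */
            uint64_t timeout;

            /* Op assigned to the request, returned on dequeue */
            struct rte_ml_op *op;
    };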

diff --git a/drivers/ml/cnxk/cn10k_ml_dev.c b/drivers/ml/cnxk/cn10k_ml_dev.c
index 367fb7014c4..f6e05cfc472 100644
--- a/drivers/ml/cnxk/cn10k_ml_dev.c
+++ b/drivers/ml/cnxk/cn10k_ml_dev.c
@@ -23,6 +23,7 @@
 #include "cn10k_ml_ops.h"
 
 #include "cnxk_ml_dev.h"
+#include "cnxk_ml_ops.h"
 
 #define CN10K_ML_FW_PATH               "fw_path"
 #define CN10K_ML_FW_ENABLE_DPE_WARNINGS "enable_dpe_warnings"
@@ -457,20 +458,23 @@ cn10k_ml_pci_remove(struct rte_pci_device *pci_dev)
 static void
 cn10k_ml_fw_print_info(struct cn10k_ml_fw *fw)
 {
-       plt_info("ML Firmware Version = %s", fw->req->jd.fw_load.version);
-
-       plt_ml_dbg("Firmware capabilities = 0x%016lx", 
fw->req->jd.fw_load.cap.u64);
-       plt_ml_dbg("Version = %s", fw->req->jd.fw_load.version);
-       plt_ml_dbg("core0_debug_ptr = 0x%016lx", 
fw->req->jd.fw_load.debug.core0_debug_ptr);
-       plt_ml_dbg("core1_debug_ptr = 0x%016lx", 
fw->req->jd.fw_load.debug.core1_debug_ptr);
-       plt_ml_dbg("debug_buffer_size = %u bytes", 
fw->req->jd.fw_load.debug.debug_buffer_size);
+       plt_info("ML Firmware Version = %s", 
fw->req->cn10k_req.jd.fw_load.version);
+
+       plt_ml_dbg("Firmware capabilities = 0x%016lx", 
fw->req->cn10k_req.jd.fw_load.cap.u64);
+       plt_ml_dbg("Version = %s", fw->req->cn10k_req.jd.fw_load.version);
+       plt_ml_dbg("core0_debug_ptr = 0x%016lx",
+                  fw->req->cn10k_req.jd.fw_load.debug.core0_debug_ptr);
+       plt_ml_dbg("core1_debug_ptr = 0x%016lx",
+                  fw->req->cn10k_req.jd.fw_load.debug.core1_debug_ptr);
+       plt_ml_dbg("debug_buffer_size = %u bytes",
+                  fw->req->cn10k_req.jd.fw_load.debug.debug_buffer_size);
        plt_ml_dbg("core0_exception_buffer = 0x%016lx",
-                  fw->req->jd.fw_load.debug.core0_exception_buffer);
+                  fw->req->cn10k_req.jd.fw_load.debug.core0_exception_buffer);
        plt_ml_dbg("core1_exception_buffer = 0x%016lx",
-                  fw->req->jd.fw_load.debug.core1_exception_buffer);
+                  fw->req->cn10k_req.jd.fw_load.debug.core1_exception_buffer);
        plt_ml_dbg("exception_state_size = %u bytes",
-                  fw->req->jd.fw_load.debug.exception_state_size);
-       plt_ml_dbg("flags = 0x%016lx", fw->req->jd.fw_load.flags);
+                  fw->req->cn10k_req.jd.fw_load.debug.exception_state_size);
+       plt_ml_dbg("flags = 0x%016lx", fw->req->cn10k_req.jd.fw_load.flags);
 }
 
 uint64_t
@@ -515,29 +519,30 @@ cn10k_ml_fw_load_asim(struct cn10k_ml_fw *fw)
        roc_ml_reg_save(&cn10k_mldev->roc, ML_MLR_BASE);
 
        /* Update FW load completion structure */
-       fw->req->jd.hdr.jce.w1.u64 = PLT_U64_CAST(&fw->req->status);
-       fw->req->jd.hdr.job_type = ML_CN10K_JOB_TYPE_FIRMWARE_LOAD;
-       fw->req->jd.hdr.result = roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &fw->req->result);
-       fw->req->jd.fw_load.flags = cn10k_ml_fw_flags_get(fw);
-       plt_write64(ML_CNXK_POLL_JOB_START, &fw->req->status);
+       fw->req->cn10k_req.jd.hdr.jce.w1.u64 = PLT_U64_CAST(&fw->req->cn10k_req.status);
+       fw->req->cn10k_req.jd.hdr.job_type = ML_CN10K_JOB_TYPE_FIRMWARE_LOAD;
+       fw->req->cn10k_req.jd.hdr.result =
+               roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &fw->req->cn10k_req.result);
+       fw->req->cn10k_req.jd.fw_load.flags = cn10k_ml_fw_flags_get(fw);
+       plt_write64(ML_CNXK_POLL_JOB_START, &fw->req->cn10k_req.status);
        plt_wmb();
 
        /* Enqueue FW load through scratch registers */
        timeout = true;
        timeout_cycle = plt_tsc_cycles() + ML_CNXK_CMD_TIMEOUT * plt_tsc_hz();
-       roc_ml_scratch_enqueue(&cn10k_mldev->roc, &fw->req->jd);
+       roc_ml_scratch_enqueue(&cn10k_mldev->roc, &fw->req->cn10k_req.jd);
 
        plt_rmb();
        do {
                if (roc_ml_scratch_is_done_bit_set(&cn10k_mldev->roc) &&
-                   (plt_read64(&fw->req->status) == ML_CNXK_POLL_JOB_FINISH)) {
+                   (plt_read64(&fw->req->cn10k_req.status) == ML_CNXK_POLL_JOB_FINISH)) {
                        timeout = false;
                        break;
                }
        } while (plt_tsc_cycles() < timeout_cycle);
 
        /* Check firmware load status, clean-up and exit on failure. */
-       if ((!timeout) && (fw->req->result.error_code.u64 == 0)) {
+       if ((!timeout) && (fw->req->cn10k_req.result.error_code == 0)) {
                cn10k_ml_fw_print_info(fw);
        } else {
                /* Set ML to disable new jobs */
@@ -711,29 +716,30 @@ cn10k_ml_fw_load_cn10ka(struct cn10k_ml_fw *fw, void *buffer, uint64_t size)
        plt_ml_dbg("ML_SW_RST_CTRL => 0x%08x", reg_val32);
 
        /* (12) Wait for notification from firmware that ML is ready for job execution. */
-       fw->req->jd.hdr.jce.w1.u64 = PLT_U64_CAST(&fw->req->status);
-       fw->req->jd.hdr.job_type = ML_CN10K_JOB_TYPE_FIRMWARE_LOAD;
-       fw->req->jd.hdr.result = roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &fw->req->result);
-       fw->req->jd.fw_load.flags = cn10k_ml_fw_flags_get(fw);
-       plt_write64(ML_CNXK_POLL_JOB_START, &fw->req->status);
+       fw->req->cn10k_req.jd.hdr.jce.w1.u64 = PLT_U64_CAST(&fw->req->cn10k_req.status);
+       fw->req->cn10k_req.jd.hdr.job_type = ML_CN10K_JOB_TYPE_FIRMWARE_LOAD;
+       fw->req->cn10k_req.jd.hdr.result =
+               roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &fw->req->cn10k_req.result);
+       fw->req->cn10k_req.jd.fw_load.flags = cn10k_ml_fw_flags_get(fw);
+       plt_write64(ML_CNXK_POLL_JOB_START, &fw->req->cn10k_req.status);
        plt_wmb();
 
        /* Enqueue FW load through scratch registers */
        timeout = true;
        timeout_cycle = plt_tsc_cycles() + ML_CNXK_CMD_TIMEOUT * plt_tsc_hz();
-       roc_ml_scratch_enqueue(&cn10k_mldev->roc, &fw->req->jd);
+       roc_ml_scratch_enqueue(&cn10k_mldev->roc, &fw->req->cn10k_req.jd);
 
        plt_rmb();
        do {
                if (roc_ml_scratch_is_done_bit_set(&cn10k_mldev->roc) &&
-                   (plt_read64(&fw->req->status) == ML_CNXK_POLL_JOB_FINISH)) {
+                   (plt_read64(&fw->req->cn10k_req.status) == ML_CNXK_POLL_JOB_FINISH)) {
                        timeout = false;
                        break;
                }
        } while (plt_tsc_cycles() < timeout_cycle);
 
        /* Check firmware load status, clean-up and exit on failure. */
-       if ((!timeout) && (fw->req->result.error_code.u64 == 0)) {
+       if ((!timeout) && (fw->req->cn10k_req.result.error_code == 0)) {
                cn10k_ml_fw_print_info(fw);
        } else {
                /* Set ML to disable new jobs */
@@ -823,11 +829,11 @@ cn10k_ml_fw_load(struct cnxk_ml_dev *cnxk_mldev)
                }
 
                /* Reserve memzone for firmware load completion and data */
-               mz_size = sizeof(struct cn10k_ml_req) + fw_size + FW_STACK_BUFFER_SIZE +
+               mz_size = sizeof(struct cnxk_ml_req) + fw_size + FW_STACK_BUFFER_SIZE +
                          FW_DEBUG_BUFFER_SIZE + FW_EXCEPTION_BUFFER_SIZE;
        } else if (roc_env_is_asim()) {
                /* Reserve memzone for firmware load completion */
-               mz_size = sizeof(struct cn10k_ml_req);
+               mz_size = sizeof(struct cnxk_ml_req);
        }
 
        mz = plt_memzone_reserve_aligned(FW_MEMZONE_NAME, mz_size, 0, ML_CN10K_ALIGN_SIZE);
@@ -839,8 +845,8 @@ cn10k_ml_fw_load(struct cnxk_ml_dev *cnxk_mldev)
        fw->req = mz->addr;
 
        /* Reset firmware load completion structure */
-       memset(&fw->req->jd, 0, sizeof(struct cn10k_ml_jd));
-       memset(&fw->req->jd.fw_load.version[0], '\0', MLDEV_FIRMWARE_VERSION_LENGTH);
+       memset(&fw->req->cn10k_req.jd, 0, sizeof(struct cn10k_ml_jd));
+       memset(&fw->req->cn10k_req.jd.fw_load.version[0], '\0', MLDEV_FIRMWARE_VERSION_LENGTH);
 
        /* Reset device, if in active state */
        if (roc_ml_mlip_is_enabled(&cn10k_mldev->roc))
@@ -848,7 +854,7 @@ cn10k_ml_fw_load(struct cnxk_ml_dev *cnxk_mldev)
 
        /* Load firmware */
        if (roc_env_is_emulator() || roc_env_is_hw()) {
-               fw->data = PLT_PTR_ADD(mz->addr, sizeof(struct cn10k_ml_req));
+               fw->data = PLT_PTR_ADD(mz->addr, sizeof(struct cnxk_ml_req));
                ret = cn10k_ml_fw_load_cn10ka(fw, fw_buffer, fw_size);
                rte_free(fw_buffer);
        } else if (roc_env_is_asim()) {
diff --git a/drivers/ml/cnxk/cn10k_ml_dev.h b/drivers/ml/cnxk/cn10k_ml_dev.h
index 99ff0a344a2..1852d4f6c9a 100644
--- a/drivers/ml/cnxk/cn10k_ml_dev.h
+++ b/drivers/ml/cnxk/cn10k_ml_dev.h
@@ -17,9 +17,6 @@ extern struct rte_ml_dev_ops ml_dev_dummy_ops;
 /* Marvell OCTEON CN10K ML PMD device name */
 #define MLDEV_NAME_CN10K_PMD ml_cn10k
 
-/* Firmware version string length */
-#define MLDEV_FIRMWARE_VERSION_LENGTH 32
-
 /* Device alignment size */
 #define ML_CN10K_ALIGN_SIZE 128
 
@@ -52,17 +49,8 @@ extern struct rte_ml_dev_ops ml_dev_dummy_ops;
 #endif
 
 struct cnxk_ml_dev;
-struct cn10k_ml_req;
-struct cn10k_ml_qp;
-
-/* Job types */
-enum cn10k_ml_job_type {
-       ML_CN10K_JOB_TYPE_MODEL_RUN = 0,
-       ML_CN10K_JOB_TYPE_MODEL_STOP,
-       ML_CN10K_JOB_TYPE_MODEL_START,
-       ML_CN10K_JOB_TYPE_FIRMWARE_LOAD,
-       ML_CN10K_JOB_TYPE_FIRMWARE_SELFTEST,
-};
+struct cnxk_ml_req;
+struct cnxk_ml_qp;
 
 /* Error types enumeration */
 enum cn10k_ml_error_etype {
@@ -112,251 +100,6 @@ union cn10k_ml_error_code {
        uint64_t u64;
 };
 
-/* Firmware stats */
-struct cn10k_ml_fw_stats {
-       /* Firmware start cycle */
-       uint64_t fw_start;
-
-       /* Firmware end cycle */
-       uint64_t fw_end;
-
-       /* Hardware start cycle */
-       uint64_t hw_start;
-
-       /* Hardware end cycle */
-       uint64_t hw_end;
-};
-
-/* Result structure */
-struct cn10k_ml_result {
-       /* Job error code */
-       union cn10k_ml_error_code error_code;
-
-       /* Firmware stats */
-       struct cn10k_ml_fw_stats stats;
-
-       /* User context pointer */
-       void *user_ptr;
-};
-
-/* Firmware capability structure */
-union cn10k_ml_fw_cap {
-       uint64_t u64;
-
-       struct {
-               /* CMPC completion support */
-               uint64_t cmpc_completions : 1;
-
-               /* Poll mode completion support */
-               uint64_t poll_completions : 1;
-
-               /* SSO completion support */
-               uint64_t sso_completions : 1;
-
-               /* Support for model side loading */
-               uint64_t side_load_model : 1;
-
-               /* Batch execution */
-               uint64_t batch_run : 1;
-
-               /* Max number of models to be loaded in parallel */
-               uint64_t max_models : 8;
-
-               /* Firmware statistics */
-               uint64_t fw_stats : 1;
-
-               /* Hardware statistics */
-               uint64_t hw_stats : 1;
-
-               /* Max number of batches */
-               uint64_t max_num_batches : 16;
-
-               uint64_t rsvd : 33;
-       } s;
-};
-
-/* Firmware debug info structure */
-struct cn10k_ml_fw_debug {
-       /* ACC core 0 debug buffer */
-       uint64_t core0_debug_ptr;
-
-       /* ACC core 1 debug buffer */
-       uint64_t core1_debug_ptr;
-
-       /* ACC core 0 exception state buffer */
-       uint64_t core0_exception_buffer;
-
-       /* ACC core 1 exception state buffer */
-       uint64_t core1_exception_buffer;
-
-       /* Debug buffer size per core */
-       uint32_t debug_buffer_size;
-
-       /* Exception state dump size */
-       uint32_t exception_state_size;
-};
-
-/* Job descriptor header (32 bytes) */
-struct cn10k_ml_jd_header {
-       /* Job completion structure */
-       struct ml_jce_s jce;
-
-       /* Model ID */
-       uint64_t model_id : 8;
-
-       /* Job type */
-       uint64_t job_type : 8;
-
-       /* Flags for fast-path jobs */
-       uint64_t fp_flags : 16;
-
-       /* Flags for slow-path jobs */
-       uint64_t sp_flags : 16;
-       uint64_t rsvd : 16;
-
-       /* Job result pointer */
-       uint64_t *result;
-};
-
-/* Extra arguments for job descriptor */
-union cn10k_ml_jd_extended_args {
-       struct cn10k_ml_jd_extended_args_section_start {
-               /** DDR Scratch base address */
-               uint64_t ddr_scratch_base_address;
-
-               /** DDR Scratch range start */
-               uint64_t ddr_scratch_range_start;
-
-               /** DDR Scratch range end */
-               uint64_t ddr_scratch_range_end;
-
-               uint8_t rsvd[104];
-       } start;
-};
-
-/* Job descriptor structure */
-struct cn10k_ml_jd {
-       /* Job descriptor header (32 bytes) */
-       struct cn10k_ml_jd_header hdr;
-
-       union {
-               struct cn10k_ml_jd_section_fw_load {
-                       /* Firmware capability structure (8 bytes) */
-                       union cn10k_ml_fw_cap cap;
-
-                       /* Firmware version (32 bytes) */
-                       uint8_t version[MLDEV_FIRMWARE_VERSION_LENGTH];
-
-                       /* Debug capability structure (40 bytes) */
-                       struct cn10k_ml_fw_debug debug;
-
-                       /* Flags to control error handling */
-                       uint64_t flags;
-
-                       uint8_t rsvd[8];
-               } fw_load;
-
-               struct cn10k_ml_jd_section_model_start {
-                       /* Extended arguments */
-                       uint64_t extended_args;
-
-                       /* Destination model start address in DDR relative to ML_MLR_BASE */
-                       uint64_t model_dst_ddr_addr;
-
-                       /* Offset to model init section in the model */
-                       uint64_t model_init_offset : 32;
-
-                       /* Size of init section in the model */
-                       uint64_t model_init_size : 32;
-
-                       /* Offset to model main section in the model */
-                       uint64_t model_main_offset : 32;
-
-                       /* Size of main section in the model */
-                       uint64_t model_main_size : 32;
-
-                       /* Offset to model finish section in the model */
-                       uint64_t model_finish_offset : 32;
-
-                       /* Size of finish section in the model */
-                       uint64_t model_finish_size : 32;
-
-                       /* Offset to WB in model bin */
-                       uint64_t model_wb_offset : 32;
-
-                       /* Number of model layers */
-                       uint64_t num_layers : 8;
-
-                       /* Number of gather entries, 0 means linear input mode (= no gather) */
-                       uint64_t num_gather_entries : 8;
-
-                       /* Number of scatter entries 0 means linear input mode (= no scatter) */
-                       uint64_t num_scatter_entries : 8;
-
-                       /* Tile mask to load model */
-                       uint64_t tilemask : 8;
-
-                       /* Batch size of model  */
-                       uint64_t batch_size : 32;
-
-                       /* OCM WB base address */
-                       uint64_t ocm_wb_base_address : 32;
-
-                       /* OCM WB range start */
-                       uint64_t ocm_wb_range_start : 32;
-
-                       /* OCM WB range End */
-                       uint64_t ocm_wb_range_end : 32;
-
-                       /* DDR WB address */
-                       uint64_t ddr_wb_base_address;
-
-                       /* DDR WB range start */
-                       uint64_t ddr_wb_range_start : 32;
-
-                       /* DDR WB range end */
-                       uint64_t ddr_wb_range_end : 32;
-
-                       union {
-                               /* Points to gather list if num_gather_entries > 0 */
-                               void *gather_list;
-                               struct {
-                                       /* Linear input mode */
-                                       uint64_t ddr_range_start : 32;
-                                       uint64_t ddr_range_end : 32;
-                               } s;
-                       } input;
-
-                       union {
-                               /* Points to scatter list if num_scatter_entries > 0 */
-                               void *scatter_list;
-                               struct {
-                                       /* Linear output mode */
-                                       uint64_t ddr_range_start : 32;
-                                       uint64_t ddr_range_end : 32;
-                               } s;
-                       } output;
-               } model_start;
-
-               struct cn10k_ml_jd_section_model_stop {
-                       uint8_t rsvd[96];
-               } model_stop;
-
-               struct cn10k_ml_jd_section_model_run {
-                       /* Address of the input for the run relative to ML_MLR_BASE */
-                       uint64_t input_ddr_addr;
-
-                       /* Address of the output for the run relative to ML_MLR_BASE */
-                       uint64_t output_ddr_addr;
-
-                       /* Number of batches to run in variable batch processing */
-                       uint16_t num_batches;
-
-                       uint8_t rsvd[78];
-               } model_run;
-       };
-};
-
 /* ML firmware structure */
 struct cn10k_ml_fw {
        /* Device reference */
@@ -375,7 +118,7 @@ struct cn10k_ml_fw {
        uint8_t *data;
 
        /* Firmware load / handshake request structure */
-       struct cn10k_ml_req *req;
+       struct cnxk_ml_req *req;
 };
 
 /* Extended stats types enum */
@@ -488,9 +231,9 @@ struct cn10k_ml_dev {
        bool (*ml_jcmdq_enqueue)(struct roc_ml *roc_ml, struct ml_job_cmd_s *job_cmd);
 
        /* Poll handling function pointers */
-       void (*set_poll_addr)(struct cn10k_ml_req *req);
-       void (*set_poll_ptr)(struct cn10k_ml_req *req);
-       uint64_t (*get_poll_ptr)(struct cn10k_ml_req *req);
+       void (*set_poll_addr)(struct cnxk_ml_req *req);
+       void (*set_poll_ptr)(struct cnxk_ml_req *req);
+       uint64_t (*get_poll_ptr)(struct cnxk_ml_req *req);
 };
 
 uint64_t cn10k_ml_fw_flags_get(struct cn10k_ml_fw *fw);
diff --git a/drivers/ml/cnxk/cn10k_ml_model.c b/drivers/ml/cnxk/cn10k_ml_model.c
index 0ea6520bf78..2a0ae44cfd5 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.c
+++ b/drivers/ml/cnxk/cn10k_ml_model.c
@@ -12,6 +12,7 @@
 
 #include "cnxk_ml_dev.h"
 #include "cnxk_ml_model.h"
+#include "cnxk_ml_ops.h"
 
 static enum rte_ml_io_type
 cn10k_ml_io_type_map(uint8_t type)
@@ -551,7 +552,6 @@ void
 cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cnxk_ml_model *model)
 {
        struct cn10k_ml_model_metadata *metadata;
-       struct cn10k_ml_dev *cn10k_mldev;
        struct cnxk_ml_dev *cnxk_mldev;
        struct rte_ml_model_info *info;
        struct rte_ml_io_info *output;
@@ -560,7 +560,6 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cnxk_ml_model *model)
        uint8_t i;
 
        cnxk_mldev = dev->data->dev_private;
-       cn10k_mldev = &cnxk_mldev->cn10k_mldev;
        metadata = &model->glow.metadata;
        info = PLT_PTR_CAST(model->info);
        input = PLT_PTR_ADD(info, sizeof(struct rte_ml_model_info));
@@ -577,7 +576,8 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cnxk_ml_model *model)
        info->io_layout = RTE_ML_IO_LAYOUT_PACKED;
        info->min_batches = model->batch_size;
        info->max_batches =
-               cn10k_mldev->fw.req->jd.fw_load.cap.s.max_num_batches / model->batch_size;
+               cnxk_mldev->cn10k_mldev.fw.req->cn10k_req.jd.fw_load.cap.s.max_num_batches /
+               model->batch_size;
        info->nb_inputs = metadata->model.num_input;
        info->input_info = input;
        info->nb_outputs = metadata->model.num_output;
diff --git a/drivers/ml/cnxk/cn10k_ml_model.h b/drivers/ml/cnxk/cn10k_ml_model.h
index 206a369ca75..74ada1531a8 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.h
+++ b/drivers/ml/cnxk/cn10k_ml_model.h
@@ -11,10 +11,10 @@
 
 #include "cn10k_ml_dev.h"
 #include "cn10k_ml_ocm.h"
-#include "cn10k_ml_ops.h"
 
 struct cnxk_ml_model;
 struct cnxk_ml_layer;
+struct cnxk_ml_req;
 
 /* Model Metadata : v 2.3.0.1 */
 #define MRVL_ML_MODEL_MAGIC_STRING "MRVL"
@@ -444,7 +444,7 @@ struct cn10k_ml_layer_data {
        struct cn10k_ml_ocm_layer_map ocm_map;
 
        /* Layer: Slow-path operations request pointer */
-       struct cn10k_ml_req *req;
+       struct cnxk_ml_req *req;
 
        /* Layer: Stats for burst ops */
        struct cn10k_ml_layer_stats *burst_stats;
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index a52509630fe..2b1fa08154d 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -13,6 +13,7 @@
 
 #include "cnxk_ml_dev.h"
 #include "cnxk_ml_model.h"
+#include "cnxk_ml_ops.h"
 
 /* ML model macros */
 #define CN10K_ML_MODEL_MEMZONE_NAME "ml_cn10k_model_mz"
@@ -80,31 +81,31 @@ print_line(FILE *fp, int len)
 }
 
 static inline void
-cn10k_ml_set_poll_addr(struct cn10k_ml_req *req)
+cn10k_ml_set_poll_addr(struct cnxk_ml_req *req)
 {
-       req->compl_W1 = PLT_U64_CAST(&req->status);
+       req->status = &req->cn10k_req.status;
 }
 
 static inline void
-cn10k_ml_set_poll_ptr(struct cn10k_ml_req *req)
+cn10k_ml_set_poll_ptr(struct cnxk_ml_req *req)
 {
-       plt_write64(ML_CNXK_POLL_JOB_START, req->compl_W1);
+       plt_write64(ML_CNXK_POLL_JOB_START, req->status);
 }
 
 static inline uint64_t
-cn10k_ml_get_poll_ptr(struct cn10k_ml_req *req)
+cn10k_ml_get_poll_ptr(struct cnxk_ml_req *req)
 {
-       return plt_read64(req->compl_W1);
+       return plt_read64(req->status);
 }
 
 static void
 qp_memzone_name_get(char *name, int size, int dev_id, int qp_id)
 {
-       snprintf(name, size, "cn10k_ml_qp_mem_%u:%u", dev_id, qp_id);
+       snprintf(name, size, "cnxk_ml_qp_mem_%u:%u", dev_id, qp_id);
 }
 
 static int
-cn10k_ml_qp_destroy(const struct rte_ml_dev *dev, struct cn10k_ml_qp *qp)
+cnxk_ml_qp_destroy(const struct rte_ml_dev *dev, struct cnxk_ml_qp *qp)
 {
        const struct rte_memzone *qp_mem;
        char name[RTE_MEMZONE_NAMESIZE];
@@ -124,14 +125,14 @@ cn10k_ml_qp_destroy(const struct rte_ml_dev *dev, struct cn10k_ml_qp *qp)
 static int
 cn10k_ml_dev_queue_pair_release(struct rte_ml_dev *dev, uint16_t queue_pair_id)
 {
-       struct cn10k_ml_qp *qp;
+       struct cnxk_ml_qp *qp;
        int ret;
 
        qp = dev->data->queue_pairs[queue_pair_id];
        if (qp == NULL)
                return -EINVAL;
 
-       ret = cn10k_ml_qp_destroy(dev, qp);
+       ret = cnxk_ml_qp_destroy(dev, qp);
        if (ret) {
                plt_err("Could not destroy queue pair %u", queue_pair_id);
                return ret;
@@ -142,18 +143,18 @@ cn10k_ml_dev_queue_pair_release(struct rte_ml_dev *dev, uint16_t queue_pair_id)
        return 0;
 }
 
-static struct cn10k_ml_qp *
-cn10k_ml_qp_create(const struct rte_ml_dev *dev, uint16_t qp_id, uint32_t nb_desc, int socket_id)
+static struct cnxk_ml_qp *
+cnxk_ml_qp_create(const struct rte_ml_dev *dev, uint16_t qp_id, uint32_t nb_desc, int socket_id)
 {
        const struct rte_memzone *qp_mem;
        char name[RTE_MEMZONE_NAMESIZE];
-       struct cn10k_ml_qp *qp;
+       struct cnxk_ml_qp *qp;
        uint32_t len;
        uint8_t *va;
        uint64_t i;
 
        /* Allocate queue pair */
-       qp = rte_zmalloc_socket("cn10k_ml_pmd_queue_pair", sizeof(struct cn10k_ml_qp), ROC_ALIGN,
+       qp = rte_zmalloc_socket("cn10k_ml_pmd_queue_pair", sizeof(struct cnxk_ml_qp), ROC_ALIGN,
                                socket_id);
        if (qp == NULL) {
                plt_err("Could not allocate queue pair");
@@ -161,7 +162,7 @@ cn10k_ml_qp_create(const struct rte_ml_dev *dev, uint16_t qp_id, uint32_t nb_des
        }
 
        /* For request queue */
-       len = nb_desc * sizeof(struct cn10k_ml_req);
+       len = nb_desc * sizeof(struct cnxk_ml_req);
        qp_memzone_name_get(name, RTE_MEMZONE_NAMESIZE, dev->data->dev_id, qp_id);
        qp_mem = rte_memzone_reserve_aligned(
                name, len, socket_id, RTE_MEMZONE_SIZE_HINT_ONLY | RTE_MEMZONE_256MB, ROC_ALIGN);
@@ -175,7 +176,7 @@ cn10k_ml_qp_create(const struct rte_ml_dev *dev, uint16_t qp_id, uint32_t nb_des
 
        /* Initialize Request queue */
        qp->id = qp_id;
-       qp->queue.reqs = (struct cn10k_ml_req *)va;
+       qp->queue.reqs = (struct cnxk_ml_req *)va;
        qp->queue.head = 0;
        qp->queue.tail = 0;
        qp->queue.wait_cycles = ML_CNXK_CMD_TIMEOUT * plt_tsc_hz();
@@ -187,8 +188,9 @@ cn10k_ml_qp_create(const struct rte_ml_dev *dev, uint16_t qp_id, uint32_t nb_des
 
        /* Initialize job command */
        for (i = 0; i < qp->nb_desc; i++) {
-               memset(&qp->queue.reqs[i].jd, 0, sizeof(struct cn10k_ml_jd));
-               qp->queue.reqs[i].jcmd.w1.s.jobptr = PLT_U64_CAST(&qp->queue.reqs[i].jd);
+               memset(&qp->queue.reqs[i].cn10k_req.jd, 0, sizeof(struct cn10k_ml_jd));
+               qp->queue.reqs[i].cn10k_req.jcmd.w1.s.jobptr =
+                       PLT_U64_CAST(&qp->queue.reqs[i].cn10k_req.jd);
        }
 
        return qp;
@@ -335,7 +337,7 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
 
 static void
 cn10k_ml_prep_sp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cnxk_ml_model *model,
-                               struct cn10k_ml_req *req, enum cn10k_ml_job_type job_type)
+                               struct cnxk_ml_req *req, enum cn10k_ml_job_type job_type)
 {
        struct cn10k_ml_model_metadata *metadata;
        struct cn10k_ml_layer_addr *addr;
@@ -343,79 +345,88 @@ cn10k_ml_prep_sp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cnxk_ml
        metadata = &model->glow.metadata;
        addr = &model->layer[0].glow.addr;
 
-       memset(&req->jd, 0, sizeof(struct cn10k_ml_jd));
-       req->jd.hdr.jce.w0.u64 = 0;
-       req->jd.hdr.jce.w1.u64 = PLT_U64_CAST(&req->status);
-       req->jd.hdr.model_id = model->model_id;
-       req->jd.hdr.job_type = job_type;
-       req->jd.hdr.fp_flags = 0x0;
-       req->jd.hdr.result = roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &req->result);
+       memset(&req->cn10k_req.jd, 0, sizeof(struct cn10k_ml_jd));
+       req->cn10k_req.jd.hdr.jce.w0.u64 = 0;
+       req->cn10k_req.jd.hdr.jce.w1.u64 = PLT_U64_CAST(&req->cn10k_req.status);
+       req->cn10k_req.jd.hdr.model_id = model->model_id;
+       req->cn10k_req.jd.hdr.job_type = job_type;
+       req->cn10k_req.jd.hdr.fp_flags = 0x0;
+       req->cn10k_req.jd.hdr.result =
+               roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &req->cn10k_req.result);
 
        if (job_type == ML_CN10K_JOB_TYPE_MODEL_START) {
                if (!model->glow.metadata.model.ocm_relocatable)
-                       req->jd.hdr.sp_flags = ML_CN10K_SP_FLAGS_OCM_NONRELOCATABLE;
+                       req->cn10k_req.jd.hdr.sp_flags = ML_CN10K_SP_FLAGS_OCM_NONRELOCATABLE;
                else
-                       req->jd.hdr.sp_flags = 0x0;
+                       req->cn10k_req.jd.hdr.sp_flags = 0x0;
 
-               req->jd.hdr.sp_flags |= ML_CN10K_SP_FLAGS_EXTENDED_LOAD_JD;
-               req->jd.model_start.extended_args =
-                       PLT_U64_CAST(roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &req->extended_args));
-               req->jd.model_start.model_dst_ddr_addr =
+               req->cn10k_req.jd.hdr.sp_flags |= ML_CN10K_SP_FLAGS_EXTENDED_LOAD_JD;
+               req->cn10k_req.jd.model_start.extended_args = PLT_U64_CAST(
+                       roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &req->cn10k_req.extended_args));
+               req->cn10k_req.jd.model_start.model_dst_ddr_addr =
                        PLT_U64_CAST(roc_ml_addr_ap2mlip(&cn10k_mldev->roc, addr->init_run_addr));
-               req->jd.model_start.model_init_offset = 0x0;
-               req->jd.model_start.model_main_offset = metadata->init_model.file_size;
-               req->jd.model_start.model_finish_offset =
+               req->cn10k_req.jd.model_start.model_init_offset = 0x0;
+               req->cn10k_req.jd.model_start.model_main_offset = metadata->init_model.file_size;
+               req->cn10k_req.jd.model_start.model_finish_offset =
                        metadata->init_model.file_size + metadata->main_model.file_size;
-               req->jd.model_start.model_init_size = metadata->init_model.file_size;
-               req->jd.model_start.model_main_size = metadata->main_model.file_size;
-               req->jd.model_start.model_finish_size = metadata->finish_model.file_size;
-               req->jd.model_start.model_wb_offset = metadata->init_model.file_size +
-                                                     metadata->main_model.file_size +
-                                                     metadata->finish_model.file_size;
-               req->jd.model_start.num_layers = metadata->model.num_layers;
-               req->jd.model_start.num_gather_entries = 0;
-               req->jd.model_start.num_scatter_entries = 0;
-               req->jd.model_start.tilemask = 0; /* Updated after reserving pages */
-               req->jd.model_start.batch_size = model->batch_size;
-               req->jd.model_start.ocm_wb_base_address = 0; /* Updated after reserving pages */
-               req->jd.model_start.ocm_wb_range_start = metadata->model.ocm_wb_range_start;
-               req->jd.model_start.ocm_wb_range_end = metadata->model.ocm_wb_range_end;
-               req->jd.model_start.ddr_wb_base_address = PLT_U64_CAST(roc_ml_addr_ap2mlip(
-                       &cn10k_mldev->roc,
-                       PLT_PTR_ADD(addr->finish_load_addr, metadata->finish_model.file_size)));
-               req->jd.model_start.ddr_wb_range_start = metadata->model.ddr_wb_range_start;
-               req->jd.model_start.ddr_wb_range_end = metadata->model.ddr_wb_range_end;
-               req->jd.model_start.input.s.ddr_range_start = metadata->model.ddr_input_range_start;
-               req->jd.model_start.input.s.ddr_range_end = metadata->model.ddr_input_range_end;
-               req->jd.model_start.output.s.ddr_range_start =
+               req->cn10k_req.jd.model_start.model_init_size = metadata->init_model.file_size;
+               req->cn10k_req.jd.model_start.model_main_size = metadata->main_model.file_size;
+               req->cn10k_req.jd.model_start.model_finish_size = metadata->finish_model.file_size;
+               req->cn10k_req.jd.model_start.model_wb_offset = metadata->init_model.file_size +
+                                                               metadata->main_model.file_size +
+                                                               metadata->finish_model.file_size;
+               req->cn10k_req.jd.model_start.num_layers = metadata->model.num_layers;
+               req->cn10k_req.jd.model_start.num_gather_entries = 0;
+               req->cn10k_req.jd.model_start.num_scatter_entries = 0;
+               req->cn10k_req.jd.model_start.tilemask = 0; /* Updated after reserving pages */
+               req->cn10k_req.jd.model_start.batch_size = model->batch_size;
+               req->cn10k_req.jd.model_start.ocm_wb_base_address =
+                       0; /* Updated after reserving pages */
+               req->cn10k_req.jd.model_start.ocm_wb_range_start =
+                       metadata->model.ocm_wb_range_start;
+               req->cn10k_req.jd.model_start.ocm_wb_range_end = metadata->model.ocm_wb_range_end;
+               req->cn10k_req.jd.model_start.ddr_wb_base_address =
+                       PLT_U64_CAST(roc_ml_addr_ap2mlip(
+                               &cn10k_mldev->roc, PLT_PTR_ADD(addr->finish_load_addr,
+                                                              metadata->finish_model.file_size)));
+               req->cn10k_req.jd.model_start.ddr_wb_range_start =
+                       metadata->model.ddr_wb_range_start;
+               req->cn10k_req.jd.model_start.ddr_wb_range_end = metadata->model.ddr_wb_range_end;
+               req->cn10k_req.jd.model_start.input.s.ddr_range_start =
+                       metadata->model.ddr_input_range_start;
+               req->cn10k_req.jd.model_start.input.s.ddr_range_end =
+                       metadata->model.ddr_input_range_end;
+               req->cn10k_req.jd.model_start.output.s.ddr_range_start =
                        metadata->model.ddr_output_range_start;
-               req->jd.model_start.output.s.ddr_range_end = metadata->model.ddr_output_range_end;
+               req->cn10k_req.jd.model_start.output.s.ddr_range_end =
+                       metadata->model.ddr_output_range_end;
 
-               req->extended_args.start.ddr_scratch_base_address = PLT_U64_CAST(
+               req->cn10k_req.extended_args.start.ddr_scratch_base_address = PLT_U64_CAST(
                        roc_ml_addr_ap2mlip(&cn10k_mldev->roc, addr->scratch_base_addr));
-               req->extended_args.start.ddr_scratch_range_start =
+               req->cn10k_req.extended_args.start.ddr_scratch_range_start =
                        metadata->model.ddr_scratch_range_start;
-               req->extended_args.start.ddr_scratch_range_end =
+               req->cn10k_req.extended_args.start.ddr_scratch_range_end =
                        metadata->model.ddr_scratch_range_end;
        }
 }
 
 static __rte_always_inline void
-cn10k_ml_prep_fp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cn10k_ml_req *req,
+cn10k_ml_prep_fp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cnxk_ml_req *req,
                                struct rte_ml_op *op)
 {
-       req->jd.hdr.jce.w0.u64 = 0;
-       req->jd.hdr.jce.w1.u64 = req->compl_W1;
-       req->jd.hdr.model_id = op->model_id;
-       req->jd.hdr.job_type = ML_CN10K_JOB_TYPE_MODEL_RUN;
-       req->jd.hdr.fp_flags = ML_FLAGS_POLL_COMPL;
-       req->jd.hdr.sp_flags = 0x0;
-       req->jd.hdr.result = roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &req->result);
-       req->jd.model_run.input_ddr_addr =
+       req->cn10k_req.jd.hdr.jce.w0.u64 = 0;
+       req->cn10k_req.jd.hdr.jce.w1.u64 = PLT_U64_CAST(req->status);
+       req->cn10k_req.jd.hdr.model_id = op->model_id;
+       req->cn10k_req.jd.hdr.job_type = ML_CN10K_JOB_TYPE_MODEL_RUN;
+       req->cn10k_req.jd.hdr.fp_flags = ML_FLAGS_POLL_COMPL;
+       req->cn10k_req.jd.hdr.sp_flags = 0x0;
+       req->cn10k_req.jd.hdr.result =
+               roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &req->cn10k_req.result);
+       req->cn10k_req.jd.model_run.input_ddr_addr =
                PLT_U64_CAST(roc_ml_addr_ap2mlip(&cn10k_mldev->roc, op->input[0]->addr));
-       req->jd.model_run.output_ddr_addr =
+       req->cn10k_req.jd.model_run.output_ddr_addr =
                PLT_U64_CAST(roc_ml_addr_ap2mlip(&cn10k_mldev->roc, op->output[0]->addr));
-       req->jd.model_run.num_batches = op->nb_batches;
+       req->cn10k_req.jd.model_run.num_batches = op->nb_batches;
 }
 
 struct xstat_info {
@@ -863,7 +874,7 @@ cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
        op.input = &inp;
        op.output = &out;
 
-       memset(model->layer[0].glow.req, 0, sizeof(struct cn10k_ml_req));
+       memset(model->layer[0].glow.req, 0, sizeof(struct cnxk_ml_req));
        ret = cn10k_ml_inference_sync(dev, &op);
        plt_memzone_free(mz);
 
@@ -906,7 +917,7 @@ cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *c
        struct cnxk_ml_dev *cnxk_mldev;
        struct cnxk_ml_model *model;
        struct cn10k_ml_ocm *ocm;
-       struct cn10k_ml_qp *qp;
+       struct cnxk_ml_qp *qp;
        uint16_t model_id;
        uint32_t mz_size;
        uint16_t tile_id;
@@ -1103,7 +1114,7 @@ cn10k_ml_dev_close(struct rte_ml_dev *dev)
        struct cn10k_ml_dev *cn10k_mldev;
        struct cnxk_ml_dev *cnxk_mldev;
        struct cnxk_ml_model *model;
-       struct cn10k_ml_qp *qp;
+       struct cnxk_ml_qp *qp;
        uint16_t model_id;
        uint16_t qp_id;
 
@@ -1138,7 +1149,7 @@ cn10k_ml_dev_close(struct rte_ml_dev *dev)
        for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
                qp = dev->data->queue_pairs[qp_id];
                if (qp != NULL) {
-                       if (cn10k_ml_qp_destroy(dev, qp) != 0)
+                       if (cnxk_ml_qp_destroy(dev, qp) != 0)
                                plt_err("Could not destroy queue pair %u", 
qp_id);
                        dev->data->queue_pairs[qp_id] = NULL;
                }
@@ -1215,7 +1226,7 @@ cn10k_ml_dev_queue_pair_setup(struct rte_ml_dev *dev, uint16_t queue_pair_id,
                              const struct rte_ml_dev_qp_conf *qp_conf, int socket_id)
 {
        struct rte_ml_dev_info dev_info;
-       struct cn10k_ml_qp *qp;
+       struct cnxk_ml_qp *qp;
        uint32_t nb_desc;
 
        if (queue_pair_id >= dev->data->nb_queue_pairs) {
@@ -1241,7 +1252,7 @@ cn10k_ml_dev_queue_pair_setup(struct rte_ml_dev *dev, uint16_t queue_pair_id,
         */
        nb_desc =
                (qp_conf->nb_desc == dev_info.max_desc) ? dev_info.max_desc : qp_conf->nb_desc + 1;
-       qp = cn10k_ml_qp_create(dev, queue_pair_id, nb_desc, socket_id);
+       qp = cnxk_ml_qp_create(dev, queue_pair_id, nb_desc, socket_id);
        if (qp == NULL) {
                plt_err("Could not create queue pair %u", queue_pair_id);
                return -ENOMEM;
@@ -1254,7 +1265,7 @@ cn10k_ml_dev_queue_pair_setup(struct rte_ml_dev *dev, uint16_t queue_pair_id,
 static int
 cn10k_ml_dev_stats_get(struct rte_ml_dev *dev, struct rte_ml_dev_stats *stats)
 {
-       struct cn10k_ml_qp *qp;
+       struct cnxk_ml_qp *qp;
        int qp_id;
 
        for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
@@ -1271,7 +1282,7 @@ cn10k_ml_dev_stats_get(struct rte_ml_dev *dev, struct rte_ml_dev_stats *stats)
 static void
 cn10k_ml_dev_stats_reset(struct rte_ml_dev *dev)
 {
-       struct cn10k_ml_qp *qp;
+       struct cnxk_ml_qp *qp;
        int qp_id;
 
        for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
@@ -1487,20 +1498,22 @@ cn10k_ml_dev_dump(struct rte_ml_dev *dev, FILE *fp)
 
        /* Dump debug buffer */
        for (core_id = 0; core_id <= 1; core_id++) {
-               bufsize = fw->req->jd.fw_load.debug.debug_buffer_size;
+               bufsize = fw->req->cn10k_req.jd.fw_load.debug.debug_buffer_size;
                if (core_id == 0) {
                        head_loc =
                                roc_ml_reg_read64(&cn10k_mldev->roc, ML_SCRATCH_DBG_BUFFER_HEAD_C0);
                        tail_loc =
                                roc_ml_reg_read64(&cn10k_mldev->roc, ML_SCRATCH_DBG_BUFFER_TAIL_C0);
-                       head_ptr = PLT_PTR_CAST(fw->req->jd.fw_load.debug.core0_debug_ptr);
+                       head_ptr =
+                               PLT_PTR_CAST(fw->req->cn10k_req.jd.fw_load.debug.core0_debug_ptr);
                        head_ptr = roc_ml_addr_mlip2ap(&cn10k_mldev->roc, head_ptr);
                } else {
                        head_loc =
                                roc_ml_reg_read64(&cn10k_mldev->roc, ML_SCRATCH_DBG_BUFFER_HEAD_C1);
                        tail_loc =
                                roc_ml_reg_read64(&cn10k_mldev->roc, ML_SCRATCH_DBG_BUFFER_TAIL_C1);
-                       head_ptr = PLT_PTR_CAST(fw->req->jd.fw_load.debug.core1_debug_ptr);
+                       head_ptr =
+                               PLT_PTR_CAST(fw->req->cn10k_req.jd.fw_load.debug.core1_debug_ptr);
                        head_ptr = roc_ml_addr_mlip2ap(&cn10k_mldev->roc, head_ptr);
                }
                if (head_loc < tail_loc) {
@@ -1513,17 +1526,19 @@ cn10k_ml_dev_dump(struct rte_ml_dev *dev, FILE *fp)
 
        /* Dump exception info */
        for (core_id = 0; core_id <= 1; core_id++) {
-               bufsize = fw->req->jd.fw_load.debug.exception_state_size;
+               bufsize = fw->req->cn10k_req.jd.fw_load.debug.exception_state_size;
                if ((core_id == 0) &&
                    (roc_ml_reg_read64(&cn10k_mldev->roc, ML_SCRATCH_EXCEPTION_SP_C0) != 0)) {
-                       head_ptr = PLT_PTR_CAST(fw->req->jd.fw_load.debug.core0_exception_buffer);
+                       head_ptr = PLT_PTR_CAST(
+                               fw->req->cn10k_req.jd.fw_load.debug.core0_exception_buffer);
                        fprintf(fp, "ML_SCRATCH_EXCEPTION_SP_C0 = 0x%016lx",
                                roc_ml_reg_read64(&cn10k_mldev->roc, ML_SCRATCH_EXCEPTION_SP_C0));
                        head_ptr = roc_ml_addr_mlip2ap(&cn10k_mldev->roc, head_ptr);
                        fprintf(fp, "%.*s", bufsize, head_ptr);
                } else if ((core_id == 1) && (roc_ml_reg_read64(&cn10k_mldev->roc,
                                                                ML_SCRATCH_EXCEPTION_SP_C1) != 0)) {
-                       head_ptr = PLT_PTR_CAST(fw->req->jd.fw_load.debug.core1_exception_buffer);
+                       head_ptr = PLT_PTR_CAST(
+                               fw->req->cn10k_req.jd.fw_load.debug.core1_exception_buffer);
                        fprintf(fp, "ML_SCRATCH_EXCEPTION_SP_C1 = 0x%016lx",
                                roc_ml_reg_read64(&cn10k_mldev->roc, ML_SCRATCH_EXCEPTION_SP_C1));
                        head_ptr = roc_ml_addr_mlip2ap(&cn10k_mldev->roc, head_ptr);
@@ -1540,14 +1555,14 @@ cn10k_ml_dev_selftest(struct rte_ml_dev *dev)
        struct cn10k_ml_dev *cn10k_mldev;
        struct cnxk_ml_dev *cnxk_mldev;
        const struct plt_memzone *mz;
-       struct cn10k_ml_req *req;
+       struct cnxk_ml_req *req;
        uint64_t timeout_cycle;
        bool timeout;
        int ret;
 
        cnxk_mldev = dev->data->dev_private;
        cn10k_mldev = &cnxk_mldev->cn10k_mldev;
-       mz = plt_memzone_reserve_aligned("dev_selftest", sizeof(struct 
cn10k_ml_req), 0,
+       mz = plt_memzone_reserve_aligned("dev_selftest", sizeof(struct 
cnxk_ml_req), 0,
                                         ML_CN10K_ALIGN_SIZE);
        if (mz == NULL) {
                plt_err("Could not allocate reserved memzone");
@@ -1556,23 +1571,24 @@ cn10k_ml_dev_selftest(struct rte_ml_dev *dev)
        req = mz->addr;
 
        /* Prepare load completion structure */
-       memset(&req->jd, 0, sizeof(struct cn10k_ml_jd));
-       req->jd.hdr.jce.w1.u64 = PLT_U64_CAST(&req->status);
-       req->jd.hdr.job_type = ML_CN10K_JOB_TYPE_FIRMWARE_SELFTEST;
-       req->jd.hdr.result = roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &req->result);
-       req->jd.fw_load.flags = cn10k_ml_fw_flags_get(&cn10k_mldev->fw);
-       plt_write64(ML_CNXK_POLL_JOB_START, &req->status);
+       memset(&req->cn10k_req.jd, 0, sizeof(struct cn10k_ml_jd));
+       req->cn10k_req.jd.hdr.jce.w1.u64 = PLT_U64_CAST(&req->cn10k_req.status);
+       req->cn10k_req.jd.hdr.job_type = ML_CN10K_JOB_TYPE_FIRMWARE_SELFTEST;
+       req->cn10k_req.jd.hdr.result =
+               roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &req->cn10k_req.result);
+       req->cn10k_req.jd.fw_load.flags = cn10k_ml_fw_flags_get(&cn10k_mldev->fw);
+       plt_write64(ML_CNXK_POLL_JOB_START, &req->cn10k_req.status);
        plt_wmb();
 
        /* Enqueue firmware selftest request through scratch registers */
        timeout = true;
        timeout_cycle = plt_tsc_cycles() + ML_CNXK_CMD_TIMEOUT * plt_tsc_hz();
-       roc_ml_scratch_enqueue(&cn10k_mldev->roc, &req->jd);
+       roc_ml_scratch_enqueue(&cn10k_mldev->roc, &req->cn10k_req.jd);
 
        plt_rmb();
        do {
                if (roc_ml_scratch_is_done_bit_set(&cn10k_mldev->roc) &&
-                   (plt_read64(&req->status) == ML_CNXK_POLL_JOB_FINISH)) {
+                   (plt_read64(&req->cn10k_req.status) == ML_CNXK_POLL_JOB_FINISH)) {
                        timeout = false;
                        break;
                }
@@ -1583,7 +1599,7 @@ cn10k_ml_dev_selftest(struct rte_ml_dev *dev)
        if (timeout) {
                ret = -ETIME;
        } else {
-               if (req->result.error_code.u64 != 0)
+               if (req->cn10k_req.result.error_code != 0)
                        ret = -1;
        }
 
@@ -1656,7 +1672,7 @@ cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params,
 
        mz_size = PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_model), ML_CN10K_ALIGN_SIZE) +
                  2 * model_data_size + model_scratch_size + model_info_size +
-                 PLT_ALIGN_CEIL(sizeof(struct cn10k_ml_req), ML_CN10K_ALIGN_SIZE) +
+                 PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_req), ML_CN10K_ALIGN_SIZE) +
                  model_stats_size;
 
        /* Allocate memzone for model object and model data */
@@ -1728,7 +1744,7 @@ cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params,
        /* Reset burst and sync stats */
        model->layer[0].glow.burst_stats =
                PLT_PTR_ADD(model->layer[0].glow.req,
-                           PLT_ALIGN_CEIL(sizeof(struct cn10k_ml_req), ML_CN10K_ALIGN_SIZE));
+                           PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_req), ML_CN10K_ALIGN_SIZE));
        for (qp_id = 0; qp_id < dev->data->nb_queue_pairs + 1; qp_id++) {
                model->layer[0].glow.burst_stats[qp_id].hw_latency_tot = 0;
                model->layer[0].glow.burst_stats[qp_id].hw_latency_min = UINT64_MAX;
@@ -1792,7 +1808,7 @@ cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
        struct cnxk_ml_dev *cnxk_mldev;
        struct cnxk_ml_model *model;
        struct cn10k_ml_ocm *ocm;
-       struct cn10k_ml_req *req;
+       struct cnxk_ml_req *req;
 
        bool job_enqueued;
        bool job_dequeued;
@@ -1817,10 +1833,10 @@ cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
        /* Prepare JD */
        req = model->layer[0].glow.req;
        cn10k_ml_prep_sp_job_descriptor(cn10k_mldev, model, req, ML_CN10K_JOB_TYPE_MODEL_START);
-       req->result.error_code.u64 = 0x0;
-       req->result.user_ptr = NULL;
+       req->cn10k_req.result.error_code = 0x0;
+       req->cn10k_req.result.user_ptr = NULL;
 
-       plt_write64(ML_CNXK_POLL_JOB_START, &req->status);
+       plt_write64(ML_CNXK_POLL_JOB_START, &req->cn10k_req.status);
        plt_wmb();
 
        num_tiles = model->layer[0].glow.metadata.model.tile_end -
@@ -1880,8 +1896,8 @@ cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
 
        /* Update JD */
        cn10k_ml_ocm_tilecount(model->layer[0].glow.ocm_map.tilemask, &tile_start, &tile_end);
-       req->jd.model_start.tilemask = GENMASK_ULL(tile_end, tile_start);
-       req->jd.model_start.ocm_wb_base_address =
+       req->cn10k_req.jd.model_start.tilemask = GENMASK_ULL(tile_end, tile_start);
+       req->cn10k_req.jd.model_start.ocm_wb_base_address =
                model->layer[0].glow.ocm_map.wb_page_start * ocm->page_size;
 
        job_enqueued = false;
@@ -1889,19 +1905,21 @@ cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
        do {
                if (!job_enqueued) {
                        req->timeout = plt_tsc_cycles() + ML_CNXK_CMD_TIMEOUT * plt_tsc_hz();
-                       job_enqueued = roc_ml_scratch_enqueue(&cn10k_mldev->roc, &req->jd);
+                       job_enqueued =
+                               roc_ml_scratch_enqueue(&cn10k_mldev->roc, &req->cn10k_req.jd);
                }
 
                if (job_enqueued && !job_dequeued)
-                       job_dequeued = roc_ml_scratch_dequeue(&cn10k_mldev->roc, &req->jd);
+                       job_dequeued =
+                               roc_ml_scratch_dequeue(&cn10k_mldev->roc, &req->cn10k_req.jd);
 
                if (job_dequeued)
                        break;
        } while (plt_tsc_cycles() < req->timeout);
 
        if (job_dequeued) {
-               if (plt_read64(&req->status) == ML_CNXK_POLL_JOB_FINISH) {
-                       if (req->result.error_code.u64 == 0)
+               if (plt_read64(&req->cn10k_req.status) == ML_CNXK_POLL_JOB_FINISH) {
+                       if (req->cn10k_req.result.error_code == 0)
                                ret = 0;
                        else
                                ret = -1;
@@ -1954,7 +1972,7 @@ cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
        struct cnxk_ml_dev *cnxk_mldev;
        struct cnxk_ml_model *model;
        struct cn10k_ml_ocm *ocm;
-       struct cn10k_ml_req *req;
+       struct cnxk_ml_req *req;
 
        bool job_enqueued;
        bool job_dequeued;
@@ -1974,10 +1992,10 @@ cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
        /* Prepare JD */
        req = model->layer[0].glow.req;
        cn10k_ml_prep_sp_job_descriptor(cn10k_mldev, model, req, ML_CN10K_JOB_TYPE_MODEL_STOP);
-       req->result.error_code.u64 = 0x0;
-       req->result.user_ptr = NULL;
+       req->cn10k_req.result.error_code = 0x0;
+       req->cn10k_req.result.user_ptr = NULL;
 
-       plt_write64(ML_CNXK_POLL_JOB_START, &req->status);
+       plt_write64(ML_CNXK_POLL_JOB_START, &req->cn10k_req.status);
        plt_wmb();
 
        locked = false;
@@ -2017,19 +2035,21 @@ cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
        do {
                if (!job_enqueued) {
                        req->timeout = plt_tsc_cycles() + ML_CNXK_CMD_TIMEOUT * plt_tsc_hz();
-                       job_enqueued = roc_ml_scratch_enqueue(&cn10k_mldev->roc, &req->jd);
+                       job_enqueued =
+                               roc_ml_scratch_enqueue(&cn10k_mldev->roc, &req->cn10k_req.jd);
                }
 
                if (job_enqueued && !job_dequeued)
-                       job_dequeued = roc_ml_scratch_dequeue(&cn10k_mldev->roc, &req->jd);
+                       job_dequeued =
+                               roc_ml_scratch_dequeue(&cn10k_mldev->roc, &req->cn10k_req.jd);
 
                if (job_dequeued)
                        break;
        } while (plt_tsc_cycles() < req->timeout);
 
        if (job_dequeued) {
-               if (plt_read64(&req->status) == ML_CNXK_POLL_JOB_FINISH) {
-                       if (req->result.error_code.u64 == 0x0)
+               if (plt_read64(&req->cn10k_req.status) == ML_CNXK_POLL_JOB_FINISH) {
+                       if (req->cn10k_req.result.error_code == 0x0)
                                ret = 0;
                        else
                                ret = -1;
@@ -2289,18 +2309,23 @@ queue_free_count(uint64_t head, uint64_t tail, uint64_t nb_desc)
 }
 
 static __rte_always_inline void
-cn10k_ml_result_update(struct rte_ml_dev *dev, int qp_id, struct cn10k_ml_result *result,
-                      struct rte_ml_op *op)
+cn10k_ml_result_update(struct rte_ml_dev *dev, int qp_id, struct cnxk_ml_req *req)
 {
+       union cn10k_ml_error_code *error_code;
        struct cn10k_ml_layer_stats *stats;
        struct cn10k_ml_dev *cn10k_mldev;
        struct cnxk_ml_dev *cnxk_mldev;
+       struct cn10k_ml_result *result;
        struct cnxk_ml_model *model;
-       struct cn10k_ml_qp *qp;
+       struct cnxk_ml_qp *qp;
+       struct rte_ml_op *op;
        uint64_t hw_latency;
        uint64_t fw_latency;
 
-       if (likely(result->error_code.u64 == 0)) {
+       result = &req->cn10k_req.result;
+       op = req->op;
+
+       if (likely(result->error_code == 0)) {
                model = dev->data->models[op->model_id];
                if (likely(qp_id >= 0)) {
                        qp = dev->data->queue_pairs[qp_id];
@@ -2331,7 +2356,7 @@ cn10k_ml_result_update(struct rte_ml_dev *dev, int qp_id, struct cn10k_ml_result
                stats->fw_latency_max = PLT_MAX(stats->fw_latency_max, fw_latency);
                stats->dequeued_count++;
 
-               op->impl_opaque = result->error_code.u64;
+               op->impl_opaque = result->error_code;
                op->status = RTE_ML_OP_STATUS_SUCCESS;
        } else {
                if (likely(qp_id >= 0)) {
@@ -2340,7 +2365,8 @@ cn10k_ml_result_update(struct rte_ml_dev *dev, int qp_id, struct cn10k_ml_result
                }
 
                /* Handle driver error */
-               if (result->error_code.s.etype == ML_ETYPE_DRIVER) {
+               error_code = (union cn10k_ml_error_code *)&result->error_code;
+               if (error_code->s.etype == ML_ETYPE_DRIVER) {
                        cnxk_mldev = dev->data->dev_private;
                        cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
@@ -2348,15 +2374,15 @@ cn10k_ml_result_update(struct rte_ml_dev *dev, int qp_id, struct cn10k_ml_result
                        if ((roc_ml_reg_read64(&cn10k_mldev->roc, ML_SCRATCH_EXCEPTION_SP_C0) !=
                             0) ||
                            (roc_ml_reg_read64(&cn10k_mldev->roc, ML_SCRATCH_EXCEPTION_SP_C1) != 0))
-                               result->error_code.s.stype = ML_DRIVER_ERR_EXCEPTION;
+                               error_code->s.stype = ML_DRIVER_ERR_EXCEPTION;
                        else if ((roc_ml_reg_read64(&cn10k_mldev->roc, ML_CORE_INT_LO) != 0) ||
                                 (roc_ml_reg_read64(&cn10k_mldev->roc, ML_CORE_INT_HI) != 0))
-                               result->error_code.s.stype = ML_DRIVER_ERR_FW_ERROR;
+                               error_code->s.stype = ML_DRIVER_ERR_FW_ERROR;
                        else
-                               result->error_code.s.stype = ML_DRIVER_ERR_UNKNOWN;
+                               error_code->s.stype = ML_DRIVER_ERR_UNKNOWN;
                }
 
-               op->impl_opaque = result->error_code.u64;
+               op->impl_opaque = result->error_code;
                op->status = RTE_ML_OP_STATUS_ERROR;
        }
 
@@ -2367,11 +2393,12 @@ __rte_hot uint16_t
 cn10k_ml_enqueue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op **ops,
                       uint16_t nb_ops)
 {
+       union cn10k_ml_error_code *error_code;
        struct cn10k_ml_dev *cn10k_mldev;
        struct cnxk_ml_dev *cnxk_mldev;
-       struct cn10k_ml_queue *queue;
-       struct cn10k_ml_req *req;
-       struct cn10k_ml_qp *qp;
+       struct cnxk_ml_queue *queue;
+       struct cnxk_ml_req *req;
+       struct cnxk_ml_qp *qp;
        struct rte_ml_op *op;
 
        uint16_t count;
@@ -2397,12 +2424,13 @@ cn10k_ml_enqueue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
        cn10k_mldev->set_poll_addr(req);
        cn10k_ml_prep_fp_job_descriptor(cn10k_mldev, req, op);
 
-       memset(&req->result, 0, sizeof(struct cn10k_ml_result));
-       req->result.error_code.s.etype = ML_ETYPE_UNKNOWN;
-       req->result.user_ptr = op->user_ptr;
+       memset(&req->cn10k_req.result, 0, sizeof(struct cn10k_ml_result));
+       error_code = (union cn10k_ml_error_code *)&req->cn10k_req.result.error_code;
+       error_code->s.etype = ML_ETYPE_UNKNOWN;
+       req->cn10k_req.result.user_ptr = op->user_ptr;
 
        cn10k_mldev->set_poll_ptr(req);
-       enqueued = cn10k_mldev->ml_jcmdq_enqueue(&cn10k_mldev->roc, &req->jcmd);
+       enqueued = cn10k_mldev->ml_jcmdq_enqueue(&cn10k_mldev->roc, &req->cn10k_req.jcmd);
        if (unlikely(!enqueued))
                goto jcmdq_full;
 
@@ -2426,11 +2454,12 @@ __rte_hot uint16_t
 cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op **ops,
                       uint16_t nb_ops)
 {
+       union cn10k_ml_error_code *error_code;
        struct cn10k_ml_dev *cn10k_mldev;
        struct cnxk_ml_dev *cnxk_mldev;
-       struct cn10k_ml_queue *queue;
-       struct cn10k_ml_req *req;
-       struct cn10k_ml_qp *qp;
+       struct cnxk_ml_queue *queue;
+       struct cnxk_ml_req *req;
+       struct cnxk_ml_qp *qp;
 
        uint64_t status;
        uint16_t count;
@@ -2452,13 +2481,15 @@ cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
        req = &queue->reqs[tail];
        status = cn10k_mldev->get_poll_ptr(req);
        if (unlikely(status != ML_CNXK_POLL_JOB_FINISH)) {
-               if (plt_tsc_cycles() < req->timeout)
+               if (plt_tsc_cycles() < req->timeout) {
                        goto empty_or_active;
-               else /* Timeout, set indication of driver error */
-                       req->result.error_code.s.etype = ML_ETYPE_DRIVER;
+               } else { /* Timeout, set indication of driver error */
+                       error_code = (union cn10k_ml_error_code *)&req->cn10k_req.result.error_code;
+                       error_code->s.etype = ML_ETYPE_DRIVER;
+               }
        }
 
-       cn10k_ml_result_update(dev, qp_id, &req->result, req->op);
+       cn10k_ml_result_update(dev, qp_id, req);
        ops[count] = req->op;
 
        queue_index_advance(&tail, qp->nb_desc);
@@ -2509,10 +2540,11 @@ cn10k_ml_op_error_get(struct rte_ml_dev *dev, struct rte_ml_op *op, struct rte_m
 __rte_hot int
 cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op)
 {
+       union cn10k_ml_error_code *error_code;
        struct cn10k_ml_dev *cn10k_mldev;
        struct cnxk_ml_dev *cnxk_mldev;
        struct cnxk_ml_model *model;
-       struct cn10k_ml_req *req;
+       struct cnxk_ml_req *req;
        bool timeout;
        int ret = 0;
 
@@ -2524,17 +2556,18 @@ cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op)
        cn10k_ml_set_poll_addr(req);
        cn10k_ml_prep_fp_job_descriptor(cn10k_mldev, req, op);
 
-       memset(&req->result, 0, sizeof(struct cn10k_ml_result));
-       req->result.error_code.s.etype = ML_ETYPE_UNKNOWN;
-       req->result.user_ptr = op->user_ptr;
+       memset(&req->cn10k_req.result, 0, sizeof(struct cn10k_ml_result));
+       error_code = (union cn10k_ml_error_code *)&req->cn10k_req.result.error_code;
+       error_code->s.etype = ML_ETYPE_UNKNOWN;
+       req->cn10k_req.result.user_ptr = op->user_ptr;
 
        cn10k_mldev->set_poll_ptr(req);
-       req->jcmd.w1.s.jobptr = PLT_U64_CAST(&req->jd);
+       req->cn10k_req.jcmd.w1.s.jobptr = PLT_U64_CAST(&req->cn10k_req.jd);
 
        timeout = true;
        req->timeout = plt_tsc_cycles() + ML_CNXK_CMD_TIMEOUT * plt_tsc_hz();
        do {
-               if (cn10k_mldev->ml_jcmdq_enqueue(&cn10k_mldev->roc, &req->jcmd)) {
+               if (cn10k_mldev->ml_jcmdq_enqueue(&cn10k_mldev->roc, &req->cn10k_req.jcmd)) {
                        req->op = op;
                        timeout = false;
                        break;
@@ -2557,7 +2590,7 @@ cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op)
        if (timeout)
                ret = -ETIME;
        else
-               cn10k_ml_result_update(dev, -1, &req->result, req->op);
+               cn10k_ml_result_update(dev, -1, req);
 
 error_enqueue:
        return ret;
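
A note on the error-word handling above: the generic result now carries
error_code as a plain uint64_t, and the CN10K code casts it to union
cn10k_ml_error_code at each use site. Below is a minimal sketch of decoding
the value an application receives in rte_ml_op::impl_opaque on error; the
helper name example_decode_error is hypothetical, and the union layout is
assumed to be the one from cn10k_ml_dev.h.

    #include <stdint.h>
    /* Assumes the driver header defining union cn10k_ml_error_code is included. */

    /* Sketch only: split the opaque error word back into CN10K fields. */
    static inline void
    example_decode_error(uint64_t impl_opaque, uint64_t *etype, uint64_t *stype)
    {
            union cn10k_ml_error_code code;

            code.u64 = impl_opaque;  /* set from result->error_code on error */
            *etype = code.s.etype;   /* e.g. ML_ETYPE_DRIVER on driver faults */
            *stype = code.s.stype;   /* e.g. ML_DRIVER_ERR_EXCEPTION */
    }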
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index 005b093e45d..fd5992e1925 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -10,63 +10,279 @@
 
 #include <roc_api.h>
 
-#include "cn10k_ml_dev.h"
+/* Firmware version string length */
+#define MLDEV_FIRMWARE_VERSION_LENGTH 32
 
-/* Request structure */
-struct cn10k_ml_req {
-       /* Job descriptor */
-       struct cn10k_ml_jd jd;
+/* Job types */
+enum cn10k_ml_job_type {
+       ML_CN10K_JOB_TYPE_MODEL_RUN = 0,
+       ML_CN10K_JOB_TYPE_MODEL_STOP,
+       ML_CN10K_JOB_TYPE_MODEL_START,
+       ML_CN10K_JOB_TYPE_FIRMWARE_LOAD,
+       ML_CN10K_JOB_TYPE_FIRMWARE_SELFTEST,
+};
 
-       /* Job descriptor extra arguments */
-       union cn10k_ml_jd_extended_args extended_args;
+/* Firmware stats */
+struct cn10k_ml_stats {
+       /* Firmware start cycle */
+       uint64_t fw_start;
 
-       /* Job result */
-       struct cn10k_ml_result result;
+       /* Firmware end cycle */
+       uint64_t fw_end;
 
-       /* Status field for poll mode requests */
-       volatile uint64_t status;
+       /* Hardware start cycle */
+       uint64_t hw_start;
 
-       /* Job command */
-       struct ml_job_cmd_s jcmd;
+       /* Hardware end cycle */
+       uint64_t hw_end;
+};
+
+/* Result structure */
+struct cn10k_ml_result {
+       /* Job error code */
+       uint64_t error_code;
+
+       /* Stats */
+       struct cn10k_ml_stats stats;
+
+       /* User context pointer */
+       void *user_ptr;
+};
+
+/* Firmware capability structure */
+union cn10k_ml_fw_cap {
+       uint64_t u64;
+
+       struct {
+               /* CMPC completion support */
+               uint64_t cmpc_completions : 1;
+
+               /* Poll mode completion support */
+               uint64_t poll_completions : 1;
+
+               /* SSO completion support */
+               uint64_t sso_completions : 1;
+
+               /* Support for model side loading */
+               uint64_t side_load_model : 1;
 
-       /* Job completion W1 */
-       uint64_t compl_W1;
+               /* Batch execution */
+               uint64_t batch_run : 1;
 
-       /* Timeout cycle */
-       uint64_t timeout;
+               /* Max number of models to be loaded in parallel */
+               uint64_t max_models : 8;
 
-       /* Op */
-       struct rte_ml_op *op;
-} __rte_aligned(ROC_ALIGN);
+               /* Firmware statistics */
+               uint64_t fw_stats : 1;
 
-/* Request queue */
-struct cn10k_ml_queue {
-       /* Array of requests */
-       struct cn10k_ml_req *reqs;
+               /* Hardware statistics */
+               uint64_t hw_stats : 1;
 
-       /* Head of the queue, used for enqueue */
-       uint64_t head;
+               /* Max number of batches */
+               uint64_t max_num_batches : 16;
 
-       /* Tail of the queue, used for dequeue */
-       uint64_t tail;
+               uint64_t rsvd : 33;
+       } s;
+};
+
+/* Firmware debug info structure */
+struct cn10k_ml_fw_debug {
+       /* ACC core 0 debug buffer */
+       uint64_t core0_debug_ptr;
+
+       /* ACC core 1 debug buffer */
+       uint64_t core1_debug_ptr;
+
+       /* ACC core 0 exception state buffer */
+       uint64_t core0_exception_buffer;
+
+       /* ACC core 1 exception state buffer */
+       uint64_t core1_exception_buffer;
+
+       /* Debug buffer size per core */
+       uint32_t debug_buffer_size;
 
-       /* Wait cycles before timeout */
-       uint64_t wait_cycles;
+       /* Exception state dump size */
+       uint32_t exception_state_size;
 };
 
-/* Queue-pair structure */
-struct cn10k_ml_qp {
-       /* ID */
-       uint32_t id;
+/* Job descriptor header (32 bytes) */
+struct cn10k_ml_jd_header {
+       /* Job completion structure */
+       struct ml_jce_s jce;
+
+       /* Model ID */
+       uint64_t model_id : 8;
+
+       /* Job type */
+       uint64_t job_type : 8;
+
+       /* Flags for fast-path jobs */
+       uint64_t fp_flags : 16;
+
+       /* Flags for slow-path jobs */
+       uint64_t sp_flags : 16;
+       uint64_t rsvd : 16;
+
+       /* Job result pointer */
+       uint64_t *result;
+};
+
+/* Extra arguments for job descriptor */
+union cn10k_ml_jd_extended_args {
+       struct cn10k_ml_jd_extended_args_section_start {
+               /* DDR Scratch base address */
+               uint64_t ddr_scratch_base_address;
+
+               /* DDR Scratch range start */
+               uint64_t ddr_scratch_range_start;
+
+               /* DDR Scratch range end */
+               uint64_t ddr_scratch_range_end;
+
+               uint8_t rsvd[104];
+       } start;
+};
+
+/* Job descriptor structure */
+struct cn10k_ml_jd {
+       /* Job descriptor header (32 bytes) */
+       struct cn10k_ml_jd_header hdr;
+
+       union {
+               struct cn10k_ml_jd_section_fw_load {
+                       /* Firmware capability structure (8 bytes) */
+                       union cn10k_ml_fw_cap cap;
+
+                       /* Firmware version (32 bytes) */
+                       uint8_t version[MLDEV_FIRMWARE_VERSION_LENGTH];
+
+                       /* Debug capability structure (40 bytes) */
+                       struct cn10k_ml_fw_debug debug;
 
-       /* Number of descriptors */
-       uint64_t nb_desc;
+                       /* Flags to control error handling */
+                       uint64_t flags;
 
-       /* Request queue */
-       struct cn10k_ml_queue queue;
+                       uint8_t rsvd[8];
+               } fw_load;
 
-       /* Statistics per queue-pair */
-       struct rte_ml_dev_stats stats;
+               struct cn10k_ml_jd_section_model_start {
+                       /* Extended arguments */
+                       uint64_t extended_args;
+
+                       /* Destination model start address in DDR relative to ML_MLR_BASE */
+                       uint64_t model_dst_ddr_addr;
+
+                       /* Offset to model init section in the model */
+                       uint64_t model_init_offset : 32;
+
+                       /* Size of init section in the model */
+                       uint64_t model_init_size : 32;
+
+                       /* Offset to model main section in the model */
+                       uint64_t model_main_offset : 32;
+
+                       /* Size of main section in the model */
+                       uint64_t model_main_size : 32;
+
+                       /* Offset to model finish section in the model */
+                       uint64_t model_finish_offset : 32;
+
+                       /* Size of finish section in the model */
+                       uint64_t model_finish_size : 32;
+
+                       /* Offset to WB in model bin */
+                       uint64_t model_wb_offset : 32;
+
+                       /* Number of model layers */
+                       uint64_t num_layers : 8;
+
+                       /* Number of gather entries, 0 means linear input mode (= no gather) */
+                       uint64_t num_gather_entries : 8;
+
+                       /* Number of scatter entries, 0 means linear output mode (= no scatter) */
+                       uint64_t num_scatter_entries : 8;
+
+                       /* Tile mask to load model */
+                       uint64_t tilemask : 8;
+
+                       /* Batch size of model */
+                       uint64_t batch_size : 32;
+
+                       /* OCM WB base address */
+                       uint64_t ocm_wb_base_address : 32;
+
+                       /* OCM WB range start */
+                       uint64_t ocm_wb_range_start : 32;
+
+                       /* OCM WB range end */
+                       uint64_t ocm_wb_range_end : 32;
+
+                       /* DDR WB address */
+                       uint64_t ddr_wb_base_address;
+
+                       /* DDR WB range start */
+                       uint64_t ddr_wb_range_start : 32;
+
+                       /* DDR WB range end */
+                       uint64_t ddr_wb_range_end : 32;
+
+                       union {
+                               /* Points to gather list if num_gather_entries > 0 */
+                               void *gather_list;
+                               struct {
+                                       /* Linear input mode */
+                                       uint64_t ddr_range_start : 32;
+                                       uint64_t ddr_range_end : 32;
+                               } s;
+                       } input;
+
+                       union {
+                               /* Points to scatter list if num_scatter_entries > 0 */
+                               void *scatter_list;
+                               struct {
+                                       /* Linear output mode */
+                                       uint64_t ddr_range_start : 32;
+                                       uint64_t ddr_range_end : 32;
+                               } s;
+                       } output;
+               } model_start;
+
+               struct cn10k_ml_jd_section_model_stop {
+                       uint8_t rsvd[96];
+               } model_stop;
+
+               struct cn10k_ml_jd_section_model_run {
+                       /* Address of the input for the run relative to ML_MLR_BASE */
+                       uint64_t input_ddr_addr;
+
+                       /* Address of the output for the run relative to ML_MLR_BASE */
+                       uint64_t output_ddr_addr;
+
+                       /* Number of batches to run in variable batch processing */
+                       uint16_t num_batches;
+
+                       uint8_t rsvd[78];
+               } model_run;
+       };
+} __plt_aligned(ROC_ALIGN);
+
+/* CN10K specific request */
+struct cn10k_ml_req {
+       /* Job descriptor */
+       struct cn10k_ml_jd jd;
+
+       /* Job descriptor extra arguments */
+       union cn10k_ml_jd_extended_args extended_args;
+
+       /* Status field for poll mode requests */
+       volatile uint64_t status;
+
+       /* Job command */
+       struct ml_job_cmd_s jcmd;
+
+       /* Result */
+       struct cn10k_ml_result result;
 };
 
 /* Device ops */
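
The fw_load section above exposes firmware capabilities through the
cn10k_ml_fw_cap bitfield union. A hedged sketch of reading a few of those
bits out of the raw 64-bit word follows; example_print_fw_cap is an
illustrative name only, and the word is assumed to come from
jd.fw_load.cap.u64 after a firmware-load job completes.

    #include <inttypes.h>
    #include <stdio.h>
    /* Assumes the header above defining union cn10k_ml_fw_cap is included. */

    /* Sketch only: decode selected capability bits. */
    static void
    example_print_fw_cap(uint64_t cap_word)
    {
            union cn10k_ml_fw_cap cap;

            cap.u64 = cap_word;
            printf("poll completions = %" PRIu64 "\n", (uint64_t)cap.s.poll_completions);
            printf("max models       = %" PRIu64 "\n", (uint64_t)cap.s.max_models);
            printf("max num batches  = %" PRIu64 "\n", (uint64_t)cap.s.max_num_batches);
    }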
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
new file mode 100644
index 00000000000..f1872dcf7c6
--- /dev/null
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -0,0 +1,7 @@
+/* SPDX-License-Identifier: BSD-3-Clause
+ * Copyright (c) 2023 Marvell.
+ */
+
+#include <rte_mldev.h>
+
+#include "cnxk_ml_ops.h"
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.h b/drivers/ml/cnxk/cnxk_ml_ops.h
new file mode 100644
index 00000000000..b953fb0f5fc
--- /dev/null
+++ b/drivers/ml/cnxk/cnxk_ml_ops.h
@@ -0,0 +1,63 @@
+/* SPDX-License-Identifier: BSD-3-Clause
+ * Copyright (c) 2023 Marvell.
+ */
+
+#ifndef _CNXK_ML_OPS_H_
+#define _CNXK_ML_OPS_H_
+
+#include <rte_mldev.h>
+#include <rte_mldev_core.h>
+
+#include <roc_api.h>
+
+#include "cn10k_ml_ops.h"
+
+/* Request structure */
+struct cnxk_ml_req {
+       /* Device specific request */
+       union {
+               /* CN10K */
+               struct cn10k_ml_req cn10k_req;
+       };
+
+       /* Address of status field */
+       volatile uint64_t *status;
+
+       /* Timeout cycle */
+       uint64_t timeout;
+
+       /* Op */
+       struct rte_ml_op *op;
+} __rte_aligned(ROC_ALIGN);
+
+/* Request queue */
+struct cnxk_ml_queue {
+       /* Array of requests */
+       struct cnxk_ml_req *reqs;
+
+       /* Head of the queue, used for enqueue */
+       uint64_t head;
+
+       /* Tail of the queue, used for dequeue */
+       uint64_t tail;
+
+       /* Wait cycles before timeout */
+       uint64_t wait_cycles;
+};
+
+/* Queue-pair structure */
+struct cnxk_ml_qp {
+       /* ID */
+       uint32_t id;
+
+       /* Number of descriptors */
+       uint64_t nb_desc;
+
+       /* Request queue */
+       struct cnxk_ml_queue queue;
+
+       /* Statistics per queue-pair */
+       struct rte_ml_dev_stats stats;
+};
+
+#endif /* _CNXK_ML_OPS_H_ */
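
The head/tail fields in cnxk_ml_queue implement a circular ring of requests:
enqueue advances head, dequeue advances tail, and both wrap at nb_desc. A
minimal sketch of the index arithmetic is below; queue_index_advance mirrors
the helper called from cn10k_ml_ops.c, while queue_pending_count is a
hypothetical helper added here for illustration.

    #include <stdint.h>

    /* Sketch only: circular index helpers for a ring of nb_desc requests. */
    static inline void
    queue_index_advance(uint64_t *index, uint64_t nb_desc)
    {
            *index = (*index + 1) % nb_desc;
    }

    /* Number of requests awaiting dequeue, given wrapped head/tail. */
    static inline uint64_t
    queue_pending_count(uint64_t head, uint64_t tail, uint64_t nb_desc)
    {
            return (nb_desc + head - tail) % nb_desc;
    }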
diff --git a/drivers/ml/cnxk/meson.build b/drivers/ml/cnxk/meson.build
index 72e03b15b5b..73db458fcd9 100644
--- a/drivers/ml/cnxk/meson.build
+++ b/drivers/ml/cnxk/meson.build
@@ -15,6 +15,7 @@ driver_sdk_headers = files(
         'cnxk_ml_dev.h',
         'cnxk_ml_io.h',
         'cnxk_ml_model.h',
+        'cnxk_ml_ops.h',
 )
 
 sources = files(
@@ -24,6 +25,7 @@ sources = files(
         'cn10k_ml_ocm.c',
         'cnxk_ml_dev.c',
         'cnxk_ml_model.c',
+        'cnxk_ml_ops.c',
 )
 
 deps += ['mldev', 'common_cnxk', 'kvargs', 'hash']
-- 
2.41.0
