PR #23258 opened by Raja-89
URL: https://code.ffmpeg.org/FFmpeg/FFmpeg/pulls/23258
Patch URL: https://code.ffmpeg.org/FFmpeg/FFmpeg/pulls/23258.patch

### avfilter/dnn: implement batching and dynamic shapes for Torch backend

This patch series adds batch processing and dynamic shape handling
to the LibTorch DNN backend. It builds on top of the persistent
input buffer patch from PR #23169.

Patch 1/2: Persistent input buffer (same as PR #23169, rebased)
- Replaces per-frame av_malloc/av_free with a lazily-allocated
  persistent buffer in THInferRequest
- Adds a no-op deleter to prevent LibTorch from double-freeing
  the persistent buffer
- Fixes a pre-existing SIGSEGV when the model has no learnable parameters

Patch 2/2: Batch processing engine and dynamic shape handling
- Adds batch_size AVOption (range 1-32, default 1) to accumulate
  frames and process them in a single torch::cat() + forward() call
- Detects mid-stream resolution changes and automatically flushes
  the accumulator to prevent torch::cat() dimension mismatches
- Handles partial batches at EOF via dnn_flush_th()

Tested with:
  # Standard batch processing (batch_size=4)
  ./ffmpeg -f lavfi -i testsrc=duration=5:size=640x480:rate=25 \
    -vf format=rgb24,dnn_processing=dnn_backend=torch:model=model.pt:batch_size=4 \
    -f null /dev/null

  # Dynamic resolution change (flushing mid-batch)
  ./ffmpeg -f lavfi -i testsrc=duration=2:size=320x240:rate=10 \
    -vf "scale='if(gt(t,1),640,320)':'if(gt(t,1),480,240)':eval=frame,format=rgb24,dnn_processing=dnn_backend=torch:model=model.pt:batch_size=4" \
    -f null /dev/null

  # EOF partial flush (3 frames, batch_size=32)
  ./ffmpeg -f lavfi -i testsrc=duration=0.12:size=320x240:rate=25 \
    -vf format=rgb24,dnn_processing=dnn_backend=torch:model=model.pt:batch_size=32 \
    -f null /dev/null

All three runs exit cleanly, and Valgrind reports 0 bytes definitely lost.

Supersedes PR #23169.

Signed-off-by: Raja Rathour <[email protected]>



>From ef1adb817e6123141c23969c47fcda23662d38cc Mon Sep 17 00:00:00 2001
From: Raja Rathour <[email protected]>
Date: Wed, 20 May 2026 14:53:27 +0530
Subject: [PATCH 1/2] avfilter/dnn: implement persistent input buffer for torch
 backend

Replace the per-frame av_malloc/av_free pattern with a persistent
buffer in THInferRequest that grows lazily on resolution increases
but is reused for every subsequent frame of the same or smaller size.

Key changes:
- Add input_data/input_data_size fields to THInferRequest to hold the
  persistent pixel buffer across frames
- Add persistent_buf_deleter() no-op deleter: memory is owned by
  THInferRequest, not the LibTorch tensor. This same ownership pattern
  will be reused for zero-copy CUDA tensors in a follow-up commit.
- Update th_create_inference_request() to zero-initialise the new fields
- Update th_free_request() to release the persistent buffer on teardown
- Add AV_PIX_FMT_CUDA detection with a clear ENOSYS error as a hook
  point for the zero-copy GPU path (follow-up commit)
- Fix pre-existing SIGSEGV: parameters().begin() was unconditionally
  dereferenced in th_start_inference() even when the model has no
  learnable parameters. Parameterless TorchScript models now default
  to the CPU device instead of crashing.

The lazy reallocation logic also lays the groundwork for dynamic-shape
handling (Phase 3 of the GSoC project).

Tested with:
  ./ffmpeg -f lavfi -i testsrc=duration=5:size=640x480:rate=25 \
    -vf format=rgb24,dnn_processing=dnn_backend=torch:model=dummy_model.pt \
    -vcodec rawvideo -f null /dev/null
  (125 frames @ 16.2x speed, exit 0, sync and async modes)

Signed-off-by: Raja Rathour <[email protected]>
---
 libavfilter/dnn/dnn_backend_torch.cpp | 61 ++++++++++++++++++++-------
 1 file changed, 46 insertions(+), 15 deletions(-)

diff --git a/libavfilter/dnn/dnn_backend_torch.cpp 
b/libavfilter/dnn/dnn_backend_torch.cpp
index 24a202f493..e1f972510b 100644
--- a/libavfilter/dnn/dnn_backend_torch.cpp
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -31,6 +31,7 @@ extern "C" {
 #include "dnn_backend_common.h"
 #include "libavutil/opt.h"
 #include "libavutil/mem.h"
+#include "libavutil/pixfmt.h"
 #include "queue.h"
 #include "safe_queue.h"
 }
@@ -47,6 +48,8 @@ typedef struct THModel {
 typedef struct THInferRequest {
     torch::Tensor *output;
     torch::Tensor *input_tensor;
+    uint8_t *input_data;      ///< Persistent buffer for input pixels
+    size_t   input_data_size; ///< Current allocated size of input_data
 } THInferRequest;
 
 typedef struct THRequestItem {
@@ -95,6 +98,10 @@ static void th_free_request(THInferRequest *request)
         delete(request->input_tensor);
         request->input_tensor = NULL;
     }
+    if (request->input_data) {
+        av_freep(&request->input_data);
+        request->input_data_size = 0;
+    }
     return;
 }
 
@@ -152,9 +159,9 @@ static int get_input_th(DNNModel *model, DNNData *input, 
const char *input_name)
     return 0;
 }
 
-static void deleter(void *arg)
+static void persistent_buf_deleter(void *arg)
 {
-    av_freep(&arg);
+    (void)arg;
 }
 
 static int fill_model_input_th(THModel *th_model, THRequestItem *request)
@@ -165,6 +172,7 @@ static int fill_model_input_th(THModel *th_model, 
THRequestItem *request)
     DNNData input = { 0 };
     DnnContext *ctx = th_model->ctx;
     int ret, width_idx, height_idx, channel_idx;
+    size_t required_size;
 
     lltask = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue);
     if (!lltask) {
@@ -175,19 +183,38 @@ static int fill_model_input_th(THModel *th_model, 
THRequestItem *request)
     task = lltask->task;
     infer_request = request->infer_request;
 
-    ret = get_input_th(&th_model->model, &input, NULL);
-    if ( ret != 0) {
+    if (task->in_frame->format == AV_PIX_FMT_CUDA) {
+        av_log(ctx, AV_LOG_ERROR,
+               "CUDA frame input is not yet supported. "
+               "Use the 'format=rgb24' filter before dnn_processing.\n");
+        ret = AVERROR(ENOSYS);
         goto err;
     }
-    width_idx = dnn_get_width_idx_by_layout(input.layout);
+
+    ret = get_input_th(&th_model->model, &input, NULL);
+    if (ret != 0) {
+        goto err;
+    }
+    width_idx  = dnn_get_width_idx_by_layout(input.layout);
     height_idx = dnn_get_height_idx_by_layout(input.layout);
     channel_idx = dnn_get_channel_idx_by_layout(input.layout);
     input.dims[height_idx] = task->in_frame->height;
-    input.dims[width_idx] = task->in_frame->width;
-    input.data = av_malloc(input.dims[height_idx] * input.dims[width_idx] *
-                           input.dims[channel_idx] * sizeof(float));
-    if (!input.data)
-        return AVERROR(ENOMEM);
+    input.dims[width_idx]  = task->in_frame->width;
+
+    required_size = (size_t)input.dims[height_idx] * input.dims[width_idx] *
+                    input.dims[channel_idx] * sizeof(float);
+
+    if (infer_request->input_data_size < required_size) {
+        av_freep(&infer_request->input_data);
+        infer_request->input_data = (uint8_t *)av_malloc(required_size);
+        if (!infer_request->input_data) {
+            infer_request->input_data_size = 0;
+            return AVERROR(ENOMEM);
+        }
+        infer_request->input_data_size = required_size;
+    }
+    input.data = infer_request->input_data;
+
     infer_request->input_tensor = new torch::Tensor();
     infer_request->output = new torch::Tensor();
 
@@ -208,7 +235,7 @@ static int fill_model_input_th(THModel *th_model, 
THRequestItem *request)
     }
     *infer_request->input_tensor = torch::from_blob(input.data,
         {1, input.dims[channel_idx], input.dims[height_idx], 
input.dims[width_idx]},
-        deleter, torch::kFloat32);
+        persistent_buf_deleter, torch::kFloat32);
     return 0;
 
 err:
@@ -246,8 +273,10 @@ static int th_start_inference(void *args)
         av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
         return DNN_GENERIC_ERROR;
     }
-    // Transfer tensor to the same device as model
-    c10::Device device = (*th_model->jit_model->parameters().begin()).device();
+    auto params = th_model->jit_model->parameters();
+    c10::Device device(torch::kCPU);
+    if (params.begin() != params.end())
+        device = (*params.begin()).device();
     if (infer_request->input_tensor->device() != device)
         *infer_request->input_tensor = infer_request->input_tensor->to(device);
     inputs.push_back(*infer_request->input_tensor);
@@ -410,8 +439,10 @@ static THInferRequest *th_create_inference_request(void)
     if (!request) {
         return NULL;
     }
-    request->input_tensor = NULL;
-    request->output = NULL;
+    request->input_tensor    = NULL;
+    request->output          = NULL;
+    request->input_data      = NULL;
+    request->input_data_size = 0;
     return request;
 }
 
-- 
2.52.0


>From 84be25a232f3c4fc1091ad9d5cf782301bedf7c7 Mon Sep 17 00:00:00 2001
From: Raja Rathour <[email protected]>
Date: Wed, 20 May 2026 15:23:41 +0530
Subject: [PATCH 2/2] avfilter/dnn: implement batching and dynamic shapes for
 Torch backend

Add support for accumulating multiple frames into a single tensor batch
before running inference, which can improve throughput on hardware
accelerators.

Key changes:
- Add 'batch_size' AVOption (default 1, max 32)
- Add accumulator arrays to THModel that hold incoming frames until a batch is full
- Implement execute_batch_th() which stacks accumulated [1,C,H,W]
  tensors into a single [B,C,H,W] batch using torch::cat()
- Split the batched output tensor back into individual frames and
  dispatch each one to the FFmpeg filter pipeline
- Update dnn_flush_th() to handle partial batches at stream end (EOS)
- Ensure backward compatibility: batch_size=1 falls back to the
  existing single-frame execution path
- Handle dynamic shapes: detect mid-stream resolution changes and
  automatically flush the accumulator so that torch::cat() never
  receives tensors with mismatched spatial dimensions

Signed-off-by: Raja Rathour <[email protected]>
---
 libavfilter/dnn/dnn_backend_torch.cpp | 278 +++++++++++++++++++++-----
 libavfilter/dnn_interface.h           |   1 +
 2 files changed, 231 insertions(+), 48 deletions(-)

diff --git a/libavfilter/dnn/dnn_backend_torch.cpp 
b/libavfilter/dnn/dnn_backend_torch.cpp
index e1f972510b..c844a99a87 100644
--- a/libavfilter/dnn/dnn_backend_torch.cpp
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -43,6 +43,11 @@ typedef struct THModel {
     SafeQueue *request_queue;
     Queue *task_queue;
     Queue *lltask_queue;
+    int              batch_size;       ///< configured batch size (from 
AVOption)
+    int              batch_count;      ///< frames currently accumulated
+    torch::Tensor  **batch_tensors;    ///< array[batch_size] of per-frame 
tensors
+    LastLevelTaskItem **batch_lltasks; ///< array[batch_size] of matching 
lltasks
+    struct THRequestItem **batch_requests;   ///< array[batch_size] of 
accumulating requests
 } THModel;
 
 typedef struct THInferRequest {
@@ -62,7 +67,10 @@ typedef struct THRequestItem {
 #define OFFSET(x) offsetof(THOptions, x)
 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM
 static const AVOption dnn_th_options[] = {
-    { "optimize", "turn on graph executor optimization", OFFSET(optimize), 
AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 1, FLAGS},
+    { "optimize",   "turn on graph executor optimization",
+      OFFSET(optimize),   AV_OPT_TYPE_INT, { .i64 = 0 }, 0,  1,  FLAGS },
+    { "batch_size", "number of frames to batch per inference call",
+      OFFSET(batch_size), AV_OPT_TYPE_INT, { .i64 = 1 }, 1, 32, FLAGS },
     { NULL }
 };
 
@@ -127,6 +135,18 @@ static void dnn_free_model_th(DNNModel **model)
 
     th_model = (THModel *)(*model);
 
+    if (th_model->batch_tensors) {
+        for (int i = 0; i < th_model->batch_count; i++) {
+            delete th_model->batch_tensors[i];
+            th_model->batch_tensors[i] = NULL;
+        }
+        av_freep(&th_model->batch_tensors);
+    }
+    if (th_model->batch_lltasks)
+        av_freep(&th_model->batch_lltasks);
+    if (th_model->batch_requests)
+        av_freep(&th_model->batch_requests);
+
     if (th_model->request_queue) {
         while (ff_safe_queue_size(th_model->request_queue) != 0) {
             THRequestItem *item = (THRequestItem 
*)ff_safe_queue_pop_front(th_model->request_queue);
@@ -159,6 +179,12 @@ static int get_input_th(DNNModel *model, DNNData *input, 
const char *input_name)
     return 0;
 }
 
+/**
+ * No-op deleter for the persistent buffer path.
+ * Memory is owned by THInferRequest->input_data, not the tensor.
+ * This same pattern will be reused for zero-copy CUDA tensors where
+ * GPU memory is owned by FFmpeg's AVBuffer reference-counting API.
+ */
 static void persistent_buf_deleter(void *arg)
 {
     (void)arg;
@@ -183,6 +209,7 @@ static int fill_model_input_th(THModel *th_model, 
THRequestItem *request)
     task = lltask->task;
     infer_request = request->infer_request;
 
+    /* Detect CUDA frames - zero-copy path will be implemented here in a 
follow-up commit */
     if (task->in_frame->format == AV_PIX_FMT_CUDA) {
         av_log(ctx, AV_LOG_ERROR,
                "CUDA frame input is not yet supported. "
@@ -204,6 +231,13 @@ static int fill_model_input_th(THModel *th_model, 
THRequestItem *request)
     required_size = (size_t)input.dims[height_idx] * input.dims[width_idx] *
                     input.dims[channel_idx] * sizeof(float);
 
+    /*
+     * Reuse the persistent buffer when it is large enough; only reallocate
+     * when the frame size exceeds the current capacity.  This eliminates
+     * per-frame av_malloc/av_free churn and also provides the lazy-
+     * reallocation behaviour needed for dynamic-shape (resolution-change)
+     * support.
+     */
     if (infer_request->input_data_size < required_size) {
         av_freep(&infer_request->input_data);
         infer_request->input_data = (uint8_t *)av_malloc(required_size);
@@ -273,6 +307,13 @@ static int th_start_inference(void *args)
         av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
         return DNN_GENERIC_ERROR;
     }
+    /*
+     * Determine the device the model lives on.  For models with learnable
+     * parameters (e.g. ResNet) we read the device from the first parameter
+     * tensor.  For pure-function TorchScript models (e.g. identity, pre-proc
+     * pipelines) the parameter list is empty, so we fall back to CPU which is
+     * always a safe choice for the initial CPU-only path.
+     */
     auto params = th_model->jit_model->parameters();
     c10::Device device(torch::kCPU);
     if (params.begin() != params.end())
@@ -286,64 +327,145 @@ static int th_start_inference(void *args)
     return 0;
 }
 
-static void infer_completion_callback(void *args) {
-    THRequestItem *request = (THRequestItem*)args;
-    LastLevelTaskItem *lltask = request->lltask;
+static int process_single_output(THModel *th_model, torch::Tensor out_slice,
+                                 LastLevelTaskItem *lltask, THRequestItem 
*request)
+{
     TaskItem *task = lltask->task;
     DNNData outputs = { 0 };
-    THInferRequest *infer_request = request->infer_request;
-    THModel *th_model = (THModel *)task->model;
-    torch::Tensor *output = infer_request->output;
+    c10::IntArrayRef sizes = out_slice.sizes();
 
-    c10::IntArrayRef sizes = output->sizes();
-    outputs.order = DCO_RGB;
+    outputs.order  = DCO_RGB;
     outputs.layout = DL_NCHW;
-    outputs.dt = DNN_FLOAT;
+    outputs.dt     = DNN_FLOAT;
+
     if (sizes.size() == 4) {
-        // 4 dimensions: [batch_size, channel, height, width]
-        // this format of data is normally used for video frame SR
-        outputs.dims[0] = sizes.at(0); // N
-        outputs.dims[1] = sizes.at(1); // C
-        outputs.dims[2] = sizes.at(2); // H
-        outputs.dims[3] = sizes.at(3); // W
+        outputs.dims[0] = sizes.at(0);
+        outputs.dims[1] = sizes.at(1);
+        outputs.dims[2] = sizes.at(2);
+        outputs.dims[3] = sizes.at(3);
     } else {
         avpriv_report_missing_feature(th_model->ctx, "Support of this kind of 
model");
-        goto err;
+        return DNN_GENERIC_ERROR;
     }
 
     switch (th_model->model.func_type) {
     case DFT_PROCESS_FRAME:
         if (task->do_ioproc) {
-            // Post process can only deal with CPU memory.
-            if (output->device() != torch::kCPU)
-                *output = output->to(torch::kCPU);
+            if (out_slice.device() != torch::kCPU)
+                out_slice = out_slice.to(torch::kCPU);
             outputs.scale = 255;
-            outputs.data = output->data_ptr();
-            if (th_model->model.frame_post_proc != NULL) {
-                th_model->model.frame_post_proc(task->out_frame, &outputs, 
th_model->model.filter_ctx);
-            } else {
+            outputs.data  = out_slice.data_ptr();
+            if (th_model->model.frame_post_proc != NULL)
+                th_model->model.frame_post_proc(task->out_frame, &outputs,
+                                                th_model->model.filter_ctx);
+            else
                 ff_proc_from_dnn_to_frame(task->out_frame, &outputs, 
th_model->ctx);
-            }
         } else {
-            task->out_frame->width = 
outputs.dims[dnn_get_width_idx_by_layout(outputs.layout)];
+            task->out_frame->width  = 
outputs.dims[dnn_get_width_idx_by_layout(outputs.layout)];
             task->out_frame->height = 
outputs.dims[dnn_get_height_idx_by_layout(outputs.layout)];
         }
         break;
     default:
-        avpriv_report_missing_feature(th_model->ctx, "model function type %d", 
th_model->model.func_type);
-        goto err;
+        avpriv_report_missing_feature(th_model->ctx, "model function type %d",
+                                      th_model->model.func_type);
+        return DNN_GENERIC_ERROR;
     }
     task->inference_done++;
+    return 0;
+}
+
+static void infer_completion_callback(void *args) {
+    THRequestItem *request = (THRequestItem*)args;
+    LastLevelTaskItem *lltask = request->lltask;
+    TaskItem *task = lltask->task;
+    THInferRequest *infer_request = request->infer_request;
+    THModel *th_model = (THModel *)task->model;
+    torch::Tensor *output = infer_request->output;
+
+    if (process_single_output(th_model, *output, lltask, request) < 0)
+        goto err;
+
     av_freep(&request->lltask);
 err:
     th_free_request(infer_request);
 
     if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) {
         destroy_request_item(&request);
-        av_log(th_model->ctx, AV_LOG_ERROR, "Unable to push back request_queue 
when failed to start inference.\n");
+        av_log(th_model->ctx, AV_LOG_ERROR,
+               "Unable to push back request_queue when failed to start 
inference.\n");
     }
 }
 
+static int execute_batch_th(THModel *th_model, int count)
+{
+    DnnContext *ctx = th_model->ctx;
+    int ret = 0;
+
+    if (count == 0)
+        goto done;
+
+    try {
+        torch::NoGradGuard no_grad;
+
+        auto params = th_model->jit_model->parameters();
+        c10::Device device(torch::kCPU);
+        if (params.begin() != params.end())
+            device = (*params.begin()).device();
+
+        if (ctx->torch_option.optimize)
+            torch::jit::setGraphExecutorOptimize(true);
+        else
+            torch::jit::setGraphExecutorOptimize(false);
+
+        std::vector<torch::Tensor> tensor_list;
+        tensor_list.reserve(count);
+        for (int i = 0; i < count; i++) {
+            torch::Tensor t = th_model->batch_tensors[i]->to(device);
+            tensor_list.push_back(t);
+        }
+        torch::Tensor batch_input = torch::cat(tensor_list, /*dim=*/0);
+
+        std::vector<torch::jit::IValue> inputs;
+        inputs.push_back(batch_input);
+        torch::Tensor batch_output = 
th_model->jit_model->forward(inputs).toTensor();
+
+        auto slices = torch::split(batch_output, /*split_size=*/1, /*dim=*/0);
+
+        for (int i = 0; i < count; i++) {
+            ret = process_single_output(th_model, slices[i],
+                                        th_model->batch_lltasks[i], 
th_model->batch_requests[i]);
+            if (ret < 0) {
+                av_log(ctx, AV_LOG_ERROR,
+                       "batch output[%d] post-processing failed\n", i);
+                /* Continue processing remaining frames to avoid leaking 
tasks. */
+            }
+            av_freep(&th_model->batch_lltasks[i]);
+        }
+    } catch (const c10::Error& e) {
+        av_log(ctx, AV_LOG_ERROR, "Batch inference failed: %s\n", e.what());
+        ret = DNN_GENERIC_ERROR;
+        for (int i = 0; i < count; i++)
+            av_freep(&th_model->batch_lltasks[i]);
+    }
+
+    for (int i = 0; i < count; i++) {
+        delete th_model->batch_tensors[i];
+        th_model->batch_tensors[i] = NULL;
+    }
+    th_model->batch_count = 0;
+
+done:
+    for (int i = 0; i < count; i++) {
+        THRequestItem *req = th_model->batch_requests[i];
+        th_free_request(req->infer_request);
+        if (ff_safe_queue_push_back(th_model->request_queue, req) < 0) {
+            destroy_request_item(&req);
+            av_log(ctx, AV_LOG_ERROR, "Unable to push back request_queue after 
batch.\n");
+        }
+    }
+    return ret;
+}
+
 static int execute_model_th(THRequestItem *request, Queue *lltask_queue)
 {
     THModel *th_model = NULL;
@@ -470,13 +592,6 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, 
DNNFunctionType func_type, A
 #else
         at::detail::getXPUHooks().initXPU();
 #endif
-    } else if (device.is_cuda()) {
-        // CUDA device - works for both NVIDIA CUDA and AMD ROCm (which uses 
CUDA-compatible API)
-        if (!torch::cuda::is_available()) {
-            av_log(ctx, AV_LOG_ERROR, "CUDA/ROCm is not available\n");
-            goto fail;
-        }
-        av_log(ctx, AV_LOG_INFO, "Using CUDA/ROCm device: %s\n", device_name);
     } else if (!device.is_cpu()) {
         av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", 
device_name);
         goto fail;
@@ -496,23 +611,36 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, 
DNNFunctionType func_type, A
         goto fail;
     }
 
-    item = (THRequestItem *)av_mallocz(sizeof(THRequestItem));
-    if (!item) {
+    th_model->batch_size  = ctx->torch_option.batch_size;
+    th_model->batch_count = 0;
+    th_model->batch_tensors = (torch::Tensor **)av_calloc(th_model->batch_size,
+                                                          
sizeof(*th_model->batch_tensors));
+    if (!th_model->batch_tensors)
         goto fail;
-    }
-    item->infer_request = th_create_inference_request();
-    if (!item->infer_request) {
+    th_model->batch_lltasks = (LastLevelTaskItem 
**)av_calloc(th_model->batch_size,
+                                                              
sizeof(*th_model->batch_lltasks));
+    if (!th_model->batch_lltasks)
         goto fail;
-    }
 
-    item->exec_module.start_inference = &th_start_inference;
-    item->exec_module.callback = &infer_completion_callback;
-    item->exec_module.args = item;
+    for (int i = 0; i < th_model->batch_size; i++) {
+        item = (THRequestItem *)av_mallocz(sizeof(THRequestItem));
+        if (!item) {
+            goto fail;
+        }
+        item->infer_request = th_create_inference_request();
+        if (!item->infer_request) {
+            goto fail;
+        }
 
-    if (ff_safe_queue_push_back(th_model->request_queue, item) < 0) {
-        goto fail;
+        item->exec_module.start_inference = &th_start_inference;
+        item->exec_module.callback = &infer_completion_callback;
+        item->exec_module.args = item;
+
+        if (ff_safe_queue_push_back(th_model->request_queue, item) < 0) {
+            goto fail;
+        }
+        item = NULL;
     }
-    item = NULL;
 
     th_model->task_queue = ff_queue_create();
     th_model->lltask_queue = ff_queue_create();
@@ -537,6 +665,7 @@ static int dnn_execute_model_th(const DNNModel *model, 
DNNExecBaseParams *exec_p
     DnnContext *ctx = th_model->ctx;
     TaskItem *task;
     THRequestItem *request;
+    LastLevelTaskItem *lltask;
     int ret = 0;
 
     ret = ff_check_exec_params(ctx, DNN_TH, model->func_type, exec_params);
@@ -571,6 +700,55 @@ static int dnn_execute_model_th(const DNNModel *model, 
DNNExecBaseParams *exec_p
         return ret;
     }
 
+    if (th_model->batch_size > 1) {
+        int bs = th_model->batch_size;
+        int bc = th_model->batch_count;
+
+        if (bc > 0) {
+            TaskItem *first_task = th_model->batch_lltasks[0]->task;
+            if (first_task->in_frame->width != task->in_frame->width ||
+                first_task->in_frame->height != task->in_frame->height) {
+                av_log(ctx, AV_LOG_INFO, "Resolution changed mid-batch, 
flushing accumulator.\n");
+                ret = execute_batch_th(th_model, bc);
+                if (ret != 0) return ret;
+                bc = 0;
+            }
+        }
+
+        request = (THRequestItem 
*)ff_safe_queue_pop_front(th_model->request_queue);
+        if (!request) {
+            av_log(ctx, AV_LOG_ERROR, "unable to get infer request for 
batch.\n");
+            return AVERROR(EINVAL);
+        }
+
+        ret = fill_model_input_th(th_model, request);
+        if (ret != 0) {
+            if (ff_safe_queue_push_back(th_model->request_queue, request) < 0)
+                destroy_request_item(&request);
+            return ret;
+        }
+
+        th_model->batch_tensors[bc]  = request->infer_request->input_tensor;
+        request->infer_request->input_tensor = NULL;
+        
+        th_model->batch_lltasks[bc]  = request->lltask;
+        request->lltask              = NULL;
+        
+        if (!th_model->batch_requests) {
+            th_model->batch_requests = (THRequestItem **)av_calloc(bs, 
sizeof(THRequestItem *));
+            if (!th_model->batch_requests) return AVERROR(ENOMEM);
+        }
+        th_model->batch_requests[bc] = request;
+        
+        th_model->batch_count        = bc + 1;
+
+        if (th_model->batch_count < bs) {
+            return 0;
+        }
+
+        return execute_batch_th(th_model, th_model->batch_count);
+    }
+
     request = (THRequestItem 
*)ff_safe_queue_pop_front(th_model->request_queue);
     if (!request) {
         av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
@@ -591,6 +769,10 @@ static int dnn_flush_th(const DNNModel *model)
     THModel *th_model = (THModel *)model;
     THRequestItem *request;
 
+    if (th_model->batch_size > 1 && th_model->batch_count > 0) {
+        return execute_batch_th(th_model, th_model->batch_count);
+    }
+
     if (ff_queue_size(th_model->lltask_queue) == 0)
         // no pending task need to flush
         return 0;
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index 66086409be..df01b7b93c 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -136,6 +136,7 @@ typedef struct OVOptions {
 typedef struct THOptions {
     const AVClass *clazz;
     int optimize;
+    int batch_size; ///< number of frames to accumulate per inference call
 } THOptions;
 
 typedef struct DNNModule DNNModule;
-- 
2.52.0

_______________________________________________
ffmpeg-devel mailing list -- [email protected]
To unsubscribe send an email to [email protected]

Reply via email to