This patch adds support for arbitrary-point FFTs and all even MDCT
transforms.
Odd-length MDCTs are not supported yet: they are based on the DCT-II and
DCT-III, and they are very niche.
With this we can now write tests.
Patch attached.
From 63eac77a5689e560dfa1da793d10851ea799bab7 Mon Sep 17 00:00:00 2001
From: Lynne
Date: Tue, 12 Jan 2021 08:11:47 +0100
Subject: [PATCH] lavu: support arbitrary-point FFT and all even (i)MDCT
transforms
This patch adds support for arbitrary-point FFTs and all even MDCT
transforms.
Odd-length MDCTs are not supported yet: they are based on the DCT-II and
DCT-III, and they are very niche.
With this we can now write tests.
---
libavutil/tx.h | 5 +-
libavutil/tx_priv.h | 3 ++
libavutil/tx_template.c | 101 +---
3 files changed, 101 insertions(+), 8 deletions(-)
diff --git a/libavutil/tx.h b/libavutil/tx.h
index 418e8ec1ed..f49eb8c4c7 100644
--- a/libavutil/tx.h
+++ b/libavutil/tx.h
@@ -51,6 +51,8 @@ enum AVTXType {
* For inverse transforms, the stride specifies the spacing between each
* sample in the input array in bytes. The output will be a flat array.
* Stride must be a non-zero multiple of sizeof(float).
+ * NOTE: the inverse transform is half-length, meaning the output will not
+ * contain redundant data. This is what most codecs work with.
*/
AV_TX_FLOAT_MDCT = 1,
/**
@@ -93,8 +95,7 @@ typedef void (*av_tx_fn)(AVTXContext *s, void *out, void *in, ptrdiff_t stride);
/**
* Initialize a transform context with the given configuration
- * Currently power of two lengths from 2 to 131072 are supported, along with
- * any length decomposable to a power of two and either 3, 5 or 15.
+ * (i)MDCTs with an odd length are currently not supported.
*
* @param ctx the context to allocate, will be NULL on error
* @param tx pointer to the transform function pointer to set
diff --git a/libavutil/tx_priv.h b/libavutil/tx_priv.h
index 0ace3e90dc..18a07c312c 100644
--- a/libavutil/tx_priv.h
+++ b/libavutil/tx_priv.h
@@ -58,6 +58,7 @@ typedef void FFTComplex;
(dim) = (are) * (bim) - (aim) * (bre); \
} while (0)
+#define UNSCALE(x) (x) /* identity here, mirroring RESCALE(x) (x) in this same branch */
#define RESCALE(x) (x)
#define FOLD(a, b) ((a) + (b))
@@ -85,6 +86,7 @@ typedef void FFTComplex;
(dim) = (int)(((accu) + 0x4000) >> 31); \
} while (0)
+#define UNSCALE(x) ((double)(x)/2147483648.0) /* undoes RESCALE's multiply by 2^31; (x) is parenthesized so an expression argument is converted as a whole */
#define RESCALE(x) (av_clip64(lrintf((x) * 2147483648.0), INT32_MIN, INT32_MAX))
#define FOLD(x, y) ((int)((x) + (unsigned)(y) + 32) >> 6)
@@ -108,6 +110,7 @@ struct AVTXContext {
int m; /* Ptwo part */
int inv;/* Is inverted */
int type; /* Type */
+double scale; /* Scale */
FFTComplex *exptab; /* MDCT exptab */
FFTComplex *tmp;/* Temporary buffer needed for all compound transforms */
diff --git a/libavutil/tx_template.c b/libavutil/tx_template.c
index 7f4ca2f31e..a91b8f900c 100644
--- a/libavutil/tx_template.c
+++ b/libavutil/tx_template.c
@@ -397,6 +397,31 @@ static void monolithic_fft(AVTXContext *s, void *_out, void *_in,
fft_dispatch[mb](out);
}
+static void naive_fft(AVTXContext *s, void *_out, void *_in,
+ ptrdiff_t stride) /* stride is not read by this function */
+{ /* Direct O(n^2) DFT: every output bin is a full sum over all n inputs */
+FFTComplex *in = _in;
+FFTComplex *out = _out;
+const int n = s->n;
+double phase = s->inv ? 2.0*M_PI/n : -2.0*M_PI/n; /* angle step: +2*pi/n when s->inv is set, -2*pi/n otherwise */
+
+for(int i = 0; i < n; i++) { /* i: output bin; in[] is re-read for every bin, so out == in would read already-overwritten values */
+FFTComplex tmp = { 0 }; /* accumulator for bin i */
+for(int j = 0; j < n; j++) {
+const double factor = phase*i*j; /* twiddle angle for (i, j), evaluated in double */
+const FFTComplex mult = { /* twiddle (cos, sin), passed through RESCALE before use */
+RESCALE(cos(factor)),
+RESCALE(sin(factor)),
+};
+FFTComplex res;
+CMUL3(res, in[j], mult); /* CMUL3 is defined outside this hunk; presumably res = in[j] * mult (complex) - confirm */
+tmp.re += res.re;
+tmp.im += res.im;
+}
+out[i] = tmp; /* stored as accumulated; no 1/n normalization is applied here */
+}
+}
+
#define DECL_COMP_IMDCT(N) \
static void compound_imdct_##N##xM(AVTXContext *s, void *_dst, void *_src, \
ptrdiff_t stride) \
@@ -553,6 +578,57 @@ static void monolithic_mdct(AVTXContext *s, void *_dst, void *_src,
}
}
+static void naive_imdct(AVTXContext *s, void *_dst, void *_src,
+ptrdiff_t stride)
+{
+int len = s->n;
+int len2 = len*2;
+FFTSample *src = _src;
+FFTSample *dst = _dst;
+double scale = s->scale;
+const double phase = M_PI/(4.0*len2);
+
+stride /= sizeof(*src);
+
+for (int i = 0; i < len; i++) {
+double sum_d = 0.0;
+double sum_u = 0.0;
+double i_d = phase * (4*len - 2*i - 1);
+double i_u = phase * (3*len2 + 2*i + 1);
+for (int j = 0; j < len2; j++) {
+double a = (2 * j + 1);
+double a_d = cos(a * i_d);
+double a_u = cos(a * i_u);
+double val = UNSCALE(src[j*stride]);