Making the MDCT produce interleaved data
[opus.git] / libcelt / mdct.c
index e04f437..1e310ad 100644 (file)
@@ -99,7 +99,8 @@ void clt_mdct_clear(mdct_lookup *l)
 
 #endif /* CUSTOM_MODES */
 
-void clt_mdct_forward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar * restrict out, const opus_val16 *window, int overlap, int shift)
+void clt_mdct_forward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar * restrict out,
+      const opus_val16 *window, int overlap, int shift, int stride)
 {
    int i;
    int N, N2, N4;
@@ -124,7 +125,7 @@ void clt_mdct_forward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar
       /* Temp pointers to make it really clear to the compiler what we're doing */
       const kiss_fft_scalar * restrict xp1 = in+(overlap>>1);
       const kiss_fft_scalar * restrict xp2 = in+N2-1+(overlap>>1);
-      kiss_fft_scalar * restrict yp = out;
+      kiss_fft_scalar * restrict yp = f;
       const opus_val16 * restrict wp1 = window+(overlap>>1);
       const opus_val16 * restrict wp2 = window+(overlap>>1)-1;
       for(i=0;i<(overlap>>2);i++)
@@ -160,7 +161,7 @@ void clt_mdct_forward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar
    }
    /* Pre-rotation */
    {
-      kiss_fft_scalar * restrict yp = out;
+      kiss_fft_scalar * restrict yp = f;
       const kiss_twiddle_scalar *t = &l->trig[0];
       for(i=0;i<N4;i++)
       {
@@ -176,14 +177,14 @@ void clt_mdct_forward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar
    }
 
    /* N/4 complex FFT, down-scales by 4/N */
-   opus_fft(l->kfft[shift], (kiss_fft_cpx *)out, (kiss_fft_cpx *)f);
+   opus_fft(l->kfft[shift], (kiss_fft_cpx *)f, (kiss_fft_cpx *)in);
 
    /* Post-rotate */
    {
       /* Temp pointers to make it really clear to the compiler what we're doing */
-      const kiss_fft_scalar * restrict fp = f;
+      const kiss_fft_scalar * restrict fp = in;
       kiss_fft_scalar * restrict yp1 = out;
-      kiss_fft_scalar * restrict yp2 = out+N2-1;
+      kiss_fft_scalar * restrict yp2 = out+stride*(N2-1);
       const kiss_twiddle_scalar *t = &l->trig[0];
       /* Temp pointers to make it really clear to the compiler what we're doing */
       for(i=0;i<N4;i++)
@@ -195,14 +196,15 @@ void clt_mdct_forward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar
          *yp1 = yr - S_MUL(yi,sine);
          *yp2 = yi + S_MUL(yr,sine);;
          fp += 2;
-         yp1 += 2;
-         yp2 -= 2;
+         yp1 += 2*stride;
+         yp2 -= 2*stride;
       }
    }
    RESTORE_STACK;
 }
 
-void clt_mdct_backward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar * restrict out, const opus_val16 * restrict window, int overlap, int shift)
+void clt_mdct_backward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar * restrict out,
+      const opus_val16 * restrict window, int overlap, int shift, int stride)
 {
    int i;
    int N, N2, N4;
@@ -227,7 +229,7 @@ void clt_mdct_backward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scala
    {
       /* Temp pointers to make it really clear to the compiler what we're doing */
       const kiss_fft_scalar * restrict xp1 = in;
-      const kiss_fft_scalar * restrict xp2 = in+N2-1;
+      const kiss_fft_scalar * restrict xp2 = in+stride*(N2-1);
       kiss_fft_scalar * restrict yp = f2;
       const kiss_twiddle_scalar *t = &l->trig[0];
       for(i=0;i<N4;i++)
@@ -238,8 +240,8 @@ void clt_mdct_backward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scala
          /* works because the cos is nearly one */
          *yp++ = yr - S_MUL(yi,sine);
          *yp++ = yi + S_MUL(yr,sine);
-         xp1+=2;
-         xp2-=2;
+         xp1+=2*stride;
+         xp2-=2*stride;
       }
    }