Real FFT cleanup, plus some testcases
[opus.git] / libcelt / kiss_fft.c
1 /*
2 Copyright (c) 2003-2004, Mark Borgerding
3 Copyright (c) 2005-2007, Jean-Marc Valin
4
5 All rights reserved.
6
7 Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
8
9     * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
10     * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11     * Neither the author nor the names of any contributors may be used to endorse or promote products derived from this software without specific prior written permission.
12
13 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
14 */
15
16
17 #ifdef HAVE_CONFIG_H
18 #include "config.h"
19 #endif
20
21 #include "_kiss_fft_guts.h"
22 #include "arch.h"
23 #include "os_support.h"
24
25 /* The guts header contains all the multiplication and addition macros that are defined for
26  fixed or floating point complex numbers.  It also delares the kf_ internal functions.
27  */
28
29 static void kf_bfly2(
30         kiss_fft_cpx * Fout,
31         const size_t fstride,
32         const kiss_fft_cfg st,
33         int m,
34         int N,
35         int mm
36         )
37 {
38     kiss_fft_cpx * Fout2;
39     kiss_fft_cpx * tw1;
40     kiss_fft_cpx t;
41     if (!st->inverse) {
42        int i,j;
43        kiss_fft_cpx * Fout_beg = Fout;
44        for (i=0;i<N;i++)
45        {
46           Fout = Fout_beg + i*mm;
47           Fout2 = Fout + m;
48           tw1 = st->twiddles;
49           for(j=0;j<m;j++)
50           {
51              /* Almost the same as the code path below, except that we divide the input by two
52               (while keeping the best accuracy possible) */
53              celt_word32_t tr, ti;
54              tr = SHR32(SUB32(MULT16_16(Fout2->r , tw1->r),MULT16_16(Fout2->i , tw1->i)), 1);
55              ti = SHR32(ADD32(MULT16_16(Fout2->i , tw1->r),MULT16_16(Fout2->r , tw1->i)), 1);
56              tw1 += fstride;
57              Fout2->r = PSHR32(SUB32(SHL32(EXTEND32(Fout->r), 14), tr), 15);
58              Fout2->i = PSHR32(SUB32(SHL32(EXTEND32(Fout->i), 14), ti), 15);
59              Fout->r = PSHR32(ADD32(SHL32(EXTEND32(Fout->r), 14), tr), 15);
60              Fout->i = PSHR32(ADD32(SHL32(EXTEND32(Fout->i), 14), ti), 15);
61              ++Fout2;
62              ++Fout;
63           }
64        }
65     } else {
66        int i,j;
67        kiss_fft_cpx * Fout_beg = Fout;
68        for (i=0;i<N;i++)
69        {
70           Fout = Fout_beg + i*mm;
71           Fout2 = Fout + m;
72           tw1 = st->twiddles;
73           for(j=0;j<m;j++)
74           {
75              C_MUL (t,  *Fout2 , *tw1);
76              tw1 += fstride;
77              C_SUB( *Fout2 ,  *Fout , t );
78              C_ADDTO( *Fout ,  t );
79              ++Fout2;
80              ++Fout;
81           }
82        }
83     }
84 }
85
86 static void kf_bfly4(
87         kiss_fft_cpx * Fout,
88         const size_t fstride,
89         const kiss_fft_cfg st,
90         int m,
91         int N,
92         int mm
93         )
94 {
95     kiss_fft_cpx *tw1,*tw2,*tw3;
96     kiss_fft_cpx scratch[6];
97     const size_t m2=2*m;
98     const size_t m3=3*m;
99     int i, j;
100
101     if (st->inverse)
102     {
103        kiss_fft_cpx * Fout_beg = Fout;
104        for (i=0;i<N;i++)
105        {
106           Fout = Fout_beg + i*mm;
107           tw3 = tw2 = tw1 = st->twiddles;
108           for (j=0;j<m;j++)
109           {
110              C_MUL(scratch[0],Fout[m] , *tw1 );
111              C_MUL(scratch[1],Fout[m2] , *tw2 );
112              C_MUL(scratch[2],Fout[m3] , *tw3 );
113              
114              C_SUB( scratch[5] , *Fout, scratch[1] );
115              C_ADDTO(*Fout, scratch[1]);
116              C_ADD( scratch[3] , scratch[0] , scratch[2] );
117              C_SUB( scratch[4] , scratch[0] , scratch[2] );
118              C_SUB( Fout[m2], *Fout, scratch[3] );
119              tw1 += fstride;
120              tw2 += fstride*2;
121              tw3 += fstride*3;
122              C_ADDTO( *Fout , scratch[3] );
123              
124              Fout[m].r = scratch[5].r - scratch[4].i;
125              Fout[m].i = scratch[5].i + scratch[4].r;
126              Fout[m3].r = scratch[5].r + scratch[4].i;
127              Fout[m3].i = scratch[5].i - scratch[4].r;
128              ++Fout;
129           }
130        }
131     } else
132     {
133        kiss_fft_cpx * Fout_beg = Fout;
134        for (i=0;i<N;i++)
135        {
136           Fout = Fout_beg + i*mm;
137           tw3 = tw2 = tw1 = st->twiddles;
138           for (j=0;j<m;j++)
139           {
140              C_MUL4(scratch[0],Fout[m] , *tw1 );
141              C_MUL4(scratch[1],Fout[m2] , *tw2 );
142              C_MUL4(scratch[2],Fout[m3] , *tw3 );
143              
144              Fout->r = PSHR16(Fout->r, 2);
145              Fout->i = PSHR16(Fout->i, 2);
146              C_SUB( scratch[5] , *Fout, scratch[1] );
147              C_ADDTO(*Fout, scratch[1]);
148              C_ADD( scratch[3] , scratch[0] , scratch[2] );
149              C_SUB( scratch[4] , scratch[0] , scratch[2] );
150              Fout[m2].r = PSHR16(Fout[m2].r, 2);
151              Fout[m2].i = PSHR16(Fout[m2].i, 2);
152              C_SUB( Fout[m2], *Fout, scratch[3] );
153              tw1 += fstride;
154              tw2 += fstride*2;
155              tw3 += fstride*3;
156              C_ADDTO( *Fout , scratch[3] );
157              
158              Fout[m].r = scratch[5].r + scratch[4].i;
159              Fout[m].i = scratch[5].i - scratch[4].r;
160              Fout[m3].r = scratch[5].r - scratch[4].i;
161              Fout[m3].i = scratch[5].i + scratch[4].r;
162              ++Fout;
163           }
164        }
165     }
166 }
167
168 static void kf_bfly3(
169          kiss_fft_cpx * Fout,
170          const size_t fstride,
171          const kiss_fft_cfg st,
172          size_t m
173          )
174 {
175      size_t k=m;
176      const size_t m2 = 2*m;
177      kiss_fft_cpx *tw1,*tw2;
178      kiss_fft_cpx scratch[5];
179      kiss_fft_cpx epi3;
180      epi3 = st->twiddles[fstride*m];
181
182      tw1=tw2=st->twiddles;
183
184      do{
185         if (!st->inverse) {
186          C_FIXDIV(*Fout,3); C_FIXDIV(Fout[m],3); C_FIXDIV(Fout[m2],3);
187         }
188
189          C_MUL(scratch[1],Fout[m] , *tw1);
190          C_MUL(scratch[2],Fout[m2] , *tw2);
191
192          C_ADD(scratch[3],scratch[1],scratch[2]);
193          C_SUB(scratch[0],scratch[1],scratch[2]);
194          tw1 += fstride;
195          tw2 += fstride*2;
196
197          Fout[m].r = Fout->r - HALF_OF(scratch[3].r);
198          Fout[m].i = Fout->i - HALF_OF(scratch[3].i);
199
200          C_MULBYSCALAR( scratch[0] , epi3.i );
201
202          C_ADDTO(*Fout,scratch[3]);
203
204          Fout[m2].r = Fout[m].r + scratch[0].i;
205          Fout[m2].i = Fout[m].i - scratch[0].r;
206
207          Fout[m].r -= scratch[0].i;
208          Fout[m].i += scratch[0].r;
209
210          ++Fout;
211      }while(--k);
212 }
213
214 static void kf_bfly5(
215         kiss_fft_cpx * Fout,
216         const size_t fstride,
217         const kiss_fft_cfg st,
218         int m
219         )
220 {
221     kiss_fft_cpx *Fout0,*Fout1,*Fout2,*Fout3,*Fout4;
222     int u;
223     kiss_fft_cpx scratch[13];
224     kiss_fft_cpx * twiddles = st->twiddles;
225     kiss_fft_cpx *tw;
226     kiss_fft_cpx ya,yb;
227     ya = twiddles[fstride*m];
228     yb = twiddles[fstride*2*m];
229
230     Fout0=Fout;
231     Fout1=Fout0+m;
232     Fout2=Fout0+2*m;
233     Fout3=Fout0+3*m;
234     Fout4=Fout0+4*m;
235
236     tw=st->twiddles;
237     for ( u=0; u<m; ++u ) {
238         if (!st->inverse) {
239         C_FIXDIV( *Fout0,5); C_FIXDIV( *Fout1,5); C_FIXDIV( *Fout2,5); C_FIXDIV( *Fout3,5); C_FIXDIV( *Fout4,5);
240         }
241         scratch[0] = *Fout0;
242
243         C_MUL(scratch[1] ,*Fout1, tw[u*fstride]);
244         C_MUL(scratch[2] ,*Fout2, tw[2*u*fstride]);
245         C_MUL(scratch[3] ,*Fout3, tw[3*u*fstride]);
246         C_MUL(scratch[4] ,*Fout4, tw[4*u*fstride]);
247
248         C_ADD( scratch[7],scratch[1],scratch[4]);
249         C_SUB( scratch[10],scratch[1],scratch[4]);
250         C_ADD( scratch[8],scratch[2],scratch[3]);
251         C_SUB( scratch[9],scratch[2],scratch[3]);
252
253         Fout0->r += scratch[7].r + scratch[8].r;
254         Fout0->i += scratch[7].i + scratch[8].i;
255
256         scratch[5].r = scratch[0].r + S_MUL(scratch[7].r,ya.r) + S_MUL(scratch[8].r,yb.r);
257         scratch[5].i = scratch[0].i + S_MUL(scratch[7].i,ya.r) + S_MUL(scratch[8].i,yb.r);
258
259         scratch[6].r =  S_MUL(scratch[10].i,ya.i) + S_MUL(scratch[9].i,yb.i);
260         scratch[6].i = -S_MUL(scratch[10].r,ya.i) - S_MUL(scratch[9].r,yb.i);
261
262         C_SUB(*Fout1,scratch[5],scratch[6]);
263         C_ADD(*Fout4,scratch[5],scratch[6]);
264
265         scratch[11].r = scratch[0].r + S_MUL(scratch[7].r,yb.r) + S_MUL(scratch[8].r,ya.r);
266         scratch[11].i = scratch[0].i + S_MUL(scratch[7].i,yb.r) + S_MUL(scratch[8].i,ya.r);
267         scratch[12].r = - S_MUL(scratch[10].i,yb.i) + S_MUL(scratch[9].i,ya.i);
268         scratch[12].i = S_MUL(scratch[10].r,yb.i) - S_MUL(scratch[9].r,ya.i);
269
270         C_ADD(*Fout2,scratch[11],scratch[12]);
271         C_SUB(*Fout3,scratch[11],scratch[12]);
272
273         ++Fout0;++Fout1;++Fout2;++Fout3;++Fout4;
274     }
275 }
276
277 /* perform the butterfly for one stage of a mixed radix FFT */
278 static void kf_bfly_generic(
279         kiss_fft_cpx * Fout,
280         const size_t fstride,
281         const kiss_fft_cfg st,
282         int m,
283         int p
284         )
285 {
286     int u,k,q1,q;
287     kiss_fft_cpx * twiddles = st->twiddles;
288     kiss_fft_cpx t;
289     kiss_fft_cpx scratchbuf[17];
290     int Norig = st->nfft;
291
292     /*CHECKBUF(scratchbuf,nscratchbuf,p);*/
293     if (p>17)
294        celt_fatal("KissFFT: max radix supported is 17");
295     
296     for ( u=0; u<m; ++u ) {
297         k=u;
298         for ( q1=0 ; q1<p ; ++q1 ) {
299             scratchbuf[q1] = Fout[ k  ];
300         if (!st->inverse) {
301             C_FIXDIV(scratchbuf[q1],p);
302         }
303             k += m;
304         }
305
306         k=u;
307         for ( q1=0 ; q1<p ; ++q1 ) {
308             int twidx=0;
309             Fout[ k ] = scratchbuf[0];
310             for (q=1;q<p;++q ) {
311                 twidx += fstride * k;
312                 if (twidx>=Norig) twidx-=Norig;
313                 C_MUL(t,scratchbuf[q] , twiddles[twidx] );
314                 C_ADDTO( Fout[ k ] ,t);
315             }
316             k += m;
317         }
318     }
319 }
320
321 static
322 void compute_bitrev_table(
323          int * Fout,
324          int f,
325          const size_t fstride,
326          int in_stride,
327          int * factors,
328          const kiss_fft_cfg st
329             )
330 {
331    const int p=*factors++; /* the radix  */
332    const int m=*factors++; /* stage's fft length/p */
333    
334     /*printf ("fft %d %d %d %d %d %d\n", p*m, m, p, s2, fstride*in_stride, N);*/
335    if (m==1)
336    {
337       int j;
338       for (j=0;j<p;j++)
339       {
340          Fout[j] = f;
341          f += fstride*in_stride;
342       }
343    } else {
344       int j;
345       for (j=0;j<p;j++)
346       {
347          compute_bitrev_table( Fout , f, fstride*p, in_stride, factors,st);
348          f += fstride*in_stride;
349          Fout += m;
350       }
351    }
352 }
353
354 static
355 void kf_work(
356         kiss_fft_cpx * Fout,
357         const kiss_fft_cpx * f,
358         const size_t fstride,
359         int in_stride,
360         int * factors,
361         const kiss_fft_cfg st,
362         int N,
363         int s2,
364         int m2
365         )
366 {
367     int i;
368     kiss_fft_cpx * Fout_beg=Fout;
369     const int p=*factors++; /* the radix  */
370     const int m=*factors++; /* stage's fft length/p */
371     /*printf ("fft %d %d %d %d %d %d %d\n", p*m, m, p, s2, fstride*in_stride, N, m2);*/
372     if (m!=1) 
373         kf_work( Fout , f, fstride*p, in_stride, factors,st, N*p, fstride*in_stride, m);
374
375     switch (p) {
376         case 2: kf_bfly2(Fout,fstride,st,m, N, m2); break;
377         case 3: for (i=0;i<N;i++){Fout=Fout_beg+i*m2; kf_bfly3(Fout,fstride,st,m);} break; 
378         case 4: kf_bfly4(Fout,fstride,st,m, N, m2); break;
379         case 5: for (i=0;i<N;i++){Fout=Fout_beg+i*m2; kf_bfly5(Fout,fstride,st,m);} break; 
380         default: for (i=0;i<N;i++){Fout=Fout_beg+i*m2; kf_bfly_generic(Fout,fstride,st,m,p);} break;
381     }    
382 }
383
384 /*  facbuf is populated by p1,m1,p2,m2, ...
385     where 
386     p[i] * m[i] = m[i-1]
387     m0 = n                  */
388 static 
389 void kf_factor(int n,int * facbuf)
390 {
391     int p=4;
392
393     /*factor out powers of 4, powers of 2, then any remaining primes */
394     do {
395         while (n % p) {
396             switch (p) {
397                 case 4: p = 2; break;
398                 case 2: p = 3; break;
399                 default: p += 2; break;
400             }
401             if (p>32000 || (celt_int32_t)p*(celt_int32_t)p > n)
402                 p = n;          /* no more factors, skip to end */
403         }
404         n /= p;
405         *facbuf++ = p;
406         *facbuf++ = n;
407     } while (n > 1);
408 }
409 /*
410  *
411  * User-callable function to allocate all necessary storage space for the fft.
412  *
413  * The return value is a contiguous block of memory, allocated with malloc.  As such,
414  * It can be freed with free(), rather than a kiss_fft-specific function.
415  * */
416 kiss_fft_cfg kiss_fft_alloc(int nfft,int inverse_fft,void * mem,size_t * lenmem )
417 {
418     kiss_fft_cfg st=NULL;
419     size_t memneeded = sizeof(struct kiss_fft_state)
420         + sizeof(kiss_fft_cpx)*(nfft-1); /* twiddle factors*/
421
422     if ( lenmem==NULL ) {
423         st = ( kiss_fft_cfg)KISS_FFT_MALLOC( memneeded );
424     }else{
425         if (mem != NULL && *lenmem >= memneeded)
426             st = (kiss_fft_cfg)mem;
427         *lenmem = memneeded;
428     }
429     if (st) {
430         int i;
431         st->nfft=nfft;
432         st->inverse = inverse_fft;
433 #ifdef FIXED_POINT
434         for (i=0;i<nfft;++i) {
435             celt_word32_t phase = i;
436             if (!st->inverse)
437                 phase = -phase;
438             kf_cexp2(st->twiddles+i, DIV32(SHL32(phase,17),nfft));
439         }
440 #else
441         for (i=0;i<nfft;++i) {
442            const double pi=3.14159265358979323846264338327;
443            double phase = ( -2*pi /nfft ) * i;
444            if (st->inverse)
445               phase *= -1;
446            kf_cexp(st->twiddles+i, phase );
447         }
448 #endif
449         kf_factor(nfft,st->factors);
450         
451         /* bitrev */
452         st->bitrev = celt_alloc(sizeof(int)*(nfft));
453         compute_bitrev_table(st->bitrev, 0, 1,1, st->factors,st);
454     }
455     return st;
456 }
457
458
459
460     
461 void kiss_fft_stride(kiss_fft_cfg st,const kiss_fft_cpx *fin,kiss_fft_cpx *fout,int in_stride)
462 {
463     if (fin == fout) 
464     {
465        celt_fatal("In-place FFT not supported");
466     } else {
467        /* Bit-reverse the input */
468        int i;
469        for (i=0;i<st->nfft;i++)
470           fout[i] = fin[st->bitrev[i]];
471        kf_work( fout, fin, 1,in_stride, st->factors,st, 1, in_stride, 1);
472     }
473 }
474
475 void kiss_fft(kiss_fft_cfg cfg,const kiss_fft_cpx *fin,kiss_fft_cpx *fout)
476 {
477     kiss_fft_stride(cfg,fin,fout,1);
478 }
479