Added a wrapper around the FFT so any FFT can be used
[speexdsp.git] / libspeex / mdf.c
1 /* Copyright (C) 2003-2005 Jean-Marc Valin
2
3    File: speex_echo.c
4    Echo cancelling based on the MDF algorithm described in:
5
6    J. S. Soo, K. K. Pang Multidelay block frequency adaptive filter, 
7    IEEE Trans. Acoust. Speech Signal Process., Vol. ASSP-38, No. 2, 
8    February 1990.
9
10    Redistribution and use in source and binary forms, with or without
11    modification, are permitted provided that the following conditions are
12    met:
13
14    1. Redistributions of source code must retain the above copyright notice,
15    this list of conditions and the following disclaimer.
16
17    2. Redistributions in binary form must reproduce the above copyright
18    notice, this list of conditions and the following disclaimer in the
19    documentation and/or other materials provided with the distribution.
20
21    3. The name of the author may not be used to endorse or promote products
22    derived from this software without specific prior written permission.
23
24    THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
25    IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
26    OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27    DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
28    INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
29    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
31    HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
32    STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33    ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34    POSSIBILITY OF SUCH DAMAGE.
35 */
36
37 #ifdef HAVE_CONFIG_H
38 #include "config.h"
39 #endif
40
41 #include "misc.h"
42 #include "speex/speex_echo.h"
43 #include "smallft.h"
44 #include "fftwrap.h"
45
46 #include <math.h>
47
48 #ifndef M_PI
49 #define M_PI 3.14159265358979323846
50 #endif
51
52 #undef BETA
53 #define BETA .65
54
55 #define min(a,b) ((a)<(b) ? (a) : (b))
56 #define max(a,b) ((a)>(b) ? (a) : (b))
57
58 /** Compute inner product of two real vectors */
59 static inline float inner_prod(float *x, float *y, int N)
60 {
61    int i;
62    float ret=0;
63    for (i=0;i<N;i++)
64       ret += x[i]*y[i];
65    return ret;
66 }
67
68 /** Compute power spectrum of a half-complex (packed) vector */
69 static inline void power_spectrum(float *X, float *ps, int N)
70 {
71    int i, j;
72    ps[0]=X[0]*X[0];
73    for (i=1,j=1;i<N-1;i+=2,j++)
74    {
75       ps[j] =  X[i]*X[i] + X[i+1]*X[i+1];
76    }
77    ps[j]=X[i]*X[i];
78 }
79
80 /** Compute cross-power spectrum of a half-complex (packed) vectors and add to acc */
81 static inline void spectral_mul_accum(float *X, float *Y, float *acc, int N)
82 {
83    int i;
84    acc[0] += X[0]*Y[0];
85    for (i=1;i<N-1;i+=2)
86    {
87       acc[i] += (X[i]*Y[i] - X[i+1]*Y[i+1]);
88       acc[i+1] += (X[i+1]*Y[i] + X[i]*Y[i+1]);
89    }
90    acc[i] += X[i]*Y[i];
91 }
92
93 /** Compute weighted cross-power spectrum of a half-complex (packed) vector with conjugate */
94 static inline void weighted_spectral_mul_conj(float *w, float *X, float *Y, float *prod, int N)
95 {
96    int i, j;
97    prod[0] = w[0]*X[0]*Y[0];
98    for (i=1,j=1;i<N-1;i+=2,j++)
99    {
100       prod[i] = w[j]*(X[i]*Y[i] + X[i+1]*Y[i+1]);
101       prod[i+1] = w[j]*(-X[i+1]*Y[i] + X[i]*Y[i+1]);
102    }
103    prod[i] = w[j]*X[i]*Y[i];
104 }
105
106
107 /** Creates a new echo canceller state */
108 SpeexEchoState *speex_echo_state_init(int frame_size, int filter_length)
109 {
110    int i,N,M;
111    SpeexEchoState *st = (SpeexEchoState *)speex_alloc(sizeof(SpeexEchoState));
112
113    st->frame_size = frame_size;
114    st->window_size = 2*frame_size;
115    N = st->window_size;
116    M = st->M = (filter_length+st->frame_size-1)/frame_size;
117    st->cancel_count=0;
118    st->sum_adapt = 0;
119
120    st->fft_table = spx_fft_init(N);
121    
122    st->e = (float*)speex_alloc(N*sizeof(float));
123    st->x = (float*)speex_alloc(N*sizeof(float));
124    st->d = (float*)speex_alloc(N*sizeof(float));
125    st->y = (float*)speex_alloc(N*sizeof(float));
126    st->Yps = (float*)speex_alloc(N*sizeof(float));
127    st->last_y = (float*)speex_alloc(N*sizeof(float));
128    st->Yf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
129    st->Rf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
130    st->Xf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
131    st->Yh = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
132    st->Eh = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
133
134    st->X = (float*)speex_alloc(M*N*sizeof(float));
135    st->Y = (float*)speex_alloc(N*sizeof(float));
136    st->E = (float*)speex_alloc(N*sizeof(float));
137    st->W = (float*)speex_alloc(M*N*sizeof(float));
138    st->PHI = (float*)speex_alloc(M*N*sizeof(float));
139    st->power = (float*)speex_alloc((frame_size+1)*sizeof(float));
140    st->power_1 = (float*)speex_alloc((frame_size+1)*sizeof(float));
141    
142    for (i=0;i<N*M;i++)
143    {
144       st->W[i] = st->PHI[i] = 0;
145    }
146
147    st->adapted = 0;
148    st->Pey = st->Pyy = 0;
149    return st;
150 }
151
152 /** Resets echo canceller state */
153 void speex_echo_state_reset(SpeexEchoState *st)
154 {
155    int i, M, N;
156    st->cancel_count=0;
157    N = st->window_size;
158    M = st->M;
159    for (i=0;i<N*M;i++)
160    {
161       st->W[i] = 0;
162       st->X[i] = 0;
163    }
164    for (i=0;i<=st->frame_size;i++)
165       st->power[i] = 0;
166    
167    st->adapted = 0;
168    st->sum_adapt = 0;
169    st->Pey = st->Pyy = 0;
170
171 }
172
173 /** Destroys an echo canceller state */
174 void speex_echo_state_destroy(SpeexEchoState *st)
175 {
176    spx_fft_destroy(st->fft_table);
177
178    speex_free(st->e);
179    speex_free(st->x);
180    speex_free(st->d);
181    speex_free(st->y);
182    speex_free(st->last_y);
183    speex_free(st->Yps);
184    speex_free(st->Yf);
185    speex_free(st->Rf);
186    speex_free(st->Xf);
187    speex_free(st->Yh);
188    speex_free(st->Eh);
189
190    speex_free(st->X);
191    speex_free(st->Y);
192    speex_free(st->E);
193    speex_free(st->W);
194    speex_free(st->PHI);
195    speex_free(st->power);
196    speex_free(st->power_1);
197
198    speex_free(st);
199 }
200
201 /** Performs echo cancellation on a frame */
202 void speex_echo_cancel(SpeexEchoState *st, short *ref, short *echo, short *out, float *Yout)
203 {
204    int i,j;
205    int N,M;
206    float Syy=0,See=0;
207    float leak_estimate;
208    float ss;
209    float adapt_rate;
210    
211    N = st->window_size;
212    M = st->M;
213    st->cancel_count++;
214    ss = 1.0f/st->cancel_count;
215    if (ss < .4/M)
216       ss=.4/M;
217
218    /* Copy input data to buffer */
219    for (i=0;i<st->frame_size;i++)
220    {
221       st->x[i] = st->x[i+st->frame_size];
222       st->x[i+st->frame_size] = echo[i];
223
224       st->d[i] = st->d[i+st->frame_size];
225       st->d[i+st->frame_size] = ref[i];
226    }
227
228    /* Shift memory: this could be optimized eventually*/
229    for (i=0;i<N*(M-1);i++)
230       st->X[i]=st->X[i+N];
231
232    /* Convert x (echo input) to frequency domain */
233    spx_fft_float(st->fft_table, st->x, &st->X[(M-1)*N]);
234    /* Compute filter response Y */
235    for (i=0;i<N;i++)
236       st->Y[i] = 0;
237    for (j=0;j<M;j++)
238       spectral_mul_accum(&st->X[j*N], &st->W[j*N], st->Y, N);
239       
240    spx_ifft_float(st->fft_table, st->Y, st->y);
241
242    /* Compute error signal (signal with echo removed) */ 
243    for (i=0;i<st->frame_size;i++)
244    {
245       float tmp_out;
246       tmp_out = (float)ref[i] - st->y[i+st->frame_size];
247       
248       st->e[i] = 0;
249       st->e[i+st->frame_size] = tmp_out;
250       
251       /* Saturation */
252       if (tmp_out>32767)
253          tmp_out = 32767;
254       else if (tmp_out<-32768)
255          tmp_out = -32768;
256       out[i] = tmp_out;
257    }
258    
259    /* Compute a bunch of correlations */
260    See = inner_prod(st->e+st->frame_size, st->e+st->frame_size, st->frame_size);
261    Syy = inner_prod(st->y+st->frame_size, st->y+st->frame_size, st->frame_size);
262    
263    /* Convert error to frequency domain */
264    spx_fft_float(st->fft_table, st->e, st->E);
265    for (i=0;i<st->frame_size;i++)
266       st->y[i] = 0;
267    spx_fft_float(st->fft_table, st->y, st->Y);
268
269    /* Compute power spectrum of echo (X), error (E) and filter response (Y) */
270    power_spectrum(st->E, st->Rf, N);
271    power_spectrum(st->Y, st->Yf, N);
272    power_spectrum(&st->X[(M-1)*N], st->Xf, N);
273    
274    /* Smooth echo energy estimate over time */
275    for (j=0;j<=st->frame_size;j++)
276       st->power[j] = (1-ss)*st->power[j] + ss*st->Xf[j];
277
278    {
279       float Pey = 0, Pyy=0;
280       float alpha;
281       for (j=0;j<=st->frame_size;j++)
282       {
283          float E, Y, Eh, Yh;
284          E = (st->Rf[j]);
285          Y = (st->Yf[j]);
286          Eh = st->Eh[j] + E;
287          Yh = st->Yh[j] + Y;
288          Pey += Eh*Yh;
289          Pyy += Yh*Yh;
290          st->Eh[j] = .95*Eh - E;
291          st->Yh[j] = .95*Yh - Y;
292       }
293       alpha = .02*Syy / (1e4+See);
294       if (alpha > .02)
295          alpha = .02;
296       st->Pey = (1-alpha)*st->Pey + alpha*Pey;
297       st->Pyy = (1-alpha)*st->Pyy + alpha*Pyy;
298       if (st->Pey< .001*st->Pyy)
299          st->Pey = .001*st->Pyy;
300       leak_estimate = st->Pey / (1+st->Pyy);
301       if (leak_estimate > 1)
302          leak_estimate = 1;
303       /*printf ("%f\n", leak_estimate);*/
304    }
305    
306    if (!st->adapted)
307    {
308       float Sxx;
309       Sxx = inner_prod(st->x+st->frame_size, st->x+st->frame_size, st->frame_size);
310
311       /* We consider that the filter is adapted if the following is true*/
312       if (st->sum_adapt > 1)
313          st->adapted = 1;
314
315       /* Temporary adaption rate if filter is not adapted correctly */
316       adapt_rate = .2f * Sxx / (1e4+See);
317       if (adapt_rate>.2)
318          adapt_rate = .2;
319       adapt_rate /= M;
320       
321       /* How much have we adapted so far? */
322       st->sum_adapt += adapt_rate;
323    }
324
325    if (st->adapted)
326    {
327       adapt_rate = 1.f/M;
328       for (i=0;i<=st->frame_size;i++)
329       {
330          float r;
331          /* Compute frequency-domain adaptation mask */
332          r = leak_estimate*st->Yf[i] / (1+st->Rf[i]);
333          if (r>1)
334             r = 1;
335          st->power_1[i] = adapt_rate*r/(1.f+st->power[i]);
336       }
337    } else {
338       for (i=0;i<=st->frame_size;i++)
339          st->power_1[i] = adapt_rate/(1.f+st->power[i]);
340    }
341
342    /* Compute weight gradient */
343    for (j=0;j<M;j++)
344    {
345       weighted_spectral_mul_conj(st->power_1, &st->X[j*N], st->E, st->PHI+N*j, N);
346    }
347
348    /* Gradient descent */
349    for (i=0;i<M*N;i++)
350       st->W[i] += st->PHI[i];
351    
352    /* AUMDF weight constraint */
353    for (j=0;j<M;j++)
354    {
355       /* Remove the "if" to make this an MDF filter */
356       if (j==M-1 || st->cancel_count%(M-1) == j)
357       {
358          float w[N];
359          spx_ifft_float(st->fft_table, &st->W[j*N], w);
360          for (i=st->frame_size;i<N;i++)
361          {
362             w[i]=0;
363          }
364          spx_fft_float(st->fft_table, w, &st->W[j*N]);
365       }
366    }
367
368    /* Compute spectrum of estimated echo for use in an echo post-filter (if necessary)*/
369    if (Yout)
370    {
371       if (st->adapted)
372       {
373          /* If the filter is adapted, take the filtered echo */
374          for (i=0;i<st->frame_size;i++)
375             st->last_y[i] = st->last_y[st->frame_size+i];
376          for (i=0;i<st->frame_size;i++)
377             st->last_y[st->frame_size+i] = st->y[st->frame_size+i];
378       } else {
379          /* If filter isn't adapted yet, all we can do is take the echo signal directly */
380          for (i=0;i<N;i++)
381             st->last_y[i] = st->x[i];
382       }
383       
384       /* Apply hanning window (should pre-compute it)*/
385       for (i=0;i<N;i++)
386          st->y[i] = (.5-.5*cos(2*M_PI*i/N))*st->last_y[i];
387       
388       /* Compute power spectrum of the echo */
389       spx_fft_float(st->fft_table, st->y, st->Yps);
390       power_spectrum(st->Yps, st->Yps, N);
391       
392       /* Estimate residual echo */
393       for (i=0;i<=st->frame_size;i++)
394          Yout[i] = 2.f*leak_estimate*st->Yps[i];
395    }
396
397 }
398