278bef770c8706f7005bc07194b29e1936f80279
[speexdsp.git] / libspeex / mdf.c
1 /* Copyright (C) 2003-2005 Jean-Marc Valin
2
3    File: speex_echo.c
4    Echo cancelling based on the MDF algorithm described in:
5
6    J. S. Soo, K. K. Pang Multidelay block frequency adaptive filter, 
7    IEEE Trans. Acoust. Speech Signal Process., Vol. ASSP-38, No. 2, 
8    February 1990.
9
10    Redistribution and use in source and binary forms, with or without
11    modification, are permitted provided that the following conditions are
12    met:
13
14    1. Redistributions of source code must retain the above copyright notice,
15    this list of conditions and the following disclaimer.
16
17    2. Redistributions in binary form must reproduce the above copyright
18    notice, this list of conditions and the following disclaimer in the
19    documentation and/or other materials provided with the distribution.
20
21    3. The name of the author may not be used to endorse or promote products
22    derived from this software without specific prior written permission.
23
24    THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
25    IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
26    OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27    DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
28    INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
29    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
31    HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
32    STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33    ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34    POSSIBILITY OF SUCH DAMAGE.
35 */
36
37 #ifdef HAVE_CONFIG_H
38 #include "config.h"
39 #endif
40
41 #include "misc.h"
42 #include "speex/speex_echo.h"
43 #include "smallft.h"
44 #include <math.h>
45
46 #ifndef M_PI
47 #define M_PI 3.14159265358979323846
48 #endif
49
50 #undef BETA
51 #define BETA .65
52
53 #define min(a,b) ((a)<(b) ? (a) : (b))
54 #define max(a,b) ((a)>(b) ? (a) : (b))
55
56 /** Compute inner product of two real vectors */
57 static inline float inner_prod(float *x, float *y, int N)
58 {
59    int i;
60    float ret=0;
61    for (i=0;i<N;i++)
62       ret += x[i]*y[i];
63    return ret;
64 }
65
66 /** Compute power spectrum of a half-complex (packed) vector */
67 static inline void power_spectrum(float *X, float *ps, int N)
68 {
69    int i, j;
70    ps[0]=X[0]*X[0];
71    for (i=1,j=1;i<N-1;i+=2,j++)
72    {
73       ps[j] =  X[i]*X[i] + X[i+1]*X[i+1];
74    }
75    ps[j]=X[i]*X[i];
76 }
77
78 /** Compute cross-power spectrum of a half-complex (packed) vectors and add to acc */
79 static inline void spectral_mul_accum(float *X, float *Y, float *acc, int N)
80 {
81    int i;
82    acc[0] += X[0]*Y[0];
83    for (i=1;i<N-1;i+=2)
84    {
85       acc[i] += (X[i]*Y[i] - X[i+1]*Y[i+1]);
86       acc[i+1] += (X[i+1]*Y[i] + X[i]*Y[i+1]);
87    }
88    acc[i] += X[i]*Y[i];
89 }
90
91 /** Compute weighted cross-power spectrum of a half-complex (packed) vector with conjugate */
92 static inline void weighted_spectral_mul_conj(float *w, float *X, float *Y, float *prod, int N)
93 {
94    int i, j;
95    prod[0] = w[0]*X[0]*Y[0];
96    for (i=1,j=1;i<N-1;i+=2,j++)
97    {
98       prod[i] = w[j]*(X[i]*Y[i] + X[i+1]*Y[i+1]);
99       prod[i+1] = w[j]*(-X[i+1]*Y[i] + X[i]*Y[i+1]);
100    }
101    prod[i] = w[j]*X[i]*Y[i];
102 }
103
104
105 /** Creates a new echo canceller state */
106 SpeexEchoState *speex_echo_state_init(int frame_size, int filter_length)
107 {
108    int i,N,M;
109    SpeexEchoState *st = (SpeexEchoState *)speex_alloc(sizeof(SpeexEchoState));
110
111    st->frame_size = frame_size;
112    st->window_size = 2*frame_size;
113    N = st->window_size;
114    M = st->M = (filter_length+st->frame_size-1)/frame_size;
115    st->cancel_count=0;
116    st->sum_adapt = 0;
117          
118    st->fft_lookup = (struct drft_lookup*)speex_alloc(sizeof(struct drft_lookup));
119    spx_drft_init(st->fft_lookup, N);
120    
121    st->x = (float*)speex_alloc(N*sizeof(float));
122    st->d = (float*)speex_alloc(N*sizeof(float));
123    st->y = (float*)speex_alloc(N*sizeof(float));
124    st->Yps = (float*)speex_alloc(N*sizeof(float));
125    st->last_y = (float*)speex_alloc(N*sizeof(float));
126    st->Yf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
127    st->Rf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
128    st->Xf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
129    st->Yh = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
130    st->Eh = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
131
132    st->X = (float*)speex_alloc(M*N*sizeof(float));
133    st->Y = (float*)speex_alloc(N*sizeof(float));
134    st->E = (float*)speex_alloc(N*sizeof(float));
135    st->W = (float*)speex_alloc(M*N*sizeof(float));
136    st->PHI = (float*)speex_alloc(M*N*sizeof(float));
137    st->power = (float*)speex_alloc((frame_size+1)*sizeof(float));
138    st->power_1 = (float*)speex_alloc((frame_size+1)*sizeof(float));
139    
140    for (i=0;i<N*M;i++)
141    {
142       st->W[i] = st->PHI[i] = 0;
143    }
144
145    st->adapted = 0;
146    st->Pey = st->Pyy = 0;
147    return st;
148 }
149
150 /** Resets echo canceller state */
151 void speex_echo_state_reset(SpeexEchoState *st)
152 {
153    int i, M, N;
154    st->cancel_count=0;
155    N = st->window_size;
156    M = st->M;
157    for (i=0;i<N*M;i++)
158    {
159       st->W[i] = 0;
160       st->X[i] = 0;
161    }
162    for (i=0;i<=st->frame_size;i++)
163       st->power[i] = 0;
164    
165    st->adapted = 0;
166    st->sum_adapt = 0;
167    st->Pey = st->Pyy = 0;
168
169 }
170
171 /** Destroys an echo canceller state */
172 void speex_echo_state_destroy(SpeexEchoState *st)
173 {
174    spx_drft_clear(st->fft_lookup);
175    speex_free(st->fft_lookup);
176    speex_free(st->x);
177    speex_free(st->d);
178    speex_free(st->y);
179    speex_free(st->last_y);
180    speex_free(st->Yps);
181    speex_free(st->Yf);
182    speex_free(st->Rf);
183    speex_free(st->Xf);
184    speex_free(st->Yh);
185    speex_free(st->Eh);
186
187    speex_free(st->X);
188    speex_free(st->Y);
189    speex_free(st->E);
190    speex_free(st->W);
191    speex_free(st->PHI);
192    speex_free(st->power);
193    speex_free(st->power_1);
194
195    speex_free(st);
196 }
197
198
199 /** Performs echo cancellation on a frame */
200 void speex_echo_cancel(SpeexEchoState *st, short *ref, short *echo, short *out, float *Yout)
201 {
202    int i,j;
203    int N,M;
204    float scale;
205    float Syy=0,See=0;
206    float leak_estimate;
207    float ss;
208    float adapt_rate;
209
210    N = st->window_size;
211    M = st->M;
212    scale = 1.0f/N;
213    st->cancel_count++;
214    ss = 1.0f/st->cancel_count;
215    if (ss < .4/M)
216       ss=.4/M;
217
218    /* Copy input data to buffer */
219    for (i=0;i<st->frame_size;i++)
220    {
221       st->x[i] = st->x[i+st->frame_size];
222       st->x[i+st->frame_size] = echo[i];
223
224       st->d[i] = st->d[i+st->frame_size];
225       st->d[i+st->frame_size] = ref[i];
226    }
227
228    /* Shift memory: this could be optimized eventually*/
229    for (i=0;i<N*(M-1);i++)
230       st->X[i]=st->X[i+N];
231
232    /* Copy new echo frame */
233    for (i=0;i<N;i++)
234       st->X[(M-1)*N+i]=st->x[i];
235
236    /* Convert x (echo input) to frequency domain */
237    spx_drft_forward(st->fft_lookup, &st->X[(M-1)*N]);
238
239    /* Compute filter response Y */
240    for (i=0;i<N;i++)
241       st->Y[i] = 0;
242    for (j=0;j<M;j++)
243       spectral_mul_accum(&st->X[j*N], &st->W[j*N], st->Y, N);
244    
245    /* Convert Y (filter response) to time domain */
246    for (i=0;i<N;i++)
247       st->y[i] = st->Y[i];
248    spx_drft_backward(st->fft_lookup, st->y);
249    for (i=0;i<N;i++)
250       st->y[i] *= scale;
251
252    /* Compute error signal (signal with echo removed) */ 
253    for (i=0;i<st->frame_size;i++)
254    {
255       float tmp_out;
256       tmp_out = (float)ref[i] - st->y[i+st->frame_size];
257       
258       st->E[i] = 0;
259       st->E[i+st->frame_size] = tmp_out;
260       
261       /* Saturation */
262       if (tmp_out>32767)
263          tmp_out = 32767;
264       else if (tmp_out<-32768)
265          tmp_out = -32768;
266       out[i] = tmp_out;
267    }
268    
269    /* Compute a bunch of correlations */
270    See = inner_prod(st->E+st->frame_size, st->E+st->frame_size, st->frame_size);
271    Syy = inner_prod(st->y+st->frame_size, st->y+st->frame_size, st->frame_size);
272    
273    /* Convert error to frequency domain */
274    spx_drft_forward(st->fft_lookup, st->E);
275    for (i=0;i<st->frame_size;i++)
276       st->y[i] = 0;
277    for (i=0;i<N;i++)
278       st->Y[i] = st->y[i];
279    spx_drft_forward(st->fft_lookup, st->Y);
280    
281    /* Compute power spectrum of echo (X), error (E) and filter response (Y) */
282    power_spectrum(st->E, st->Rf, N);
283    power_spectrum(st->Y, st->Yf, N);
284    power_spectrum(&st->X[(M-1)*N], st->Xf, N);
285    
286    /* Smooth echo energy estimate over time */
287    for (j=0;j<=st->frame_size;j++)
288       st->power[j] = (1-ss)*st->power[j] + ss*st->Xf[j];
289
290    {
291       float Pey = 0, Pyy=0;
292       float alpha;
293       for (j=0;j<=st->frame_size;j++)
294       {
295          float E, Y, Eh, Yh;
296          E = (st->Rf[j]);
297          Y = (st->Yf[j]);
298          Eh = st->Eh[j] + E;
299          Yh = st->Yh[j] + Y;
300          Pey += Eh*Yh;
301          Pyy += Yh*Yh;
302          st->Eh[j] = .95*Eh - E;
303          st->Yh[j] = .95*Yh - Y;
304       }
305       alpha = .02*Syy / (1+See);
306       if (alpha > .02)
307          alpha = .02;
308       st->Pey = (1-alpha)*st->Pey + alpha*Pey;
309       st->Pyy = (1-alpha)*st->Pyy + alpha*Pyy;
310       if (st->Pey< .001*st->Pyy)
311          st->Pey = .001*st->Pyy;
312       leak_estimate = st->Pey / (1+st->Pyy);
313       if (leak_estimate > 1)
314          leak_estimate = 1;
315       /*printf ("%f\n", leak_estimate);*/
316    }
317    
318    if (!st->adapted)
319    {
320       float Sxx;
321       Sxx = inner_prod(st->x+st->frame_size, st->x+st->frame_size, st->frame_size);
322
323       /* We consider that the filter is adapted if the following is true*/
324       if (st->sum_adapt > 1)
325          st->adapted = 1;
326
327       /* Temporary adaption rate if filter is not adapted correctly */
328       adapt_rate = .2f * Sxx / (1e4+See);
329       if (adapt_rate>.2)
330          adapt_rate = .2;
331       adapt_rate /= M;
332       
333       /* How much have we adapted so far? */
334       st->sum_adapt += adapt_rate;
335    }
336
337    if (st->adapted)
338    {
339       adapt_rate = 1.f/M;
340       for (i=0;i<=st->frame_size;i++)
341       {
342          float r;
343          /* Compute frequency-domain adaptation mask */
344          r = leak_estimate*st->Yf[i] / (1+st->Rf[i]);
345          if (r>1)
346             r = 1;
347          st->power_1[i] = adapt_rate*r/(1.f+st->power[i]);
348       }
349    } else {
350       for (i=0;i<=st->frame_size;i++)
351          st->power_1[i] = adapt_rate/(1.f+st->power[i]);      
352    }
353
354    /* Compute weight gradient */
355    for (j=0;j<M;j++)
356    {
357       weighted_spectral_mul_conj(st->power_1, &st->X[j*N], st->E, st->PHI+N*j, N);
358    }
359
360    /* Gradient descent */
361    for (i=0;i<M*N;i++)
362       st->W[i] += st->PHI[i];
363    
364    /* AUMDF weight constraint */
365    for (j=0;j<M;j++)
366    {
367       /* Remove the "if" to make this an MDF filter */
368       if (j==M-1 || st->cancel_count%(M-1) == j)
369       {
370          spx_drft_backward(st->fft_lookup, &st->W[j*N]);
371          for (i=0;i<N;i++)
372             st->W[j*N+i]*=scale;
373          for (i=st->frame_size;i<N;i++)
374          {
375             st->W[j*N+i]=0;
376          }
377          spx_drft_forward(st->fft_lookup, &st->W[j*N]);
378       }
379    }
380
381    /* Compute spectrum of estimated echo for use in an echo post-filter (if necessary)*/
382    if (Yout)
383    {
384       if (st->adapted)
385       {
386          /* If the filter is adapted, take the filtered echo */
387          for (i=0;i<st->frame_size;i++)
388             st->last_y[i] = st->last_y[st->frame_size+i];
389          for (i=0;i<st->frame_size;i++)
390             st->last_y[st->frame_size+i] = st->y[st->frame_size+i];
391       } else {
392          /* If filter isn't adapted yet, all we can do is take the echo signal directly */
393          for (i=0;i<N;i++)
394             st->last_y[i] = st->x[i];
395       }
396       
397       /* Apply hanning window (should pre-compute it)*/
398       for (i=0;i<N;i++)
399          st->Yps[i] = (.5-.5*cos(2*M_PI*i/N))*st->last_y[i];
400       
401       /* Compute power spectrum of the echo */
402       spx_drft_forward(st->fft_lookup, st->Yps);
403       power_spectrum(st->Yps, st->Yps, N);
404       
405       /* Estimate residual echo */
406       for (i=0;i<=st->frame_size;i++)
407          Yout[i] = 2.f*leak_estimate*st->Yps[i];
408    }
409
410 }
411