31e46adc41dc17e451d0a30b9b223b650199abbe
[speexdsp.git] / libspeex / mdf.c
1 /* Copyright (C) Jean-Marc Valin
2
3    File: speex_echo.c
4    Echo cancelling based on the MDF algorithm described in:
5
6    J. S. Soo, K. K. Pang Multidelay block frequency adaptive filter, 
7    IEEE Trans. Acoust. Speech Signal Process., Vol. ASSP-38, No. 2, 
8    February 1990.
9
10    Redistribution and use in source and binary forms, with or without
11    modification, are permitted provided that the following conditions are
12    met:
13
14    1. Redistributions of source code must retain the above copyright notice,
15    this list of conditions and the following disclaimer.
16
17    2. Redistributions in binary form must reproduce the above copyright
18    notice, this list of conditions and the following disclaimer in the
19    documentation and/or other materials provided with the distribution.
20
21    3. The name of the author may not be used to endorse or promote products
22    derived from this software without specific prior written permission.
23
24    THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
25    IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
26    OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27    DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
28    INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
29    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
31    HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
32    STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33    ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34    POSSIBILITY OF SUCH DAMAGE.
35 */
36
37 #ifdef HAVE_CONFIG_H
38 #include "config.h"
39 #endif
40
41 #include "misc.h"
42 #include "speex/speex_echo.h"
43 #include "smallft.h"
44 #include <math.h>
45
46 #ifndef M_PI
47 #define M_PI 3.14159265358979323846
48 #endif
49
50 #define BETA .65
51
52 #define min(a,b) ((a)<(b) ? (a) : (b))
53 #define max(a,b) ((a)>(b) ? (a) : (b))
54
55 /** Compute inner product of two real vectors */
56 static inline float inner_prod(float *x, float *y, int N)
57 {
58    int i;
59    float ret=0;
60    for (i=0;i<N;i++)
61       ret += x[i]*y[i];
62    return ret;
63 }
64
65 /** Compute power spectrum of a half-complex (packed) vector */
66 static inline void power_spectrum(float *X, float *ps, int N)
67 {
68    int i, j;
69    ps[0]=X[0]*X[0];
70    for (i=1,j=1;i<N-1;i+=2,j++)
71    {
72       ps[j] =  X[i]*X[i] + X[i+1]*X[i+1];
73    }
74    ps[j]=X[i]*X[i];
75 }
76
77 /** Compute cross-power spectrum of a half-complex (packed) vectors and add to acc */
78 static inline void spectral_mul_accum(float *X, float *Y, float *acc, int N)
79 {
80    int i;
81    acc[0] += X[0]*Y[0];
82    for (i=1;i<N-1;i+=2)
83    {
84       acc[i] += (X[i]*Y[i] - X[i+1]*Y[i+1]);
85       acc[i+1] += (X[i+1]*Y[i] + X[i]*Y[i+1]);
86    }
87    acc[i] += X[i]*Y[i];
88 }
89
90 /** Compute cross-power spectrum of a half-complex (packed) vector with conjugate */
91 static inline void spectral_mul_conj(float *X, float *Y, float *prod, int N)
92 {
93    int i;
94    prod[0] = X[0]*Y[0];
95    for (i=1;i<N-1;i+=2)
96    {
97       prod[i] = (X[i]*Y[i] + X[i+1]*Y[i+1]);
98       prod[i+1] = (-X[i+1]*Y[i] + X[i]*Y[i+1]);
99    }
100    prod[i] = X[i]*Y[i];
101 }
102
103
104 /** Compute weighted cross-power spectrum of a half-complex (packed) vector with conjugate */
105 static inline void weighted_spectral_mul_conj(float *w, float *X, float *Y, float *prod, int N)
106 {
107    int i, j;
108    prod[0] = w[0]*X[0]*Y[0];
109    for (i=1,j=1;i<N-1;i+=2,j++)
110    {
111       prod[i] = w[j]*(X[i]*Y[i] + X[i+1]*Y[i+1]);
112       prod[i+1] = w[j]*(-X[i+1]*Y[i] + X[i]*Y[i+1]);
113    }
114    prod[i] = w[j]*X[i]*Y[i];
115 }
116
117
118 /** Creates a new echo canceller state */
119 SpeexEchoState *speex_echo_state_init(int frame_size, int filter_length)
120 {
121    int i,j,N,M;
122    SpeexEchoState *st = (SpeexEchoState *)speex_alloc(sizeof(SpeexEchoState));
123
124    st->frame_size = frame_size;
125    st->window_size = 2*frame_size;
126    N = st->window_size;
127    M = st->M = (filter_length+st->frame_size-1)/frame_size;
128    st->cancel_count=0;
129    st->adapt_rate = .01f;
130    st->sum_adapt = 0;
131    st->Sey = 0;
132    st->Syy = 0;
133    st->See = 0;
134          
135    st->fft_lookup = (struct drft_lookup*)speex_alloc(sizeof(struct drft_lookup));
136    spx_drft_init(st->fft_lookup, N);
137    
138    st->x = (float*)speex_alloc(N*sizeof(float));
139    st->d = (float*)speex_alloc(N*sizeof(float));
140    st->y = (float*)speex_alloc(N*sizeof(float));
141    st->y2 = (float*)speex_alloc(N*sizeof(float));
142    st->Yps = (float*)speex_alloc(N*sizeof(float));
143    st->last_y = (float*)speex_alloc(N*sizeof(float));
144    st->Yf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
145    st->Rf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
146    st->Xf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
147    st->fratio = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
148    st->regul = (float*)speex_alloc(N*sizeof(float));
149
150    st->X = (float*)speex_alloc(M*N*sizeof(float));
151    st->D = (float*)speex_alloc(N*sizeof(float));
152    st->Y = (float*)speex_alloc(N*sizeof(float));
153    st->Y2 = (float*)speex_alloc(N*sizeof(float));
154    st->E = (float*)speex_alloc(N*sizeof(float));
155    st->W = (float*)speex_alloc(M*N*sizeof(float));
156    st->PHI = (float*)speex_alloc(M*N*sizeof(float));
157    st->power = (float*)speex_alloc((frame_size+1)*sizeof(float));
158    st->power_1 = (float*)speex_alloc((frame_size+1)*sizeof(float));
159    st->grad = (float*)speex_alloc(N*M*sizeof(float));
160    
161    for (i=0;i<N*M;i++)
162    {
163       st->W[i] = st->PHI[i] = 0;
164    }
165    
166    st->regul[0] = (.01+(10.)/((4.)*(4.)))/M;
167    for (i=1,j=1;i<N-1;i+=2,j++)
168    {
169       st->regul[i] = .01+((10.)/((j+4.)*(j+4.)))/M;
170       st->regul[i+1] = .01+((10.)/((j+4.)*(j+4.)))/M;
171    }
172    st->regul[i] = .01+((10.)/((j+4.)*(j+4.)))/M;
173          
174    st->adapted = 0;
175    return st;
176 }
177
178 /** Resets echo canceller state */
179 void speex_echo_state_reset(SpeexEchoState *st)
180 {
181    int i, M, N;
182    st->cancel_count=0;
183    st->adapt_rate = .01f;
184    N = st->window_size;
185    M = st->M;
186    for (i=0;i<N*M;i++)
187    {
188       st->W[i] = 0;
189       st->X[i] = 0;
190    }
191    for (i=0;i<=st->frame_size;i++)
192       st->power[i] = 0;
193    
194    st->adapted = 0;
195    st->adapt_rate = .01f;
196    st->sum_adapt = 0;
197    st->Sey = 0;
198    st->Syy = 0;
199    st->See = 0;
200
201 }
202
203 /** Destroys an echo canceller state */
204 void speex_echo_state_destroy(SpeexEchoState *st)
205 {
206    spx_drft_clear(st->fft_lookup);
207    speex_free(st->fft_lookup);
208    speex_free(st->x);
209    speex_free(st->d);
210    speex_free(st->y);
211    speex_free(st->last_y);
212    speex_free(st->Yps);
213    speex_free(st->Yf);
214    speex_free(st->Rf);
215    speex_free(st->Xf);
216    speex_free(st->fratio);
217    speex_free(st->regul);
218
219    speex_free(st->X);
220    speex_free(st->D);
221    speex_free(st->Y);
222    speex_free(st->E);
223    speex_free(st->W);
224    speex_free(st->PHI);
225    speex_free(st->power);
226    speex_free(st->power_1);
227    speex_free(st->grad);
228
229    speex_free(st);
230 }
231
232       
233 /** Performs echo cancellation on a frame */
234 void speex_echo_cancel(SpeexEchoState *st, short *ref, short *echo, short *out, float *Yout)
235 {
236    int i,j,m;
237    int N,M;
238    float scale;
239    float ESR;
240    float SER;
241    float Sry=0,Srr=0,Syy=0,Sey=0,See=0,Sxx=0;
242    float leak_estimate;
243    
244    leak_estimate = .1+(.9/(1+2*st->sum_adapt));
245          
246    N = st->window_size;
247    M = st->M;
248    scale = 1.0f/N;
249    st->cancel_count++;
250
251    /* Copy input data to buffer */
252    for (i=0;i<st->frame_size;i++)
253    {
254       st->x[i] = st->x[i+st->frame_size];
255       st->x[i+st->frame_size] = echo[i];
256
257       st->d[i] = st->d[i+st->frame_size];
258       st->d[i+st->frame_size] = ref[i];
259    }
260
261    /* Shift memory: this could be optimized eventually*/
262    for (i=0;i<N*(M-1);i++)
263       st->X[i]=st->X[i+N];
264
265    /* Copy new echo frame */
266    for (i=0;i<N;i++)
267       st->X[(M-1)*N+i]=st->x[i];
268
269    /* Convert x (echo input) to frequency domain */
270    spx_drft_forward(st->fft_lookup, &st->X[(M-1)*N]);
271
272    /* Compute filter response Y */
273    for (i=0;i<N;i++)
274       st->Y[i] = 0;
275    for (j=0;j<M;j++)
276       spectral_mul_accum(&st->X[j*N], &st->W[j*N], st->Y, N);
277    
278    /* Convert Y (filter response) to time domain */
279    for (i=0;i<N;i++)
280       st->y[i] = st->Y[i];
281    spx_drft_backward(st->fft_lookup, st->y);
282    for (i=0;i<N;i++)
283       st->y[i] *= scale;
284
285    /* Transform d (reference signal) to frequency domain */
286    for (i=0;i<N;i++)
287       st->D[i]=st->d[i];
288    spx_drft_forward(st->fft_lookup, st->D);
289
290    /* Compute error signal (signal with echo removed) */ 
291    for (i=0;i<st->frame_size;i++)
292    {
293       float tmp_out;
294       tmp_out = (float)ref[i] - st->y[i+st->frame_size];
295       
296       st->E[i] = 0;
297       st->E[i+st->frame_size] = tmp_out;
298       
299       /* Saturation */
300       if (tmp_out>32767)
301          tmp_out = 32767;
302       else if (tmp_out<-32768)
303          tmp_out = -32768;
304       out[i] = tmp_out;
305    }
306    
307    /* This bit of code is optional and provides faster adaptation by doing a projection 
308       of the previous gradient on the "MMSE surface" */
309    if (1)
310    {
311       float Sge, Sgg, Syy;
312       float gain;
313       Syy = inner_prod(st->y+st->frame_size, st->y+st->frame_size, st->frame_size);
314       for (i=0;i<N;i++)
315          st->Y2[i] = 0;
316       for (j=0;j<M;j++)
317          spectral_mul_accum(&st->X[j*N], &st->PHI[j*N], st->Y2, N);
318       for (i=0;i<N;i++)
319          st->y2[i] = st->Y2[i];
320       spx_drft_backward(st->fft_lookup, st->y2);
321       for (i=0;i<N;i++)
322          st->y2[i] *= scale;
323       Sge = inner_prod(st->y2+st->frame_size, st->E+st->frame_size, st->frame_size);
324       Sgg = inner_prod(st->y2+st->frame_size, st->y2+st->frame_size, st->frame_size);
325       /* Compute projection gain */
326       gain = Sge/(N+.03*Syy+Sgg);
327       if (gain>2)
328          gain = 2;
329       if (gain < -2)
330          gain = -2;
331       
332       /* Apply gain to weights, echo estimates, output */
333       for (i=0;i<N;i++)
334          st->Y[i] += gain*st->Y2[i];
335       for (i=0;i<st->frame_size;i++)
336       {
337          st->y[i+st->frame_size] += gain*st->y2[i+st->frame_size];
338          st->E[i+st->frame_size] -= gain*st->y2[i+st->frame_size];
339       }
340       for (i=0;i<M*N;i++)
341          st->W[i] += gain*st->PHI[i];
342    }
343
344    /* Compute power spectrum of output (D-Y) and filter response (Y) */
345    for (i=0;i<N;i++)
346       st->D[i] -= st->Y[i];
347    power_spectrum(st->D, st->Rf, N);
348    power_spectrum(st->Y, st->Yf, N);
349    
350    /* Compute frequency-domain adaptation mask */
351    for (j=0;j<=st->frame_size;j++)
352    {
353       float r;
354       r = leak_estimate*st->Yf[j] / (1+st->Rf[j]);
355       if (r>1)
356          r = 1;
357       st->fratio[j] = r;
358    }
359
360    /* Compute a bunch of correlations */
361    Sry = inner_prod(st->y+st->frame_size, st->d+st->frame_size, st->frame_size);
362    Sey = inner_prod(st->y+st->frame_size, st->E+st->frame_size, st->frame_size);
363    See = inner_prod(st->E+st->frame_size, st->E+st->frame_size, st->frame_size);
364    Syy = inner_prod(st->y+st->frame_size, st->y+st->frame_size, st->frame_size);
365    Srr = inner_prod(st->d+st->frame_size, st->d+st->frame_size, st->frame_size);
366    Sxx = inner_prod(st->x+st->frame_size, st->x+st->frame_size, st->frame_size);
367
368    /* Compute smoothed cross-correlation and energy */   
369    st->Sey = .98*st->Sey + .02*Sey;
370    st->Syy = .98*st->Syy + .02*Syy;
371    st->See = .98*st->See + .02*See;
372    
373    /* Check if filter is completely mis-adapted (if so, reset filter) */
374    if (st->Sey/(1+st->Syy + .01*st->See) < -1)
375    {
376       /*fprintf (stderr, "reset at %d\n", st->cancel_count);*/
377       speex_echo_state_reset(st);
378       return;
379    }
380
381    SER = Srr / (1+Sxx);
382    ESR = leak_estimate*Syy / (1+See);
383    if (ESR>1)
384       ESR = 1;
385 #if 1
386    /* If over-cancellation (creating echo with 180 phase) damp filter */
387    if (st->Sey/(1+st->Syy) < -.1 && (ESR > .3))
388    {
389       for (i=0;i<M*N;i++)
390          st->W[i] *= .95;
391       st->Sey *= .5;
392       /*fprintf (stderr, "corrected down\n");*/
393    }
394 #endif
395 #if 1
396    /* If under-cancellation (leaving echo with 0 phase) scale filter up */
397    if (st->Sey/(1+st->Syy) > .1 && (ESR > .1 || SER < 10))
398    {
399       for (i=0;i<M*N;i++)
400          st->W[i] *= 1.05;
401       st->Sey *= .5;
402       /*fprintf (stderr, "corrected up %d\n", st->cancel_count);*/
403    }
404 #endif
405    
406    /* We consider that the filter is adapted if the following is true*/
407    if (ESR>.6 && st->sum_adapt > 1)
408    {
409       /*if (!st->adapted)
410          fprintf(stderr, "Adapted at %d %f\n", st->cancel_count, st->sum_adapt);*/
411       st->adapted = 1;
412    }
413    
414    /* Update frequency-dependent energy ratio with the total energy ratio */
415    for (i=0;i<=st->frame_size;i++)
416    {
417       st->fratio[i]  = (.2*ESR+.8*min(.005+ESR,st->fratio[i]));
418    }   
419
420    if (st->adapted)
421    {
422       st->adapt_rate = .95f/(2+M);
423    } else {
424       /* Temporary adaption rate if filter is not adapted correctly */
425       if (SER<.1)
426          st->adapt_rate =.5/(2+M);
427       else if (SER<1)
428          st->adapt_rate =.3/(2+M);
429       else if (SER<10)
430          st->adapt_rate =.2/(2+M);
431       else if (SER<30)
432          st->adapt_rate =.08/(2+M);
433       else
434          st->adapt_rate = 0;
435    }
436    
437    /* How much have we adapted so far? */
438    st->sum_adapt += st->adapt_rate;
439
440    /* Compute echo power in each frequency bin */
441    {
442       float ss = 1.0f/st->cancel_count;
443       if (ss < .3/M)
444          ss=.3/M;
445       power_spectrum(&st->X[(M-1)*N], st->Xf, N);
446       /* Smooth echo energy estimate over time */
447       for (j=0;j<=st->frame_size;j++)
448          st->power[j] = (1-ss)*st->power[j] + ss*st->Xf[j];
449       
450       
451       /* Combine adaptation rate to the the inverse energy estimate */
452       if (st->adapted)
453       {
454          /* If filter is adapted, include the frequency-dependent ratio too */
455          for (i=0;i<=st->frame_size;i++)
456             st->power_1[i] = st->adapt_rate*st->fratio[i] /(1.f+st->power[i]);
457       } else {
458          for (i=0;i<=st->frame_size;i++)
459             st->power_1[i] = st->adapt_rate/(1.f+st->power[i]);
460       }
461    }
462
463    
464    /* Convert error to frequency domain */
465    spx_drft_forward(st->fft_lookup, st->E);
466
467    /* Do some regularization (prevents problems when system is ill-conditoned) */
468    for (m=0;m<M;m++)
469       for (i=0;i<N;i++)
470          st->W[m*N+i] *= 1-st->regul[i]*ESR;
471    
472    /* Compute weight gradient */
473    for (j=0;j<M;j++)
474    {
475       weighted_spectral_mul_conj(st->power_1, &st->X[j*N], st->E, st->PHI+N*j, N);
476    }
477
478    /* Gradient descent */
479    for (i=0;i<M*N;i++)
480       st->W[i] += st->PHI[i];
481    
482    /* AUMDF weight constraint */
483    for (j=0;j<M;j++)
484    {
485       /* Remove the "if" to make this an MDF filter */
486       if (st->cancel_count%M == j)
487       {
488          spx_drft_backward(st->fft_lookup, &st->W[j*N]);
489          for (i=0;i<N;i++)
490             st->W[j*N+i]*=scale;
491          for (i=st->frame_size;i<N;i++)
492          {
493             st->W[j*N+i]=0;
494          }
495          spx_drft_forward(st->fft_lookup, &st->W[j*N]);
496       }
497    }
498
499    /* Compute spectrum of estimated echo for use in an echo post-filter (if necessary)*/
500    if (Yout)
501    {
502       if (st->adapted)
503       {
504          /* If the filter is adapted, take the filtered echo */
505          for (i=0;i<st->frame_size;i++)
506             st->last_y[i] = st->last_y[st->frame_size+i];
507          for (i=0;i<st->frame_size;i++)
508             st->last_y[st->frame_size+i] = st->y[st->frame_size+i];
509       } else {
510          /* If filter isn't adapted yet, all we can do is take the echo signal directly */
511          for (i=0;i<N;i++)
512             st->last_y[i] = st->x[i];
513       }
514       
515       /* Apply hanning window (should pre-compute it)*/
516       for (i=0;i<N;i++)
517          st->Yps[i] = (.5-.5*cos(2*M_PI*i/N))*st->last_y[i];
518       
519       /* Compute power spectrum of the echo */
520       spx_drft_forward(st->fft_lookup, st->Yps);
521       power_spectrum(st->Yps, st->Yps, N);
522       
523       /* Estimate residual echo */
524       for (i=0;i<=st->frame_size;i++)
525          Yout[i] = 2*leak_estimate*st->Yps[i];
526    }
527
528 }
529