Some more AEC tuning
[speexdsp.git] / libspeex / mdf.c
1 /* Copyright (C) Jean-Marc Valin
2
3    File: speex_echo.c
4    Echo cancelling based on the MDF algorithm described in:
5
6    J. S. Soo, K. K. Pang Multidelay block frequency adaptive filter, 
7    IEEE Trans. Acoust. Speech Signal Process., Vol. ASSP-38, No. 2, 
8    February 1990.
9
10    Redistribution and use in source and binary forms, with or without
11    modification, are permitted provided that the following conditions are
12    met:
13
14    1. Redistributions of source code must retain the above copyright notice,
15    this list of conditions and the following disclaimer.
16
17    2. Redistributions in binary form must reproduce the above copyright
18    notice, this list of conditions and the following disclaimer in the
19    documentation and/or other materials provided with the distribution.
20
21    3. The name of the author may not be used to endorse or promote products
22    derived from this software without specific prior written permission.
23
24    THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
25    IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
26    OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27    DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
28    INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
29    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
31    HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
32    STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33    ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34    POSSIBILITY OF SUCH DAMAGE.
35 */
36
37 #ifdef HAVE_CONFIG_H
38 #include "config.h"
39 #endif
40
41 #include "misc.h"
42 #include <speex/speex_echo.h>
43 #include "smallft.h"
44 #include <math.h>
45 #include <stdio.h>
46
47 #ifndef M_PI
48 #define M_PI 3.14159265358979323846
49 #endif
50
51 #define BETA .65
52 /*#define BETA 0*/
53
54 #define min(a,b) ((a)<(b) ? (a) : (b))
55 #define max(a,b) ((a)>(b) ? (a) : (b))
56
57 static inline float inner_prod(float *x, float *y, int N)
58 {
59    int i;
60    float ret=0;
61    for (i=0;i<N;i++)
62       ret += x[i]*y[i];
63    return ret;
64 }
65
66 static inline void power_spectrum(float *X, float *ps, int N)
67 {
68    int i, j;
69    ps[0]=X[0]*X[0];
70    for (i=1,j=1;i<N-1;i+=2,j++)
71    {
72       ps[j] =  X[i]*X[i] + X[i+1]*X[i+1];
73    }
74    ps[j]=X[i]*X[i];
75 }
76
77 static inline void spectral_mul_accum(float *X, float *Y, float *acc, int N)
78 {
79    int i;
80    acc[0] += X[0]*Y[0];
81    for (i=1;i<N-1;i+=2)
82    {
83       acc[i] += (X[i]*Y[i] - X[i+1]*Y[i+1]);
84       acc[i+1] += (X[i+1]*Y[i] + X[i]*Y[i+1]);
85    }
86    acc[i] += X[i]*Y[i];
87 }
88
89 static inline void spectral_mul_conj(float *X, float *Y, float *prod, int N)
90 {
91    int i;
92    prod[0] = X[0]*Y[0];
93    for (i=1;i<N-1;i+=2)
94    {
95       prod[i] = (X[i]*Y[i] + X[i+1]*Y[i+1]);
96       prod[i+1] = (-X[i+1]*Y[i] + X[i]*Y[i+1]);
97    }
98    prod[i] = X[i]*Y[i];
99 }
100
101
102 static inline void weighted_spectral_mul_conj(float *w, float *X, float *Y, float *prod, int N)
103 {
104    int i, j;
105    prod[0] = w[0]*X[0]*Y[0];
106    for (i=1,j=1;i<N-1;i+=2,j++)
107    {
108       prod[i] = w[j]*(X[i]*Y[i] + X[i+1]*Y[i+1]);
109       prod[i+1] = w[j]*(-X[i+1]*Y[i] + X[i]*Y[i+1]);
110    }
111    prod[i] = w[j]*X[i]*Y[i];
112 }
113
114
115 /** Creates a new echo canceller state */
116 SpeexEchoState *speex_echo_state_init(int frame_size, int filter_length)
117 {
118    int i,j,N,M;
119    SpeexEchoState *st = (SpeexEchoState *)speex_alloc(sizeof(SpeexEchoState));
120
121    st->frame_size = frame_size;
122    st->window_size = 2*frame_size;
123    N = st->window_size;
124    M = st->M = (filter_length+st->frame_size-1)/frame_size;
125    st->cancel_count=0;
126    st->adapt_rate = .01f;
127    st->sum_adapt = 0;
128    st->Sey = 0;
129    st->Syy = 0;
130    st->See = 0;
131          
132    st->fft_lookup = (struct drft_lookup*)speex_alloc(sizeof(struct drft_lookup));
133    spx_drft_init(st->fft_lookup, N);
134    
135    st->x = (float*)speex_alloc(N*sizeof(float));
136    st->d = (float*)speex_alloc(N*sizeof(float));
137    st->y = (float*)speex_alloc(N*sizeof(float));
138    st->Yps = (float*)speex_alloc(N*sizeof(float));
139    st->last_y = (float*)speex_alloc(N*sizeof(float));
140    st->Yf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
141    st->Rf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
142    st->Xf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
143    st->fratio = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
144    st->regul = (float*)speex_alloc(N*sizeof(float));
145
146    st->X = (float*)speex_alloc(M*N*sizeof(float));
147    st->D = (float*)speex_alloc(N*sizeof(float));
148    st->Y = (float*)speex_alloc(N*sizeof(float));
149    st->E = (float*)speex_alloc(N*sizeof(float));
150    st->W = (float*)speex_alloc(M*N*sizeof(float));
151    st->PHI = (float*)speex_alloc(N*sizeof(float));
152    st->power = (float*)speex_alloc((frame_size+1)*sizeof(float));
153    st->power_1 = (float*)speex_alloc((frame_size+1)*sizeof(float));
154    st->grad = (float*)speex_alloc(N*M*sizeof(float));
155    
156    for (i=0;i<N*M;i++)
157    {
158       st->W[i] = 0;
159    }
160    
161    st->regul[0] = (.01+(10.)/((4.)*(4.)))/M;
162    for (i=1,j=1;i<N-1;i+=2,j++)
163    {
164       st->regul[i] = .01+((10.)/((j+4.)*(j+4.)))/M;
165       st->regul[i+1] = .01+((10.)/((j+4.)*(j+4.)))/M;
166    }
167    st->regul[i] = .01+((10.)/((j+4.)*(j+4.)))/M;
168          
169    st->adapted = 0;
170    return st;
171 }
172
173 void speex_echo_reset(SpeexEchoState *st)
174 {
175    int i, M, N;
176    st->cancel_count=0;
177    st->adapt_rate = .01f;
178    N = st->window_size;
179    M = st->M;
180    for (i=0;i<N*M;i++)
181    {
182       st->W[i] = 0;
183       st->X[i] = 0;
184    }
185    for (i=0;i<=st->frame_size;i++)
186       st->power[i] = 0;
187    
188    st->adapted = 0;
189    st->adapt_rate = .01f;
190    st->sum_adapt = 0;
191    st->Sey = 0;
192    st->Syy = 0;
193    st->See = 0;
194
195 }
196
197 /** Destroys an echo canceller state */
198 void speex_echo_state_destroy(SpeexEchoState *st)
199 {
200    spx_drft_clear(st->fft_lookup);
201    speex_free(st->fft_lookup);
202    speex_free(st->x);
203    speex_free(st->d);
204    speex_free(st->y);
205    speex_free(st->last_y);
206    speex_free(st->Yps);
207    speex_free(st->Yf);
208    speex_free(st->Rf);
209    speex_free(st->Xf);
210    speex_free(st->fratio);
211    speex_free(st->regul);
212
213    speex_free(st->X);
214    speex_free(st->D);
215    speex_free(st->Y);
216    speex_free(st->E);
217    speex_free(st->W);
218    speex_free(st->PHI);
219    speex_free(st->power);
220    speex_free(st->power_1);
221    speex_free(st->grad);
222
223    speex_free(st);
224 }
225
226       
227 /** Performs echo cancellation on a frame */
228 void speex_echo_cancel(SpeexEchoState *st, short *ref, short *echo, short *out, float *Yout)
229 {
230    int i,j,m;
231    int N,M;
232    float scale;
233    float ESR;
234    float SER;
235    float Sry=0,Srr=0,Syy=0,Sey=0,See=0,Sxx=0;
236    float leak_estimate = .1+(.9/(1+2*st->sum_adapt));
237          
238    N = st->window_size;
239    M = st->M;
240    scale = 1.0f/N;
241    st->cancel_count++;
242
243    /* Copy input data to buffer */
244    for (i=0;i<st->frame_size;i++)
245    {
246       st->x[i] = st->x[i+st->frame_size];
247       st->x[i+st->frame_size] = echo[i];
248
249       st->d[i] = st->d[i+st->frame_size];
250       st->d[i+st->frame_size] = ref[i];
251    }
252
253    /* Shift memory: this could be optimized eventually*/
254    for (i=0;i<N*(M-1);i++)
255       st->X[i]=st->X[i+N];
256
257    /* Copy new echo frame */
258    for (i=0;i<N;i++)
259       st->X[(M-1)*N+i]=st->x[i];
260
261    /* Convert x (echo input) to frequency domain */
262    spx_drft_forward(st->fft_lookup, &st->X[(M-1)*N]);
263
264    /* Compute filter response Y */
265    for (i=0;i<N;i++)
266       st->Y[i] = 0;
267    for (j=0;j<M;j++)
268       spectral_mul_accum(&st->X[j*N], &st->W[j*N], st->Y, N);
269    
270    /* Convert Y (filter response) to time domain */
271    for (i=0;i<N;i++)
272       st->y[i] = st->Y[i];
273    spx_drft_backward(st->fft_lookup, st->y);
274    for (i=0;i<N;i++)
275       st->y[i] *= scale;
276
277    /* Transform d (reference signal) to frequency domain */
278    for (i=0;i<N;i++)
279       st->D[i]=st->d[i];
280    spx_drft_forward(st->fft_lookup, st->D);
281
282    /* Compute error signal (signal with echo removed) */ 
283    for (i=0;i<st->frame_size;i++)
284    {
285       float tmp_out;
286       tmp_out = (float)ref[i] - st->y[i+st->frame_size];
287       
288       st->E[i] = 0;
289       st->E[i+st->frame_size] = tmp_out;
290       
291       /* Saturation */
292       if (tmp_out>32767)
293          tmp_out = 32767;
294       else if (tmp_out<-32768)
295          tmp_out = -32768;
296       out[i] = tmp_out;  
297    }
298    
299    /* Compute power spectrum of output (D-Y) and filter response (Y) */
300    for (i=0;i<N;i++)
301       st->D[i] -= st->Y[i];
302    power_spectrum(st->D, st->Rf, N);
303    power_spectrum(st->Y, st->Yf, N);
304    
305    /* Compute frequency-domain adaptation mask */
306    for (j=0;j<=st->frame_size;j++)
307    {
308       float r;
309       r = leak_estimate*st->Yf[j] / (1+st->Rf[j]);
310       if (r>1)
311          r = 1;
312       st->fratio[j] = r;
313       /*printf ("%f ", r);*/
314    }
315    /*printf ("\n");*/
316
317    /*float Sww=0;*/
318    /* Compute a bunch of correlations */
319    Sry = inner_prod(st->y+st->frame_size, st->d+st->frame_size, st->frame_size);
320    Sey = inner_prod(st->y+st->frame_size, st->E+st->frame_size, st->frame_size);
321    See = inner_prod(st->E+st->frame_size, st->E+st->frame_size, st->frame_size);
322    Syy = inner_prod(st->y+st->frame_size, st->y+st->frame_size, st->frame_size);
323    Srr = inner_prod(st->d+st->frame_size, st->d+st->frame_size, st->frame_size);
324    Sxx = inner_prod(st->x+st->frame_size, st->x+st->frame_size, st->frame_size);
325    
326    st->Sey = .98*st->Sey + .02*Sey;
327    st->Syy = .98*st->Syy + .02*Syy;
328    st->See = .98*st->See + .02*See;
329    
330    if (st->Sey/(1+st->Syy + .01*st->See) < -1)
331    {
332       fprintf (stderr, "reset at %d\n", st->cancel_count);
333       speex_echo_reset(st);
334       return;
335    }
336    
337    /*for (i=0;i<M*N;i++)
338       Sww += st->W[i]*st->W[i];
339    */
340    
341    SER = Srr / (1+Sxx);
342    ESR = leak_estimate*Syy / (1+See);
343    if (ESR>1)
344       ESR = 1;
345 #if 1
346    if (st->Sey/(1+st->Syy) < -.1 && (ESR > .3))
347    {
348       for (i=0;i<M*N;i++)
349          st->W[i] *= .95;
350       st->Sey *= .5;
351       /*fprintf (stderr, "corrected down\n");*/
352    }
353 #endif
354 #if 1
355    if (st->Sey/(1+st->Syy) > .1 && (ESR > .1 || SER < 10))
356    {
357       for (i=0;i<M*N;i++)
358          st->W[i] *= 1.05;
359       st->Sey *= .5;
360       /*fprintf (stderr, "corrected up %d\n", st->cancel_count);*/
361    }
362 #endif
363    
364    if (ESR>.6 && st->sum_adapt > 1)
365    /*if (st->cancel_count > 40)*/
366    {
367       if (!st->adapted)
368          fprintf(stderr, "Adapted at %d %f\n", st->cancel_count, st->sum_adapt);
369       st->adapted = 1;
370    }
371    /*printf ("%f %f %f %f %f %f %f %f %f %f %f %f\n", Srr, Syy, Sxx, See, ESR, SER, Sry, Sey, Sww, st->Sey, st->Syy, st->See);*/
372    for (i=0;i<=st->frame_size;i++)
373    {
374       st->fratio[i]  = (.2*ESR+.8*min(.005+ESR,st->fratio[i]));
375       /*printf ("%f ", st->fratio[i]);*/
376    }
377    /*printf ("\n");*/
378    
379    
380    if (st->adapted)
381    {
382       st->adapt_rate = .95f/(2+M);
383    } else {
384       if (SER<.1)
385          st->adapt_rate =.8/(2+M);
386       else if (SER<1)
387          st->adapt_rate =.4/(2+M);
388       else if (SER<10)
389          st->adapt_rate =.2/(2+M);
390       else if (SER<30)
391          st->adapt_rate =.08/(2+M);
392       else
393          st->adapt_rate = 0;
394    }
395    st->sum_adapt += st->adapt_rate;
396
397    /* Compute input power in each frequency bin */
398    {
399       float ss = 1.0f/st->cancel_count;
400       if (ss < .3/M)
401          ss=.3/M;
402       power_spectrum(&st->X[(M-1)*N], st->Xf, N);
403       for (j=0;j<=st->frame_size;j++)
404          st->power[j] = (1-ss)*st->power[j] + ss*st->Xf[j];
405       
406       
407       if (st->adapted)
408       {
409          for (i=0;i<=st->frame_size;i++)
410             st->power_1[i] = st->fratio[i] /(1.f+st->power[i]);
411       } else {
412          for (i=0;i<=st->frame_size;i++)
413             st->power_1[i] = 1.0f/(1.f+st->power[i]);
414       }
415    }
416
417    
418    /* Convert error to frequency domain */
419    spx_drft_forward(st->fft_lookup, st->E);
420
421    /* Do some regularization (prevents problems when system is ill-conditoned) */
422    for (m=0;m<M;m++)
423    {
424       for (i=0;i<N;i++)
425       {
426          st->W[m*N+i] *= 1-st->regul[i]*ESR;
427       }
428    }
429    
430    /* Compute weight gradient */
431    for (j=0;j<M;j++)
432    {
433       weighted_spectral_mul_conj(st->power_1, &st->X[j*N], st->E, st->PHI, N);
434
435       for (i=0;i<N;i++)
436          st->W[j*N+i] += st->adapt_rate*st->PHI[i];
437    }
438    
439    /* AUMDF weight constraint */
440    for (j=0;j<M;j++)
441    {
442       /* Remove the "if" to make this an MDF filter */
443       if (st->cancel_count%M == j)
444       {
445          spx_drft_backward(st->fft_lookup, &st->W[j*N]);
446          for (i=0;i<N;i++)
447             st->W[j*N+i]*=scale;
448          for (i=st->frame_size;i<N;i++)
449          {
450             st->W[j*N+i]=0;
451          }
452          spx_drft_forward(st->fft_lookup, &st->W[j*N]);
453       }
454    }
455    
456    /*if (st->cancel_count%100==0)
457    {
458       for (i=0;i<M*N;i++)
459          printf ("%f ", st->W[i]);
460       printf ("\n");
461    }*/
462
463
464    /* Compute spectrum of estimated echo for use in an echo post-filter (if necessary)*/
465    if (Yout)
466    {
467       if (st->adapted)
468       {
469          for (i=0;i<st->frame_size;i++)
470             st->last_y[i] = st->last_y[st->frame_size+i];
471          for (i=0;i<st->frame_size;i++)
472             st->last_y[st->frame_size+i] = st->y[st->frame_size+i];
473       } else {
474          for (i=0;i<N;i++)
475             st->last_y[i] = st->x[i];
476       }
477       for (i=0;i<N;i++)
478          st->Yps[i] = (.5-.5*cos(2*M_PI*i/N))*st->last_y[i];
479       
480       spx_drft_forward(st->fft_lookup, st->Yps);
481       power_spectrum(st->Yps, st->Yps, N);
482       
483       for (i=0;i<=st->frame_size;i++)
484          Yout[i] = 2*leak_estimate*st->Yps[i];
485    }
486
487 }
488