some comments
[speexdsp.git] / libspeex / mdf.c
1 /* Copyright (C) Jean-Marc Valin
2
3    File: speex_echo.c
4    Echo cancelling based on the MDF algorithm described in:
5
6    J. S. Soo, K. K. Pang Multidelay block frequency adaptive filter, 
7    IEEE Trans. Acoust. Speech Signal Process., Vol. ASSP-38, No. 2, 
8    February 1990.
9
10    Redistribution and use in source and binary forms, with or without
11    modification, are permitted provided that the following conditions are
12    met:
13
14    1. Redistributions of source code must retain the above copyright notice,
15    this list of conditions and the following disclaimer.
16
17    2. Redistributions in binary form must reproduce the above copyright
18    notice, this list of conditions and the following disclaimer in the
19    documentation and/or other materials provided with the distribution.
20
21    3. The name of the author may not be used to endorse or promote products
22    derived from this software without specific prior written permission.
23
24    THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
25    IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
26    OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27    DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
28    INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
29    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
31    HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
32    STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33    ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34    POSSIBILITY OF SUCH DAMAGE.
35 */
36
37 #ifdef HAVE_CONFIG_H
38 #include "config.h"
39 #endif
40
41 #include "misc.h"
42 #include <speex/speex_echo.h>
43 #include "smallft.h"
44 #include <math.h>
45 #include <stdio.h>
46
47 #ifndef M_PI
48 #define M_PI 3.14159265358979323846
49 #endif
50
51 #define BETA .65
52 /*#define BETA 0*/
53
54 #define min(a,b) ((a)<(b) ? (a) : (b))
55 #define max(a,b) ((a)>(b) ? (a) : (b))
56
57 /** Compute inner product of two real vectors */
58 static inline float inner_prod(float *x, float *y, int N)
59 {
60    int i;
61    float ret=0;
62    for (i=0;i<N;i++)
63       ret += x[i]*y[i];
64    return ret;
65 }
66
67 /** Compute power spectrum of a half-complex (packed) vector */
68 static inline void power_spectrum(float *X, float *ps, int N)
69 {
70    int i, j;
71    ps[0]=X[0]*X[0];
72    for (i=1,j=1;i<N-1;i+=2,j++)
73    {
74       ps[j] =  X[i]*X[i] + X[i+1]*X[i+1];
75    }
76    ps[j]=X[i]*X[i];
77 }
78
79 /** Compute cross-power spectrum of a half-complex (packed) vectors and add to acc */
80 static inline void spectral_mul_accum(float *X, float *Y, float *acc, int N)
81 {
82    int i;
83    acc[0] += X[0]*Y[0];
84    for (i=1;i<N-1;i+=2)
85    {
86       acc[i] += (X[i]*Y[i] - X[i+1]*Y[i+1]);
87       acc[i+1] += (X[i+1]*Y[i] + X[i]*Y[i+1]);
88    }
89    acc[i] += X[i]*Y[i];
90 }
91
92 /** Compute cross-power spectrum of a half-complex (packed) vector with conjugate */
93 static inline void spectral_mul_conj(float *X, float *Y, float *prod, int N)
94 {
95    int i;
96    prod[0] = X[0]*Y[0];
97    for (i=1;i<N-1;i+=2)
98    {
99       prod[i] = (X[i]*Y[i] + X[i+1]*Y[i+1]);
100       prod[i+1] = (-X[i+1]*Y[i] + X[i]*Y[i+1]);
101    }
102    prod[i] = X[i]*Y[i];
103 }
104
105
106 /** Compute weighted cross-power spectrum of a half-complex (packed) vector with conjugate */
107 static inline void weighted_spectral_mul_conj(float *w, float *X, float *Y, float *prod, int N)
108 {
109    int i, j;
110    prod[0] = w[0]*X[0]*Y[0];
111    for (i=1,j=1;i<N-1;i+=2,j++)
112    {
113       prod[i] = w[j]*(X[i]*Y[i] + X[i+1]*Y[i+1]);
114       prod[i+1] = w[j]*(-X[i+1]*Y[i] + X[i]*Y[i+1]);
115    }
116    prod[i] = w[j]*X[i]*Y[i];
117 }
118
119
120 /** Creates a new echo canceller state */
121 SpeexEchoState *speex_echo_state_init(int frame_size, int filter_length)
122 {
123    int i,j,N,M;
124    SpeexEchoState *st = (SpeexEchoState *)speex_alloc(sizeof(SpeexEchoState));
125
126    st->frame_size = frame_size;
127    st->window_size = 2*frame_size;
128    N = st->window_size;
129    M = st->M = (filter_length+st->frame_size-1)/frame_size;
130    st->cancel_count=0;
131    st->adapt_rate = .01f;
132    st->sum_adapt = 0;
133    st->Sey = 0;
134    st->Syy = 0;
135    st->See = 0;
136          
137    st->fft_lookup = (struct drft_lookup*)speex_alloc(sizeof(struct drft_lookup));
138    spx_drft_init(st->fft_lookup, N);
139    
140    st->x = (float*)speex_alloc(N*sizeof(float));
141    st->d = (float*)speex_alloc(N*sizeof(float));
142    st->y = (float*)speex_alloc(N*sizeof(float));
143    st->Yps = (float*)speex_alloc(N*sizeof(float));
144    st->last_y = (float*)speex_alloc(N*sizeof(float));
145    st->Yf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
146    st->Rf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
147    st->Xf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
148    st->fratio = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
149    st->regul = (float*)speex_alloc(N*sizeof(float));
150
151    st->X = (float*)speex_alloc(M*N*sizeof(float));
152    st->D = (float*)speex_alloc(N*sizeof(float));
153    st->Y = (float*)speex_alloc(N*sizeof(float));
154    st->E = (float*)speex_alloc(N*sizeof(float));
155    st->W = (float*)speex_alloc(M*N*sizeof(float));
156    st->PHI = (float*)speex_alloc(N*sizeof(float));
157    st->power = (float*)speex_alloc((frame_size+1)*sizeof(float));
158    st->power_1 = (float*)speex_alloc((frame_size+1)*sizeof(float));
159    st->grad = (float*)speex_alloc(N*M*sizeof(float));
160    
161    for (i=0;i<N*M;i++)
162    {
163       st->W[i] = 0;
164    }
165    
166    st->regul[0] = (.01+(10.)/((4.)*(4.)))/M;
167    for (i=1,j=1;i<N-1;i+=2,j++)
168    {
169       st->regul[i] = .01+((10.)/((j+4.)*(j+4.)))/M;
170       st->regul[i+1] = .01+((10.)/((j+4.)*(j+4.)))/M;
171    }
172    st->regul[i] = .01+((10.)/((j+4.)*(j+4.)))/M;
173          
174    st->adapted = 0;
175    return st;
176 }
177
178 /** Resets echo canceller state */
179 void speex_echo_reset(SpeexEchoState *st)
180 {
181    int i, M, N;
182    st->cancel_count=0;
183    st->adapt_rate = .01f;
184    N = st->window_size;
185    M = st->M;
186    for (i=0;i<N*M;i++)
187    {
188       st->W[i] = 0;
189       st->X[i] = 0;
190    }
191    for (i=0;i<=st->frame_size;i++)
192       st->power[i] = 0;
193    
194    st->adapted = 0;
195    st->adapt_rate = .01f;
196    st->sum_adapt = 0;
197    st->Sey = 0;
198    st->Syy = 0;
199    st->See = 0;
200
201 }
202
203 /** Destroys an echo canceller state */
204 void speex_echo_state_destroy(SpeexEchoState *st)
205 {
206    spx_drft_clear(st->fft_lookup);
207    speex_free(st->fft_lookup);
208    speex_free(st->x);
209    speex_free(st->d);
210    speex_free(st->y);
211    speex_free(st->last_y);
212    speex_free(st->Yps);
213    speex_free(st->Yf);
214    speex_free(st->Rf);
215    speex_free(st->Xf);
216    speex_free(st->fratio);
217    speex_free(st->regul);
218
219    speex_free(st->X);
220    speex_free(st->D);
221    speex_free(st->Y);
222    speex_free(st->E);
223    speex_free(st->W);
224    speex_free(st->PHI);
225    speex_free(st->power);
226    speex_free(st->power_1);
227    speex_free(st->grad);
228
229    speex_free(st);
230 }
231
232       
233 /** Performs echo cancellation on a frame */
234 void speex_echo_cancel(SpeexEchoState *st, short *ref, short *echo, short *out, float *Yout)
235 {
236    int i,j,m;
237    int N,M;
238    float scale;
239    float ESR;
240    float SER;
241    float Sry=0,Srr=0,Syy=0,Sey=0,See=0,Sxx=0;
242    float leak_estimate;
243    
244    leak_estimate = .1+(.9/(1+2*st->sum_adapt));
245          
246    N = st->window_size;
247    M = st->M;
248    scale = 1.0f/N;
249    st->cancel_count++;
250
251    /* Copy input data to buffer */
252    for (i=0;i<st->frame_size;i++)
253    {
254       st->x[i] = st->x[i+st->frame_size];
255       st->x[i+st->frame_size] = echo[i];
256
257       st->d[i] = st->d[i+st->frame_size];
258       st->d[i+st->frame_size] = ref[i];
259    }
260
261    /* Shift memory: this could be optimized eventually*/
262    for (i=0;i<N*(M-1);i++)
263       st->X[i]=st->X[i+N];
264
265    /* Copy new echo frame */
266    for (i=0;i<N;i++)
267       st->X[(M-1)*N+i]=st->x[i];
268
269    /* Convert x (echo input) to frequency domain */
270    spx_drft_forward(st->fft_lookup, &st->X[(M-1)*N]);
271
272    /* Compute filter response Y */
273    for (i=0;i<N;i++)
274       st->Y[i] = 0;
275    for (j=0;j<M;j++)
276       spectral_mul_accum(&st->X[j*N], &st->W[j*N], st->Y, N);
277    
278    /* Convert Y (filter response) to time domain */
279    for (i=0;i<N;i++)
280       st->y[i] = st->Y[i];
281    spx_drft_backward(st->fft_lookup, st->y);
282    for (i=0;i<N;i++)
283       st->y[i] *= scale;
284
285    /* Transform d (reference signal) to frequency domain */
286    for (i=0;i<N;i++)
287       st->D[i]=st->d[i];
288    spx_drft_forward(st->fft_lookup, st->D);
289
290    /* Compute error signal (signal with echo removed) */ 
291    for (i=0;i<st->frame_size;i++)
292    {
293       float tmp_out;
294       tmp_out = (float)ref[i] - st->y[i+st->frame_size];
295       
296       st->E[i] = 0;
297       st->E[i+st->frame_size] = tmp_out;
298       
299       /* Saturation */
300       if (tmp_out>32767)
301          tmp_out = 32767;
302       else if (tmp_out<-32768)
303          tmp_out = -32768;
304       out[i] = tmp_out;  
305    }
306    
307    /* Compute power spectrum of output (D-Y) and filter response (Y) */
308    for (i=0;i<N;i++)
309       st->D[i] -= st->Y[i];
310    power_spectrum(st->D, st->Rf, N);
311    power_spectrum(st->Y, st->Yf, N);
312    
313    /* Compute frequency-domain adaptation mask */
314    for (j=0;j<=st->frame_size;j++)
315    {
316       float r;
317       r = leak_estimate*st->Yf[j] / (1+st->Rf[j]);
318       if (r>1)
319          r = 1;
320       st->fratio[j] = r;
321       /*printf ("%f ", r);*/
322    }
323    /*printf ("\n");*/
324
325    /* Compute a bunch of correlations */
326    Sry = inner_prod(st->y+st->frame_size, st->d+st->frame_size, st->frame_size);
327    Sey = inner_prod(st->y+st->frame_size, st->E+st->frame_size, st->frame_size);
328    See = inner_prod(st->E+st->frame_size, st->E+st->frame_size, st->frame_size);
329    Syy = inner_prod(st->y+st->frame_size, st->y+st->frame_size, st->frame_size);
330    Srr = inner_prod(st->d+st->frame_size, st->d+st->frame_size, st->frame_size);
331    Sxx = inner_prod(st->x+st->frame_size, st->x+st->frame_size, st->frame_size);
332
333    /* Compute smoothed cross-correlation and energy */   
334    st->Sey = .98*st->Sey + .02*Sey;
335    st->Syy = .98*st->Syy + .02*Syy;
336    st->See = .98*st->See + .02*See;
337    
338    if (st->Sey/(1+st->Syy + .01*st->See) < -1)
339    {
340       fprintf (stderr, "reset at %d\n", st->cancel_count);
341       speex_echo_reset(st);
342       return;
343    }
344    
345    /*for (i=0;i<M*N;i++)
346       Sww += st->W[i]*st->W[i];
347    */
348    
349    SER = Srr / (1+Sxx);
350    ESR = leak_estimate*Syy / (1+See);
351    if (ESR>1)
352       ESR = 1;
353 #if 1
354    /* If over-cancellation (creating echo with 180 phase) damp filter */
355    if (st->Sey/(1+st->Syy) < -.1 && (ESR > .3))
356    {
357       for (i=0;i<M*N;i++)
358          st->W[i] *= .95;
359       st->Sey *= .5;
360       /*fprintf (stderr, "corrected down\n");*/
361    }
362 #endif
363 #if 1
364    /* If under-cancellation (leaving echo with 0 phase) scale filter up */
365    if (st->Sey/(1+st->Syy) > .1 && (ESR > .1 || SER < 10))
366    {
367       for (i=0;i<M*N;i++)
368          st->W[i] *= 1.05;
369       st->Sey *= .5;
370       /*fprintf (stderr, "corrected up %d\n", st->cancel_count);*/
371    }
372 #endif
373    
374    /* We consider that the filter is adapted if the following is true*/
375    if (ESR>.6 && st->sum_adapt > 1)
376    /*if (st->cancel_count > 40)*/
377    {
378       if (!st->adapted)
379          fprintf(stderr, "Adapted at %d %f\n", st->cancel_count, st->sum_adapt);
380       st->adapted = 1;
381    }
382    /*printf ("%f %f %f %f %f %f %f %f %f %f %f %f\n", Srr, Syy, Sxx, See, ESR, SER, Sry, Sey, Sww, st->Sey, st->Syy, st->See);*/
383    for (i=0;i<=st->frame_size;i++)
384    {
385       st->fratio[i]  = (.2*ESR+.8*min(.005+ESR,st->fratio[i]));
386       /*printf ("%f ", st->fratio[i]);*/
387    }
388    /*printf ("\n");*/
389    
390    
391    if (st->adapted)
392    {
393       st->adapt_rate = .95f/(2+M);
394    } else {
395       /* Temporary adaption if filter is not adapted correctly */
396       if (SER<.1)
397          st->adapt_rate =.8/(2+M);
398       else if (SER<1)
399          st->adapt_rate =.4/(2+M);
400       else if (SER<10)
401          st->adapt_rate =.2/(2+M);
402       else if (SER<30)
403          st->adapt_rate =.08/(2+M);
404       else
405          st->adapt_rate = 0;
406    }
407    st->sum_adapt += st->adapt_rate;
408
409    /* Compute echo power in each frequency bin */
410    {
411       float ss = 1.0f/st->cancel_count;
412       if (ss < .3/M)
413          ss=.3/M;
414       power_spectrum(&st->X[(M-1)*N], st->Xf, N);
415       for (j=0;j<=st->frame_size;j++)
416          st->power[j] = (1-ss)*st->power[j] + ss*st->Xf[j];
417       
418       
419       if (st->adapted)
420       {
421          for (i=0;i<=st->frame_size;i++)
422             st->power_1[i] = st->fratio[i] /(1.f+st->power[i]);
423       } else {
424          for (i=0;i<=st->frame_size;i++)
425             st->power_1[i] = 1.0f/(1.f+st->power[i]);
426       }
427    }
428
429    
430    /* Convert error to frequency domain */
431    spx_drft_forward(st->fft_lookup, st->E);
432
433    /* Do some regularization (prevents problems when system is ill-conditoned) */
434    for (m=0;m<M;m++)
435    {
436       for (i=0;i<N;i++)
437       {
438          st->W[m*N+i] *= 1-st->regul[i]*ESR;
439       }
440    }
441    
442    /* Compute weight gradient */
443    for (j=0;j<M;j++)
444    {
445       weighted_spectral_mul_conj(st->power_1, &st->X[j*N], st->E, st->PHI, N);
446
447       for (i=0;i<N;i++)
448          st->W[j*N+i] += st->adapt_rate*st->PHI[i];
449    }
450    
451    /* AUMDF weight constraint */
452    for (j=0;j<M;j++)
453    {
454       /* Remove the "if" to make this an MDF filter */
455       if (st->cancel_count%M == j)
456       {
457          spx_drft_backward(st->fft_lookup, &st->W[j*N]);
458          for (i=0;i<N;i++)
459             st->W[j*N+i]*=scale;
460          for (i=st->frame_size;i<N;i++)
461          {
462             st->W[j*N+i]=0;
463          }
464          spx_drft_forward(st->fft_lookup, &st->W[j*N]);
465       }
466    }
467    
468    /*if (st->cancel_count%100==0)
469    {
470       for (i=0;i<M*N;i++)
471          printf ("%f ", st->W[i]);
472       printf ("\n");
473    }*/
474
475
476    /* Compute spectrum of estimated echo for use in an echo post-filter (if necessary)*/
477    if (Yout)
478    {
479       if (st->adapted)
480       {
481          for (i=0;i<st->frame_size;i++)
482             st->last_y[i] = st->last_y[st->frame_size+i];
483          for (i=0;i<st->frame_size;i++)
484             st->last_y[st->frame_size+i] = st->y[st->frame_size+i];
485       } else {
486          for (i=0;i<N;i++)
487             st->last_y[i] = st->x[i];
488       }
489       for (i=0;i<N;i++)
490          st->Yps[i] = (.5-.5*cos(2*M_PI*i/N))*st->last_y[i];
491       
492       spx_drft_forward(st->fft_lookup, st->Yps);
493       power_spectrum(st->Yps, st->Yps, N);
494       
495       for (i=0;i<=st->frame_size;i++)
496          Yout[i] = 2*leak_estimate*st->Yps[i];
497    }
498
499 }
500