trying some ideas for soft-decision DTD based on residual-to-signal ratio
[speexdsp.git] / libspeex / mdf.c
1 /* Copyright (C) Jean-Marc Valin
2
3    File: speex_echo.c
4    Echo cancelling based on the MDF algorithm described in:
5
6    J. S. Soo, K. K. Pang Multidelay block frequency adaptive filter, 
7    IEEE Trans. Acoust. Speech Signal Process., Vol. ASSP-38, No. 2, 
8    February 1990.
9
10    Redistribution and use in source and binary forms, with or without
11    modification, are permitted provided that the following conditions are
12    met:
13
14    1. Redistributions of source code must retain the above copyright notice,
15    this list of conditions and the following disclaimer.
16
17    2. Redistributions in binary form must reproduce the above copyright
18    notice, this list of conditions and the following disclaimer in the
19    documentation and/or other materials provided with the distribution.
20
21    3. The name of the author may not be used to endorse or promote products
22    derived from this software without specific prior written permission.
23
24    THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
25    IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
26    OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27    DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
28    INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
29    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
31    HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
32    STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33    ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34    POSSIBILITY OF SUCH DAMAGE.
35 */
36
37 #ifdef HAVE_CONFIG_H
38 #include "config.h"
39 #endif
40
41 #include "misc.h"
42 #include <speex/speex_echo.h>
43 #include "smallft.h"
44 #include <math.h>
45 /*#include <stdio.h>*/
46
47 #define BETA .65
48
49 #define min(a,b) ((a)<(b) ? (a) : (b))
50
51 /** Creates a new echo canceller state */
52 SpeexEchoState *speex_echo_state_init(int frame_size, int filter_length)
53 {
54    int i,N,M;
55    SpeexEchoState *st = (SpeexEchoState *)speex_alloc(sizeof(SpeexEchoState));
56
57    st->frame_size = frame_size;
58    st->window_size = 2*frame_size;
59    N = st->window_size;
60    M = st->M = (filter_length+N-1)/frame_size;
61    st->cancel_count=0;
62    st->adapt_rate = .01f;
63
64    st->fft_lookup = (struct drft_lookup*)speex_alloc(sizeof(struct drft_lookup));
65    spx_drft_init(st->fft_lookup, N);
66    
67    st->x = (float*)speex_alloc(N*sizeof(float));
68    st->d = (float*)speex_alloc(N*sizeof(float));
69    st->y = (float*)speex_alloc(N*sizeof(float));
70    st->Yf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
71    st->Rf = (float*)speex_alloc((st->frame_size+1)*sizeof(float));
72
73    st->X = (float*)speex_alloc(M*N*sizeof(float));
74    st->D = (float*)speex_alloc(N*sizeof(float));
75    st->Y = (float*)speex_alloc(N*sizeof(float));
76    st->E = (float*)speex_alloc(N*sizeof(float));
77    st->W = (float*)speex_alloc(M*N*sizeof(float));
78    st->PHI = (float*)speex_alloc(N*sizeof(float));
79    st->power = (float*)speex_alloc((frame_size+1)*sizeof(float));
80    st->power_1 = (float*)speex_alloc((frame_size+1)*sizeof(float));
81    st->grad = (float*)speex_alloc(N*M*sizeof(float));
82    
83    for (i=0;i<N*M;i++)
84    {
85       st->W[i] = 0;
86    }
87    
88    st->adapted = 0;
89    return st;
90 }
91
92 void speex_echo_reset(SpeexEchoState *st)
93 {
94    int i, M, N;
95    st->cancel_count=0;
96    st->adapt_rate = .01f;
97    N = st->window_size;
98    M = st->M;
99    for (i=0;i<N*M;i++)
100    {
101       st->W[i] = 0;
102       st->X[i] = 0;
103    }
104 }
105
106 /** Destroys an echo canceller state */
107 void speex_echo_state_destroy(SpeexEchoState *st)
108 {
109    spx_drft_clear(st->fft_lookup);
110    speex_free(st->fft_lookup);
111    speex_free(st->x);
112    speex_free(st->d);
113    speex_free(st->y);
114    speex_free(st->Yf);
115    speex_free(st->Rf);
116
117    speex_free(st->X);
118    speex_free(st->D);
119    speex_free(st->Y);
120    speex_free(st->E);
121    speex_free(st->W);
122    speex_free(st->PHI);
123    speex_free(st->power);
124    speex_free(st->power_1);
125    speex_free(st->grad);
126
127    speex_free(st);
128 }
129
130 /** Performs echo cancellation on a frame */
131 void speex_echo_cancel(SpeexEchoState *st, short *ref, short *echo, short *out, int *Yout)
132 {
133    int i,j,m;
134    int N,M;
135    float scale;
136    float spectral_dist=0;
137    float cos_dist=0;
138    float Eout=0;
139    float See=0;
140    float ESR;
141          
142    N = st->window_size;
143    M = st->M;
144    scale = 1.0f/N;
145    st->cancel_count++;
146
147    /* Copy input data to buffer */
148    for (i=0;i<st->frame_size;i++)
149    {
150       st->x[i] = st->x[i+st->frame_size];
151       st->x[i+st->frame_size] = echo[i];
152
153       st->d[i] = st->d[i+st->frame_size];
154       st->d[i+st->frame_size] = ref[i];
155    }
156
157    /* Shift memory: this could be optimized eventually*/
158    for (i=0;i<N*(M-1);i++)
159       st->X[i]=st->X[i+N];
160
161    for (i=0;i<N;i++)
162       st->X[(M-1)*N+i]=st->x[i];
163
164    /* Convert x (echo input) to frequency domain */
165    spx_drft_forward(st->fft_lookup, &st->X[(M-1)*N]);
166
167    /* Compute filter response Y */
168    for (i=1;i<N-1;i+=2)
169    {
170       st->Y[i] = st->Y[i+1] = 0;
171       for (j=0;j<M;j++)
172       {
173          st->Y[i] += st->X[j*N+i]*st->W[j*N+i] - st->X[j*N+i+1]*st->W[j*N+i+1];
174          st->Y[i+1] += st->X[j*N+i+1]*st->W[j*N+i] + st->X[j*N+i]*st->W[j*N+i+1];
175       }
176    }
177    st->Y[0] = st->Y[N-1] = 0;
178    for (j=0;j<M;j++)
179    {
180       st->Y[0] += st->X[j*N]*st->W[j*N];
181       st->Y[N-1] += st->X[(j+1)*N-1]*st->W[(j+1)*N-1];
182    }
183
184
185    /* Transform d (reference signal) to frequency domain */
186    for (i=0;i<N;i++)
187       st->D[i]=st->d[i];
188    spx_drft_forward(st->fft_lookup, st->D);
189
190    {
191       for (i=1,j=1;i<N-1;i+=2,j++)
192       {
193          st->Yf[j] =  st->Y[i]*st->Y[i] + st->Y[i+1]*st->Y[i+1];
194       }
195       st->Yf[0]=st->Y[0]*st->Y[0];
196       st->Yf[st->frame_size]=st->Y[i]*st->Y[i];
197       
198       for (i=1,j=1;i<N-1;i+=2,j++)
199       {
200          st->Rf[j] = (st->Y[i]-st->D[i])*(st->Y[i]-st->D[i]) + (st->Y[i+1]-st->D[i+1])*(st->Y[i+1]-st->D[i+1]);
201       }
202       st->Rf[0]=(st->Y[0]-st->D[0])*(st->Y[0]-st->D[0]);
203       st->Rf[st->frame_size]=(st->Y[i]-st->D[i])*(st->Y[i]-st->D[i]);
204       for (j=0;j<=st->frame_size;j++)
205       {
206          float r;
207          r = .3*st->Yf[j] / (1+st->Rf[j]);
208          if (r>1)
209             r = 1;
210          st->power_1[j] = r;
211          //printf ("%f ", r);
212       }
213       //printf ("\n");
214    }
215    
216    /* Copy spectrum of Y to Yout for use in an echo post-filter */
217    if (Yout)
218    {
219       for (i=1,j=1;i<N-1;i+=2,j++)
220       {
221          Yout[j] =  st->Y[i]*st->Y[i] + st->Y[i+1]*st->Y[i+1];
222       }
223       Yout[0] = Yout[st->frame_size] = 0;
224       for (i=0;i<=st->frame_size;i++)
225          Yout[i] *= .1;
226    }
227
228    for (i=0;i<N;i++)
229       st->y[i] = st->Y[i];
230    
231    /* Convery Y (filter response) to time domain */
232    spx_drft_backward(st->fft_lookup, st->y);
233    for (i=0;i<N;i++)
234       st->y[i] *= scale;
235
236    Eout = 0;
237    /* Compute error signal (echo canceller output) */
238    for (i=0;i<st->frame_size;i++)
239    {
240       float tmp_out;
241       tmp_out = (float)ref[i] - st->y[i+st->frame_size];
242       Eout += tmp_out*tmp_out;
243       
244       if (tmp_out>32767)
245          tmp_out = 32767;
246       else if (tmp_out<-32768)
247          tmp_out = -32768;
248       out[i] = tmp_out;
249         
250       st->E[i] = 0;
251       st->E[i+st->frame_size] = out[i];
252    }
253
254    {
255       float Sry=0, Srr=0,Syy=0;
256       /*float cos_dist;*/
257       for (i=0;i<st->frame_size;i++)
258       {
259          Sry += st->y[i+st->frame_size] * ref[i];
260          Srr += (float)ref[i] * (float)ref[i];
261          See += (float)echo[i] * (float)echo[i];
262          Syy += st->y[i+st->frame_size]*st->y[i+st->frame_size];
263       }
264       cos_dist = Sry/(sqrt(1e8+Srr)*sqrt(1e8+Syy));
265       /*printf (" %f ", cos_dist);*/
266       spectral_dist = Sry/(1e8+Srr);
267       /*printf (" %f\n", spectral_dist);*/
268       ESR = .2*Syy / (1+Eout);
269       if (ESR>1)
270          ESR = 1;
271       
272       if (ESR>.5)// && st->cancel_count > 50)
273       {
274          if (!st->adapted)
275             fprintf(stderr, "Adapted at %d\n", st->cancel_count); 
276          st->adapted = 1;
277       }
278       //printf ("%f %f %f %f %f\n", Srr, Syy, See, Eout, ESR);
279       for (i=0;i<=st->frame_size;i++)
280       {
281          //st->power_1[i]  = (.1*ESR+.9*min(.3+2*ESR,st->power_1[i]));
282          //st->power_1[i]  = ESR;
283          printf ("%f ", st->power_1[i]);
284       }
285       printf ("\n");
286    }
287    /* Convert error to frequency domain */
288    spx_drft_forward(st->fft_lookup, st->E);
289
290    /* Compute input power in each frequency bin */
291    {
292       float s;
293       float tmp, tmp2;
294
295       if (st->cancel_count<M)
296          s = 1.0f/st->cancel_count;
297       else
298          s = 1.0f/M;
299       
300       for (i=1,j=1;i<N-1;i+=2,j++)
301       {
302          tmp=0;
303          for (m=0;m<M;m++)
304          {
305             float E = st->X[m*N+i]*st->X[m*N+i] + st->X[m*N+i+1]*st->X[m*N+i+1];
306             tmp += E;
307             if (st->power[j] < .2*E)
308                st->power[j] = .2*E;
309
310          }
311          tmp *= s;
312          if (st->cancel_count<M)
313             st->power[j] = tmp;
314          else
315             st->power[j] = BETA*st->power[j] + (1-BETA)*tmp;
316       }
317       tmp=tmp2=0;
318       for (m=0;m<M;m++)
319       {
320          tmp += st->X[m*N]*st->X[m*N];
321          tmp2 += st->X[(m+1)*N-1]*st->X[(m+1)*N-1];
322          /*FIXME: Should put a bound on energy like several lines above */
323       }
324       tmp *= s;
325       tmp2 *= s;
326       if (st->cancel_count<M)
327       {
328          st->power[0] = tmp;
329          st->power[st->frame_size] = tmp2;
330       } else {
331          st->power[0] = BETA*st->power[0] + (1-BETA)*tmp;
332          st->power[st->frame_size] = BETA*st->power[st->frame_size] + (1-BETA)*tmp2;
333       }
334       
335       if (st->adapted)
336       //if (0)
337       {
338          for (i=0;i<=st->frame_size;i++)
339          {
340             //st->power_1[i]  = (.1*ESR+.9*min(1.5f*ESR,st->power_1[i]));
341             st->power_1[i] /= (1e3f+st->power[i]);
342          }
343       } else {
344          for (i=0;i<=st->frame_size;i++)
345             st->power_1[i] = 1.0f/(1e5f+st->power[i]);
346       }
347    }
348
349    /* Compute weight gradient */
350    for (j=0;j<M;j++)
351    {
352       for (i=1,m=1;i<N-1;i+=2,m++)
353       {
354          st->PHI[i] = st->power_1[m] 
355          * (st->X[j*N+i]*st->E[i] + st->X[j*N+i+1]*st->E[i+1]);
356          st->PHI[i+1] = st->power_1[m] 
357          * (-st->X[j*N+i+1]*st->E[i] + st->X[j*N+i]*st->E[i+1]);
358       }
359       st->PHI[0] = st->power_1[0] * st->X[j*N]*st->E[0];
360       st->PHI[N-1] = st->power_1[st->frame_size] * st->X[(j+1)*N-1]*st->E[N-1];
361       
362
363 #if 0 /* Set to 1 to enable MDF instead of AUMDF (and comment out weight constraint below) */
364       spx_drft_backward(st->fft_lookup, st->PHI);
365       for (i=0;i<N;i++)
366          st->PHI[i]*=scale;
367       for (i=st->frame_size;i<N;i++)
368         st->PHI[i]=0;
369       spx_drft_forward(st->fft_lookup, st->PHI);
370 #endif
371      
372
373       for (i=0;i<N;i++)
374       {
375          st->grad[j*N+i] = st->PHI[i];
376       }
377
378       
379    }
380
381    /* Adjust adaptation rate */
382    if (st->cancel_count>2*M)
383    {
384       if (st->cancel_count<8*M)
385       {
386          st->adapt_rate = .3f/(2+M);
387       } else {
388          if (spectral_dist > .5 && cos_dist > .7)
389             st->adapt_rate = .4f/(2+M);
390          else if (spectral_dist > .3 && cos_dist > .5)
391             st->adapt_rate = .2f/(2+M);
392          else if (spectral_dist > .15 && cos_dist > .3)
393             st->adapt_rate = .1f/(2+M);
394          else if (cos_dist > .01)
395             st->adapt_rate = .05f/(2+M);
396          else
397             st->adapt_rate = .01f/(2+M);
398       }
399    } else
400       st->adapt_rate = .0f;
401       
402       //st->adapt_rate *=4;// .1f/(2+M);
403    if (See>1e8)
404       st->adapt_rate =.8/(2+M);
405    else if (See>3e7)
406       st->adapt_rate =.4/(2+M);
407    else if (See>1e7)
408       st->adapt_rate =.2/(2+M);
409    else if (See>3e6)
410       st->adapt_rate =.1/(2+M);
411    else
412       st->adapt_rate = 0;
413    
414    st->adapt_rate =.4/(2+M);
415    /*if (st->cancel_count < 40)
416    st->adapt_rate *= 2.;
417    */
418    
419    /*if (st->cancel_count<30)
420       st->adapt_rate *= 1.5;
421    else
422       st->adapt_rate *= .9;
423    */
424 #if 0
425    if (st->cancel_count>70)
426       st->adapt_rate = .6*ESR/(2+M);
427 #else
428    if (st->adapted)
429       st->adapt_rate = .9f/(1+M);
430 #endif
431    /* Update weights */
432    for (i=0;i<M*N;i++)
433       st->W[i] += st->adapt_rate*st->grad[i];
434
435    /* AUMDF weight constraint */
436    for (j=0;j<M;j++)
437    {
438       if (st->cancel_count%M == j)
439       {
440          spx_drft_backward(st->fft_lookup, &st->W[j*N]);
441          for (i=0;i<N;i++)
442             st->W[j*N+i]*=scale;
443          for (i=st->frame_size;i<N;i++)
444          {
445             st->W[j*N+i]=0;
446          }
447          spx_drft_forward(st->fft_lookup, &st->W[j*N]);
448       }
449
450    }
451
452 }
453