denoiser now behaves correctly with 240-sample frames
[speexdsp.git] / libspeex / denoise.c
1 /* Copyright (C) 2003 Epic Games 
2    Written by Jean-Marc Valin
3
4    File: denoise.c
5
6
7    Redistribution and use in source and binary forms, with or without
8    modification, are permitted provided that the following conditions are
9    met:
10
11    1. Redistributions of source code must retain the above copyright notice,
12    this list of conditions and the following disclaimer.
13
14    2. Redistributions in binary form must reproduce the above copyright
15    notice, this list of conditions and the following disclaimer in the
16    documentation and/or other materials provided with the distribution.
17
18    3. The name of the author may not be used to endorse or promote products
19    derived from this software without specific prior written permission.
20
21    THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
22    IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
23    OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24    DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
25    INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28    HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
29    STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30    ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31    POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #include <math.h>
35 #include "speex_denoise.h"
36 #include <stdio.h>
37 #include "misc.h"
38
39 #define STABILITY_TIME 20
40 #define NB_LAST_PS 10
41
42 #define max(a,b) ((a) > (b) ? (a) : (b))
43 #define min(a,b) ((a) < (b) ? (a) : (b))
44
45 #ifndef M_PI
46 #define M_PI 3.14159263
47 #endif
48
49 #define SQRT_M_PI_2 0.88623
50 #define LOUDNESS_EXP 3.5
51
52 static void conj_window(float *w, int len)
53 {
54    int i;
55    for (i=0;i<len;i++)
56    {
57       float x=4*((float)i)/len;
58       int inv=0;
59       if (x<1)
60       {
61       } else if (x<2)
62       {
63          x=2-x;
64          inv=1;
65       } else if (x<3)
66       {
67          x=x-2;
68          inv=1;
69       } else {
70          x=4-x;
71       }
72       x*=1.9979;
73       w[i]=(.5-.5*cos(x))*(.5-.5*cos(x));
74       if (inv)
75          w[i]=1-w[i];
76       w[i]=sqrt(w[i]);
77    }
78 }
79
80 DenoiseState *denoise_state_init(int frame_size)
81 {
82    int i;
83    int N, N3, N4;
84
85    DenoiseState *st = (DenoiseState *)speex_alloc(sizeof(DenoiseState));
86    st->frame_size = frame_size;
87
88    /* Round ps_size down to the nearest power of two */
89    i=1;
90    st->ps_size = st->frame_size;
91    while(1)
92    {
93       if (st->ps_size & ~i)
94       {
95          st->ps_size &= ~i;
96          i<<=1;
97       } else {
98          break;
99       }
100    }
101
102    if (st->ps_size < 3*st->frame_size/4)
103       st->ps_size = st->ps_size * 3 / 2;
104    N = st->ps_size;
105    N3 = 2*N - st->frame_size;
106    N4 = st->frame_size - N3;
107
108    st->frame = (float*)speex_alloc(2*N*sizeof(float));
109    st->ps = (float*)speex_alloc(N*sizeof(float));
110    st->gain2 = (float*)speex_alloc(N*sizeof(float));
111    st->window = (float*)speex_alloc(2*N*sizeof(float));
112    st->noise = (float*)speex_alloc(N*sizeof(float));
113    st->old_ps = (float*)speex_alloc(N*sizeof(float));
114    st->gain = (float*)speex_alloc(N*sizeof(float));
115    st->prior = (float*)speex_alloc(N*sizeof(float));
116    st->post = (float*)speex_alloc(N*sizeof(float));
117    st->min_ps = (float*)speex_alloc(N*sizeof(float));
118    st->last_energy = (float*)speex_alloc(STABILITY_TIME*sizeof(float));
119    st->last_ps = (float*)speex_alloc(NB_LAST_PS*N*sizeof(float));
120    st->loudness_weight = (float*)speex_alloc(N*sizeof(float));
121    st->inbuf = (float*)speex_alloc(N3*sizeof(float));
122    st->outbuf = (float*)speex_alloc(N3*sizeof(float));
123
124    conj_window(st->window, 2*N3);
125    for (i=2*N3;i<2*st->ps_size;i++)
126       st->window[i]=1;
127    for (i=N3-1;i>=0;i--)
128    {
129       st->window[i+N3+N4]=st->window[i+N3];
130       st->window[i+N3]=1;
131    }
132
133    for (i=0;i<N;i++)
134    {
135       st->noise[i]=1e4;
136       st->old_ps[i]=1e4;
137       st->gain[i]=1;
138       st->post[i]=1;
139       st->prior[i]=1;
140    }
141
142    for (i=0;i<N3;i++)
143    {
144       st->inbuf[i]=0;
145       st->outbuf[i]=0;
146    }
147
148    for (i=0;i<N;i++)
149    {
150       float ff=((float)i)*128.0/4000.0;
151       st->loudness_weight[i] = .35-.35*ff/16000+.73*exp(-.5*(ff-3800)*(ff-3800)/9e5);
152       st->loudness_weight[i] *= st->loudness_weight[i];
153    }
154
155    st->loudness = pow(6000,LOUDNESS_EXP);
156    st->loudness2 = 6000;
157    st->nb_loudness_adapt = 0;
158
159    drft_init(&st->fft_lookup,2*N);
160
161
162    st->nb_adapt=0;
163    st->consec_noise=0;
164    st->nb_denoise=0;
165    st->nb_min_estimate=0;
166    st->last_update=0;
167    st->last_id=0;
168    return st;
169 }
170
171 void denoise_state_destroy(DenoiseState *st)
172 {
173    speex_free(st->frame);
174    speex_free(st->ps);
175    speex_free(st->gain2);
176    speex_free(st->window);
177    speex_free(st->noise);
178    speex_free(st->old_ps);
179    speex_free(st->gain);
180    speex_free(st->prior);
181    speex_free(st->post);
182    speex_free(st->min_ps);
183    speex_free(st->last_energy);
184    speex_free(st->last_ps);
185    speex_free(st->loudness_weight);
186
187    speex_free(st->inbuf);
188    speex_free(st->outbuf);
189
190    drft_clear(&st->fft_lookup);
191    
192    speex_free(st);
193 }
194
195 static void update_noise(DenoiseState *st, float *ps)
196 {
197    int i;
198    float beta;
199    st->nb_adapt++;
200    beta=1.0/st->nb_adapt;
201    if (beta < .05)
202       beta=.05;
203    
204    for (i=0;i<st->ps_size;i++)
205       st->noise[i] = (1-beta)*st->noise[i] + beta*ps[i];   
206 }
207
208 int denoise(DenoiseState *st, float *x)
209 {
210    int i;
211    float mean_post=0;
212    float mean_prior=0;
213    float energy;
214    int N = st->ps_size;
215    int N3 = 2*N - st->frame_size;
216    int N4 = st->frame_size - N3;
217    float scale=.5/N;
218    float *ps=st->ps;
219
220    /* 'Build' input frame */
221    for (i=0;i<N3;i++)
222       st->frame[i]=st->inbuf[i];
223    for (i=0;i<st->frame_size;i++)
224       st->frame[N3+i]=x[i];
225    
226    /* Update inbuf */
227    for (i=0;i<N3;i++)
228       st->inbuf[i]=x[N4+i];
229
230    /* Windowing */
231    for (i=0;i<2*N;i++)
232       st->frame[i] *= st->window[i];
233
234    /* Perform FFT */
235    drft_forward(&st->fft_lookup, st->frame);
236
237    /************************************************************** 
238     *  Denoise in spectral domain using Ephraim-Malah algorithm  *
239     **************************************************************/
240
241    /* Power spectrum */
242    ps[0]=1;
243    for (i=1;i<N;i++)
244       ps[i]=1+st->frame[2*i-1]*st->frame[2*i-1] + st->frame[2*i]*st->frame[2*i];
245
246    energy=0;
247    for (i=1;i<N;i++)
248       energy += log(100+ps[i]);
249    energy /= 160;
250    st->last_energy[st->nb_denoise%STABILITY_TIME]=energy;
251
252    if (st->nb_denoise>=STABILITY_TIME)
253    {
254       float E=0, E2=0;
255       float std;
256       for (i=0;i<STABILITY_TIME;i++)
257       {
258          E+=st->last_energy[i];
259          E2+=st->last_energy[i]*st->last_energy[i];
260       }
261       E2=E2/STABILITY_TIME;
262       E=E/STABILITY_TIME;
263       std = sqrt(E2-E*E);
264       if (std<.15 && st->last_update>20)
265       {
266          update_noise(st, &st->last_ps[st->last_id*N]);
267       }
268       /*fprintf (stderr, "%f\n", std);*/
269    }
270
271    st->nb_denoise++;
272 #if 0
273    if (st->nb_min_estimate<50)
274    {
275       float ener=0;
276       for (i=1;i<N;i++)
277          ener += ps[i];
278       /*fprintf (stderr, "%f\n", ener);*/
279       if (ener < st->min_ener || st->nb_min_estimate==0)
280       {
281          st->min_ener = ener;
282          for (i=1;i<N;i++)
283             st->min_ps[i] = ps[i];
284       }
285       st->nb_min_estimate++;
286    } else {
287       float noise_ener=0;
288       st->nb_min_estimate=0;
289       for (i=1;i<N;i++)
290          noise_ener += st->noise[i];
291       /*fprintf (stderr, "%f %f\n", noise_ener, st->min_ener);*/
292       if (0&&(st->last_update>50 && st->min_ener > 3*noise_ener) || st->last_update>50)
293       {
294          for (i=1;i<N;i++)
295          {
296             if (st->noise[i] < st->min_ps[i])
297                st->noise[i] = st->min_ps[i];
298          }
299          /*fprintf (stderr, "tata %d\n",st->last_update);*/
300          st->last_update=0;
301       } else {
302          /*fprintf (stderr, "+");*/
303       }
304    }
305 #endif
306
307    /* Noise estimation always updated for the 20 first times */
308    if (st->nb_adapt<15)
309    {
310       update_noise(st, ps);
311       st->last_update=0;
312    }
313
314    /* Compute a posteriori SNR */
315    for (i=1;i<N;i++)
316    {
317       st->post[i] = ps[i]/(1+st->noise[i]) - 1;
318       if (st->post[i]>100)
319          st->post[i]=100;
320       if (st->post[i]<0)
321         st->post[i]=0;
322       mean_post+=st->post[i];
323    }
324    mean_post /= N;
325    if (mean_post<0)
326       mean_post=0;
327
328    /* Special case for first frame */
329    if (st->nb_adapt==1)
330       for (i=1;i<N;i++)
331          st->old_ps[i] = ps[i];
332
333    /* Compute a priori SNR */
334    {
335       /* A priori update rate */
336       float gamma;
337       float min_gamma=0.05;
338       gamma = 1.0/st->nb_denoise;
339
340       /*Make update rate smaller when there's no speech*/
341       if (mean_post<3)
342          min_gamma *= (mean_post+.1);
343       else
344          min_gamma *= 3.1;
345
346       if (gamma<min_gamma)
347          gamma=min_gamma;
348
349       for (i=1;i<N;i++)
350       {
351          
352          /* A priori SNR update */
353          st->prior[i] = gamma*max(0.0,st->post[i]) +
354          (1-gamma)*st->gain[i]*st->gain[i]*st->old_ps[i]/st->noise[i];
355          
356          if (st->prior[i]>100)
357             st->prior[i]=100;
358          
359          mean_prior+=st->prior[i];
360       }
361    }
362    mean_prior /= N;
363
364 #if 0
365    for (i=0;i<N;i++)
366    {
367       fprintf (stderr, "%f ", st->prior[i]);
368    }
369    fprintf (stderr, "\n");
370 #endif
371    /*fprintf (stderr, "%f %f\n", mean_prior,mean_post);*/
372
373    /* If SNR is low (both a priori and a posteriori), update the noise estimate*/
374    if (mean_prior<.23 && mean_post < .5 && st->nb_adapt>=20)
375    {
376       st->consec_noise++;
377       /*fprintf (stderr, "noise\n");*/
378    } else {
379       st->consec_noise=0;
380    }
381
382    if (st->consec_noise>=3)
383    {
384       update_noise(st, st->old_ps);
385       st->last_update=0;
386    } else {
387       st->last_update++;
388    }
389
390    /* Compute gain according to the Ephraim-Malah algorithm */
391    for (i=1;i<N;i++)
392    {
393       float MM;
394       float theta;
395       float prior_ratio;
396
397       prior_ratio = st->prior[i]/(1.0001+st->prior[i]);
398       theta = (1+st->post[i])*prior_ratio;
399
400       /* Approximation of:
401          exp(-theta/2)*((1+theta)*I0(theta/2) + theta.*I1(theta/2))
402          because I don't feel like computing Bessel functions
403       */
404       /*MM = -.22+1.155*sqrt(theta+1.1);*/
405       MM=-.22+1.163*sqrt(theta+1.1)-.0015*theta;
406
407       st->gain[i] = SQRT_M_PI_2*sqrt(prior_ratio/(1.0001+st->post[i]))*MM;
408       if (st->gain[i]>1)
409       {
410          st->gain[i]=1;
411       }
412       /*st->gain[i] = prior_ratio;*/
413    }
414    st->gain[0]=0;
415    st->gain[N-1]=0;
416
417    for (i=1;i<N-1;i++)
418    {
419       st->gain2[i]=st->gain[i];
420       if (st->gain2[i]<.1)
421          st->gain2[i]=.1;
422    }
423    st->gain2[N-1]=0;
424
425    if ((mean_prior>3&&mean_prior>3))
426    {
427       st->nb_loudness_adapt++;
428       float rate=2.0/(1+st->nb_loudness_adapt);
429       if (rate < .01)
430          rate = .01;
431
432       float loudness=0;
433       for (i=2;i<N;i++)
434       {
435          loudness += scale*st->ps[i] * st->gain2[i] * st->gain2[i] * st->loudness_weight[i];
436       }
437       loudness=sqrt(loudness);
438       /*if (loudness < 2*pow(st->loudness, 1.0/LOUDNESS_EXP) &&
439         loudness*2 > pow(st->loudness, 1.0/LOUDNESS_EXP))*/
440       st->loudness = (1-rate)*st->loudness + (rate)*pow(loudness, LOUDNESS_EXP);
441       
442       st->loudness2 = (1-rate)*st->loudness2 + rate*pow(st->loudness, 1.0/LOUDNESS_EXP);
443
444       loudness = pow(st->loudness, 1.0/LOUDNESS_EXP);
445
446       /*fprintf (stderr, "%f %f %f\n", loudness, st->loudness2, rate);*/
447    }
448    for (i=0;i<N;i++)
449       st->gain2[i] *= 6000.0/st->loudness2;
450
451    /* Apply computed gain */
452    for (i=1;i<N;i++)
453    {
454       st->frame[2*i-1] *= st->gain2[i];
455       st->frame[2*i] *= st->gain2[i];
456    }
457    /* Get rid of the DC and very low frequencies */
458    st->frame[0]=0;
459    st->frame[1]=0;
460    st->frame[2]=0;
461    /* Nyquist frequency is mostly useless too */
462    st->frame[2*N-1]=0;
463
464    /* Inverse FFT with 1/N scaling */
465    drft_backward(&st->fft_lookup, st->frame);
466
467    for (i=0;i<2*N;i++)
468       st->frame[i] *= scale*st->window[i];
469
470    /* Perform overlap and add */
471    for (i=0;i<N3;i++)
472       x[i] = st->outbuf[i] + st->frame[i];
473    for (i=0;i<N4;i++)
474       x[N3+i] = st->frame[N3+i];
475    
476    /* Update outbuf */
477    for (i=0;i<N3;i++)
478       st->outbuf[i] = st->frame[st->frame_size+i];
479
480    /* Save old power spectrum */
481    for (i=1;i<N;i++)
482       st->old_ps[i] = ps[i];
483
484    for (i=1;i<N;i++)
485       st->last_ps[st->last_id*N+i] = ps[i];
486    st->last_id++;
487    if (st->last_id>=NB_LAST_PS)
488       st->last_id=0;
489
490    return 1;
491 }