added the "speex" prefix to the denoising stuff
[speexdsp.git] / libspeex / denoise.c
1 /* Copyright (C) 2003 Epic Games 
2    Written by Jean-Marc Valin
3
4    File: denoise.c
5
6
7    Redistribution and use in source and binary forms, with or without
8    modification, are permitted provided that the following conditions are
9    met:
10
11    1. Redistributions of source code must retain the above copyright notice,
12    this list of conditions and the following disclaimer.
13
14    2. Redistributions in binary form must reproduce the above copyright
15    notice, this list of conditions and the following disclaimer in the
16    documentation and/or other materials provided with the distribution.
17
18    3. The name of the author may not be used to endorse or promote products
19    derived from this software without specific prior written permission.
20
21    THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
22    IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
23    OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24    DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
25    INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28    HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
29    STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30    ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31    POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #include <math.h>
35 #include "speex_denoise.h"
36 #include <stdio.h>
37 #include "misc.h"
38
39 #define STABILITY_TIME 20
40 #define NB_LAST_PS 10
41
42 #define max(a,b) ((a) > (b) ? (a) : (b))
43 #define min(a,b) ((a) < (b) ? (a) : (b))
44
45 #ifndef M_PI
46 #define M_PI 3.14159263
47 #endif
48
49 #define SQRT_M_PI_2 0.88623
50 #define LOUDNESS_EXP 3.5
51
52 static void conj_window(float *w, int len)
53 {
54    int i;
55    for (i=0;i<len;i++)
56    {
57       float x=4*((float)i)/len;
58       int inv=0;
59       if (x<1)
60       {
61       } else if (x<2)
62       {
63          x=2-x;
64          inv=1;
65       } else if (x<3)
66       {
67          x=x-2;
68          inv=1;
69       } else {
70          x=4-x;
71       }
72       x*=1.9979;
73       w[i]=(.5-.5*cos(x))*(.5-.5*cos(x));
74       if (inv)
75          w[i]=1-w[i];
76       w[i]=sqrt(w[i]);
77    }
78 }
79
80 SpeexDenoiseState *speex_denoise_state_init(int frame_size)
81 {
82    int i;
83    int N, N3, N4;
84
85    SpeexDenoiseState *st = (SpeexDenoiseState *)speex_alloc(sizeof(SpeexDenoiseState));
86    st->frame_size = frame_size;
87
88    /* Round ps_size down to the nearest power of two */
89    i=1;
90    st->ps_size = st->frame_size;
91    while(1)
92    {
93       if (st->ps_size & ~i)
94       {
95          st->ps_size &= ~i;
96          i<<=1;
97       } else {
98          break;
99       }
100    }
101
102    if (st->ps_size < 3*st->frame_size/4)
103       st->ps_size = st->ps_size * 3 / 2;
104    N = st->ps_size;
105    N3 = 2*N - st->frame_size;
106    N4 = st->frame_size - N3;
107
108    st->frame = (float*)speex_alloc(2*N*sizeof(float));
109    st->ps = (float*)speex_alloc(N*sizeof(float));
110    st->gain2 = (float*)speex_alloc(N*sizeof(float));
111    st->window = (float*)speex_alloc(2*N*sizeof(float));
112    st->noise = (float*)speex_alloc(N*sizeof(float));
113    st->old_ps = (float*)speex_alloc(N*sizeof(float));
114    st->gain = (float*)speex_alloc(N*sizeof(float));
115    st->prior = (float*)speex_alloc(N*sizeof(float));
116    st->post = (float*)speex_alloc(N*sizeof(float));
117    st->min_ps = (float*)speex_alloc(N*sizeof(float));
118    st->last_energy = (float*)speex_alloc(STABILITY_TIME*sizeof(float));
119    st->last_ps = (float*)speex_alloc(NB_LAST_PS*N*sizeof(float));
120    st->loudness_weight = (float*)speex_alloc(N*sizeof(float));
121    st->inbuf = (float*)speex_alloc(N3*sizeof(float));
122    st->outbuf = (float*)speex_alloc(N3*sizeof(float));
123
124    conj_window(st->window, 2*N3);
125    for (i=2*N3;i<2*st->ps_size;i++)
126       st->window[i]=1;
127    for (i=N3-1;i>=0;i--)
128    {
129       st->window[i+N3+N4]=st->window[i+N3];
130       st->window[i+N3]=1;
131    }
132
133    for (i=0;i<N;i++)
134    {
135       st->noise[i]=1e4;
136       st->old_ps[i]=1e4;
137       st->gain[i]=1;
138       st->post[i]=1;
139       st->prior[i]=1;
140    }
141
142    for (i=0;i<N3;i++)
143    {
144       st->inbuf[i]=0;
145       st->outbuf[i]=0;
146    }
147
148    for (i=0;i<N;i++)
149    {
150       float ff=((float)i)*128.0/4000.0;
151       st->loudness_weight[i] = .35-.35*ff/16000+.73*exp(-.5*(ff-3800)*(ff-3800)/9e5);
152       st->loudness_weight[i] *= st->loudness_weight[i];
153    }
154
155    st->loudness = pow(6000,LOUDNESS_EXP);
156    st->loudness2 = 6000;
157    st->nb_loudness_adapt = 0;
158
159    drft_init(&st->fft_lookup,2*N);
160
161
162    st->nb_adapt=0;
163    st->consec_noise=0;
164    st->nb_denoise=0;
165    st->nb_min_estimate=0;
166    st->last_update=0;
167    st->last_id=0;
168    return st;
169 }
170
171 void speex_denoise_state_destroy(SpeexDenoiseState *st)
172 {
173    speex_free(st->frame);
174    speex_free(st->ps);
175    speex_free(st->gain2);
176    speex_free(st->window);
177    speex_free(st->noise);
178    speex_free(st->old_ps);
179    speex_free(st->gain);
180    speex_free(st->prior);
181    speex_free(st->post);
182    speex_free(st->min_ps);
183    speex_free(st->last_energy);
184    speex_free(st->last_ps);
185    speex_free(st->loudness_weight);
186
187    speex_free(st->inbuf);
188    speex_free(st->outbuf);
189
190    drft_clear(&st->fft_lookup);
191    
192    speex_free(st);
193 }
194
195 static void update_noise(SpeexDenoiseState *st, float *ps)
196 {
197    int i;
198    float beta;
199    st->nb_adapt++;
200    beta=1.0/st->nb_adapt;
201    if (beta < .05)
202       beta=.05;
203    
204    for (i=0;i<st->ps_size;i++)
205       st->noise[i] = (1-beta)*st->noise[i] + beta*ps[i];   
206 }
207
208 int speex_denoise(SpeexDenoiseState *st, float *x)
209 {
210    int i;
211    int is_speech=0;
212    float mean_post=0;
213    float mean_prior=0;
214    float energy;
215    int N = st->ps_size;
216    int N3 = 2*N - st->frame_size;
217    int N4 = st->frame_size - N3;
218    float scale=.5/N;
219    float *ps=st->ps;
220
221    /* 'Build' input frame */
222    for (i=0;i<N3;i++)
223       st->frame[i]=st->inbuf[i];
224    for (i=0;i<st->frame_size;i++)
225       st->frame[N3+i]=x[i];
226    
227    /* Update inbuf */
228    for (i=0;i<N3;i++)
229       st->inbuf[i]=x[N4+i];
230
231    /* Windowing */
232    for (i=0;i<2*N;i++)
233       st->frame[i] *= st->window[i];
234
235    /* Perform FFT */
236    drft_forward(&st->fft_lookup, st->frame);
237
238    /************************************************************** 
239     *  Denoise in spectral domain using Ephraim-Malah algorithm  *
240     **************************************************************/
241
242    /* Power spectrum */
243    ps[0]=1;
244    for (i=1;i<N;i++)
245       ps[i]=1+st->frame[2*i-1]*st->frame[2*i-1] + st->frame[2*i]*st->frame[2*i];
246
247    energy=0;
248    for (i=1;i<N;i++)
249       energy += log(100+ps[i]);
250    energy /= 160;
251    st->last_energy[st->nb_denoise%STABILITY_TIME]=energy;
252
253    if (st->nb_denoise>=STABILITY_TIME)
254    {
255       float E=0, E2=0;
256       float std;
257       for (i=0;i<STABILITY_TIME;i++)
258       {
259          E+=st->last_energy[i];
260          E2+=st->last_energy[i]*st->last_energy[i];
261       }
262       E2=E2/STABILITY_TIME;
263       E=E/STABILITY_TIME;
264       std = sqrt(E2-E*E);
265       if (std<.15 && st->last_update>20)
266       {
267          update_noise(st, &st->last_ps[st->last_id*N]);
268       }
269       /*fprintf (stderr, "%f\n", std);*/
270    }
271
272    st->nb_denoise++;
273 #if 0
274    if (st->nb_min_estimate<50)
275    {
276       float ener=0;
277       for (i=1;i<N;i++)
278          ener += ps[i];
279       /*fprintf (stderr, "%f\n", ener);*/
280       if (ener < st->min_ener || st->nb_min_estimate==0)
281       {
282          st->min_ener = ener;
283          for (i=1;i<N;i++)
284             st->min_ps[i] = ps[i];
285       }
286       st->nb_min_estimate++;
287    } else {
288       float noise_ener=0;
289       st->nb_min_estimate=0;
290       for (i=1;i<N;i++)
291          noise_ener += st->noise[i];
292       /*fprintf (stderr, "%f %f\n", noise_ener, st->min_ener);*/
293       if (0&&(st->last_update>50 && st->min_ener > 3*noise_ener) || st->last_update>50)
294       {
295          for (i=1;i<N;i++)
296          {
297             if (st->noise[i] < st->min_ps[i])
298                st->noise[i] = st->min_ps[i];
299          }
300          /*fprintf (stderr, "tata %d\n",st->last_update);*/
301          st->last_update=0;
302       } else {
303          /*fprintf (stderr, "+");*/
304       }
305    }
306 #endif
307
308    /* Noise estimation always updated for the 20 first times */
309    if (st->nb_adapt<15)
310    {
311       update_noise(st, ps);
312       st->last_update=0;
313    }
314
315    /* Compute a posteriori SNR */
316    for (i=1;i<N;i++)
317    {
318       st->post[i] = ps[i]/(1+st->noise[i]) - 1;
319       if (st->post[i]>100)
320          st->post[i]=100;
321       if (st->post[i]<0)
322         st->post[i]=0;
323       mean_post+=st->post[i];
324    }
325    mean_post /= N;
326    if (mean_post<0)
327       mean_post=0;
328
329    /* Special case for first frame */
330    if (st->nb_adapt==1)
331       for (i=1;i<N;i++)
332          st->old_ps[i] = ps[i];
333
334    /* Compute a priori SNR */
335    {
336       /* A priori update rate */
337       float gamma;
338       float min_gamma=0.05;
339       gamma = 1.0/st->nb_denoise;
340
341       /*Make update rate smaller when there's no speech*/
342       if (mean_post<3)
343          min_gamma *= (mean_post+.1);
344       else
345          min_gamma *= 3.1;
346
347       if (gamma<min_gamma)
348          gamma=min_gamma;
349
350       for (i=1;i<N;i++)
351       {
352          
353          /* A priori SNR update */
354          st->prior[i] = gamma*max(0.0,st->post[i]) +
355          (1-gamma)*st->gain[i]*st->gain[i]*st->old_ps[i]/st->noise[i];
356          
357          if (st->prior[i]>100)
358             st->prior[i]=100;
359          
360          mean_prior+=st->prior[i];
361       }
362    }
363    mean_prior /= N;
364
365 #if 0
366    for (i=0;i<N;i++)
367    {
368       fprintf (stderr, "%f ", st->prior[i]);
369    }
370    fprintf (stderr, "\n");
371 #endif
372    /*fprintf (stderr, "%f %f\n", mean_prior,mean_post);*/
373
374    /* If SNR is low (both a priori and a posteriori), update the noise estimate*/
375    if (mean_prior<.23 && mean_post < .5 && st->nb_adapt>=20)
376    {
377       st->consec_noise++;
378       /*fprintf (stderr, "noise\n");*/
379    } else {
380       st->consec_noise=0;
381    }
382
383    /*fprintf (stderr, "%f %f ", mean_prior, mean_post);*/
384    if (mean_prior>1 && mean_post > 1)
385    {
386       is_speech=1;
387    }
388
389    if (st->consec_noise>=3)
390    {
391       update_noise(st, st->old_ps);
392       st->last_update=0;
393    } else {
394       st->last_update++;
395    }
396
397    /* Compute gain according to the Ephraim-Malah algorithm */
398    for (i=1;i<N;i++)
399    {
400       float MM;
401       float theta;
402       float prior_ratio;
403
404       prior_ratio = st->prior[i]/(1.0001+st->prior[i]);
405       theta = (1+st->post[i])*prior_ratio;
406
407       /* Approximation of:
408          exp(-theta/2)*((1+theta)*I0(theta/2) + theta.*I1(theta/2))
409          because I don't feel like computing Bessel functions
410       */
411       /*MM = -.22+1.155*sqrt(theta+1.1);*/
412       MM=-.22+1.163*sqrt(theta+1.1)-.0015*theta;
413
414       st->gain[i] = SQRT_M_PI_2*sqrt(prior_ratio/(1.0001+st->post[i]))*MM;
415       if (st->gain[i]>1)
416       {
417          st->gain[i]=1;
418       }
419       /*st->gain[i] = prior_ratio;*/
420    }
421    st->gain[0]=0;
422    st->gain[N-1]=0;
423
424    for (i=1;i<N-1;i++)
425    {
426       st->gain2[i]=st->gain[i];
427       if (st->gain2[i]<.1)
428          st->gain2[i]=.1;
429    }
430    st->gain2[N-1]=0;
431
432    if ((mean_prior>3&&mean_prior>3))
433    {
434       float loudness=0;
435       float rate;
436       st->nb_loudness_adapt++;
437       rate=2.0/(1+st->nb_loudness_adapt);
438       if (rate < .01)
439          rate = .01;
440
441       for (i=2;i<N;i++)
442       {
443          loudness += scale*st->ps[i] * st->gain2[i] * st->gain2[i] * st->loudness_weight[i];
444       }
445       loudness=sqrt(loudness);
446       /*if (loudness < 2*pow(st->loudness, 1.0/LOUDNESS_EXP) &&
447         loudness*2 > pow(st->loudness, 1.0/LOUDNESS_EXP))*/
448       st->loudness = (1-rate)*st->loudness + (rate)*pow(loudness, LOUDNESS_EXP);
449       
450       st->loudness2 = (1-rate)*st->loudness2 + rate*pow(st->loudness, 1.0/LOUDNESS_EXP);
451
452       loudness = pow(st->loudness, 1.0/LOUDNESS_EXP);
453
454       /*fprintf (stderr, "%f %f %f\n", loudness, st->loudness2, rate);*/
455    }
456    for (i=0;i<N;i++)
457       st->gain2[i] *= 6000.0/st->loudness2;
458
459    /* Apply computed gain */
460    for (i=1;i<N;i++)
461    {
462       st->frame[2*i-1] *= st->gain2[i];
463       st->frame[2*i] *= st->gain2[i];
464    }
465    /* Get rid of the DC and very low frequencies */
466    st->frame[0]=0;
467    st->frame[1]=0;
468    st->frame[2]=0;
469    /* Nyquist frequency is mostly useless too */
470    st->frame[2*N-1]=0;
471
472    /* Inverse FFT with 1/N scaling */
473    drft_backward(&st->fft_lookup, st->frame);
474
475    for (i=0;i<2*N;i++)
476       st->frame[i] *= scale*st->window[i];
477
478    /* Perform overlap and add */
479    for (i=0;i<N3;i++)
480       x[i] = st->outbuf[i] + st->frame[i];
481    for (i=0;i<N4;i++)
482       x[N3+i] = st->frame[N3+i];
483    
484    /* Update outbuf */
485    for (i=0;i<N3;i++)
486       st->outbuf[i] = st->frame[st->frame_size+i];
487
488    /* Save old power spectrum */
489    for (i=1;i<N;i++)
490       st->old_ps[i] = ps[i];
491
492    for (i=1;i<N;i++)
493       st->last_ps[st->last_id*N+i] = ps[i];
494    st->last_id++;
495    if (st->last_id>=NB_LAST_PS)
496       st->last_id=0;
497
498    return is_speech;
499 }