Added probability of speech presence to denoiser.
[speexdsp.git] / libspeex / preprocess.c
1 /* Copyright (C) 2003 Epic Games 
2    Written by Jean-Marc Valin
3
4    File: preprocess.c
5    Preprocessor with denoising based on the algorithm by Ephraim and Malah
6
7    Redistribution and use in source and binary forms, with or without
8    modification, are permitted provided that the following conditions are
9    met:
10
11    1. Redistributions of source code must retain the above copyright notice,
12    this list of conditions and the following disclaimer.
13
14    2. Redistributions in binary form must reproduce the above copyright
15    notice, this list of conditions and the following disclaimer in the
16    documentation and/or other materials provided with the distribution.
17
18    3. The name of the author may not be used to endorse or promote products
19    derived from this software without specific prior written permission.
20
21    THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
22    IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
23    OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24    DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
25    INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28    HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
29    STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30    ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31    POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #include <math.h>
35 #include "speex_preprocess.h"
36 #include <stdio.h>
37 #include "misc.h"
38 #include "smallft.h"
39
40 #define max(a,b) ((a) > (b) ? (a) : (b))
41 #define min(a,b) ((a) < (b) ? (a) : (b))
42
43 #ifndef M_PI
44 #define M_PI 3.14159263
45 #endif
46
47 #define SQRT_M_PI_2 0.88623
48 #define LOUDNESS_EXP 2.5
49
50 #define NB_BANDS 8
51
52 #define ZMIN .1
53 #define ZMAX .316
54 #define ZMIN_1 10
55 #define LOG_MIN_MAX_1 0.86859
56
57 static void conj_window(float *w, int len)
58 {
59    int i;
60    for (i=0;i<len;i++)
61    {
62       float x=4*((float)i)/len;
63       int inv=0;
64       if (x<1)
65       {
66       } else if (x<2)
67       {
68          x=2-x;
69          inv=1;
70       } else if (x<3)
71       {
72          x=x-2;
73          inv=1;
74       } else {
75          x=4-x;
76       }
77       x*=1.9979;
78       w[i]=(.5-.5*cos(x))*(.5-.5*cos(x));
79       if (inv)
80          w[i]=1-w[i];
81       w[i]=sqrt(w[i]);
82    }
83 }
84
85 /* This function approximates the gain function 
86    y = gamma(1.25)^2 * M(-.25;1;-x) / sqrt(x)  
87    which multiplied by xi/(1+xi) is the optimal gain
88    in the loudness domain ( sqrt[amplitude] )
89 */
90 static float hypergeom_gain(float x)
91 {
92    int ind;
93    float integer, frac;
94    static float table[21] = {
95       0.82157, 1.02017, 1.20461, 1.37534, 1.53363, 1.68092, 1.81865, 
96       1.94811, 2.07038, 2.18638, 2.29688, 2.40255, 2.50391, 2.60144, 
97       2.69551, 2.78647, 2.87458, 2.96015, 3.04333, 3.12431, 3.20326};
98    
99    if (x>9.5)
100       return 1+.12/x;
101
102    integer = floor(x);
103    frac = x-integer;
104    ind = (int)integer;
105    
106    return ((1-frac)*table[ind] + frac*table[ind+1])/sqrt(x+.0001);
107 }
108
109 SpeexPreprocessState *speex_preprocess_state_init(int frame_size, int sampling_rate)
110 {
111    int i;
112    int N, N3, N4;
113
114    SpeexPreprocessState *st = (SpeexPreprocessState *)speex_alloc(sizeof(SpeexPreprocessState));
115    st->frame_size = frame_size;
116
117    /* Round ps_size down to the nearest power of two */
118 #if 0
119    i=1;
120    st->ps_size = st->frame_size;
121    while(1)
122    {
123       if (st->ps_size & ~i)
124       {
125          st->ps_size &= ~i;
126          i<<=1;
127       } else {
128          break;
129       }
130    }
131    
132    
133    if (st->ps_size < 3*st->frame_size/4)
134       st->ps_size = st->ps_size * 3 / 2;
135 #else
136    st->ps_size = st->frame_size;
137 #endif
138
139    N = st->ps_size;
140    N3 = 2*N - st->frame_size;
141    N4 = st->frame_size - N3;
142    
143    st->sampling_rate = sampling_rate;
144    st->denoise_enabled = 1;
145    st->agc_enabled = 0;
146    st->agc_level = 8000;
147    st->vad_enabled = 0;
148
149    st->frame = (float*)speex_alloc(2*N*sizeof(float));
150    st->ps = (float*)speex_alloc(N*sizeof(float));
151    st->gain2 = (float*)speex_alloc(N*sizeof(float));
152    st->window = (float*)speex_alloc(2*N*sizeof(float));
153    st->noise = (float*)speex_alloc(N*sizeof(float));
154    st->old_ps = (float*)speex_alloc(N*sizeof(float));
155    st->gain = (float*)speex_alloc(N*sizeof(float));
156    st->prior = (float*)speex_alloc(N*sizeof(float));
157    st->post = (float*)speex_alloc(N*sizeof(float));
158    st->loudness_weight = (float*)speex_alloc(N*sizeof(float));
159    st->inbuf = (float*)speex_alloc(N3*sizeof(float));
160    st->outbuf = (float*)speex_alloc(N3*sizeof(float));
161    st->echo_noise = (float*)speex_alloc(N*sizeof(float));
162
163    st->S = (float*)speex_alloc(N*sizeof(float));
164    st->Smin = (float*)speex_alloc(N*sizeof(float));
165    st->Stmp = (float*)speex_alloc(N*sizeof(float));
166    st->update_prob = (float*)speex_alloc(N*sizeof(float));
167
168    st->zeta = (float*)speex_alloc(N*sizeof(float));
169    st->Zpeak = 0;
170    st->Zlast = 0;
171
172    st->noise_bands = (float*)speex_alloc(NB_BANDS*sizeof(float));
173    st->noise_bands2 = (float*)speex_alloc(NB_BANDS*sizeof(float));
174    st->speech_bands = (float*)speex_alloc(NB_BANDS*sizeof(float));
175    st->speech_bands2 = (float*)speex_alloc(NB_BANDS*sizeof(float));
176    st->noise_bandsN = st->speech_bandsN = 1;
177
178    conj_window(st->window, 2*N3);
179    for (i=2*N3;i<2*st->ps_size;i++)
180       st->window[i]=1;
181    
182    if (N4>0)
183    {
184       for (i=N3-1;i>=0;i--)
185       {
186          st->window[i+N3+N4]=st->window[i+N3];
187          st->window[i+N3]=1;
188       }
189    }
190    for (i=0;i<N;i++)
191    {
192       st->noise[i]=1e4;
193       st->old_ps[i]=1e4;
194       st->gain[i]=1;
195       st->post[i]=1;
196       st->prior[i]=1;
197    }
198
199    for (i=0;i<N3;i++)
200    {
201       st->inbuf[i]=0;
202       st->outbuf[i]=0;
203    }
204
205    for (i=0;i<N;i++)
206    {
207       float ff=((float)i)*.5*sampling_rate/((float)N);
208       st->loudness_weight[i] = .35-.35*ff/16000+.73*exp(-.5*(ff-3800)*(ff-3800)/9e5);
209       if (st->loudness_weight[i]<.01)
210          st->loudness_weight[i]=.01;
211       st->loudness_weight[i] *= st->loudness_weight[i];
212    }
213
214    st->speech_prob = 0;
215    st->last_speech = 1000;
216    st->loudness = pow(6000,LOUDNESS_EXP);
217    st->loudness2 = 6000;
218    st->nb_loudness_adapt = 0;
219
220    st->fft_lookup = speex_alloc(sizeof(struct drft_lookup));
221    drft_init(st->fft_lookup,2*N);
222
223    st->nb_adapt=0;
224    st->consec_noise=0;
225    st->nb_preprocess=0;
226    return st;
227 }
228
229 void speex_preprocess_state_destroy(SpeexPreprocessState *st)
230 {
231    speex_free(st->frame);
232    speex_free(st->ps);
233    speex_free(st->gain2);
234    speex_free(st->window);
235    speex_free(st->noise);
236    speex_free(st->old_ps);
237    speex_free(st->gain);
238    speex_free(st->prior);
239    speex_free(st->post);
240    speex_free(st->loudness_weight);
241    speex_free(st->echo_noise);
242
243    speex_free(st->S);
244    speex_free(st->Smin);
245    speex_free(st->Stmp);
246    speex_free(st->update_prob);
247
248    speex_free(st->noise_bands);
249    speex_free(st->noise_bands2);
250    speex_free(st->speech_bands);
251    speex_free(st->speech_bands2);
252
253    speex_free(st->inbuf);
254    speex_free(st->outbuf);
255
256    drft_clear(st->fft_lookup);
257    speex_free(st->fft_lookup);
258
259    speex_free(st);
260 }
261
262 static void update_noise(SpeexPreprocessState *st, float *ps, float *echo)
263 {
264    int i;
265    float beta;
266    st->nb_adapt++;
267    beta=1.0/st->nb_adapt;
268    if (beta < .05)
269       beta=.05;
270    
271    if (!echo)
272    {
273       for (i=0;i<st->ps_size;i++)
274          st->noise[i] = (1-beta)*st->noise[i] + beta*ps[i];   
275    } else {
276       for (i=0;i<st->ps_size;i++)
277          st->noise[i] = (1-beta)*st->noise[i] + beta*max(0,ps[i]-echo[i]); 
278 #if 0
279       for (i=0;i<st->ps_size;i++)
280          st->noise[i] = 0;
281 #endif
282    }
283 }
284
285 static int speex_compute_vad(SpeexPreprocessState *st, float *ps, float mean_prior, float mean_post)
286 {
287    int i, is_speech=0;
288    int N = st->ps_size;
289    float scale=.5/N;
290
291    /* FIXME: Clean this up a bit */
292    {
293       float bands[NB_BANDS];
294       int j;
295       float p0, p1;
296       float tot_loudness=0;
297       float x = sqrt(mean_post);
298
299       for (i=5;i<N-10;i++)
300       {
301          tot_loudness += scale*st->ps[i] * st->loudness_weight[i];
302       }
303
304       for (i=0;i<NB_BANDS;i++)
305       {
306          bands[i]=1e4;
307          for (j=i*N/NB_BANDS;j<(i+1)*N/NB_BANDS;j++)
308          {
309             bands[i] += ps[j];
310          }
311          bands[i]=log(bands[i]);
312       }
313       
314       /*p1 = .0005+.6*exp(-.5*(x-.4)*(x-.4)*11)+.1*exp(-1.2*x);
315       if (x<1.5)
316          p0=.1*exp(2*(x-1.5));
317       else
318          p0=.02+.1*exp(-.2*(x-1.5));
319       */
320
321       p0=1/(1+exp(3*(1.5-x)));
322       p1=1-p0;
323
324       /*fprintf (stderr, "%f %f ", p0, p1);*/
325       /*p0 *= .99*st->speech_prob + .01*(1-st->speech_prob);
326       p1 *= .01*st->speech_prob + .99*(1-st->speech_prob);
327       
328       st->speech_prob = p0/(p1+p0);
329       */
330
331       if (st->noise_bandsN < 50 || st->speech_bandsN < 50)
332       {
333          if (mean_post > 5)
334          {
335             float adapt = 1./st->speech_bandsN++;
336             if (adapt<.005)
337                adapt = .005;
338             for (i=0;i<NB_BANDS;i++)
339             {
340                st->speech_bands[i] = (1-adapt)*st->speech_bands[i] + adapt*bands[i];
341                /*st->speech_bands2[i] = (1-adapt)*st->speech_bands2[i] + adapt*bands[i]*bands[i];*/
342                st->speech_bands2[i] = (1-adapt)*st->speech_bands2[i] + adapt*(bands[i]-st->speech_bands[i])*(bands[i]-st->speech_bands[i]);
343             }
344          } else {
345             float adapt = 1./st->noise_bandsN++;
346             if (adapt<.005)
347                adapt = .005;
348             for (i=0;i<NB_BANDS;i++)
349             {
350                st->noise_bands[i] = (1-adapt)*st->noise_bands[i] + adapt*bands[i];
351                /*st->noise_bands2[i] = (1-adapt)*st->noise_bands2[i] + adapt*bands[i]*bands[i];*/
352                st->noise_bands2[i] = (1-adapt)*st->noise_bands2[i] + adapt*(bands[i]-st->noise_bands[i])*(bands[i]-st->noise_bands[i]);
353             }
354          }
355       }
356       p0=p1=1;
357       for (i=0;i<NB_BANDS;i++)
358       {
359          float noise_var, speech_var;
360          float noise_mean, speech_mean;
361          float tmp1, tmp2, pr;
362
363          /*noise_var = 1.01*st->noise_bands2[i] - st->noise_bands[i]*st->noise_bands[i];
364            speech_var = 1.01*st->speech_bands2[i] - st->speech_bands[i]*st->speech_bands[i];*/
365          noise_var = st->noise_bands2[i];
366          speech_var = st->speech_bands2[i];
367          if (noise_var < .1)
368             noise_var = .1;
369          if (speech_var < .1)
370             speech_var = .1;
371          
372          /*speech_var = sqrt(speech_var*noise_var);
373            noise_var = speech_var;*/
374          if (speech_var < .05*speech_var)
375             noise_var = .05*speech_var; 
376          if (speech_var < .05*noise_var)
377             speech_var = .05*noise_var;
378          
379          if (bands[i] < st->noise_bands[i])
380             speech_var = noise_var;
381          if (bands[i] > st->speech_bands[i])
382             noise_var = speech_var;
383
384          speech_mean = st->speech_bands[i];
385          noise_mean = st->noise_bands[i];
386          if (noise_mean < speech_mean - 5)
387             noise_mean = speech_mean - 5;
388
389          tmp1 = exp(-.5*(bands[i]-speech_mean)*(bands[i]-speech_mean)/speech_var)/sqrt(2*M_PI*speech_var);
390          tmp2 = exp(-.5*(bands[i]-noise_mean)*(bands[i]-noise_mean)/noise_var)/sqrt(2*M_PI*noise_var);
391          /*fprintf (stderr, "%f ", (float)(p0/(.01+p0+p1)));*/
392          /*fprintf (stderr, "%f ", (float)(bands[i]));*/
393          pr = tmp1/(1e-25+tmp1+tmp2);
394          /*if (bands[i] < st->noise_bands[i])
395             pr=.01;
396          if (bands[i] > st->speech_bands[i] && pr < .995)
397          pr=.995;*/
398          if (pr>.999)
399             pr=.999;
400          if (pr<.001)
401             pr=.001;
402          /*fprintf (stderr, "%f ", pr);*/
403          p0 *= pr;
404          p1 *= (1-pr);
405       }
406
407       p0 = pow(p0,.2);
408       p1 = pow(p1,.2);      
409       
410 #if 1
411       p0 *= 2;
412       p0=p0/(p1+p0);
413       if (st->last_speech>20) 
414       {
415          float tmp = sqrt(tot_loudness)/st->loudness2;
416          tmp = 1-exp(-10*tmp);
417          if (p0>tmp)
418             p0=tmp;
419       }
420       p1=1-p0;
421 #else
422       if (sqrt(tot_loudness) < .6*st->loudness2 && p0>15*p1)
423          p0=15*p1;
424       if (sqrt(tot_loudness) < .45*st->loudness2 && p0>7*p1)
425          p0=7*p1;
426       if (sqrt(tot_loudness) < .3*st->loudness2 && p0>3*p1)
427          p0=3*p1;
428       if (sqrt(tot_loudness) < .15*st->loudness2 && p0>p1)
429          p0=p1;
430       /*fprintf (stderr, "%f %f ", (float)(sqrt(tot_loudness) /( .25*st->loudness2)), p0/(p1+p0));*/
431 #endif
432
433       p0 *= .99*st->speech_prob + .01*(1-st->speech_prob);
434       p1 *= .01*st->speech_prob + .99*(1-st->speech_prob);
435       
436       st->speech_prob = p0/(1e-25+p1+p0);
437       /*fprintf (stderr, "%f %f %f ", tot_loudness, st->loudness2, st->speech_prob);*/
438
439       if (st->speech_prob>.35 || (st->last_speech < 20 && st->speech_prob>.1))
440       {
441          is_speech = 1;
442          st->last_speech = 0;
443       } else {
444          st->last_speech++;
445          if (st->last_speech<20)
446            is_speech = 1;
447       }
448
449       if (st->noise_bandsN > 50 && st->speech_bandsN > 50)
450       {
451          if (mean_post > 5)
452          {
453             float adapt = 1./st->speech_bandsN++;
454             if (adapt<.005)
455                adapt = .005;
456             for (i=0;i<NB_BANDS;i++)
457             {
458                st->speech_bands[i] = (1-adapt)*st->speech_bands[i] + adapt*bands[i];
459                /*st->speech_bands2[i] = (1-adapt)*st->speech_bands2[i] + adapt*bands[i]*bands[i];*/
460                st->speech_bands2[i] = (1-adapt)*st->speech_bands2[i] + adapt*(bands[i]-st->speech_bands[i])*(bands[i]-st->speech_bands[i]);
461             }
462          } else {
463             float adapt = 1./st->noise_bandsN++;
464             if (adapt<.005)
465                adapt = .005;
466             for (i=0;i<NB_BANDS;i++)
467             {
468                st->noise_bands[i] = (1-adapt)*st->noise_bands[i] + adapt*bands[i];
469                /*st->noise_bands2[i] = (1-adapt)*st->noise_bands2[i] + adapt*bands[i]*bands[i];*/
470                st->noise_bands2[i] = (1-adapt)*st->noise_bands2[i] + adapt*(bands[i]-st->noise_bands[i])*(bands[i]-st->noise_bands[i]);
471             }
472          }
473       }
474
475
476    }
477
478    return is_speech;
479 }
480
481 static void speex_compute_agc(SpeexPreprocessState *st, float mean_prior)
482 {
483    int i;
484    int N = st->ps_size;
485    float scale=.5/N;
486    float agc_gain;
487    int freq_start, freq_end;
488    float active_bands = 0;
489
490    freq_start = (int)(300.0*2*N/st->sampling_rate);
491    freq_end   = (int)(2000.0*2*N/st->sampling_rate);
492    for (i=freq_start;i<freq_end;i++)
493    {
494       if (st->S[i] > 20*st->Smin[i]+1000)
495          active_bands+=1;
496    }
497    active_bands /= (freq_end-freq_start+1);
498
499    if (active_bands > .2)
500    {
501       float loudness=0;
502       float rate, rate2=.2;
503       st->nb_loudness_adapt++;
504       rate=2.0/(1+st->nb_loudness_adapt);
505       if (rate < .05)
506          rate = .05;
507       if (rate < .1 && pow(loudness, LOUDNESS_EXP) > st->loudness)
508          rate = .1;
509       if (rate < .2 && pow(loudness, LOUDNESS_EXP) > 3*st->loudness)
510          rate = .2;
511       if (rate < .4 && pow(loudness, LOUDNESS_EXP) > 10*st->loudness)
512          rate = .4;
513
514       for (i=2;i<N;i++)
515       {
516          loudness += scale*st->ps[i] * st->gain2[i] * st->gain2[i] * st->loudness_weight[i];
517       }
518       loudness=sqrt(loudness);
519       /*if (loudness < 2*pow(st->loudness, 1.0/LOUDNESS_EXP) &&
520         loudness*2 > pow(st->loudness, 1.0/LOUDNESS_EXP))*/
521       st->loudness = (1-rate)*st->loudness + (rate)*pow(loudness, LOUDNESS_EXP);
522       
523       st->loudness2 = (1-rate2)*st->loudness2 + rate2*pow(st->loudness, 1.0/LOUDNESS_EXP);
524
525       loudness = pow(st->loudness, 1.0/LOUDNESS_EXP);
526
527       /*fprintf (stderr, "%f %f %f\n", loudness, st->loudness2, rate);*/
528    }
529    
530    agc_gain = st->agc_level/st->loudness2;
531    /*fprintf (stderr, "%f %f %f %f\n", active_bands, st->loudness, st->loudness2, agc_gain);*/
532    if (agc_gain>200)
533       agc_gain = 200;
534
535    for (i=0;i<N;i++)
536       st->gain2[i] *= agc_gain;
537    
538 }
539
540 static void preprocess_analysis(SpeexPreprocessState *st, float *x)
541 {
542    int i;
543    int N = st->ps_size;
544    int N3 = 2*N - st->frame_size;
545    int N4 = st->frame_size - N3;
546    float *ps=st->ps;
547
548    /* 'Build' input frame */
549    for (i=0;i<N3;i++)
550       st->frame[i]=st->inbuf[i];
551    for (i=0;i<st->frame_size;i++)
552       st->frame[N3+i]=x[i];
553    
554    /* Update inbuf */
555    for (i=0;i<N3;i++)
556       st->inbuf[i]=x[N4+i];
557
558    /* Windowing */
559    for (i=0;i<2*N;i++)
560       st->frame[i] *= st->window[i];
561
562    /* Perform FFT */
563    drft_forward(st->fft_lookup, st->frame);
564
565    /* Power spectrum */
566    ps[0]=1;
567    for (i=1;i<N;i++)
568       ps[i]=1+st->frame[2*i-1]*st->frame[2*i-1] + st->frame[2*i]*st->frame[2*i];
569
570 }
571
572 static void update_noise_prob(SpeexPreprocessState *st)
573 {
574    int i;
575    int N = st->ps_size;
576
577    for (i=1;i<N-1;i++)
578       st->S[i] = 100+ .8*st->S[i] + .05*st->ps[i-1]+.1*st->ps[i]+.05*st->ps[i+1];
579    
580    if (st->nb_preprocess<1)
581    {
582       for (i=1;i<N-1;i++)
583          st->Smin[i] = st->Stmp[i] = st->S[i]+100;
584    }
585
586    if (st->nb_preprocess%80==0)
587    {
588       for (i=1;i<N-1;i++)
589       {
590          st->Smin[i] = min(st->Stmp[i], st->S[i]);
591          st->Stmp[i] = st->S[i];
592       }
593    } else {
594       for (i=1;i<N-1;i++)
595       {
596          st->Smin[i] = min(st->Smin[i], st->S[i]);
597          st->Stmp[i] = min(st->Stmp[i], st->S[i]);      
598       }
599    }
600    for (i=1;i<N-1;i++)
601    {
602       st->update_prob[i] *= .2;
603       if (st->S[i] > 5*st->Smin[i])
604          st->update_prob[i] += .8;
605       /*fprintf (stderr, "%f ", st->S[i]/st->Smin[i]);*/
606       /*fprintf (stderr, "%f ", st->update_prob[i]);*/
607    }
608
609 }
610
611 int speex_preprocess(SpeexPreprocessState *st, float *x, float *echo)
612 {
613    int i;
614    int is_speech=1;
615    float mean_post=0;
616    float mean_prior=0;
617    int N = st->ps_size;
618    int N3 = 2*N - st->frame_size;
619    int N4 = st->frame_size - N3;
620    float scale=.5/N;
621    float *ps=st->ps;
622    float Zframe=0, Pframe;
623
624    preprocess_analysis(st, x);
625
626    update_noise_prob(st);
627
628    st->nb_preprocess++;
629
630    /* Noise estimation always updated for the 20 first times */
631    if (st->nb_adapt<10)
632    {
633       update_noise(st, ps, echo);
634    }
635
636    /* Deal with residual echo if provided */
637    if (echo)
638       for (i=1;i<N;i++)
639          st->echo_noise[i] = (.7*st->echo_noise[i] + .3* echo[i]);
640
641    /* Compute a posteriori SNR */
642    for (i=1;i<N;i++)
643    {
644       st->post[i] = ps[i]/(1+st->noise[i]+st->echo_noise[i]) - 1;
645       if (st->post[i]>100)
646          st->post[i]=100;
647       /*if (st->post[i]<0)
648         st->post[i]=0;*/
649       mean_post+=st->post[i];
650    }
651    mean_post /= N;
652    if (mean_post<0)
653       mean_post=0;
654
655    /* Special case for first frame */
656    if (st->nb_adapt==1)
657       for (i=1;i<N;i++)
658          st->old_ps[i] = ps[i];
659
660    /* Compute a priori SNR */
661    {
662       /* A priori update rate */
663       float gamma;
664       float min_gamma=0.12;
665       gamma = 1.0/st->nb_preprocess;
666
667       /*Make update rate smaller when there's no speech*/
668 #if 0
669       if (mean_post<3.5 && mean_prior < 1)
670          min_gamma *= (mean_post+.5);
671       else
672          min_gamma *= 4.;
673 #else
674       min_gamma = .2*fabs(mean_prior - mean_post)*fabs(mean_prior - mean_post);
675       if (min_gamma>.6)
676          min_gamma = .6;
677       if (min_gamma<.01)
678          min_gamma = .01;
679 #endif
680
681       if (gamma<min_gamma)
682          gamma=min_gamma;
683       
684       for (i=1;i<N;i++)
685       {
686          
687          /* A priori SNR update */
688          st->prior[i] = gamma*max(0.0,st->post[i]) +
689          (1-gamma)*st->gain[i]*st->gain[i]*st->old_ps[i]/(1+st->noise[i]+st->echo_noise[i]);
690          
691          if (st->prior[i]>100)
692             st->prior[i]=100;
693          
694          mean_prior+=st->prior[i];
695       }
696    }
697    mean_prior /= N;
698
699 #if 0
700    for (i=0;i<N;i++)
701    {
702       fprintf (stderr, "%f ", st->prior[i]);
703    }
704    fprintf (stderr, "\n");
705 #endif
706    /*fprintf (stderr, "%f %f\n", mean_prior,mean_post);*/
707
708    if (st->nb_preprocess>=20)
709    {
710       int do_update = 0;
711       float noise_ener=0, sig_ener=0;
712       /* If SNR is low (both a priori and a posteriori), update the noise estimate*/
713       /*if (mean_prior<.23 && mean_post < .5)*/
714       if (mean_prior<.23 && mean_post < .5)
715          do_update = 1;
716       for (i=1;i<N;i++)
717       {
718          noise_ener += st->noise[i];
719          sig_ener += ps[i];
720       }
721       if (noise_ener > 3*sig_ener)
722          do_update = 1;
723       /*do_update = 0;*/
724       if (do_update)
725       {
726          st->consec_noise++;
727       } else {
728          st->consec_noise=0;
729       }
730    }
731
732    if (st->vad_enabled)
733       is_speech = speex_compute_vad(st, ps, mean_prior, mean_post);
734
735
736    if (st->consec_noise>=3)
737    {
738       update_noise(st, st->old_ps, echo);
739    } else {
740       for (i=1;i<N-1;i++)
741       {
742          if (st->update_prob[i]<.5)
743             st->noise[i] = .90*st->noise[i] + .1*st->ps[i];
744       }
745    }
746
747    for (i=1;i<N;i++)
748    {
749       st->zeta[i] = .7*st->zeta[i] + .3*st->prior[i];
750    }
751
752    {
753       int freq_start = (int)(300.0*2*N/st->sampling_rate);
754       int freq_end   = (int)(2000.0*2*N/st->sampling_rate);
755       for (i=freq_start;i<freq_end;i++)
756       {
757          Zframe += st->zeta[i];         
758       }
759    }
760
761    Zframe /= N;
762    if (Zframe<ZMIN)
763    {
764       Pframe = 0;
765    } else {
766       if (Zframe > 1.5*st->Zlast)
767       {
768          Pframe = 1;
769          st->Zpeak = Zframe;
770          if (st->Zpeak > 10)
771             st->Zpeak = 10;
772          if (st->Zpeak < 1)
773             st->Zpeak = 1;
774       } else {
775          if (Zframe < st->Zpeak*ZMIN)
776          {
777             Pframe = 0;
778          } else if (Zframe > st->Zpeak*ZMAX)
779          {
780             Pframe = 1;
781          } else {
782             Pframe = log(Zframe/(st->Zpeak*ZMIN)) / log(ZMAX/ZMIN);
783          }
784       }
785    }
786    st->Zlast = Zframe;
787
788    fprintf (stderr, "%f\n", Pframe);
789    /* Compute gain according to the Ephraim-Malah algorithm */
790    for (i=1;i<N;i++)
791    {
792       float MM;
793       float theta;
794       float prior_ratio;
795       float p, q;
796       float zeta1;
797       float P1;
798
799       prior_ratio = st->prior[i]/(1.0001+st->prior[i]);
800       theta = (1+st->post[i])*prior_ratio;
801
802       if (i==1 || i==N-1)
803          zeta1 = st->zeta[i];
804       else
805          zeta1 = .25*st->zeta[i-1] + .5*st->zeta[i] + .25*st->zeta[i+1];
806       if (zeta1<ZMIN)
807          P1 = 0;
808       else if (zeta1>ZMAX)
809          P1 = 1;
810       else
811          P1 = LOG_MIN_MAX_1 * log(ZMIN_1*zeta1);
812   
813       /*P1 = log(zeta1/ZMIN)/log(ZMAX/ZMIN);*/
814       
815       /* FIXME: add global prop (P2) */
816       q = 1-Pframe*P1;
817       if (q>.95)
818          q=.95;
819       p=1/(1 + (q/(1-q))*(1+st->prior[i])*exp(-theta));
820       
821
822 #if 0
823       /* log-spectral magnitude estimator */
824       if (theta<6)
825          MM = 0.74082*pow(theta+1,.61)/sqrt(.0001+theta);
826       else
827          MM=1;
828 #else
829       /* Optimal estimator for loudness domain */
830       MM = hypergeom_gain(theta);
831 #endif
832
833       st->gain[i] = prior_ratio * MM;
834       /*Put some (very arbitraty) limit on the gain*/
835       if (st->gain[i]>2)
836       {
837          st->gain[i]=2;
838       }
839
840       if (st->denoise_enabled)
841       {
842          st->gain2[i]=p*p*st->gain[i];
843       } else {
844          st->gain2[i]=1;
845       }
846    }
847    st->gain2[0]=st->gain[0]=0;
848    st->gain2[N-1]=st->gain[N-1]=0;
849
850    if (st->agc_enabled)
851       speex_compute_agc(st, mean_prior);
852
853 #if 0
854    if (!is_speech)
855    {
856       for (i=0;i<N;i++)
857          st->gain2[i] = 0;
858    }
859 #if 0
860  else {
861       for (i=0;i<N;i++)
862          st->gain2[i] = 1;
863    }
864 #endif
865 #endif
866
867    /* Apply computed gain */
868    for (i=1;i<N;i++)
869    {
870       st->frame[2*i-1] *= st->gain2[i];
871       st->frame[2*i] *= st->gain2[i];
872    }
873
874    /* Get rid of the DC and very low frequencies */
875    st->frame[0]=0;
876    st->frame[1]=0;
877    st->frame[2]=0;
878    /* Nyquist frequency is mostly useless too */
879    st->frame[2*N-1]=0;
880
881    /* Inverse FFT with 1/N scaling */
882    drft_backward(st->fft_lookup, st->frame);
883
884    for (i=0;i<2*N;i++)
885       st->frame[i] *= scale;
886
887    {
888       float max_sample=0;
889       for (i=0;i<2*N;i++)
890          if (fabs(st->frame[i])>max_sample)
891             max_sample = fabs(st->frame[i]);
892       if (max_sample>28000)
893       {
894          float damp = 28000./max_sample;
895          for (i=0;i<2*N;i++)
896             st->frame[i] *= damp;
897       }
898    }
899
900    for (i=0;i<2*N;i++)
901       st->frame[i] *= st->window[i];
902
903    /* Perform overlap and add */
904    for (i=0;i<N3;i++)
905       x[i] = st->outbuf[i] + st->frame[i];
906    for (i=0;i<N4;i++)
907       x[N3+i] = st->frame[N3+i];
908    
909    /* Update outbuf */
910    for (i=0;i<N3;i++)
911       st->outbuf[i] = st->frame[st->frame_size+i];
912
913    /* Save old power spectrum */
914    for (i=1;i<N;i++)
915       st->old_ps[i] = ps[i];
916
917    return is_speech;
918 }
919
920 void speex_preprocess_estimate_update(SpeexPreprocessState *st, float *x, float *noise)
921 {
922    int i;
923    int N = st->ps_size;
924    int N3 = 2*N - st->frame_size;
925
926    float *ps=st->ps;
927
928    preprocess_analysis(st, x);
929
930    update_noise_prob(st);
931
932    st->nb_preprocess++;
933    
934    for (i=1;i<N-1;i++)
935    {
936       if (st->update_prob[i]<.5)
937          st->noise[i] = .90*st->noise[i] + .1*ps[i];
938    }
939
940    for (i=0;i<N3;i++)
941       st->outbuf[i] = x[st->frame_size-N3+i]*st->window[st->frame_size+i];
942
943    /* Save old power spectrum */
944    for (i=1;i<N;i++)
945       st->old_ps[i] = ps[i];
946
947 }
948
949
950 int speex_preprocess_ctl(SpeexPreprocessState *state, int request, void *ptr)
951 {
952    SpeexPreprocessState *st;
953    st=(SpeexPreprocessState*)state;
954    switch(request)
955    {
956    case SPEEX_PREPROCESS_SET_DENOISE:
957       st->denoise_enabled = (*(int*)ptr);
958       break;
959    case SPEEX_PREPROCESS_GET_DENOISE:
960       (*(int*)ptr) = st->denoise_enabled;
961       break;
962
963    case SPEEX_PREPROCESS_SET_AGC:
964       st->agc_enabled = (*(int*)ptr);
965       break;
966    case SPEEX_PREPROCESS_GET_AGC:
967       (*(int*)ptr) = st->agc_enabled;
968       break;
969
970    case SPEEX_PREPROCESS_SET_AGC_LEVEL:
971       st->agc_level = (*(float*)ptr);
972       if (st->agc_level<1)
973          st->agc_level=1;
974       if (st->agc_level>32768)
975          st->agc_level=32768;
976       break;
977    case SPEEX_PREPROCESS_GET_AGC_LEVEL:
978       (*(float*)ptr) = st->agc_level;
979       break;
980
981    case SPEEX_PREPROCESS_SET_VAD:
982       st->vad_enabled = (*(int*)ptr);
983       break;
984    case SPEEX_PREPROCESS_GET_VAD:
985       (*(int*)ptr) = st->vad_enabled;
986       break;
987    default:
988       speex_warning_int("Unknown speex_preprocess_ctl request: ", request);
989       return -1;
990    }
991    return 0;
992 }