Last cleanup for 0.0.3
[speexdsp.git] / libspeex / mpulse.c
1 /* Copyright (C) 2002 Jean-Marc Valin 
2    File: mpulse.c
3
4    Multi-pulse code
5
6    This library is free software; you can redistribute it and/or
7    modify it under the terms of the GNU Lesser General Public
8    License as published by the Free Software Foundation; either
9    version 2.1 of the License, or (at your option) any later version.
10    
11    This library is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14    Lesser General Public License for more details.
15    
16    You should have received a copy of the GNU Lesser General Public
17    License along with this library; if not, write to the Free Software
18    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19
20 */
21
22 #include "mpulse.h"
23 #include "stack_alloc.h"
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include "filters.h"
27 #include <math.h>
28
29 #define MAX_PULSE 60
30 #define MAX_POS  100
31
32 int porder(int *p, int *s, int *o, int len)
33 {
34    int i,j, bit1, nb_uniq=0;
35    int *st, *en;
36    int rep[MAX_POS];
37    int uniq[MAX_PULSE];
38    int n;
39    /*Stupid bubble sort but for small N, we don't care!*/
40    for (i=0;i<MAX_POS;i++)
41       rep[i]=0;
42    for (i=0;i<len;i++)
43    {
44       for (j=i+1;j<len;j++)
45       {
46          if (p[i]>p[j])
47          {
48             int tmp;
49             tmp=p[j];
50             p[j]=p[i];
51             p[i]=tmp;
52             tmp=s[j];
53             s[j]=s[i];
54             s[i]=tmp;
55          }
56       }
57    }
58 #ifdef DEBUG
59    printf ("quant_pulse\n");
60    for (i=0;i<len;i++)
61       printf ("%d ", p[i]);
62    printf ("\n");
63    for (i=0;i<len;i++)
64       printf ("%d ", s[i]);
65    printf ("\n");
66 #endif
67    for (i=0;i<len;i++)
68    {
69       rep[p[i]]++;
70       if (i==0 || p[i]!=p[i-1])
71       {
72          uniq[nb_uniq]=p[i];
73          s[nb_uniq]=s[i];
74          nb_uniq++;
75       }
76    }
77    st=uniq;
78    en=&uniq[nb_uniq-1];
79
80
81    bit1=s[0];
82    n=0;
83    for (i=0;i<nb_uniq;i++)
84    {
85       int next;
86       if (i==nb_uniq-1)
87       {
88          next=*st;
89          for (j=0;j<rep[next];j++)
90          {
91             o[n]=next;
92             n++;
93          }
94       } else {
95          if (s[i+1])
96          {
97             next=*en;
98             for (j=0;j<rep[next];j++)
99             {
100                o[n]=next;
101                n++;
102             }
103             en--;
104          } else {
105             next=*st;
106             for (j=0;j<rep[next];j++)
107             {
108                o[n]=next;
109                n++;
110             }
111             st++;
112          }
113       }
114    }
115 #ifdef DEBUG
116    for (i=0;i<len;i++)
117       printf ("%d ", o[i]);
118    printf ("\n");
119 #endif
120    return s[0];
121 }
122
123 void rorder(int *p, int *s, int *o, int bit, int len)
124 {
125    int i,j,nb_uniq=0;
126    int *st, *en;
127    int rep[MAX_POS];
128    int uniq[MAX_PULSE];
129    int ss[MAX_PULSE];
130    int n;
131    /*Stupid bubble sort but for small N, we don't care!*/
132    for (i=0;i<len;i++)
133       o[i]=p[i];
134    for (i=0;i<len;i++)
135    {
136       for (j=i+1;j<len;j++)
137       {
138          if (o[i]>o[j])
139          {
140             int tmp;
141             tmp=o[j];
142             o[j]=o[i];
143             o[i]=tmp;
144          }
145       }
146    }
147
148    for (i=0;i<len;i++)  
149    {
150       rep[p[i]]++;
151       if (i==0 || o[i]!=o[i-1])
152       {
153          uniq[nb_uniq]=o[i];
154          s[nb_uniq]=s[i];
155          nb_uniq++;
156       }
157    }
158    st=uniq;
159    en=&uniq[nb_uniq-1];
160
161    ss[0]=bit;
162    n=1;
163 #ifdef DEBUG
164    printf ("unquant_pulse\n");
165    for (i=0;i<len;i++)
166       printf ("%d ", o[i]);
167    printf ("\n");
168 #endif
169    for (i=1;i<len;i++)
170    {
171       if (i>1&&p[i-1]==p[i-2])
172          continue;
173       if (p[i-1]==*st)
174       {
175          ss[n++]=0;
176          st++;
177       } else if (p[i-1]==*en)
178       {
179          ss[n++]=1;
180          en--;
181       } else 
182       {
183          fprintf (stderr, "ERROR in decoding signs\n");
184          exit(1);
185       }
186    }
187    
188    n=0;
189    for (i=0;i<len;i++)
190    {
191       s[i]=ss[n];
192       if (i<len&&o[i]!=o[i+1])
193          n++;
194    }
195 #ifdef DEBUG
196    for (i=0;i<len;i++)
197       printf ("%d ", s[i]);
198    printf ("\n");
199 #endif
200 }
201
202
203 void mpulse_search(
204 float target[],                 /* target vector */
205 float ak[],                     /* LPCs for this subframe */
206 float awk1[],                   /* Weighted LPCs for this subframe */
207 float awk2[],                   /* Weighted LPCs for this subframe */
208 void *par,                      /* Codebook/search parameters*/
209 int   p,                        /* number of LPC coeffs */
210 int   nsf,                      /* number of samples in subframe */
211 float *exc,
212 FrameBits *bits,
213 float *stack
214 )
215 {
216    int i,j, nb_pulse;
217    float *resp, *resp2, *energy, *t, *e, *pulses;
218    float te=0,ee=0;
219    float g, gain_coef;
220    int nb_tracks, track_ind_bits;
221    int *tracks, *signs, *tr, *nb;
222    int full_div=0, full_mod=0;
223    mpulse_params *params;
224    int pulses_per_track;
225    params = (mpulse_params *) par;
226
227    nb_pulse=params->nb_pulse;
228    nb_tracks=params->nb_tracks;
229    pulses_per_track=nb_pulse/nb_tracks;
230    track_ind_bits=params->track_ind_bits;
231    gain_coef=params->gain_coef;
232
233    tracks = (int*)PUSH(stack,nb_pulse);
234    signs = (int*)PUSH(stack,nb_pulse);
235    tr = (int*)PUSH(stack,pulses_per_track);
236    nb = (int*)PUSH(stack,nb_tracks);
237
238    resp=PUSH(stack, nsf);
239    resp2=PUSH(stack, nsf);
240    energy=PUSH(stack, nsf);
241    t=PUSH(stack, nsf);
242    e=PUSH(stack, nsf);
243    pulses=PUSH(stack, nsf);
244    
245    /*Compute optimal (real) excitation from target*/
246    syn_filt_zero(target, awk1, e, nsf, p);
247    residue_zero(e, ak, e, nsf, p);
248    residue_zero(e, awk2, e, nsf, p);
249    for (i=0;i<nsf;i++)
250    {
251       pulses[i]=0;
252       te+=target[i]*target[i];
253       ee+=e[i]*e[i];
254    }
255    /*Compute global gain (coef found from linear regression and tweaking)*/
256    g=gain_coef/sqrt(nb_pulse)*exp(0.18163*log(te+1)+0.17293*log(ee+1));
257    
258    e[0]=1;
259    for (i=1;i<nsf;i++)
260       e[i]=0;
261
262    /*Impulse response of W(z)/A(z)*/
263    residue_zero(e, awk1, resp, nsf, p);
264    syn_filt_zero(resp, ak, resp, nsf, p);
265    syn_filt_zero(resp, awk2, resp, nsf, p);
266    
267    /*Impulse response * gain*/
268    for (i=0;i<nsf;i++)
269       resp2[i]=g*resp[i];
270
271    for (i=0;i<nsf;i++)
272       e[i]=0;
273
274    for (i=0;i<nsf;i++)
275       t[i]=target[i];
276
277    for (i=0;i<nb_tracks;i++)
278       nb[i]=0;
279
280    /*For all pulses*/
281    for (i=0;i<nb_pulse;i++)
282    {
283       float best_score=1e30, best_gain=0;
284       int best_ind=0;
285       int mod_track=nb_tracks-1;
286       /*For all positions*/
287       energy[0]=0;
288       for (j=1;j<nsf;j++)
289          energy[j]=energy[j-1]+t[j-1]*t[j-1];
290       /*For each position*/
291       for (j=0;j<nsf;j++)
292       {
293          int k;
294          float dist;
295          float *base=t+j;
296          mod_track++;
297          if (mod_track==nb_tracks)
298             mod_track=0;
299          /*Fill any track until it's full*/
300          if (nb[mod_track]==pulses_per_track || nb[mod_track] > full_div)
301               continue;
302
303          /*Try for positive sign*/
304          if (pulses[j]>=0)
305          {
306             dist=energy[j];
307             for (k=0;k<nsf-j;k++)
308             {
309                float tmp=(base[k]-resp2[k]);
310                dist+=tmp*tmp;
311             }
312             if (dist<best_score || j==0)
313             {
314                best_score=dist;
315                best_gain=g;
316                best_ind=j;
317             }
318          }
319          /*Try again for negative sign*/
320          if (pulses[j]<=0)
321          {
322             dist=energy[j];
323             for (k=0;k<nsf-j;k++)
324             {
325                float tmp=(base[k]+resp2[k]);
326                dist+=tmp*tmp;
327             }
328             if (dist<best_score)
329             {
330                best_score=dist;
331                best_gain=-g;
332                best_ind=j;
333             }
334          }
335       }
336 #ifdef DEBUG
337       printf ("best pulse: %d %d %f %f %f %f\n", i, best_ind, best_gain, te, ee, g);
338 #endif
339       /*Remove pulse contribution from target*/
340       for (j=best_ind;j<nsf;j++)
341          t[j] -= best_gain * resp[j-best_ind];
342       e[best_ind]+=best_gain;
343       if (best_gain>0)
344          pulses[best_ind]+=1;
345       else
346          pulses[best_ind]-=1;
347       {
348          int t=best_ind%nb_tracks;
349          tracks[t*pulses_per_track+nb[t]] = best_ind/nb_tracks;
350          signs[t*pulses_per_track+nb[t]]  = best_gain >= 0 ? 0 : 1;
351          nb[t]++;
352       }
353          full_mod++;
354          if (full_mod==nb_tracks)
355          {
356             full_mod=0;
357             full_div++;
358          }
359    }
360
361    /*Global gain re-estimation*/
362    if (1) {
363       float f;
364       int quant_gain;
365       residue_zero(e, awk1, resp, nsf, p);
366       syn_filt_zero(resp, ak, resp, nsf, p);
367       syn_filt_zero(resp, awk2, resp, nsf, p);
368
369       f=((.1+(xcorr(resp,target,nsf)))/(.1+xcorr(resp,resp,nsf)));
370       /*for (i=0;i<nsf;i++)
371         e[i]*=f;*/
372       g *= f;
373       if (g<0)
374          g=0;
375       
376       quant_gain=(int)floor(.5+8*(log(1+fabs(g))-1));
377       if (quant_gain<0)
378          quant_gain=0;
379       if (quant_gain>127)
380          quant_gain=127;
381       speex_bits_pack(bits,quant_gain,7);
382       g=exp((quant_gain/8.0)+1);
383       
384       for (i=0;i<nsf;i++)
385          e[i]=g*pulses[i];
386 #ifdef DEBUG
387       printf ("global gain = %f\n", g);
388 #endif
389       for (i=0;i<nsf;i++)
390          t[i]=target[i]-f*resp[i];
391
392    }
393 #ifdef DEBUG
394    for (i=0;i<nsf;i++)
395       printf ("%f ", e[i]);
396    printf ("\n");
397 #endif
398    for (i=0;i<nsf;i++)
399       exc[i]+=e[i];
400    for (i=0;i<nsf;i++)
401       target[i]=t[i];
402    
403    for (i=0;i<nb_tracks;i++)
404    {
405       int bit1, ind=0;
406       bit1=porder(tracks+i*pulses_per_track, signs+i*pulses_per_track,tr,pulses_per_track);
407       speex_bits_pack(bits,bit1,1);
408       for (j=0;j<pulses_per_track;j++)
409       {
410          ind*=nsf/nb_tracks;
411          ind+=tr[j];
412          /*printf ("%d ", ind);*/
413       }
414       
415       speex_bits_pack(bits,ind,track_ind_bits);
416
417       /*printf ("track %d %d:", i, ind);
418       for (j=0;j<pulses_per_track;j++)
419         printf ("%d ", tr[j]);
420         printf ("\n");*/
421    }
422    POP(stack);
423    POP(stack);
424    POP(stack);
425    POP(stack);
426    POP(stack);
427    POP(stack);
428    POP(stack);
429    POP(stack);
430    POP(stack);
431    POP(stack);
432 }
433
434
435 void mpulse_unquant(
436 float *exc,
437 void *par,                      /* non-overlapping codebook */
438 int   nsf,                      /* number of samples in subframe */
439 FrameBits *bits,
440 float *stack
441 )
442 {
443    int i,j, bit1, nb_pulse, quant_gain;
444    float g;
445    int nb_tracks, track_ind_bits;
446    int *track, *signs, *tr;
447    mpulse_params *params;
448    int pulses_per_track;
449    params = (mpulse_params *) par;
450
451    nb_pulse=params->nb_pulse;
452    nb_tracks=params->nb_tracks;
453    pulses_per_track=nb_pulse/nb_tracks;
454    track_ind_bits=params->track_ind_bits;
455
456    track = (int*)PUSH(stack,pulses_per_track);
457    signs = (int*)PUSH(stack,pulses_per_track);
458    tr = (int*)PUSH(stack,pulses_per_track);
459    
460    quant_gain=speex_bits_unpack_unsigned(bits, 7);
461    g=exp((quant_gain/8.0)+1);
462    /*Removes glitches when energy is near-zero*/
463    if (g<3)
464       g=0;
465    for (i=0;i<nb_tracks;i++)
466    {
467       int ind;
468       int max_val=nsf/nb_tracks;
469       bit1=speex_bits_unpack_unsigned(bits, 1);
470       ind = speex_bits_unpack_unsigned(bits,track_ind_bits);
471       /*printf ("unquant ind = %d\n", ind);*/
472       for (j=0;j<pulses_per_track;j++)
473       {
474          track[pulses_per_track-1-j]=ind%max_val;
475          ind /= max_val;
476       }
477       rorder(track, signs, tr, bit1, pulses_per_track);
478       for (j=0;j<pulses_per_track;j++)
479       {
480          exc[tr[j]*nb_tracks+i] += signs[j] ? -g : g;
481       }
482    }
483
484    POP(stack);
485    POP(stack);
486    POP(stack);
487 }