code cleanup, some optimizations
[speexdsp.git] / libspeex / mpulse.c
1 /* Copyright (C) 2002 Jean-Marc Valin 
2    File: mpulse.c
3
4    Multi-pulse code
5
6    This library is free software; you can redistribute it and/or
7    modify it under the terms of the GNU Lesser General Public
8    License as published by the Free Software Foundation; either
9    version 2.1 of the License, or (at your option) any later version.
10    
11    This library is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14    Lesser General Public License for more details.
15    
16    You should have received a copy of the GNU Lesser General Public
17    License along with this library; if not, write to the Free Software
18    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19
20 */
21
22 #include "mpulse.h"
23 #include "stack_alloc.h"
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include "filters.h"
27 #include <math.h>
28
29 #define MAX_PULSE 30
30 #define MAX_POS  100
31
32 int porder(int *p, int *s, int *o, int len)
33 {
34    int i,j, bit1, nb_uniq=0;
35    int *st, *en;
36    int rep[MAX_POS];
37    int uniq[MAX_PULSE];
38    int n;
39    /*Stupid bubble sort but for small N, we don't care!*/
40    for (i=0;i<MAX_POS;i++)
41       rep[i]=0;
42    for (i=0;i<len;i++)
43    {
44       for (j=i+1;j<len;j++)
45       {
46          if (p[i]>p[j])
47          {
48             int tmp;
49             tmp=p[j];
50             p[j]=p[i];
51             p[i]=tmp;
52             tmp=s[j];
53             s[j]=s[i];
54             s[i]=tmp;
55          }
56       }
57    }
58 #ifdef DEBUG
59    printf ("quant_pulse\n");
60    for (i=0;i<len;i++)
61       printf ("%d ", p[i]);
62    printf ("\n");
63    for (i=0;i<len;i++)
64       printf ("%d ", s[i]);
65    printf ("\n");
66 #endif
67    for (i=0;i<len;i++)
68    {
69       rep[p[i]]++;
70       if (i==0 || p[i]!=p[i-1])
71       {
72          uniq[nb_uniq]=p[i];
73          s[nb_uniq]=s[i];
74          nb_uniq++;
75       }
76    }
77    st=uniq;
78    en=&uniq[nb_uniq-1];
79
80
81    bit1=s[0];
82    n=0;
83    for (i=0;i<nb_uniq;i++)
84    {
85       int next;
86       if (i==nb_uniq-1)
87       {
88          next=*st;
89          for (j=0;j<rep[next];j++)
90          {
91             o[n]=next;
92             n++;
93          }
94       } else {
95          if (s[i+1])
96          {
97             next=*en;
98             for (j=0;j<rep[next];j++)
99             {
100                o[n]=next;
101                n++;
102             }
103             en--;
104          } else {
105             next=*st;
106             for (j=0;j<rep[next];j++)
107             {
108                o[n]=next;
109                n++;
110             }
111             st++;
112          }
113       }
114    }
115 #ifdef DEBUG
116    for (i=0;i<len;i++)
117       printf ("%d ", o[i]);
118    printf ("\n");
119 #endif
120    return s[0];
121 }
122
123 void rorder(int *p, int *s, int *o, int bit, int len)
124 {
125    int i,j,nb_uniq=0;
126    int *st, *en;
127    int rep[MAX_POS];
128    int uniq[MAX_PULSE];
129    int ss[MAX_PULSE];
130    int n;
131    /*Stupid bubble sort but for small N, we don't care!*/
132    for (i=0;i<len;i++)
133       o[i]=p[i];
134    for (i=0;i<len;i++)
135    {
136       for (j=i+1;j<len;j++)
137       {
138          if (o[i]>o[j])
139          {
140             int tmp;
141             tmp=o[j];
142             o[j]=o[i];
143             o[i]=tmp;
144          }
145       }
146    }
147
148    for (i=0;i<len;i++)  
149    {
150       rep[p[i]]++;
151       if (i==0 || o[i]!=o[i-1])
152       {
153          uniq[nb_uniq]=o[i];
154          s[nb_uniq]=s[i];
155          nb_uniq++;
156       }
157    }
158    st=uniq;
159    en=&uniq[nb_uniq-1];
160
161    ss[0]=bit;
162    n=1;
163 #ifdef DEBUG
164    printf ("unquant_pulse\n");
165    for (i=0;i<len;i++)
166       printf ("%d ", o[i]);
167    printf ("\n");
168 #endif
169    for (i=1;i<len;i++)
170    {
171       if (i>1&&p[i-1]==p[i-2])
172          continue;
173       if (p[i-1]==*st)
174       {
175          ss[n++]=0;
176          st++;
177       } else if (p[i-1]==*en)
178       {
179          ss[n++]=1;
180          en--;
181       } else 
182       {
183          fprintf (stderr, "ERROR in decoding signs\n");
184          exit(1);
185       }
186    }
187    
188    n=0;
189    for (i=0;i<len;i++)
190    {
191       s[i]=ss[n];
192       if (i<len&&o[i]!=o[i+1])
193          n++;
194    }
195 #ifdef DEBUG
196    for (i=0;i<len;i++)
197       printf ("%d ", s[i]);
198    printf ("\n");
199 #endif
200 }
201
202
203 void mpulse_search(
204 float target[],                 /* target vector */
205 float ak[],                     /* LPCs for this subframe */
206 float awk1[],                   /* Weighted LPCs for this subframe */
207 float awk2[],                   /* Weighted LPCs for this subframe */
208 void *par,                      /* Codebook/search parameters*/
209 int   p,                        /* number of LPC coeffs */
210 int   nsf,                      /* number of samples in subframe */
211 float *exc,
212 FrameBits *bits,
213 float *stack
214 )
215 {
216    int i,j, nb_pulse;
217    float *resp, *resp2, *energy, *t, *e, *pulses;
218    float te=0,ee=0;
219    float g;
220    int nb_tracks, track_ind_bits;
221    int *tracks, *signs, *tr, *nb;
222    mpulse_params *params;
223    int pulses_per_track;
224    params = (mpulse_params *) par;
225
226    nb_pulse=params->nb_pulse;
227    nb_tracks=params->nb_tracks;
228    pulses_per_track=nb_pulse/nb_tracks;
229    track_ind_bits=params->track_ind_bits;
230
231    tracks = (int*)PUSH(stack,nb_pulse);
232    signs = (int*)PUSH(stack,nb_pulse);
233    tr = (int*)PUSH(stack,pulses_per_track);
234    nb = (int*)PUSH(stack,nb_tracks);
235
236    resp=PUSH(stack, nsf);
237    resp2=PUSH(stack, nsf);
238    energy=PUSH(stack, nsf);
239    t=PUSH(stack, nsf);
240    e=PUSH(stack, nsf);
241    pulses=PUSH(stack, nsf);
242    
243    /*Compute optimal (real) excitation from target*/
244    syn_filt_zero(target, awk1, e, nsf, p);
245    residue_zero(e, ak, e, nsf, p);
246    residue_zero(e, awk2, e, nsf, p);
247    for (i=0;i<nsf;i++)
248    {
249       pulses[i]=0;
250       te+=target[i]*target[i];
251       ee+=e[i]*e[i];
252    }
253    /*Compute global gain (coef found from linear regression and tweaking)*/
254    g=2.2/sqrt(nb_pulse)*exp(0.18163*log(te+1)+0.17293*log(ee+1));
255    
256    e[0]=1;
257    for (i=1;i<nsf;i++)
258       e[i]=0;
259
260    /*Impulse response of W(z)/A(z)*/
261    residue_zero(e, awk1, resp, nsf, p);
262    syn_filt_zero(resp, ak, resp, nsf, p);
263    syn_filt_zero(resp, awk2, resp, nsf, p);
264    
265    /*Impulse response * gain*/
266    for (i=0;i<nsf;i++)
267       resp2[i]=g*resp[i];
268
269    for (i=0;i<nsf;i++)
270       e[i]=0;
271
272    for (i=0;i<nsf;i++)
273       t[i]=target[i];
274
275    for (i=0;i<nb_tracks;i++)
276       nb[i]=0;
277
278    /*For all pulses*/
279    for (i=0;i<nb_pulse;i++)
280    {
281       float best_score=1e30, best_gain=0;
282       int best_ind=0;
283       /*For all positions*/
284       energy[0]=0;
285       for (j=1;j<nsf;j++)
286          energy[j]=energy[j-1]+t[j-1]*t[j-1];
287       /*For each position*/
288       for (j=0;j<nsf;j++)
289       {
290          int k;
291          float dist;
292          float *base=t+j;
293          /*Fill any track until it's full*/
294          /*if (nb[j%nb_tracks]==pulses_per_track)
295               continue;*/
296          /*Constrain search in alternating tracks*/
297          /*FIXME: Should get rid of this, it's *really* slow*/
298          if ((i%nb_tracks) != (j%nb_tracks))
299            continue;
300          /*Try for positive sign*/
301          dist=energy[j];
302          for (k=0;k<nsf-j;k++)
303          {
304             float tmp=(base[k]-resp2[k]);
305             dist+=tmp*tmp;
306          }
307          if (dist<best_score || j==0)
308          {
309             best_score=dist;
310             best_gain=g;
311             best_ind=j;
312          }
313          /*Try again for negative sign*/
314          dist=energy[j];
315          for (k=0;k<nsf-j;k++)
316          {
317             float tmp=(base[k]+resp2[k]);
318             dist+=tmp*tmp;
319          }
320             if (dist<best_score || j==0)
321          {
322             best_score=dist;
323             best_gain=-g;
324             best_ind=j;
325          }
326       }
327 #ifdef DEBUG
328       printf ("best pulse: %d %d %f %f %f %f\n", i, best_ind, best_gain, te, ee, g);
329 #endif
330       /*Remove pulse contribution from target*/
331       for (j=best_ind;j<nsf;j++)
332          t[j] -= best_gain * resp[j-best_ind];
333       e[best_ind]+=best_gain;
334       if (best_gain>0)
335          pulses[best_ind]+=1;
336       else
337          pulses[best_ind]-=1;
338       {
339          int t=best_ind%nb_tracks;
340          tracks[t*pulses_per_track+nb[t]] = best_ind/nb_tracks;
341          signs[t*pulses_per_track+nb[t]]  = best_gain >= 0 ? 0 : 1;
342          nb[t]++;
343       }
344    }
345
346    /*Global gain re-estimation*/
347    if (1) {
348       float f;
349       int quant_gain;
350       residue_zero(e, awk1, resp, nsf, p);
351       syn_filt_zero(resp, ak, resp, nsf, p);
352       syn_filt_zero(resp, awk2, resp, nsf, p);
353
354       f=((.1+(xcorr(resp,target,nsf)))/(.1+xcorr(resp,resp,nsf)));
355       /*for (i=0;i<nsf;i++)
356         e[i]*=f;*/
357       g *= f;
358       if (g<0)
359          g=0;
360       
361       quant_gain=(int)floor(.5+8*(log(1+fabs(g))-1));
362       if (quant_gain<0)
363          quant_gain=0;
364       if (quant_gain>127)
365          quant_gain=127;
366       frame_bits_pack(bits,quant_gain,7);
367       g=exp((quant_gain/8.0)+1);
368       
369       for (i=0;i<nsf;i++)
370          e[i]=g*pulses[i];
371 #ifdef DEBUG
372       printf ("global gain = %f\n", g);
373 #endif
374       for (i=0;i<nsf;i++)
375          t[i]=target[i]-f*resp[i];
376
377    }
378 #ifdef DEBUG
379    for (i=0;i<nsf;i++)
380       printf ("%f ", e[i]);
381    printf ("\n");
382 #endif
383    for (i=0;i<nsf;i++)
384       exc[i]+=e[i];
385    for (i=0;i<nsf;i++)
386       target[i]=t[i];
387    
388    for (i=0;i<nb_tracks;i++)
389    {
390       int bit1, ind=0;
391       bit1=porder(tracks+i*pulses_per_track, signs+i*pulses_per_track,tr,pulses_per_track);
392       frame_bits_pack(bits,bit1,1);
393       for (j=0;j<pulses_per_track;j++)
394       {
395          ind*=nsf/nb_tracks;
396          ind+=tr[j];
397          /*printf ("%d ", ind);*/
398       }
399       
400       frame_bits_pack(bits,ind,track_ind_bits);
401
402       /*printf ("track %d %d:", i, ind);
403       for (j=0;j<pulses_per_track;j++)
404         printf ("%d ", tr[j]);
405         printf ("\n");*/
406    }
407    POP(stack);
408    POP(stack);
409    POP(stack);
410    POP(stack);
411    POP(stack);
412    POP(stack);
413    POP(stack);
414    POP(stack);
415    POP(stack);
416    POP(stack);
417 }
418
419
420 void mpulse_unquant(
421 float *exc,
422 void *par,                      /* non-overlapping codebook */
423 int   nsf,                      /* number of samples in subframe */
424 FrameBits *bits,
425 float *stack
426 )
427 {
428    int i,j, bit1, nb_pulse, quant_gain;
429    float g;
430    int nb_tracks, track_ind_bits;
431    int *track, *signs, *tr;
432    mpulse_params *params;
433    int pulses_per_track;
434    params = (mpulse_params *) par;
435
436    nb_pulse=params->nb_pulse;
437    nb_tracks=params->nb_tracks;
438    pulses_per_track=nb_pulse/nb_tracks;
439    track_ind_bits=params->track_ind_bits;
440
441    track = (int*)PUSH(stack,pulses_per_track);
442    signs = (int*)PUSH(stack,pulses_per_track);
443    tr = (int*)PUSH(stack,pulses_per_track);
444    
445    quant_gain=frame_bits_unpack_unsigned(bits, 7);
446    g=exp((quant_gain/8.0)+1);
447    /*Removes glitches when energy is near-zero*/
448    if (g<3)
449       g=0;
450    for (i=0;i<nb_tracks;i++)
451    {
452       int ind;
453       int max_val=nsf/nb_tracks;
454       bit1=frame_bits_unpack_unsigned(bits, 1);
455       ind = frame_bits_unpack_unsigned(bits,track_ind_bits);
456       /*printf ("unquant ind = %d\n", ind);*/
457       for (j=0;j<pulses_per_track;j++)
458       {
459          track[pulses_per_track-1-j]=ind%max_val;
460          ind /= max_val;
461       }
462       rorder(track, signs, tr, bit1, pulses_per_track);
463       for (j=0;j<pulses_per_track;j++)
464       {
465          exc[tr[j]*nb_tracks+i] += signs[j] ? -g : g;
466       }
467    }
468
469    POP(stack);
470    POP(stack);
471    POP(stack);
472 }