e1ed8854dc7fcd1409ee3d6bcd5f9a4daad6d005
[speexdsp.git] / libspeex / mpulse.c
1 /* Copyright (C) 2002 Jean-Marc Valin 
2    File: mpulse.c
3
4    Multi-pulse code
5
6    This library is free software; you can redistribute it and/or
7    modify it under the terms of the GNU Lesser General Public
8    License as published by the Free Software Foundation; either
9    version 2.1 of the License, or (at your option) any later version.
10    
11    This library is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14    Lesser General Public License for more details.
15    
16    You should have received a copy of the GNU Lesser General Public
17    License along with this library; if not, write to the Free Software
18    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19
20 */
21
22 #include "mpulse.h"
23 #include "stack_alloc.h"
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include "filters.h"
27 #include <math.h>
28
29 #define MAX_PULSE 30
30 #define MAX_POS  100
31
32 int porder(int *p, int *s, int *o, int len)
33 {
34    int i,j, bit1, nb_uniq=0;
35    int *st, *en;
36    int rep[MAX_POS];
37    int uniq[MAX_PULSE];
38    int n;
39    /*Stupid bubble sort but for small N, we don't care!*/
40    for (i=0;i<MAX_POS;i++)
41       rep[i]=0;
42    for (i=0;i<len;i++)
43    {
44       for (j=i+1;j<len;j++)
45       {
46          if (p[i]>p[j])
47          {
48             int tmp;
49             tmp=p[j];
50             p[j]=p[i];
51             p[i]=tmp;
52             tmp=s[j];
53             s[j]=s[i];
54             s[i]=tmp;
55          }
56       }
57    }
58 #ifdef DEBUG
59    printf ("quant_pulse\n");
60    for (i=0;i<len;i++)
61       printf ("%d ", p[i]);
62    printf ("\n");
63    for (i=0;i<len;i++)
64       printf ("%d ", s[i]);
65    printf ("\n");
66 #endif
67    for (i=0;i<len;i++)
68    {
69       rep[p[i]]++;
70       if (i==0 || p[i]!=p[i-1])
71       {
72          uniq[nb_uniq]=p[i];
73          s[nb_uniq]=s[i];
74          nb_uniq++;
75       }
76    }
77    st=uniq;
78    en=&uniq[nb_uniq-1];
79
80
81    bit1=s[0];
82    n=0;
83    for (i=0;i<nb_uniq;i++)
84    {
85       int next;
86       if (i==nb_uniq-1)
87       {
88          next=*st;
89          for (j=0;j<rep[next];j++)
90          {
91             o[n]=next;
92             n++;
93          }
94       } else {
95          if (s[i+1])
96          {
97             next=*en;
98             for (j=0;j<rep[next];j++)
99             {
100                o[n]=next;
101                n++;
102             }
103             en--;
104          } else {
105             next=*st;
106             for (j=0;j<rep[next];j++)
107             {
108                o[n]=next;
109                n++;
110             }
111             st++;
112          }
113       }
114    }
115 #ifdef DEBUG
116    for (i=0;i<len;i++)
117       printf ("%d ", o[i]);
118    printf ("\n");
119 #endif
120    return s[0];
121 }
122
123 void rorder(int *p, int *s, int *o, int bit, int len)
124 {
125    int i,j,nb_uniq=0;
126    int *st, *en;
127    int rep[MAX_POS];
128    int uniq[MAX_PULSE];
129    int ss[MAX_PULSE];
130    int n;
131    /*Stupid bubble sort but for small N, we don't care!*/
132    for (i=0;i<len;i++)
133       o[i]=p[i];
134    for (i=0;i<len;i++)
135    {
136       for (j=i+1;j<len;j++)
137       {
138          if (o[i]>o[j])
139          {
140             int tmp;
141             tmp=o[j];
142             o[j]=o[i];
143             o[i]=tmp;
144          }
145       }
146    }
147
148    for (i=0;i<len;i++)  
149    {
150       rep[p[i]]++;
151       if (i==0 || o[i]!=o[i-1])
152       {
153          uniq[nb_uniq]=o[i];
154          s[nb_uniq]=s[i];
155          nb_uniq++;
156       }
157    }
158    st=uniq;
159    en=&uniq[nb_uniq-1];
160
161    ss[0]=bit;
162    n=1;
163 #ifdef DEBUG
164    printf ("unquant_pulse\n");
165    for (i=0;i<len;i++)
166       printf ("%d ", o[i]);
167    printf ("\n");
168 #endif
169    for (i=1;i<len;i++)
170    {
171       if (i>1&&p[i-1]==p[i-2])
172          continue;
173       if (p[i-1]==*st)
174       {
175          ss[n++]=0;
176          st++;
177       } else if (p[i-1]==*en)
178       {
179          ss[n++]=1;
180          en--;
181       } else 
182       {
183          fprintf (stderr, "ERROR in decoding signs\n");
184          exit(1);
185       }
186    }
187    
188    n=0;
189    for (i=0;i<len;i++)
190    {
191       s[i]=ss[n];
192       if (i<len&&o[i]!=o[i+1])
193          n++;
194    }
195 #ifdef DEBUG
196    for (i=0;i<len;i++)
197       printf ("%d ", s[i]);
198    printf ("\n");
199 #endif
200 }
201
202
203 void mpulse_search(
204 float target[],                 /* target vector */
205 float ak[],                     /* LPCs for this subframe */
206 float awk1[],                   /* Weighted LPCs for this subframe */
207 float awk2[],                   /* Weighted LPCs for this subframe */
208 void *par,                      /* Codebook/search parameters*/
209 int   p,                        /* number of LPC coeffs */
210 int   nsf,                      /* number of samples in subframe */
211 float *exc,
212 FrameBits *bits,
213 float *stack
214 )
215 {
216    int i,j, nb_pulse;
217    float *resp, *t, *e, *pulses;
218    float te=0,ee=0;
219    float g;
220    int nb_tracks, track_ind_bits;
221    int *tracks, *signs, *tr, *nb;
222    mpulse_params *params;
223    int pulses_per_track;
224    params = (mpulse_params *) par;
225
226    nb_pulse=params->nb_pulse;
227    nb_tracks=params->nb_tracks;
228    pulses_per_track=nb_pulse/nb_tracks;
229    track_ind_bits=params->track_ind_bits;
230
231    tracks = (int*)PUSH(stack,nb_pulse);
232    signs = (int*)PUSH(stack,nb_pulse);
233    tr = (int*)PUSH(stack,pulses_per_track);
234    nb = (int*)PUSH(stack,nb_tracks);
235
236    resp=PUSH(stack, nsf);
237    t=PUSH(stack, nsf);
238    e=PUSH(stack, nsf);
239    pulses=PUSH(stack, nsf);
240    
241    syn_filt_zero(target, awk1, e, nsf, p);
242    residue_zero(e, ak, e, nsf, p);
243    residue_zero(e, awk2, e, nsf, p);
244    for (i=0;i<nsf;i++)
245    {
246       pulses[i]=0;
247       te+=target[i]*target[i];
248       ee+=e[i]*e[i];
249    }
250    g=2.2/sqrt(nb_pulse)*exp(0.18163*log(te+1)+0.17293*log(ee+1));
251    
252    e[0]=1;
253    for (i=1;i<nsf;i++)
254       e[i]=0;
255
256    residue_zero(e, awk1, resp, nsf, p);
257    syn_filt_zero(resp, ak, resp, nsf, p);
258    syn_filt_zero(resp, awk2, resp, nsf, p);
259    
260    for (i=0;i<nsf;i++)
261       e[i]=0;
262
263    for (i=0;i<nsf;i++)
264       t[i]=target[i];
265
266    for (i=0;i<nb_tracks;i++)
267       nb[i]=0;
268
269    /*For all pulses*/
270    for (i=0;i<nb_pulse;i++)
271    {
272       float best_score=1e30, best_gain=0;
273       int best_ind=0;
274       /*For all positions*/
275       for (j=0;j<nsf;j++)
276       {
277          int k;
278          float dist=0;
279          /*Fill any track until it's full*/
280          /*if (nb[j%nb_tracks]==pulses_per_track)
281               continue;*/
282          /*Constrain search in alternating tracks*/
283          if ((i%nb_tracks) != (j%nb_tracks))
284            continue;
285          /*Try for positive sign*/
286          for (k=0;k<j;k++)
287             dist+=t[k]*t[k];
288          for (k=0;k<nsf-j;k++)
289             dist+=(t[k+j]-g*resp[k])*(t[k+j]-g*resp[k]);
290          if (dist<best_score || j==0)
291          {
292             best_score=dist;
293             best_gain=g;
294             best_ind=j;
295          }
296          /*Try again for negative sign*/
297          dist=0;
298          for (k=0;k<j;k++)
299             dist+=t[k]*t[k];
300          for (k=0;k<nsf-j;k++)
301             dist+=(t[k+j]+g*resp[k])*(t[k+j]+g*resp[k]);
302          if (dist<best_score || j==0)
303          {
304             best_score=dist;
305             best_gain=-g;
306             best_ind=j;
307          }
308       }
309 #ifdef DEBUG
310       printf ("best pulse: %d %d %f %f %f %f\n", i, best_ind, best_gain, te, ee, g);
311 #endif
312       /*Remove pulse contribution from target*/
313       for (j=best_ind;j<nsf;j++)
314          t[j] -= best_gain * resp[j-best_ind];
315       e[best_ind]+=best_gain;
316       if (best_gain>0)
317          pulses[best_ind]+=1;
318       else
319          pulses[best_ind]-=1;
320       {
321          int t=best_ind%nb_tracks;
322          tracks[t*pulses_per_track+nb[t]] = best_ind/nb_tracks;
323          signs[t*pulses_per_track+nb[t]]  = best_gain >= 0 ? 0 : 1;
324          nb[t]++;
325       }
326    }
327    
328    /*Global gain re-estimation*/
329    if (1) {
330       float f;
331       int quant_gain;
332       residue_zero(e, awk1, resp, nsf, p);
333       syn_filt_zero(resp, ak, resp, nsf, p);
334       syn_filt_zero(resp, awk2, resp, nsf, p);
335
336       f=((.1+(xcorr(resp,target,nsf)))/(.1+xcorr(resp,resp,nsf)));
337       /*for (i=0;i<nsf;i++)
338         e[i]*=f;*/
339       g *= f;
340       if (g<0)
341          g=0;
342       
343       quant_gain=(int)floor(.5+8*(log(1+fabs(g))-1));
344       if (quant_gain<0)
345          quant_gain=0;
346       if (quant_gain>127)
347          quant_gain=127;
348       frame_bits_pack(bits,quant_gain,7);
349       g=exp((quant_gain/8.0)+1);
350       
351       for (i=0;i<nsf;i++)
352          e[i]=g*pulses[i];
353 #ifdef DEBUG
354       printf ("global gain = %f\n", g);
355 #endif
356       for (i=0;i<nsf;i++)
357          t[i]=target[i]-f*resp[i];
358
359    }
360 #ifdef DEBUG
361    for (i=0;i<nsf;i++)
362       printf ("%f ", e[i]);
363    printf ("\n");
364 #endif
365    for (i=0;i<nsf;i++)
366       exc[i]+=e[i];
367    for (i=0;i<nsf;i++)
368       target[i]=t[i];
369    
370    for (i=0;i<nb_tracks;i++)
371    {
372       int bit1, ind=0;
373       bit1=porder(tracks+i*pulses_per_track, signs+i*pulses_per_track,tr,pulses_per_track);
374       frame_bits_pack(bits,bit1,1);
375       for (j=0;j<pulses_per_track;j++)
376       {
377          ind*=nsf/nb_tracks;
378          ind+=tr[j];
379          /*printf ("%d ", ind);*/
380       }
381       
382       frame_bits_pack(bits,ind,track_ind_bits);
383
384       /*printf ("track %d %d:", i, ind);
385       for (j=0;j<pulses_per_track;j++)
386         printf ("%d ", tr[j]);
387         printf ("\n");*/
388    }
389    POP(stack);
390    POP(stack);
391    POP(stack);
392    POP(stack);
393    POP(stack);
394    POP(stack);
395    POP(stack);
396    POP(stack);
397 }
398
399
400 void mpulse_unquant(
401 float *exc,
402 void *par,                      /* non-overlapping codebook */
403 int   nsf,                      /* number of samples in subframe */
404 FrameBits *bits,
405 float *stack
406 )
407 {
408    int i,j, bit1, nb_pulse, quant_gain;
409    float g;
410    int nb_tracks, track_ind_bits;
411    int *track, *signs, *tr;
412    mpulse_params *params;
413    int pulses_per_track;
414    params = (mpulse_params *) par;
415
416    nb_pulse=params->nb_pulse;
417    nb_tracks=params->nb_tracks;
418    pulses_per_track=nb_pulse/nb_tracks;
419    track_ind_bits=params->track_ind_bits;
420
421    track = (int*)PUSH(stack,pulses_per_track);
422    signs = (int*)PUSH(stack,pulses_per_track);
423    tr = (int*)PUSH(stack,pulses_per_track);
424    
425    quant_gain=frame_bits_unpack_unsigned(bits, 7);
426    g=exp((quant_gain/8.0)+1);
427    /*Removes glitches when energy is near-zero*/
428    if (g<3)
429       g=0;
430    for (i=0;i<nb_tracks;i++)
431    {
432       int ind;
433       int max_val=nsf/nb_tracks;
434       bit1=frame_bits_unpack_unsigned(bits, 1);
435       ind = frame_bits_unpack_unsigned(bits,track_ind_bits);
436       /*printf ("unquant ind = %d\n", ind);*/
437       for (j=0;j<pulses_per_track;j++)
438       {
439          track[pulses_per_track-1-j]=ind%max_val;
440          ind /= max_val;
441       }
442       rorder(track, signs, tr, bit1, pulses_per_track);
443       for (j=0;j<pulses_per_track;j++)
444       {
445          exc[tr[j]*nb_tracks+i] += signs[j] ? -g : g;
446       }
447    }
448
449    POP(stack);
450    POP(stack);
451    POP(stack);
452 }