Added an n-best VQ search function in order to simplify the code
[speexdsp.git] / libspeex / cb_search.c
1 /* Copyright (C) 2002 Jean-Marc Valin 
2    File: cb_search.c
3
4    This library is free software; you can redistribute it and/or
5    modify it under the terms of the GNU Lesser General Public
6    License as published by the Free Software Foundation; either
7    version 2.1 of the License, or (at your option) any later version.
8    
9    This library is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12    Lesser General Public License for more details.
13    
14    You should have received a copy of the GNU Lesser General Public
15    License along with this library; if not, write to the Free Software
16    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17 */
18
19
20
21 #include <stdlib.h>
22 #include <cb_search.h>
23 #include "filters.h"
24 #include <math.h>
25 #include <stdio.h>
26 #include "stack_alloc.h"
27 #include "vq.h"
28
29 #define min(a,b) ((a) < (b) ? (a) : (b))
30 #define max(a,b) ((a) > (b) ? (a) : (b))
31
32 void split_cb_search_nogain(
33 float target[],                 /* target vector */
34 float ak[],                     /* LPCs for this subframe */
35 float awk1[],                   /* Weighted LPCs for this subframe */
36 float awk2[],                   /* Weighted LPCs for this subframe */
37 void *par,                      /* Codebook/search parameters*/
38 int   p,                        /* number of LPC coeffs */
39 int   nsf,                      /* number of samples in subframe */
40 float *exc,
41 SpeexBits *bits,
42 float *stack
43 )
44 {
45    int i,j;
46    float *resp;
47    float *t, *r, *e;
48    int *ind;
49    float *shape_cb;
50    int shape_cb_size, subvect_size, nb_subvect;
51    split_cb_params *params;
52
53    params = (split_cb_params *) par;
54    subvect_size = params->subvect_size;
55    nb_subvect = params->nb_subvect;
56    shape_cb_size = 1<<params->shape_bits;
57    shape_cb = params->shape_cb;
58    resp = PUSH(stack, shape_cb_size*subvect_size);
59    t = PUSH(stack, nsf);
60    r = PUSH(stack, nsf);
61    e = PUSH(stack, nsf);
62    ind = (int*)PUSH(stack, nb_subvect);
63
64    for (i=0;i<nsf;i++)
65       t[i]=target[i];
66
67    e[0]=1;
68    for (i=1;i<nsf;i++)
69       e[i]=0;
70    residue_zero(e, awk1, r, nsf, p);
71    syn_filt_zero(r, ak, r, nsf, p);
72    syn_filt_zero(r, awk2, r, nsf,p);
73    
74    /* Pre-compute codewords response and energy */
75    for (i=0;i<shape_cb_size;i++)
76    {
77       float *res = resp+i*subvect_size;
78
79       /* Compute codeword response */
80       int k;
81       for(j=0;j<subvect_size;j++)
82          res[j]=0;
83       for(j=0;j<subvect_size;j++)
84       {
85          for (k=j;k<subvect_size;k++)
86             res[k]+=shape_cb[i*subvect_size+j]*r[k-j];
87       }
88    }
89
90    for (i=0;i<nb_subvect;i++)
91    {
92       int best_index=0, k, m;
93       float g, dist, best_dist=-1;
94       float *a, *b;
95
96       /* Find best codeword for current sub-vector */
97       for (j=0;j<shape_cb_size;j++)
98       {
99          dist=0;
100          a=resp+j*subvect_size;
101          b=t+subvect_size*i;
102          for (k=0;k<subvect_size;k++)
103             dist += (a[k]-b[k])*(a[k]-b[k]);
104          if (dist<best_dist || j==0)
105          {
106             best_dist=dist;
107             best_index=j;
108          }
109       }
110       /*printf ("best index: %d/%d\n", best_index, shape_cb_size);*/
111       speex_bits_pack(bits,best_index,params->shape_bits);
112
113       ind[i]=best_index;
114       /* Update target for next subvector */
115       for (j=0;j<subvect_size;j++)
116       {
117          g=shape_cb[best_index*subvect_size+j];
118          for (k=subvect_size*i+j,m=0;k<nsf;k++,m++)
119             t[k] -= g*r[m];
120       }
121
122    }
123    
124    /* Put everything back together */
125    for (i=0;i<nb_subvect;i++)
126       for (j=0;j<subvect_size;j++)
127          e[subvect_size*i+j]=shape_cb[ind[i]*subvect_size+j];
128
129    /* Update excitation */
130    for (j=0;j<nsf;j++)
131       exc[j]+=e[j];
132    
133    /* Update target */
134    residue_zero(e, awk1, r, nsf, p);
135    syn_filt_zero(r, ak, r, nsf, p);
136    syn_filt_zero(r, awk2, r, nsf,p);
137    for (j=0;j<nsf;j++)
138       target[j]-=r[j];
139
140    
141
142    POP(stack);
143    POP(stack);
144    POP(stack);
145    POP(stack);
146    POP(stack);
147 }
148
149
150 void split_cb_search_nogain2(
151 float target[],                 /* target vector */
152 float ak[],                     /* LPCs for this subframe */
153 float awk1[],                   /* Weighted LPCs for this subframe */
154 float awk2[],                   /* Weighted LPCs for this subframe */
155 void *par,                      /* Codebook/search parameters*/
156 int   p,                        /* number of LPC coeffs */
157 int   nsf,                      /* number of samples in subframe */
158 float *exc,
159 SpeexBits *bits,
160 float *stack
161 )
162 {
163    int i,j;
164    float *resp;
165    float *t, *r, *e, *E;
166    int *ind;
167    float *shape_cb;
168    int shape_cb_size, subvect_size, nb_subvect;
169    split_cb_params *params;
170
171    params = (split_cb_params *) par;
172    subvect_size = params->subvect_size;
173    nb_subvect = params->nb_subvect;
174    shape_cb_size = 1<<params->shape_bits;
175    shape_cb = params->shape_cb;
176    resp = PUSH(stack, shape_cb_size*subvect_size);
177    t = PUSH(stack, nsf);
178    r = PUSH(stack, nsf);
179    e = PUSH(stack, nsf);
180    E = PUSH(stack, shape_cb_size);
181    ind = (int*)PUSH(stack, nb_subvect);
182
183    for (i=0;i<nsf;i++)
184       t[i]=target[i];
185
186    e[0]=1;
187    for (i=1;i<nsf;i++)
188       e[i]=0;
189    residue_zero(e, awk1, r, nsf, p);
190    syn_filt_zero(r, ak, r, nsf, p);
191    syn_filt_zero(r, awk2, r, nsf,p);
192    
193    /* Pre-compute codewords response and energy */
194    for (i=0;i<shape_cb_size;i++)
195    {
196       float *res = resp+i*subvect_size;
197
198       /* Compute codeword response */
199       int k;
200       for(j=0;j<subvect_size;j++)
201          res[j]=0;
202       for(j=0;j<subvect_size;j++)
203       {
204          for (k=j;k<subvect_size;k++)
205             res[k]+=shape_cb[i*subvect_size+j]*r[k-j];
206       }
207       E[i]=0;
208       for(j=0;j<subvect_size;j++)
209          E[i]+=res[j]*res[j];
210    }
211
212    for (i=0;i<nb_subvect;i++)
213    {
214       int best_index[2]={0,0}, k, m;
215       float g, best_dist[2]={-1,-1};
216       float *x;
217       float energy=0;
218       x=t+subvect_size*i;
219
220       for (k=0;k<subvect_size;k++)
221          energy+=x[k]*x[k];
222       /* Find best codewords for current sub-vector */
223       vq_nbest(x, resp, subvect_size, shape_cb_size, E, 2, best_index, best_dist);
224       if (i<nb_subvect-1)
225       {
226          int nbest;
227          float *tt, err[2];
228          float best_score[2];
229          tt=PUSH(stack,nsf);
230          for (nbest=0;nbest<2;nbest++)
231          {
232             for (j=0;j<nsf;j++)
233                tt[j]=t[j];
234             for (j=0;j<subvect_size;j++)
235             {
236                g=shape_cb[best_index[nbest]*subvect_size+j];
237                for (k=subvect_size*i+j,m=0;k<nsf;k++,m++)
238                   tt[k] -= g*r[m];
239             }
240             
241             {
242                float dd;
243                int i2;
244                vq_nbest(&tt[subvect_size*(i+1)], resp, subvect_size, shape_cb_size, E, 1, &i2, &dd);
245                for (j=0;j<subvect_size;j++)
246                {
247                   g=shape_cb[i2*subvect_size+j];
248                   for (k=subvect_size*(i+1)+j,m=0;k<nsf;k++,m++)
249                      tt[k] -= g*r[m];
250                }
251             }
252
253             err[nbest]=0;
254             for (j=subvect_size*i;j<subvect_size*(i+2);j++)
255                err[nbest]-=tt[j]*tt[j];
256             
257             best_score[nbest]=err[nbest];
258          }
259
260          if (best_score[1]>best_score[0])
261          {
262             best_index[0]=best_index[1];
263             best_score[0]=best_score[1];
264          }
265          POP(stack);
266
267       }
268
269       ind[i]=best_index[0];
270
271       /*printf ("best index: %d/%d\n", best_index, shape_cb_size);*/
272       speex_bits_pack(bits,ind[i],params->shape_bits);
273
274       /* Update target for next subvector */
275       for (j=0;j<subvect_size;j++)
276       {
277          g=shape_cb[ind[i]*subvect_size+j];
278          for (k=subvect_size*i+j,m=0;k<nsf;k++,m++)
279             t[k] -= g*r[m];
280       }
281    }
282    
283    /* Put everything back together */
284    for (i=0;i<nb_subvect;i++)
285       for (j=0;j<subvect_size;j++)
286          e[subvect_size*i+j]=shape_cb[ind[i]*subvect_size+j];
287
288    /* Update excitation */
289    for (j=0;j<nsf;j++)
290       exc[j]+=e[j];
291    
292    /* Update target */
293    residue_zero(e, awk1, r, nsf, p);
294    syn_filt_zero(r, ak, r, nsf, p);
295    syn_filt_zero(r, awk2, r, nsf,p);
296    for (j=0;j<nsf;j++)
297       target[j]-=r[j];
298
299    
300    POP(stack);
301    POP(stack);
302    POP(stack);
303    POP(stack);
304    POP(stack);
305    POP(stack);
306 }
307
308 void split_cb_search_shape_sign(
309 float target[],                 /* target vector */
310 float ak[],                     /* LPCs for this subframe */
311 float awk1[],                   /* Weighted LPCs for this subframe */
312 float awk2[],                   /* Weighted LPCs for this subframe */
313 void *par,                      /* Codebook/search parameters*/
314 int   p,                        /* number of LPC coeffs */
315 int   nsf,                      /* number of samples in subframe */
316 float *exc,
317 SpeexBits *bits,
318 float *stack
319 )
320 {
321    int i,j;
322    float *resp;
323    float *t, *r, *e, *E;
324    int *ind, *signs;
325    float *shape_cb;
326    int shape_cb_size, subvect_size, nb_subvect;
327    split_cb_params *params;
328
329    params = (split_cb_params *) par;
330    subvect_size = params->subvect_size;
331    nb_subvect = params->nb_subvect;
332    shape_cb_size = 1<<params->shape_bits;
333    shape_cb = params->shape_cb;
334    resp = PUSH(stack, shape_cb_size*subvect_size);
335    t = PUSH(stack, nsf);
336    r = PUSH(stack, nsf);
337    e = PUSH(stack, nsf);
338    E = PUSH(stack, shape_cb_size);
339    ind = (int*)PUSH(stack, nb_subvect);
340    signs = (int*)PUSH(stack, nb_subvect);
341
342    for (i=0;i<nsf;i++)
343       t[i]=target[i];
344
345    e[0]=1;
346    for (i=1;i<nsf;i++)
347       e[i]=0;
348    residue_zero(e, awk1, r, nsf, p);
349    syn_filt_zero(r, ak, r, nsf, p);
350    syn_filt_zero(r, awk2, r, nsf,p);
351    
352    /* Pre-compute codewords response and energy */
353    for (i=0;i<shape_cb_size;i++)
354    {
355       float *res = resp+i*subvect_size;
356
357       /* Compute codeword response */
358       int k;
359       for(j=0;j<subvect_size;j++)
360          res[j]=0;
361       for(j=0;j<subvect_size;j++)
362       {
363          for (k=j;k<subvect_size;k++)
364             res[k]+=shape_cb[i*subvect_size+j]*r[k-j];
365       }
366       E[i]=0;
367       for(j=0;j<subvect_size;j++)
368          E[i]+=res[j]*res[j];
369    }
370
371    for (i=0;i<nb_subvect;i++)
372    {
373       int best_index[2]={0,0}, k, m;
374       float g, dist, best_dist[2]={-1,-1}, best_sign[2]={0,0};
375       float *a, *x;
376       float energy=0;
377       x=t+subvect_size*i;
378
379       for (k=0;k<subvect_size;k++)
380          energy+=x[k]*x[k];
381       /* Find best codeword for current sub-vector */
382       for (j=0;j<shape_cb_size;j++)
383       {
384          int sign;
385          dist=0;
386          a=resp+j*subvect_size;
387          dist=0;
388          for (k=0;k<subvect_size;k++)
389             dist -= 2*a[k]*x[k];
390          if (dist > 0)
391          {
392             sign=1;
393             dist =- dist;
394          } else
395             sign=0;
396          dist += energy+E[j];
397          if (dist<best_dist[0] || best_dist[0]<0)
398          {
399             best_dist[1]=best_dist[0];
400             best_index[1]=best_index[0];
401             best_sign[1]=best_sign[0];
402             best_dist[0]=dist;
403             best_index[0]=j;
404             best_sign[0]=sign;
405          } else if (dist<best_dist[1] || best_dist[1]<0)
406          {
407             best_dist[1]=dist;
408             best_index[1]=j;
409             best_sign[1]=sign;
410          }
411       }
412       if (i<nb_subvect-1)
413       {
414          int nbest;
415          float *tt, err[2];
416          float best_score[2];
417          tt=PUSH(stack,nsf);
418          for (nbest=0;nbest<2;nbest++)
419          {
420             float s=1;
421             if (best_sign[nbest])
422                s=-1;
423             for (j=0;j<nsf;j++)
424                tt[j]=t[j];
425             for (j=0;j<subvect_size;j++)
426             {
427                g=s*shape_cb[best_index[nbest]*subvect_size+j];
428                for (k=subvect_size*i+j,m=0;k<nsf;k++,m++)
429                   tt[k] -= g*r[m];
430             }
431             
432             {
433                int best_index2=0, best_sign2=0, sign2;
434                float  best_dist2=0;
435                x=t+subvect_size*(i+1);
436                for (j=0;j<shape_cb_size;j++)
437                {
438                   a=resp+j*subvect_size;
439                   dist = 0;
440                   for (k=0;k<subvect_size;k++)
441                      dist -= 2*a[k]*x[k];
442                   if (dist > 0)
443                   {
444                      sign2=1;
445                      dist =- dist;
446                   } else
447                      sign2=0;
448                   dist += energy+E[j];
449                   if (dist<best_dist2 || j==0)
450                   {
451                      best_dist2=dist;
452                      best_index2=j;
453                      best_sign2=sign2;
454                   }
455                }
456                s=1;
457                if (best_sign2)
458                   s=-1;
459                /*int i2=vq_index(&tt[subvect_size*(i+1)], resp, subvect_size, shape_cb_size);*/
460                
461                for (j=0;j<subvect_size;j++)
462                {
463                   g=s*shape_cb[best_index2*subvect_size+j];
464                   for (k=subvect_size*(i+1)+j,m=0;k<nsf;k++,m++)
465                      tt[k] -= g*r[m];
466                }
467             }
468
469             err[nbest]=0;
470             for (j=subvect_size*i;j<subvect_size*(i+2);j++)
471                err[nbest]-=tt[j]*tt[j];
472             
473             best_score[nbest]=err[nbest];
474          }
475
476          if (best_score[1]>best_score[0])
477          {
478             best_sign[0]=best_sign[1];
479             best_index[0]=best_index[1];
480             best_score[0]=best_score[1];
481          }
482          POP(stack);
483
484       }
485
486       ind[i]=best_index[0];
487       signs[i] = best_sign[0];
488
489       /*printf ("best index: %d/%d\n", best_index, shape_cb_size);*/
490       speex_bits_pack(bits,signs[i],1);
491       speex_bits_pack(bits,ind[i],params->shape_bits);
492
493       /* Update target for next subvector */
494       for (j=0;j<subvect_size;j++)
495       {
496          g=shape_cb[ind[i]*subvect_size+j];
497          if (signs[i])
498             g=-g;
499          for (k=subvect_size*i+j,m=0;k<nsf;k++,m++)
500             t[k] -= g*r[m];
501       }
502    }
503    
504    /* Put everything back together */
505    for (i=0;i<nb_subvect;i++)
506    {
507       float s=1;
508       if (signs[i])
509          s=-1;
510       for (j=0;j<subvect_size;j++)
511          e[subvect_size*i+j]=s*shape_cb[ind[i]*subvect_size+j];
512    }
513    /* Update excitation */
514    for (j=0;j<nsf;j++)
515       exc[j]+=e[j];
516    
517    /* Update target */
518    residue_zero(e, awk1, r, nsf, p);
519    syn_filt_zero(r, ak, r, nsf, p);
520    syn_filt_zero(r, awk2, r, nsf,p);
521    for (j=0;j<nsf;j++)
522       target[j]-=r[j];
523
524    
525    POP(stack);
526    POP(stack);
527    POP(stack);
528    POP(stack);
529    POP(stack);
530    POP(stack);
531    POP(stack);
532 }
533
534
535 void split_cb_nogain_unquant(
536 float *exc,
537 void *par,                      /* non-overlapping codebook */
538 int   nsf,                      /* number of samples in subframe */
539 SpeexBits *bits,
540 float *stack
541 )
542 {
543    int i,j;
544    int *ind;
545    float *shape_cb;
546    int shape_cb_size, subvect_size, nb_subvect;
547    split_cb_params *params;
548
549    params = (split_cb_params *) par;
550    subvect_size = params->subvect_size;
551    nb_subvect = params->nb_subvect;
552    shape_cb_size = 1<<params->shape_bits;
553    shape_cb = params->shape_cb;
554    
555    ind = (int*)PUSH(stack, nb_subvect);
556
557    /* Decode codewords and gains */
558    for (i=0;i<nb_subvect;i++)
559       ind[i] = speex_bits_unpack_unsigned(bits, params->shape_bits);
560
561    /* Compute decoded excitation */
562    for (i=0;i<nb_subvect;i++)
563       for (j=0;j<subvect_size;j++)
564          exc[subvect_size*i+j]+=shape_cb[ind[i]*subvect_size+j];
565
566    POP(stack);
567 }
568
569 void split_cb_shape_sign_unquant(
570 float *exc,
571 void *par,                      /* non-overlapping codebook */
572 int   nsf,                      /* number of samples in subframe */
573 SpeexBits *bits,
574 float *stack
575 )
576 {
577    int i,j;
578    int *ind, *signs;
579    float *shape_cb;
580    int shape_cb_size, subvect_size, nb_subvect;
581    split_cb_params *params;
582
583    params = (split_cb_params *) par;
584    subvect_size = params->subvect_size;
585    nb_subvect = params->nb_subvect;
586    shape_cb_size = 1<<params->shape_bits;
587    shape_cb = params->shape_cb;
588    
589    ind = (int*)PUSH(stack, nb_subvect);
590    signs = (int*)PUSH(stack, nb_subvect);
591
592    /* Decode codewords and gains */
593    for (i=0;i<nb_subvect;i++)
594    {
595       signs[i] = speex_bits_unpack_unsigned(bits, 1);
596       ind[i] = speex_bits_unpack_unsigned(bits, params->shape_bits);
597    }
598    /* Compute decoded excitation */
599    for (i=0;i<nb_subvect;i++)
600    {
601       float s=1;
602       if (signs[i])
603          s=-1;
604       for (j=0;j<subvect_size;j++)
605          exc[subvect_size*i+j]+=s*shape_cb[ind[i]*subvect_size+j];
606    }
607    POP(stack);
608    POP(stack);
609 }