Cleaned up vocoder mode...
[speexdsp.git] / libspeex / cb_search.c
1 /* Copyright (C) 2002 Jean-Marc Valin 
2    File: cb_search.c
3
4    This library is free software; you can redistribute it and/or
5    modify it under the terms of the GNU Lesser General Public
6    License as published by the Free Software Foundation; either
7    version 2.1 of the License, or (at your option) any later version.
8    
9    This library is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12    Lesser General Public License for more details.
13    
14    You should have received a copy of the GNU Lesser General Public
15    License along with this library; if not, write to the Free Software
16    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17 */
18
19
20
21 #include <stdlib.h>
22 #include "cb_search.h"
23 #include "filters.h"
24 #include <math.h>
25 #ifdef DEBUG
26 #include <stdio.h>
27 #endif
28 #include "stack_alloc.h"
29 #include "vq.h"
30
31 #define min(a,b) ((a) < (b) ? (a) : (b))
32 #define max(a,b) ((a) > (b) ? (a) : (b))
33
34
35
36 void split_cb_search_nogain(
37 float target[],                 /* target vector */
38 float ak[],                     /* LPCs for this subframe */
39 float awk1[],                   /* Weighted LPCs for this subframe */
40 float awk2[],                   /* Weighted LPCs for this subframe */
41 void *par,                      /* Codebook/search parameters*/
42 int   p,                        /* number of LPC coeffs */
43 int   nsf,                      /* number of samples in subframe */
44 float *exc,
45 SpeexBits *bits,
46 float *stack,
47 int   complexity
48 )
49 {
50    int i,j,k,m,n,q;
51    float *resp;
52    float *t, *r, *e, *E;
53    /*FIXME: Should make this dynamic*/
54    float *tmp, *ot[20], *nt[20];
55    float *ndist, *odist;
56    int *itmp, *nind[20], *oind[20];
57
58    int *ind;
59    float *shape_cb;
60    int shape_cb_size, subvect_size, nb_subvect;
61    split_cb_params *params;
62    int N=2;
63    int *best_index;
64    float *best_dist;
65
66    N=complexity;
67    if (N<1)
68       N=1;
69    if (N>10)
70       N=10;
71
72    params = (split_cb_params *) par;
73    subvect_size = params->subvect_size;
74    nb_subvect = params->nb_subvect;
75    shape_cb_size = 1<<params->shape_bits;
76    shape_cb = params->shape_cb;
77    resp = PUSH(stack, shape_cb_size*subvect_size);
78    t = PUSH(stack, nsf);
79    r = PUSH(stack, nsf);
80    e = PUSH(stack, nsf);
81    E = PUSH(stack, shape_cb_size);
82    ind = (int*)PUSH(stack, nb_subvect);
83
84    tmp = PUSH(stack, 2*N*nsf);
85    for (i=0;i<N;i++)
86    {
87       ot[i]=tmp;
88       tmp += nsf;
89       nt[i]=tmp;
90       tmp += nsf;
91    }
92
93    best_index = (int*)PUSH(stack, N);
94    best_dist = PUSH(stack, N);
95    ndist = PUSH(stack, N);
96    odist = PUSH(stack, N);
97    
98    itmp = (int*)PUSH(stack, 2*N*nb_subvect);
99    for (i=0;i<N;i++)
100    {
101       nind[i]=itmp;
102       itmp+=nb_subvect;
103       oind[i]=itmp;
104       itmp+=nb_subvect;
105       for (j=0;j<nb_subvect;j++)
106          nind[i][j]=oind[i][j]=-1;
107    }
108
109    for (j=0;j<N;j++)
110       for (i=0;i<nsf;i++)
111          ot[j][i]=target[i];
112
113    for (i=0;i<nsf;i++)
114       t[i]=target[i];
115
116    e[0]=1;
117    for (i=1;i<nsf;i++)
118       e[i]=0;
119    residue_zero(e, awk1, r, nsf, p);
120    syn_filt_zero(r, ak, r, nsf, p);
121    syn_filt_zero(r, awk2, r, nsf,p);
122    
123    /* Pre-compute codewords response and energy */
124    for (i=0;i<shape_cb_size;i++)
125    {
126       float *res = resp+i*subvect_size;
127
128       /* Compute codeword response */
129       int k;
130       for(j=0;j<subvect_size;j++)
131          res[j]=0;
132       for(j=0;j<subvect_size;j++)
133       {
134          for (k=j;k<subvect_size;k++)
135             res[k]+=shape_cb[i*subvect_size+j]*r[k-j];
136       }
137       E[i]=0;
138       for(j=0;j<subvect_size;j++)
139          E[i]+=res[j]*res[j];
140    }
141
142    for (j=0;j<N;j++)
143       odist[j]=0;
144    /*For all subvectors*/
145    for (i=0;i<nb_subvect;i++)
146    {
147       /*"erase" nbest list*/
148       for (j=0;j<N;j++)
149          ndist[j]=-1;
150
151       /*For all n-bests of previous subvector*/
152       for (j=0;j<N;j++)
153       {
154          float *x=ot[j]+subvect_size*i;
155          /*Find new n-best based on previous n-best j*/
156          vq_nbest(x, resp, subvect_size, shape_cb_size, E, N, best_index, best_dist);
157
158          /*For all new n-bests*/
159          for (k=0;k<N;k++)
160          {
161             float err=0;
162             /*previous target*/
163             for (m=0;m<nsf;m++)
164                t[m]=ot[j][m];
165             /*update target*/
166             for (m=0;m<subvect_size;m++)
167             {
168                float g=shape_cb[best_index[k]*subvect_size+m];
169                for (n=subvect_size*i+m,q=0;n<nsf;n++,q++)
170                   t[n] -= g*r[q];
171             }
172             
173             /*compute error (distance)*/
174             err=odist[j];
175             for (m=i*subvect_size;m<(i+1)*subvect_size;m++)
176                err += t[m]*t[m];
177             /*update n-best list*/
178             if (err<ndist[N-1] || ndist[N-1]<-.5)
179             {
180                for (m=0;m<N;m++)
181                {
182                   if (err < ndist[m] || ndist[m]<-.5)
183                   {
184                      for (n=N-1;n>m;n--)
185                      {
186                         for (q=0;q<nsf;q++)
187                            nt[n][q]=nt[n-1][q];
188                         for (q=0;q<nb_subvect;q++)
189                            nind[n][q]=nind[n-1][q];
190                         ndist[n]=ndist[n-1];
191                      }
192                      for (q=0;q<nsf;q++)
193                         nt[m][q]=t[q];
194                      for (q=0;q<nb_subvect;q++)
195                         nind[m][q]=oind[j][q];
196                      nind[m][i]=best_index[k];
197                      ndist[m]=err;
198                      break;
199                   }
200                }
201             }
202          }
203          if (i==0)
204            break;
205       }
206
207       /*opdate old-new data*/
208       for (j=0;j<N;j++)
209          for (m=0;m<nsf;m++)
210             ot[j][m]=nt[j][m];
211       for (j=0;j<N;j++)
212          for (m=0;m<nb_subvect;m++)
213             oind[j][m]=nind[j][m];
214       for (j=0;j<N;j++)
215          odist[j]=ndist[j];
216    }
217
218    /*save indices*/
219    for (i=0;i<nb_subvect;i++)
220    {
221       ind[i]=nind[0][i];
222       speex_bits_pack(bits,ind[i],params->shape_bits);
223    }
224    
225    /* Put everything back together */
226    for (i=0;i<nb_subvect;i++)
227       for (j=0;j<subvect_size;j++)
228          e[subvect_size*i+j]=shape_cb[ind[i]*subvect_size+j];
229
230    /* Update excitation */
231    for (j=0;j<nsf;j++)
232       exc[j]+=e[j];
233    
234    /* Update target */
235    residue_zero(e, awk1, r, nsf, p);
236    syn_filt_zero(r, ak, r, nsf, p);
237    syn_filt_zero(r, awk2, r, nsf,p);
238    for (j=0;j<nsf;j++)
239       target[j]-=r[j];
240
241    POP(stack);
242    POP(stack);
243    POP(stack);
244    POP(stack);
245    POP(stack);
246    POP(stack);
247    POP(stack);
248    POP(stack);
249    POP(stack);
250    POP(stack);
251    POP(stack);
252    POP(stack);
253 }
254
255
256
257
258
259 void split_cb_search_shape_sign(
260 float target[],                 /* target vector */
261 float ak[],                     /* LPCs for this subframe */
262 float awk1[],                   /* Weighted LPCs for this subframe */
263 float awk2[],                   /* Weighted LPCs for this subframe */
264 void *par,                      /* Codebook/search parameters*/
265 int   p,                        /* number of LPC coeffs */
266 int   nsf,                      /* number of samples in subframe */
267 float *exc,
268 SpeexBits *bits,
269 float *stack,
270 int   complexity
271 )
272 {
273    int i,j;
274    float *resp;
275    float *t, *r, *e, *E;
276    int *ind, *signs;
277    float *shape_cb;
278    int shape_cb_size, subvect_size, nb_subvect;
279    split_cb_params *params;
280
281    params = (split_cb_params *) par;
282    subvect_size = params->subvect_size;
283    nb_subvect = params->nb_subvect;
284    shape_cb_size = 1<<params->shape_bits;
285    shape_cb = params->shape_cb;
286    resp = PUSH(stack, shape_cb_size*subvect_size);
287    t = PUSH(stack, nsf);
288    r = PUSH(stack, nsf);
289    e = PUSH(stack, nsf);
290    E = PUSH(stack, shape_cb_size);
291    ind = (int*)PUSH(stack, nb_subvect);
292    signs = (int*)PUSH(stack, nb_subvect);
293
294    for (i=0;i<nsf;i++)
295       t[i]=target[i];
296
297    e[0]=1;
298    for (i=1;i<nsf;i++)
299       e[i]=0;
300    residue_zero(e, awk1, r, nsf, p);
301    syn_filt_zero(r, ak, r, nsf, p);
302    syn_filt_zero(r, awk2, r, nsf,p);
303    
304    /* Pre-compute codewords response and energy */
305    for (i=0;i<shape_cb_size;i++)
306    {
307       float *res = resp+i*subvect_size;
308
309       /* Compute codeword response */
310       int k;
311       for(j=0;j<subvect_size;j++)
312          res[j]=0;
313       for(j=0;j<subvect_size;j++)
314       {
315          for (k=j;k<subvect_size;k++)
316             res[k]+=shape_cb[i*subvect_size+j]*r[k-j];
317       }
318       E[i]=0;
319       for(j=0;j<subvect_size;j++)
320          E[i]+=res[j]*res[j];
321    }
322
323    for (i=0;i<nb_subvect;i++)
324    {
325       int best_index[2]={0,0}, k, m;
326       float g, dist, best_dist[2]={-1,-1}, best_sign[2]={0,0};
327       float *a, *x;
328       float energy=0;
329       x=t+subvect_size*i;
330
331       for (k=0;k<subvect_size;k++)
332          energy+=x[k]*x[k];
333       /* Find best codeword for current sub-vector */
334       for (j=0;j<shape_cb_size;j++)
335       {
336          int sign;
337          dist=0;
338          a=resp+j*subvect_size;
339          dist=0;
340          for (k=0;k<subvect_size;k++)
341             dist -= 2*a[k]*x[k];
342          if (dist > 0)
343          {
344             sign=1;
345             dist =- dist;
346          } else
347             sign=0;
348          dist += energy+E[j];
349          if (dist<best_dist[0] || best_dist[0]<0)
350          {
351             best_dist[1]=best_dist[0];
352             best_index[1]=best_index[0];
353             best_sign[1]=best_sign[0];
354             best_dist[0]=dist;
355             best_index[0]=j;
356             best_sign[0]=sign;
357          } else if (dist<best_dist[1] || best_dist[1]<0)
358          {
359             best_dist[1]=dist;
360             best_index[1]=j;
361             best_sign[1]=sign;
362          }
363       }
364       if (i<nb_subvect-1)
365       {
366          int nbest;
367          float *tt, err[2];
368          float best_score[2];
369          tt=PUSH(stack,nsf);
370          for (nbest=0;nbest<2;nbest++)
371          {
372             float s=1;
373             if (best_sign[nbest])
374                s=-1;
375             for (j=0;j<nsf;j++)
376                tt[j]=t[j];
377             for (j=0;j<subvect_size;j++)
378             {
379                g=s*shape_cb[best_index[nbest]*subvect_size+j];
380                for (k=subvect_size*i+j,m=0;k<nsf;k++,m++)
381                   tt[k] -= g*r[m];
382             }
383             
384             {
385                int best_index2=0, best_sign2=0, sign2;
386                float  best_dist2=0;
387                x=t+subvect_size*(i+1);
388                for (j=0;j<shape_cb_size;j++)
389                {
390                   a=resp+j*subvect_size;
391                   dist = 0;
392                   for (k=0;k<subvect_size;k++)
393                      dist -= 2*a[k]*x[k];
394                   if (dist > 0)
395                   {
396                      sign2=1;
397                      dist =- dist;
398                   } else
399                      sign2=0;
400                   dist += energy+E[j];
401                   if (dist<best_dist2 || j==0)
402                   {
403                      best_dist2=dist;
404                      best_index2=j;
405                      best_sign2=sign2;
406                   }
407                }
408                s=1;
409                if (best_sign2)
410                   s=-1;
411                /*int i2=vq_index(&tt[subvect_size*(i+1)], resp, subvect_size, shape_cb_size);*/
412                
413                for (j=0;j<subvect_size;j++)
414                {
415                   g=s*shape_cb[best_index2*subvect_size+j];
416                   for (k=subvect_size*(i+1)+j,m=0;k<nsf;k++,m++)
417                      tt[k] -= g*r[m];
418                }
419             }
420
421             err[nbest]=0;
422             for (j=subvect_size*i;j<subvect_size*(i+2);j++)
423                err[nbest]-=tt[j]*tt[j];
424             
425             best_score[nbest]=err[nbest];
426          }
427
428          if (best_score[1]>best_score[0])
429          {
430             best_sign[0]=best_sign[1];
431             best_index[0]=best_index[1];
432             best_score[0]=best_score[1];
433          }
434          POP(stack);
435
436       }
437
438       ind[i]=best_index[0];
439       signs[i] = best_sign[0];
440
441       /*printf ("best index: %d/%d\n", best_index, shape_cb_size);*/
442       speex_bits_pack(bits,signs[i],1);
443       speex_bits_pack(bits,ind[i],params->shape_bits);
444
445       /* Update target for next subvector */
446       for (j=0;j<subvect_size;j++)
447       {
448          g=shape_cb[ind[i]*subvect_size+j];
449          if (signs[i])
450             g=-g;
451          for (k=subvect_size*i+j,m=0;k<nsf;k++,m++)
452             t[k] -= g*r[m];
453       }
454    }
455    
456    /* Put everything back together */
457    for (i=0;i<nb_subvect;i++)
458    {
459       float s=1;
460       if (signs[i])
461          s=-1;
462       for (j=0;j<subvect_size;j++)
463          e[subvect_size*i+j]=s*shape_cb[ind[i]*subvect_size+j];
464    }
465    /* Update excitation */
466    for (j=0;j<nsf;j++)
467       exc[j]+=e[j];
468    
469    /* Update target */
470    residue_zero(e, awk1, r, nsf, p);
471    syn_filt_zero(r, ak, r, nsf, p);
472    syn_filt_zero(r, awk2, r, nsf,p);
473    for (j=0;j<nsf;j++)
474       target[j]-=r[j];
475
476    
477    POP(stack);
478    POP(stack);
479    POP(stack);
480    POP(stack);
481    POP(stack);
482    POP(stack);
483    POP(stack);
484 }
485
486
487 void split_cb_nogain_unquant(
488 float *exc,
489 void *par,                      /* non-overlapping codebook */
490 int   nsf,                      /* number of samples in subframe */
491 SpeexBits *bits,
492 float *stack
493 )
494 {
495    int i,j;
496    int *ind;
497    float *shape_cb;
498    int shape_cb_size, subvect_size, nb_subvect;
499    split_cb_params *params;
500
501    params = (split_cb_params *) par;
502    subvect_size = params->subvect_size;
503    nb_subvect = params->nb_subvect;
504    shape_cb_size = 1<<params->shape_bits;
505    shape_cb = params->shape_cb;
506    
507    ind = (int*)PUSH(stack, nb_subvect);
508
509    /* Decode codewords and gains */
510    for (i=0;i<nb_subvect;i++)
511    {
512       ind[i] = speex_bits_unpack_unsigned(bits, params->shape_bits);
513       ind[i] = rand()%(1<<params->shape_bits);
514    }
515    /* Compute decoded excitation */
516    for (i=0;i<nb_subvect;i++)
517       for (j=0;j<subvect_size;j++)
518          exc[subvect_size*i+j]+=shape_cb[ind[i]*subvect_size+j];
519
520    POP(stack);
521 }
522
523 void split_cb_shape_sign_unquant(
524 float *exc,
525 void *par,                      /* non-overlapping codebook */
526 int   nsf,                      /* number of samples in subframe */
527 SpeexBits *bits,
528 float *stack
529 )
530 {
531    int i,j;
532    int *ind, *signs;
533    float *shape_cb;
534    int shape_cb_size, subvect_size, nb_subvect;
535    split_cb_params *params;
536
537    params = (split_cb_params *) par;
538    subvect_size = params->subvect_size;
539    nb_subvect = params->nb_subvect;
540    shape_cb_size = 1<<params->shape_bits;
541    shape_cb = params->shape_cb;
542    
543    ind = (int*)PUSH(stack, nb_subvect);
544    signs = (int*)PUSH(stack, nb_subvect);
545
546    /* Decode codewords and gains */
547    for (i=0;i<nb_subvect;i++)
548    {
549       signs[i] = speex_bits_unpack_unsigned(bits, 1);
550       ind[i] = speex_bits_unpack_unsigned(bits, params->shape_bits);
551    }
552    /* Compute decoded excitation */
553    for (i=0;i<nb_subvect;i++)
554    {
555       float s=1;
556       if (signs[i])
557          s=-1;
558       for (j=0;j<subvect_size;j++)
559          exc[subvect_size*i+j]+=s*shape_cb[ind[i]*subvect_size+j];
560    }
561    POP(stack);
562    POP(stack);
563 }
564
565 void noise_codebook_quant(
566 float target[],                 /* target vector */
567 float ak[],                     /* LPCs for this subframe */
568 float awk1[],                   /* Weighted LPCs for this subframe */
569 float awk2[],                   /* Weighted LPCs for this subframe */
570 void *par,                      /* Codebook/search parameters*/
571 int   p,                        /* number of LPC coeffs */
572 int   nsf,                      /* number of samples in subframe */
573 float *exc,
574 SpeexBits *bits,
575 float *stack,
576 int   complexity
577 )
578 {
579    int i;
580    float *tmp=PUSH(stack, nsf);
581    syn_filt_zero(target, awk1, tmp, nsf, p);
582    residue_zero(tmp, ak, tmp, nsf, p);
583    residue_zero(tmp, awk2, tmp, nsf, p);
584
585    for (i=0;i<nsf;i++)
586       exc[i]+=tmp[i];
587    for (i=0;i<nsf;i++)
588       target[i]=0;
589
590    POP(stack);
591 }
592
593
594 void noise_codebook_unquant(
595 float *exc,
596 void *par,                      /* non-overlapping codebook */
597 int   nsf,                      /* number of samples in subframe */
598 SpeexBits *bits,
599 float *stack
600 )
601 {
602    int i;
603
604    for (i=0;i<nsf;i++)
605       exc[i]+=2.5*((((float)rand())/RAND_MAX)-.5);
606 }