More cleanup in codebook search...
[speexdsp.git] / libspeex / cb_search.c
1 /*-----------------------------------------------------------------------*\
2
3     FILE........: GAINSHAPE.C
4     TYPE........: C Module
5     AUTHOR......: David Rowe
6     COMPANY.....: Voicetronix
7     DATE CREATED: 19/2/02
8
9     General gain-shape codebook search.
10
11 \*-----------------------------------------------------------------------*/
12
13 /* Modified by Jean-Marc Valin 2002
14
15    This library is free software; you can redistribute it and/or
16    modify it under the terms of the GNU Lesser General Public
17    License as published by the Free Software Foundation; either
18    version 2.1 of the License, or (at your option) any later version.
19    
20    This library is distributed in the hope that it will be useful,
21    but WITHOUT ANY WARRANTY; without even the implied warranty of
22    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
23    Lesser General Public License for more details.
24    
25    You should have received a copy of the GNU Lesser General Public
26    License along with this library; if not, write to the Free Software
27    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
28 */
29
30
31
32 #include <stdlib.h>
33 #include <cb_search.h>
34 #include "filters.h"
35 #include <math.h>
36 #include <stdio.h>
37 #include "stack_alloc.h"
38 #include "vq.h"
39
40 #define EXC_CB_SIZE 128
41 #define min(a,b) ((a) < (b) ? (a) : (b))
42 extern float exc_gains_table[];
43 extern float exc_table[];
44
45 /*---------------------------------------------------------------------------*\
46                                                                              
47  void overlap_cb_search()                                                             
48                                                                               
49  Searches a gain/shape codebook consisting of overlapping entries for the    
50  closest vector to the target.  Gives identical results to search() above   
51  buts uses fast end correction algorithm for the synthesis of response       
52  vectors.                                                                     
53                                                                              
54 \*---------------------------------------------------------------------------*/
55
56 float overlap_cb_search(
57 float target[],                 /* target vector */
58 float ak[],                     /* LPCs for this subframe */
59 float awk1[],                   /* Weighted LPCs for this subframe */
60 float awk2[],                   /* Weighted LPCs for this subframe */
61 float codebook[],               /* overlapping codebook */
62 int   entries,                  /* number of overlapping entries to search */
63 float *gain,                    /* gain of optimum entry */
64 int   *index,                   /* index of optimum entry */
65 int   p,                        /* number of LPC coeffs */
66 int   nsf                       /* number of samples in subframe */
67 )
68 {
69   float *resp;                  /* zero state response to current entry */
70   float *h;                     /* impulse response of synthesis filter */
71   float *impulse;               /* excitation vector containing one impulse */
72   float d,e,g,score;            /* codebook searching variables */
73   float bscore;                 /* score of "best" vector so far */
74   int i,k;                      /* loop variables */
75
76   /* Initialise */
77   
78   resp = (float*)malloc(sizeof(float)*nsf);
79   h = (float*)malloc(sizeof(float)*nsf);
80   impulse = (float*)malloc(sizeof(float)*nsf);
81
82   for(i=0; i<nsf; i++)
83     impulse[i] = 0.0;
84    
85   *gain = 0.0;
86   *index = 0;
87   bscore = 0.0;
88   impulse[0] = 1.0;
89
90   /* Calculate impulse response of  A(z/g2) / ( A(z)*(z/g1) ) */
91   residue_zero(impulse, awk1, h, nsf, p);
92   syn_filt_zero(h, ak, h, nsf, p);
93   syn_filt_zero(h, awk2, h, nsf,p);
94   
95   /* Calculate codebook zero-response */
96   residue_zero(&codebook[entries-1],awk1,resp,nsf,p);
97   syn_filt_zero(resp,ak,resp,nsf,p);
98   syn_filt_zero(resp,awk2,resp,nsf,p);
99     
100   /* Search codebook backwards using end correction for synthesis */
101   
102   for(k=entries-1; k>=0; k--) {
103
104     d = 0.0; e = 0.0;
105     for(i=0; i<nsf; i++) {
106       d += target[i]*resp[i];
107       e += resp[i]*resp[i];
108     }
109     g = d/(e+.1);
110     score = g*d;
111     /*printf ("score: %f %f %f %f\n", target[0],d,e,score);*/
112     if (score >= bscore) {
113       bscore = score;
114       *gain = g;
115       *index = k;
116     }
117     
118     /* Synthesise next entry */
119     
120     if (k) {
121       for(i=nsf-1; i>=1; i--)
122         resp[i] = resp[i-1] + codebook[k-1]*h[i];
123       resp[0] = codebook[k-1]*h[0];
124     }
125   }
126
127   free(resp);
128   free(h);
129   free(impulse);
130   return bscore;
131 }
132
133
134 split_cb_params split_cb_nb = {
135    8,               /*subvect_size*/
136    5,               /*nb_subvect*/
137    exc_table,       /*shape_cb*/
138    7,               /*shape_bits*/
139    exc_gains_table, /*gain_cb*/
140    8                /*gain_bits*/
141 };
142
143
144 void split_cb_search(
145 float target[],                 /* target vector */
146 float ak[],                     /* LPCs for this subframe */
147 float awk1[],                   /* Weighted LPCs for this subframe */
148 float awk2[],                   /* Weighted LPCs for this subframe */
149 void *par,                      /* Codebook/search parameters*/
150 int   p,                        /* number of LPC coeffs */
151 int   nsf,                      /* number of samples in subframe */
152 float *exc,
153 FrameBits *bits,
154 float *stack
155 )
156 {
157    int i,j;
158    float *resp, *E, *Ee;
159    float *t, *r, *e;
160    float *gains;
161    int *ind;
162    float *shape_cb, *gain_cb;
163    int shape_cb_size, gain_cb_size, subvect_size, nb_subvect;
164    split_cb_params *params;
165
166    params = (split_cb_params *) par;
167    subvect_size = params->subvect_size;
168    nb_subvect = params->nb_subvect;
169    shape_cb_size = 1<<params->shape_bits;
170    shape_cb = params->shape_cb;
171    gain_cb_size = 1<<params->gain_bits;
172    gain_cb = params->gain_cb;
173    resp = PUSH(stack, shape_cb_size*8);
174    E = PUSH(stack, shape_cb_size);
175    Ee = PUSH(stack, shape_cb_size);
176    t = PUSH(stack, nsf);
177    r = PUSH(stack, nsf);
178    e = PUSH(stack, nsf);
179    gains = PUSH(stack, nb_subvect);
180    ind = (int*)PUSH(stack, nb_subvect);
181    
182
183    for (i=0;i<nsf;i++)
184       t[i]=target[i];
185    for (i=0;i<shape_cb_size;i++)
186    {
187       float *res = resp+i*subvect_size;
188       residue_zero(shape_cb+i*subvect_size, awk1, res, 8, p);
189       syn_filt_zero(res, ak, res, 8, p);
190       syn_filt_zero(res, awk2, res, 8,p);
191       E[i]=0;
192       for(j=0;j<8;j++)
193          E[i]+=res[j]*res[j];
194       Ee[i]=0;
195       for(j=0;j<8;j++)
196          Ee[i]+=shape_cb[i*subvect_size+j]*shape_cb[i*subvect_size+j];
197       
198    }
199    for (i=0;i<5;i++)
200    {
201       int best_index=0;
202       float g, corr, best_gain=0, score, best_score=-1;
203       for (j=0;j<shape_cb_size;j++)
204       {
205          corr=xcorr(resp+j*subvect_size,t+8*i,8);
206          score=corr*corr/(.001+E[j]);
207          g = corr/(.001+E[j]);
208          if (score>best_score)
209          {
210             best_index=j;
211             best_score=score;
212             best_gain=corr/(.001+E[j]);
213          }
214       }
215       frame_bits_pack(bits,best_index,params->shape_bits);
216       if (best_gain>0)
217          frame_bits_pack(bits,0,1);
218       else
219           frame_bits_pack(bits,1,1);        
220       ind[i]=best_index;
221       gains[i]=best_gain*Ee[ind[i]];
222
223       for (j=0;j<nsf;j++)
224          e[j]=0;
225       for (j=0;j<8;j++)
226          e[8*i+j]=best_gain*shape_cb[best_index*subvect_size+j];
227       residue_zero(e, awk1, r, nsf, p);
228       syn_filt_zero(r, ak, r, nsf, p);
229       syn_filt_zero(r, awk2, r, nsf,p);
230       for (j=0;j<nsf;j++)
231          t[j]-=r[j];
232    }
233
234    {
235       int best_vq_index=0, max_index;
236       float max_gain=0, log_max, min_dist=0, sign[5];
237
238       for (i=0;i<5;i++)
239       {
240          if (gains[i]<0)
241          {
242             gains[i]=-gains[i];
243             sign[i]=-1;
244          } else {
245             sign[i]=1;
246          }
247       }
248       for (i=0;i<5;i++)
249          if (gains[i]>max_gain)
250             max_gain=gains[i];
251       log_max=log(max_gain+1);
252       max_index = (int)(floor(.5+log_max-3));
253       if (max_index>7)
254          max_index=7;
255       if (max_index<0)
256          max_index=0;
257       max_gain=1/exp(max_index+3.0);
258       for (i=0;i<5;i++)
259         gains[i]*=max_gain;
260       frame_bits_pack(bits,max_index,3);
261
262       /*Vector quantize gains[i]*/
263       best_vq_index = vq_index(gains, gain_cb, nb_subvect, gain_cb_size);
264       frame_bits_pack(bits,best_vq_index,params->gain_bits);
265
266       printf ("best_gains_vq_index %d %f %d\n", best_vq_index, min_dist, max_index);
267
268 #if 1 /* If 0, the gains are not quantized */
269       for (i=0;i<5;i++)
270          gains[i]= sign[i]*gain_cb[best_vq_index*nb_subvect+i]/max_gain/(Ee[ind[i]]+.001);
271 #else 
272       for (i=0;i<5;i++)
273          gains[i]= sign[i]*gains[i]/max_gain/(Ee[ind[i]]+.001);
274 #endif  
275     
276       for (i=0;i<5;i++)
277          for (j=0;j<8;j++)
278             exc[8*i+j]+=gains[i]*shape_cb[ind[i]*subvect_size+j];
279    }
280
281    /*TODO: Perform joint optimization of gains*/
282    
283    for (i=0;i<nsf;i++)
284       target[i]=t[i];
285
286    POP(stack);
287    POP(stack);
288    POP(stack);
289    POP(stack);
290    POP(stack);
291    POP(stack);
292    POP(stack);
293    POP(stack);
294 }
295
296 void split_cb_unquant(
297 float *exc,
298 float codebook[][8],            /* non-overlapping codebook */
299 int   nsf,                      /* number of samples in subframe */
300 FrameBits *bits
301 )
302 {
303    int i,j;
304    int ind[5];
305    float gains[5];
306    float sign[5];
307    int max_gain_ind, vq_gain_ind;
308    float max_gain, Ee[5];
309    for (i=0;i<5;i++)
310    {
311       ind[i] = frame_bits_unpack_unsigned(bits, 7);
312       if (frame_bits_unpack_unsigned(bits, 1))
313          sign[i]=-1;
314       else
315          sign[i]=1;
316       Ee[i]=.001;
317       for (j=0;j<8;j++)
318          Ee[i]+=codebook[ind[i]][j]*codebook[ind[i]][j];
319    }
320    max_gain_ind = frame_bits_unpack_unsigned(bits, 3);
321    vq_gain_ind = frame_bits_unpack_unsigned(bits, 8);
322    printf ("unquant gains ind: %d %d\n", max_gain_ind, vq_gain_ind);
323
324    max_gain=exp(max_gain_ind+3.0);
325    for (i=0;i<5;i++)
326       gains[i] = sign[i]*exc_gains_table[vq_gain_ind*5+i]*max_gain/Ee[i];
327    
328    printf ("unquant gains: ");
329    for (i=0;i<5;i++)
330       printf ("%f ", gains[i]);
331    printf ("\n");
332
333    for (i=0;i<5;i++)
334       for (j=0;j<8;j++)
335          exc[8*i+j]+=gains[i]*codebook[ind[i]][j];
336    
337 }