85d00a7940498c6613d8dbb2f33f630a8bd0086e
[speexdsp.git] / libspeex / cb_search.c
1 /* Copyright (C) 2002 Jean-Marc Valin 
2    File: cb_search.c
3
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32
33
34 #include <stdlib.h>
35 #include "cb_search.h"
36 #include "filters.h"
37 #include <math.h>
38 #ifdef DEBUG
39 #include <stdio.h>
40 #endif
41 #include "stack_alloc.h"
42 #include "vq.h"
43
44 #ifndef min
45 # define min(a,b) ((a) < (b) ? (a) : (b))
46 #endif
47 #ifndef max
48 # define max(a,b) ((a) > (b) ? (a) : (b))
49 #endif
50
51
52
53
54 void split_cb_search_shape_sign(
55 float target[],                 /* target vector */
56 float ak[],                     /* LPCs for this subframe */
57 float awk1[],                   /* Weighted LPCs for this subframe */
58 float awk2[],                   /* Weighted LPCs for this subframe */
59 void *par,                      /* Codebook/search parameters*/
60 int   p,                        /* number of LPC coeffs */
61 int   nsf,                      /* number of samples in subframe */
62 float *exc,
63 SpeexBits *bits,
64 float *stack,
65 int   complexity
66 )
67 {
68    int i,j,k,m,n,q;
69    float *resp;
70    float *t, *r, *e, *E;
71    /*FIXME: Should make this dynamic*/
72    float *tmp, *_ot[20], *_nt[20];
73    float *ndist, *odist;
74    int *itmp, *_nind[20], *_oind[20];
75    float **ot, **nt;
76    int **nind, **oind;
77    int *ind;
78    float *shape_cb;
79    int shape_cb_size, subvect_size, nb_subvect;
80    split_cb_params *params;
81    int N=2;
82    int *best_index;
83    float *best_dist;
84    int have_sign;
85
86    ot=_ot;
87    nt=_nt;
88    oind=_oind;
89    nind=_nind;
90    N=complexity;
91    if (N<1)
92       N=1;
93    if (N>10)
94       N=10;
95
96    params = (split_cb_params *) par;
97    subvect_size = params->subvect_size;
98    nb_subvect = params->nb_subvect;
99    shape_cb_size = 1<<params->shape_bits;
100    shape_cb = params->shape_cb;
101    have_sign = params->have_sign;
102    resp = PUSH(stack, shape_cb_size*subvect_size);
103    t = PUSH(stack, nsf);
104    r = PUSH(stack, nsf);
105    e = PUSH(stack, nsf);
106    E = PUSH(stack, shape_cb_size);
107    /*FIXME: This breaks if sizeof(int) != sizeof(float) */
108    ind = (int*)PUSH(stack, nb_subvect);
109
110    tmp = PUSH(stack, 2*N*nsf);
111    for (i=0;i<N;i++)
112    {
113       ot[i]=tmp;
114       tmp += nsf;
115       nt[i]=tmp;
116       tmp += nsf;
117    }
118
119    /*FIXME: This breaks if sizeof(int) != sizeof(float) */
120    best_index = (int*)PUSH(stack, N);
121    best_dist = PUSH(stack, N);
122    ndist = PUSH(stack, N);
123    odist = PUSH(stack, N);
124    
125    /*FIXME: This breaks if sizeof(int) != sizeof(float) */
126    itmp = (int*)PUSH(stack, 2*N*nb_subvect);
127    for (i=0;i<N;i++)
128    {
129       nind[i]=itmp;
130       itmp+=nb_subvect;
131       oind[i]=itmp;
132       itmp+=nb_subvect;
133       for (j=0;j<nb_subvect;j++)
134          nind[i][j]=oind[i][j]=-1;
135    }
136
137    for (j=0;j<N;j++)
138       for (i=0;i<nsf;i++)
139          ot[j][i]=target[i];
140
141    for (i=0;i<nsf;i++)
142       t[i]=target[i];
143
144    e[0]=1;
145    for (i=1;i<nsf;i++)
146       e[i]=0;
147    syn_percep_zero(e, ak, awk1, awk2, r, nsf,p, stack);
148
149    /* Pre-compute codewords response and energy */
150    for (i=0;i<shape_cb_size;i++)
151    {
152       float *res;
153       float *shape;
154
155       res = resp+i*subvect_size;
156       shape = shape_cb+i*subvect_size;
157       /* Compute codeword response */
158
159       for(j=0;j<subvect_size;j++)
160       {
161          res[j]=0;
162          for (k=0;k<=j;k++)
163             res[j] += shape[k]*r[j-k];
164       }
165       E[i]=0;
166       for(j=0;j<subvect_size;j++)
167          E[i]+=res[j]*res[j];
168    }
169
170    for (j=0;j<N;j++)
171       odist[j]=0;
172    /*For all subvectors*/
173    for (i=0;i<nb_subvect;i++)
174    {
175       /*"erase" nbest list*/
176       for (j=0;j<N;j++)
177          ndist[j]=-1;
178
179       /*For all n-bests of previous subvector*/
180       for (j=0;j<N;j++)
181       {
182          float *x=ot[j]+subvect_size*i;
183          /*Find new n-best based on previous n-best j*/
184          if (have_sign)
185             vq_nbest_sign(x, resp, subvect_size, shape_cb_size, E, N, best_index, best_dist);
186          else
187             vq_nbest(x, resp, subvect_size, shape_cb_size, E, N, best_index, best_dist);
188
189          /*For all new n-bests*/
190          for (k=0;k<N;k++)
191          {
192             float *ct;
193             float err=0;
194             ct = ot[j];
195             /*update target*/
196
197             /*previous target*/
198             for (m=i*subvect_size;m<(i+1)*subvect_size;m++)
199                t[m]=ct[m];
200
201             /* New code: update only enough of the target to calculate error*/
202             {
203                int rind;
204                float *res;
205                float sign=1;
206                rind = best_index[k];
207                if (rind>shape_cb_size)
208                {
209                   sign=-1;
210                   rind-=shape_cb_size;
211                }
212                res = resp+rind*subvect_size;
213                if (sign>0)
214                   for (m=0;m<subvect_size;m++)
215                      t[subvect_size*i+m] -= res[m];
216                else
217                   for (m=0;m<subvect_size;m++)
218                      t[subvect_size*i+m] += res[m];
219             }
220             
221             /*compute error (distance)*/
222             err=odist[j];
223             for (m=i*subvect_size;m<(i+1)*subvect_size;m++)
224                err += t[m]*t[m];
225             /*update n-best list*/
226             if (err<ndist[N-1] || ndist[N-1]<-.5)
227             {
228
229                /*previous target (we don't care what happened before*/
230                for (m=(i+1)*subvect_size;m<nsf;m++)
231                   t[m]=ct[m];
232                /* New code: update the rest of the target only if it's worth it */
233                for (m=0;m<subvect_size;m++)
234                {
235                   float g;
236                   int rind;
237                   float sign=1;
238                   rind = best_index[k];
239                   if (rind>shape_cb_size)
240                   {
241                      sign=-1;
242                      rind-=shape_cb_size;
243                   }
244
245                   g=sign*shape_cb[rind*subvect_size+m];
246                   q=subvect_size-m;
247                   for (n=subvect_size*(i+1);n<nsf;n++,q++)
248                      t[n] -= g*r[q];
249                }
250
251
252                for (m=0;m<N;m++)
253                {
254                   if (err < ndist[m] || ndist[m]<-.5)
255                   {
256                      for (n=N-1;n>m;n--)
257                      {
258                         for (q=0;q<nsf;q++)
259                            nt[n][q]=nt[n-1][q];
260                         for (q=0;q<nb_subvect;q++)
261                            nind[n][q]=nind[n-1][q];
262                         ndist[n]=ndist[n-1];
263                      }
264                      for (q=0;q<nsf;q++)
265                         nt[m][q]=t[q];
266                      for (q=0;q<nb_subvect;q++)
267                         nind[m][q]=oind[j][q];
268                      nind[m][i]=best_index[k];
269                      ndist[m]=err;
270                      break;
271                   }
272                }
273             }
274          }
275          if (i==0)
276            break;
277       }
278
279       /*update old-new data*/
280       /* just swap pointers instead of a long copy */
281       {
282          float **tmp;
283          tmp=ot;
284          ot=nt;
285          nt=tmp;
286       }
287       for (j=0;j<N;j++)
288          for (m=0;m<nb_subvect;m++)
289             oind[j][m]=nind[j][m];
290       for (j=0;j<N;j++)
291          odist[j]=ndist[j];
292    }
293
294    /*save indices*/
295    for (i=0;i<nb_subvect;i++)
296    {
297       ind[i]=nind[0][i];
298       speex_bits_pack(bits,ind[i],params->shape_bits+have_sign);
299    }
300    
301    /* Put everything back together */
302    for (i=0;i<nb_subvect;i++)
303    {
304       int rind;
305       float sign=1;
306       rind = ind[i];
307       if (rind>shape_cb_size)
308       {
309          sign=-1;
310          rind-=shape_cb_size;
311       }
312
313       for (j=0;j<subvect_size;j++)
314          e[subvect_size*i+j]=sign*shape_cb[rind*subvect_size+j];
315    }   
316    /* Update excitation */
317    for (j=0;j<nsf;j++)
318       exc[j]+=e[j];
319    
320    /* Update target */
321    syn_percep_zero(e, ak, awk1, awk2, r, nsf,p, stack);
322    for (j=0;j<nsf;j++)
323       target[j]-=r[j];
324
325 }
326
327
328 void split_cb_shape_sign_unquant(
329 float *exc,
330 void *par,                      /* non-overlapping codebook */
331 int   nsf,                      /* number of samples in subframe */
332 SpeexBits *bits,
333 float *stack
334 )
335 {
336    int i,j;
337    int *ind, *signs;
338    float *shape_cb;
339    int shape_cb_size, subvect_size, nb_subvect;
340    split_cb_params *params;
341    int have_sign;
342
343    params = (split_cb_params *) par;
344    subvect_size = params->subvect_size;
345    nb_subvect = params->nb_subvect;
346    shape_cb_size = 1<<params->shape_bits;
347    shape_cb = params->shape_cb;
348    have_sign = params->have_sign;
349
350    /*FIXME: This breaks if sizeof(int) != sizeof(float) */
351    ind = (int*)PUSH(stack, nb_subvect);
352    signs = (int*)PUSH(stack, nb_subvect);
353
354    /* Decode codewords and gains */
355    for (i=0;i<nb_subvect;i++)
356    {
357       if (have_sign)
358          signs[i] = speex_bits_unpack_unsigned(bits, 1);
359       else
360          signs[i] = 0;
361       ind[i] = speex_bits_unpack_unsigned(bits, params->shape_bits);
362    }
363    /* Compute decoded excitation */
364    for (i=0;i<nb_subvect;i++)
365    {
366       float s=1;
367       if (signs[i])
368          s=-1;
369       for (j=0;j<subvect_size;j++)
370          exc[subvect_size*i+j]+=s*shape_cb[ind[i]*subvect_size+j];
371    }
372 }
373
374 void noise_codebook_quant(
375 float target[],                 /* target vector */
376 float ak[],                     /* LPCs for this subframe */
377 float awk1[],                   /* Weighted LPCs for this subframe */
378 float awk2[],                   /* Weighted LPCs for this subframe */
379 void *par,                      /* Codebook/search parameters*/
380 int   p,                        /* number of LPC coeffs */
381 int   nsf,                      /* number of samples in subframe */
382 float *exc,
383 SpeexBits *bits,
384 float *stack,
385 int   complexity
386 )
387 {
388    int i;
389    float *tmp=PUSH(stack, nsf);
390    residue_percep_zero(target, ak, awk1, awk2, tmp, nsf, p, stack);
391
392    for (i=0;i<nsf;i++)
393       exc[i]+=tmp[i];
394    for (i=0;i<nsf;i++)
395       target[i]=0;
396
397 }
398
399
400 void noise_codebook_unquant(
401 float *exc,
402 void *par,                      /* non-overlapping codebook */
403 int   nsf,                      /* number of samples in subframe */
404 SpeexBits *bits,
405 float *stack
406 )
407 {
408    int i;
409
410    for (i=0;i<nsf;i++)
411       exc[i]+=3*((((float)rand())/RAND_MAX)-.5);
412 }