Oops. Thanks to Jim Crichton for pointing out that the complexity could end up
[speexdsp.git] / libspeex / cb_search.c
1 /* Copyright (C) 2002-2006 Jean-Marc Valin 
2    File: cb_search.c
3
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "cb_search.h"
37 #include "filters.h"
38 #include "stack_alloc.h"
39 #include "vq.h"
40 #include "misc.h"
41
42 #ifdef _USE_SSE
43 #include "cb_search_sse.h"
44 #elif defined(ARM4_ASM) || defined(ARM5E_ASM)
45 #include "cb_search_arm4.h"
46 #elif defined(BFIN_ASM)
47 #include "cb_search_bfin.h"
48 #endif
49
50 #ifndef OVERRIDE_COMPUTE_WEIGHTED_CODEBOOK
51 static void compute_weighted_codebook(const signed char *shape_cb, const spx_word16_t *r, spx_word16_t *resp, spx_word16_t *resp2, spx_word32_t *E, int shape_cb_size, int subvect_size, char *stack)
52 {
53    int i, j, k;
54    VARDECL(spx_word16_t *shape);
55    ALLOC(shape, subvect_size, spx_word16_t);
56    for (i=0;i<shape_cb_size;i++)
57    {
58       spx_word16_t *res;
59       
60       res = resp+i*subvect_size;
61       for (k=0;k<subvect_size;k++)
62          shape[k] = (spx_word16_t)shape_cb[i*subvect_size+k];
63       E[i]=0;
64
65       /* Compute codeword response using convolution with impulse response */
66       for(j=0;j<subvect_size;j++)
67       {
68          spx_word32_t resj=0;
69          spx_word16_t res16;
70          for (k=0;k<=j;k++)
71             resj = MAC16_16(resj,shape[k],r[j-k]);
72 #ifdef FIXED_POINT
73          res16 = EXTRACT16(SHR32(resj, 13));
74 #else
75          res16 = 0.03125f*resj;
76 #endif
77          /* Compute codeword energy */
78          E[i]=MAC16_16(E[i],res16,res16);
79          res[j] = res16;
80          /*printf ("%d\n", (int)res[j]);*/
81       }
82    }
83
84 }
85 #endif
86
87 #ifndef OVERRIDE_TARGET_UPDATE
88 static inline void target_update(spx_word16_t *t, spx_word16_t g, spx_word16_t *r, int len)
89 {
90    int n;
91    for (n=0;n<len;n++)
92       t[n] = SUB16(t[n],PSHR32(MULT16_16(g,r[n]),13));
93 }
94 #endif
95
96
97
98 static void split_cb_search_shape_sign_N1(
99 spx_word16_t target[],                  /* target vector */
100 spx_coef_t ak[],                        /* LPCs for this subframe */
101 spx_coef_t awk1[],                      /* Weighted LPCs for this subframe */
102 spx_coef_t awk2[],                      /* Weighted LPCs for this subframe */
103 const void *par,                      /* Codebook/search parameters*/
104 int   p,                        /* number of LPC coeffs */
105 int   nsf,                      /* number of samples in subframe */
106 spx_sig_t *exc,
107 spx_word16_t *r,
108 SpeexBits *bits,
109 char *stack,
110 int   update_target
111 )
112 {
113    int i,j,m,q;
114    VARDECL(spx_word16_t *resp);
115 #ifdef _USE_SSE
116    VARDECL(__m128 *resp2);
117    VARDECL(__m128 *E);
118 #else
119    spx_word16_t *resp2;
120    VARDECL(spx_word32_t *E);
121 #endif
122    VARDECL(spx_word16_t *t);
123    VARDECL(spx_sig_t *e);
124    const signed char *shape_cb;
125    int shape_cb_size, subvect_size, nb_subvect;
126    const split_cb_params *params;
127    int best_index;
128    spx_word32_t best_dist;
129    int have_sign;
130    
131    params = (const split_cb_params *) par;
132    subvect_size = params->subvect_size;
133    nb_subvect = params->nb_subvect;
134    shape_cb_size = 1<<params->shape_bits;
135    shape_cb = params->shape_cb;
136    have_sign = params->have_sign;
137    ALLOC(resp, shape_cb_size*subvect_size, spx_word16_t);
138 #ifdef _USE_SSE
139    ALLOC(resp2, (shape_cb_size*subvect_size)>>2, __m128);
140    ALLOC(E, shape_cb_size>>2, __m128);
141 #else
142    resp2 = resp;
143    ALLOC(E, shape_cb_size, spx_word32_t);
144 #endif
145    ALLOC(t, nsf, spx_word16_t);
146    ALLOC(e, nsf, spx_sig_t);
147    
148    /* FIXME: Do we still need to copy the target? */
149    for (i=0;i<nsf;i++)
150       t[i]=target[i];
151
152    compute_weighted_codebook(shape_cb, r, resp, resp2, E, shape_cb_size, subvect_size, stack);
153
154    for (i=0;i<nb_subvect;i++)
155    {
156       spx_word16_t *x=t+subvect_size*i;
157       /*Find new n-best based on previous n-best j*/
158       if (have_sign)
159          vq_nbest_sign(x, resp2, subvect_size, shape_cb_size, E, 1, &best_index, &best_dist, stack);
160       else
161          vq_nbest(x, resp2, subvect_size, shape_cb_size, E, 1, &best_index, &best_dist, stack);
162       
163       speex_bits_pack(bits,best_index,params->shape_bits+have_sign);
164       
165       {
166          int rind;
167          spx_word16_t *res;
168          spx_word16_t sign=1;
169          rind = best_index;
170          if (rind>=shape_cb_size)
171          {
172             sign=-1;
173             rind-=shape_cb_size;
174          }
175          res = resp+rind*subvect_size;
176          if (sign>0)
177             for (m=0;m<subvect_size;m++)
178                t[subvect_size*i+m] = SUB16(t[subvect_size*i+m], res[m]);
179          else
180             for (m=0;m<subvect_size;m++)
181                t[subvect_size*i+m] = ADD16(t[subvect_size*i+m], res[m]);
182
183 #ifdef FIXED_POINT
184          if (sign)
185          {
186             for (j=0;j<subvect_size;j++)
187                e[subvect_size*i+j]=SHL32(EXTEND32(shape_cb[rind*subvect_size+j]),SIG_SHIFT-5);
188          } else {
189             for (j=0;j<subvect_size;j++)
190                e[subvect_size*i+j]=NEG32(SHL32(EXTEND32(shape_cb[rind*subvect_size+j]),SIG_SHIFT-5));
191          }
192 #else
193          for (j=0;j<subvect_size;j++)
194             e[subvect_size*i+j]=sign*0.03125*shape_cb[rind*subvect_size+j];
195 #endif
196       
197       }
198             
199       for (m=0;m<subvect_size;m++)
200       {
201          spx_word16_t g;
202          int rind;
203          spx_word16_t sign=1;
204          rind = best_index;
205          if (rind>=shape_cb_size)
206          {
207             sign=-1;
208             rind-=shape_cb_size;
209          }
210          
211          q=subvect_size-m;
212 #ifdef FIXED_POINT
213          g=sign*shape_cb[rind*subvect_size+m];
214 #else
215          g=sign*0.03125*shape_cb[rind*subvect_size+m];
216 #endif
217          target_update(t+subvect_size*(i+1), g, r+q, nsf-subvect_size*(i+1));
218       }
219    }
220
221    /* Update excitation */
222    /* FIXME: We could update the excitation directly above */
223    for (j=0;j<nsf;j++)
224       exc[j]=ADD32(exc[j],e[j]);
225    
226    /* Update target: only update target if necessary */
227    if (update_target)
228    {
229       VARDECL(spx_sig_t *r2);
230       ALLOC(r2, nsf, spx_sig_t);
231       syn_percep_zero(e, ak, awk1, awk2, r2, nsf,p, stack);
232       for (j=0;j<nsf;j++)
233          target[j]=SUB16(target[j],EXTRACT16(PSHR32(r2[j],8)));
234    }
235 }
236
237
238
239 void split_cb_search_shape_sign(
240 spx_word16_t target[],                  /* target vector */
241 spx_coef_t ak[],                        /* LPCs for this subframe */
242 spx_coef_t awk1[],                      /* Weighted LPCs for this subframe */
243 spx_coef_t awk2[],                      /* Weighted LPCs for this subframe */
244 const void *par,                      /* Codebook/search parameters*/
245 int   p,                        /* number of LPC coeffs */
246 int   nsf,                      /* number of samples in subframe */
247 spx_sig_t *exc,
248 spx_word16_t *r,
249 SpeexBits *bits,
250 char *stack,
251 int   complexity,
252 int   update_target
253 )
254 {
255    int i,j,k,m,n,q;
256    VARDECL(spx_word16_t *resp);
257 #ifdef _USE_SSE
258    VARDECL(__m128 *resp2);
259    VARDECL(__m128 *E);
260 #else
261    spx_word16_t *resp2;
262    VARDECL(spx_word32_t *E);
263 #endif
264    VARDECL(spx_word16_t *t);
265    VARDECL(spx_sig_t *e);
266    VARDECL(spx_sig_t *r2);
267    VARDECL(spx_word16_t *tmp);
268    VARDECL(spx_word32_t *ndist);
269    VARDECL(spx_word32_t *odist);
270    VARDECL(int *itmp);
271    VARDECL(spx_word16_t **ot2);
272    VARDECL(spx_word16_t **nt2);
273    spx_word16_t **ot, **nt;
274    VARDECL(int **nind);
275    VARDECL(int **oind);
276    VARDECL(int *ind);
277    const signed char *shape_cb;
278    int shape_cb_size, subvect_size, nb_subvect;
279    const split_cb_params *params;
280    int N=2;
281    VARDECL(int *best_index);
282    VARDECL(spx_word32_t *best_dist);
283    VARDECL(int *best_nind);
284    VARDECL(int *best_ntarget);
285    int have_sign;
286    N=complexity;
287    if (N>10)
288       N=10;
289    /* Complexity isn't as important for the codebooks as it is for the pitch */
290    N=(2*N)/3;
291    if (N<1)
292       N=1;
293    if (N==1)
294    {
295       split_cb_search_shape_sign_N1(target,ak,awk1,awk2,par,p,nsf,exc,r,bits,stack,update_target);
296       return;
297    }
298    ALLOC(ot2, N, spx_word16_t*);
299    ALLOC(nt2, N, spx_word16_t*);
300    ALLOC(oind, N, int*);
301    ALLOC(nind, N, int*);
302
303    params = (const split_cb_params *) par;
304    subvect_size = params->subvect_size;
305    nb_subvect = params->nb_subvect;
306    shape_cb_size = 1<<params->shape_bits;
307    shape_cb = params->shape_cb;
308    have_sign = params->have_sign;
309    ALLOC(resp, shape_cb_size*subvect_size, spx_word16_t);
310 #ifdef _USE_SSE
311    ALLOC(resp2, (shape_cb_size*subvect_size)>>2, __m128);
312    ALLOC(E, shape_cb_size>>2, __m128);
313 #else
314    resp2 = resp;
315    ALLOC(E, shape_cb_size, spx_word32_t);
316 #endif
317    ALLOC(t, nsf, spx_word16_t);
318    ALLOC(e, nsf, spx_sig_t);
319    ALLOC(r2, nsf, spx_sig_t);
320    ALLOC(ind, nb_subvect, int);
321
322    ALLOC(tmp, 2*N*nsf, spx_word16_t);
323    for (i=0;i<N;i++)
324    {
325       ot2[i]=tmp+2*i*nsf;
326       nt2[i]=tmp+(2*i+1)*nsf;
327    }
328    ot=ot2;
329    nt=nt2;
330    ALLOC(best_index, N, int);
331    ALLOC(best_dist, N, spx_word32_t);
332    ALLOC(best_nind, N, int);
333    ALLOC(best_ntarget, N, int);
334    ALLOC(ndist, N, spx_word32_t);
335    ALLOC(odist, N, spx_word32_t);
336    
337    ALLOC(itmp, 2*N*nb_subvect, int);
338    for (i=0;i<N;i++)
339    {
340       nind[i]=itmp+2*i*nb_subvect;
341       oind[i]=itmp+(2*i+1)*nb_subvect;
342    }
343    
344    for (i=0;i<nsf;i++)
345       t[i]=target[i];
346
347    for (j=0;j<N;j++)
348       speex_move(&ot[j][0], t, nsf*sizeof(spx_word16_t));
349
350    /* Pre-compute codewords response and energy */
351    compute_weighted_codebook(shape_cb, r, resp, resp2, E, shape_cb_size, subvect_size, stack);
352
353    for (j=0;j<N;j++)
354       odist[j]=0;
355    
356    /*For all subvectors*/
357    for (i=0;i<nb_subvect;i++)
358    {
359       /*"erase" nbest list*/
360       for (j=0;j<N;j++)
361          ndist[j]=VERY_LARGE32;
362
363       /*For all n-bests of previous subvector*/
364       for (j=0;j<N;j++)
365       {
366          spx_word16_t *x=ot[j]+subvect_size*i;
367          spx_word32_t tener = 0;
368          for (m=0;m<subvect_size;m++)
369             tener = MAC16_16(tener, x[m],x[m]);
370 #ifdef FIXED_POINT
371          tener = SHR32(tener,1);
372 #else
373          tener *= .5;
374 #endif
375          /*Find new n-best based on previous n-best j*/
376          if (have_sign)
377             vq_nbest_sign(x, resp2, subvect_size, shape_cb_size, E, N, best_index, best_dist, stack);
378          else
379             vq_nbest(x, resp2, subvect_size, shape_cb_size, E, N, best_index, best_dist, stack);
380
381          /*For all new n-bests*/
382          for (k=0;k<N;k++)
383          {
384             /* Compute total distance (including previous sub-vectors */
385             spx_word32_t err = ADD32(ADD32(odist[j],best_dist[k]),tener);
386             
387             /*update n-best list*/
388             if (err<ndist[N-1])
389             {
390                for (m=0;m<N;m++)
391                {
392                   if (err < ndist[m])
393                   {
394                      for (n=N-1;n>m;n--)
395                      {
396                         ndist[n] = ndist[n-1];
397                         best_nind[n] = best_nind[n-1];
398                         best_ntarget[n] = best_ntarget[n-1];
399                      }
400                      ndist[m] = err;
401                      best_nind[n] = best_index[k];
402                      best_ntarget[n] = j;
403                      break;
404                   }
405                }
406             }
407          }
408          if (i==0)
409             break;
410       }
411       for (j=0;j<N;j++)
412       {
413          /*previous target (we don't care what happened before*/
414          for (m=(i+1)*subvect_size;m<nsf;m++)
415             nt[j][m]=ot[best_ntarget[j]][m];
416          
417          /* New code: update the rest of the target only if it's worth it */
418          for (m=0;m<subvect_size;m++)
419          {
420             spx_word16_t g;
421             int rind;
422             spx_word16_t sign=1;
423             rind = best_nind[j];
424             if (rind>=shape_cb_size)
425             {
426                sign=-1;
427                rind-=shape_cb_size;
428             }
429
430             q=subvect_size-m;
431 #ifdef FIXED_POINT
432             g=sign*shape_cb[rind*subvect_size+m];
433 #else
434             g=sign*0.03125*shape_cb[rind*subvect_size+m];
435 #endif
436             target_update(nt[j]+subvect_size*(i+1), g, r+q, nsf-subvect_size*(i+1));
437          }
438
439          for (q=0;q<nb_subvect;q++)
440             nind[j][q]=oind[best_ntarget[j]][q];
441          nind[j][i]=best_nind[j];
442       }
443
444       /*update old-new data*/
445       /* just swap pointers instead of a long copy */
446       {
447          spx_word16_t **tmp2;
448          tmp2=ot;
449          ot=nt;
450          nt=tmp2;
451       }
452       for (j=0;j<N;j++)
453          for (m=0;m<nb_subvect;m++)
454             oind[j][m]=nind[j][m];
455       for (j=0;j<N;j++)
456          odist[j]=ndist[j];
457    }
458
459    /*save indices*/
460    for (i=0;i<nb_subvect;i++)
461    {
462       ind[i]=nind[0][i];
463       speex_bits_pack(bits,ind[i],params->shape_bits+have_sign);
464    }
465    
466    /* Put everything back together */
467    for (i=0;i<nb_subvect;i++)
468    {
469       int rind;
470       spx_word16_t sign=1;
471       rind = ind[i];
472       if (rind>=shape_cb_size)
473       {
474          sign=-1;
475          rind-=shape_cb_size;
476       }
477 #ifdef FIXED_POINT
478       if (sign==1)
479       {
480          for (j=0;j<subvect_size;j++)
481             e[subvect_size*i+j]=SHL32(EXTEND32(shape_cb[rind*subvect_size+j]),SIG_SHIFT-5);
482       } else {
483          for (j=0;j<subvect_size;j++)
484             e[subvect_size*i+j]=NEG32(SHL32(EXTEND32(shape_cb[rind*subvect_size+j]),SIG_SHIFT-5));
485       }
486 #else
487       for (j=0;j<subvect_size;j++)
488          e[subvect_size*i+j]=sign*0.03125*shape_cb[rind*subvect_size+j];
489 #endif
490    }   
491    /* Update excitation */
492    for (j=0;j<nsf;j++)
493       exc[j]=ADD32(exc[j],e[j]);
494    
495    /* Update target: only update target if necessary */
496    if (update_target)
497    {
498       syn_percep_zero(e, ak, awk1, awk2, r2, nsf,p, stack);
499       for (j=0;j<nsf;j++)
500          target[j]=SUB16(target[j],EXTRACT16(PSHR32(r2[j],8)));
501    }
502 }
503
504
505 void split_cb_shape_sign_unquant(
506 spx_sig_t *exc,
507 const void *par,                      /* non-overlapping codebook */
508 int   nsf,                      /* number of samples in subframe */
509 SpeexBits *bits,
510 char *stack,
511 spx_int32_t *seed
512 )
513 {
514    int i,j;
515    VARDECL(int *ind);
516    VARDECL(int *signs);
517    const signed char *shape_cb;
518    int shape_cb_size, subvect_size, nb_subvect;
519    const split_cb_params *params;
520    int have_sign;
521
522    params = (const split_cb_params *) par;
523    subvect_size = params->subvect_size;
524    nb_subvect = params->nb_subvect;
525    shape_cb_size = 1<<params->shape_bits;
526    shape_cb = params->shape_cb;
527    have_sign = params->have_sign;
528
529    ALLOC(ind, nb_subvect, int);
530    ALLOC(signs, nb_subvect, int);
531
532    /* Decode codewords and gains */
533    for (i=0;i<nb_subvect;i++)
534    {
535       if (have_sign)
536          signs[i] = speex_bits_unpack_unsigned(bits, 1);
537       else
538          signs[i] = 0;
539       ind[i] = speex_bits_unpack_unsigned(bits, params->shape_bits);
540    }
541    /* Compute decoded excitation */
542    for (i=0;i<nb_subvect;i++)
543    {
544       spx_word16_t s=1;
545       if (signs[i])
546          s=-1;
547 #ifdef FIXED_POINT
548       if (s==1)
549       {
550          for (j=0;j<subvect_size;j++)
551             exc[subvect_size*i+j]=SHL32(EXTEND32(shape_cb[ind[i]*subvect_size+j]),SIG_SHIFT-5);
552       } else {
553          for (j=0;j<subvect_size;j++)
554             exc[subvect_size*i+j]=NEG32(SHL32(EXTEND32(shape_cb[ind[i]*subvect_size+j]),SIG_SHIFT-5));
555       }
556 #else
557       for (j=0;j<subvect_size;j++)
558          exc[subvect_size*i+j]+=s*0.03125*shape_cb[ind[i]*subvect_size+j];      
559 #endif
560    }
561 }
562
563 void noise_codebook_quant(
564 spx_word16_t target[],                  /* target vector */
565 spx_coef_t ak[],                        /* LPCs for this subframe */
566 spx_coef_t awk1[],                      /* Weighted LPCs for this subframe */
567 spx_coef_t awk2[],                      /* Weighted LPCs for this subframe */
568 const void *par,                      /* Codebook/search parameters*/
569 int   p,                        /* number of LPC coeffs */
570 int   nsf,                      /* number of samples in subframe */
571 spx_sig_t *exc,
572 spx_word16_t *r,
573 SpeexBits *bits,
574 char *stack,
575 int   complexity,
576 int   update_target
577 )
578 {
579    int i;
580    VARDECL(spx_sig_t *tmp);
581    ALLOC(tmp, nsf, spx_sig_t);
582    for (i=0;i<nsf;i++)
583       tmp[i]=PSHR32(EXTEND32(target[i]),SIG_SHIFT);
584    residue_percep_zero(tmp, ak, awk1, awk2, tmp, nsf, p, stack);
585
586    for (i=0;i<nsf;i++)
587       exc[i]+=tmp[i];
588    for (i=0;i<nsf;i++)
589       target[i]=0;
590 }
591
592
593 void noise_codebook_unquant(
594 spx_sig_t *exc,
595 const void *par,                      /* non-overlapping codebook */
596 int   nsf,                      /* number of samples in subframe */
597 SpeexBits *bits,
598 char *stack,
599 spx_int32_t *seed
600 )
601 {
602    int i;
603    /* FIXME: This is bad, but I don't think the function ever gets called anyway */
604    for (i=0;i<nsf;i++)
605       exc[i]=SHL32(EXTEND32(speex_rand(1, seed)),SIG_SHIFT);
606 }