Thorvald Natvig: Patch to query the impulse response from the AEC
[speexdsp.git] / libspeex / cb_search.c
1 /* Copyright (C) 2002-2006 Jean-Marc Valin 
2    File: cb_search.c
3
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "cb_search.h"
37 #include "filters.h"
38 #include "stack_alloc.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "math_approx.h"
42 #include "os_support.h"
43
44 #ifdef _USE_SSE
45 #include "cb_search_sse.h"
46 #elif defined(ARM4_ASM) || defined(ARM5E_ASM)
47 #include "cb_search_arm4.h"
48 #elif defined(BFIN_ASM)
49 #include "cb_search_bfin.h"
50 #endif
51
52 #ifndef OVERRIDE_COMPUTE_WEIGHTED_CODEBOOK
53 static void compute_weighted_codebook(const signed char *shape_cb, const spx_word16_t *r, spx_word16_t *resp, spx_word16_t *resp2, spx_word32_t *E, int shape_cb_size, int subvect_size, char *stack)
54 {
55    int i, j, k;
56    VARDECL(spx_word16_t *shape);
57    ALLOC(shape, subvect_size, spx_word16_t);
58    for (i=0;i<shape_cb_size;i++)
59    {
60       spx_word16_t *res;
61       
62       res = resp+i*subvect_size;
63       for (k=0;k<subvect_size;k++)
64          shape[k] = (spx_word16_t)shape_cb[i*subvect_size+k];
65       E[i]=0;
66
67       /* Compute codeword response using convolution with impulse response */
68       for(j=0;j<subvect_size;j++)
69       {
70          spx_word32_t resj=0;
71          spx_word16_t res16;
72          for (k=0;k<=j;k++)
73             resj = MAC16_16(resj,shape[k],r[j-k]);
74 #ifdef FIXED_POINT
75          res16 = EXTRACT16(SHR32(resj, 13));
76 #else
77          res16 = 0.03125f*resj;
78 #endif
79          /* Compute codeword energy */
80          E[i]=MAC16_16(E[i],res16,res16);
81          res[j] = res16;
82          /*printf ("%d\n", (int)res[j]);*/
83       }
84    }
85
86 }
87 #endif
88
89 #ifndef OVERRIDE_TARGET_UPDATE
90 static inline void target_update(spx_word16_t *t, spx_word16_t g, spx_word16_t *r, int len)
91 {
92    int n;
93    for (n=0;n<len;n++)
94       t[n] = SUB16(t[n],PSHR32(MULT16_16(g,r[n]),13));
95 }
96 #endif
97
98
99
100 static void split_cb_search_shape_sign_N1(
101 spx_word16_t target[],                  /* target vector */
102 spx_coef_t ak[],                        /* LPCs for this subframe */
103 spx_coef_t awk1[],                      /* Weighted LPCs for this subframe */
104 spx_coef_t awk2[],                      /* Weighted LPCs for this subframe */
105 const void *par,                      /* Codebook/search parameters*/
106 int   p,                        /* number of LPC coeffs */
107 int   nsf,                      /* number of samples in subframe */
108 spx_sig_t *exc,
109 spx_word16_t *r,
110 SpeexBits *bits,
111 char *stack,
112 int   update_target
113 )
114 {
115    int i,j,m,q;
116    VARDECL(spx_word16_t *resp);
117 #ifdef _USE_SSE
118    VARDECL(__m128 *resp2);
119    VARDECL(__m128 *E);
120 #else
121    spx_word16_t *resp2;
122    VARDECL(spx_word32_t *E);
123 #endif
124    VARDECL(spx_word16_t *t);
125    VARDECL(spx_sig_t *e);
126    const signed char *shape_cb;
127    int shape_cb_size, subvect_size, nb_subvect;
128    const split_cb_params *params;
129    int best_index;
130    spx_word32_t best_dist;
131    int have_sign;
132    
133    params = (const split_cb_params *) par;
134    subvect_size = params->subvect_size;
135    nb_subvect = params->nb_subvect;
136    shape_cb_size = 1<<params->shape_bits;
137    shape_cb = params->shape_cb;
138    have_sign = params->have_sign;
139    ALLOC(resp, shape_cb_size*subvect_size, spx_word16_t);
140 #ifdef _USE_SSE
141    ALLOC(resp2, (shape_cb_size*subvect_size)>>2, __m128);
142    ALLOC(E, shape_cb_size>>2, __m128);
143 #else
144    resp2 = resp;
145    ALLOC(E, shape_cb_size, spx_word32_t);
146 #endif
147    ALLOC(t, nsf, spx_word16_t);
148    ALLOC(e, nsf, spx_sig_t);
149    
150    /* FIXME: Do we still need to copy the target? */
151    SPEEX_COPY(t, target, nsf);
152
153    compute_weighted_codebook(shape_cb, r, resp, resp2, E, shape_cb_size, subvect_size, stack);
154
155    for (i=0;i<nb_subvect;i++)
156    {
157       spx_word16_t *x=t+subvect_size*i;
158       /*Find new n-best based on previous n-best j*/
159       if (have_sign)
160          vq_nbest_sign(x, resp2, subvect_size, shape_cb_size, E, 1, &best_index, &best_dist, stack);
161       else
162          vq_nbest(x, resp2, subvect_size, shape_cb_size, E, 1, &best_index, &best_dist, stack);
163       
164       speex_bits_pack(bits,best_index,params->shape_bits+have_sign);
165       
166       {
167          int rind;
168          spx_word16_t *res;
169          spx_word16_t sign=1;
170          rind = best_index;
171          if (rind>=shape_cb_size)
172          {
173             sign=-1;
174             rind-=shape_cb_size;
175          }
176          res = resp+rind*subvect_size;
177          if (sign>0)
178             for (m=0;m<subvect_size;m++)
179                t[subvect_size*i+m] = SUB16(t[subvect_size*i+m], res[m]);
180          else
181             for (m=0;m<subvect_size;m++)
182                t[subvect_size*i+m] = ADD16(t[subvect_size*i+m], res[m]);
183
184 #ifdef FIXED_POINT
185          if (sign==1)
186          {
187             for (j=0;j<subvect_size;j++)
188                e[subvect_size*i+j]=SHL32(EXTEND32(shape_cb[rind*subvect_size+j]),SIG_SHIFT-5);
189          } else {
190             for (j=0;j<subvect_size;j++)
191                e[subvect_size*i+j]=NEG32(SHL32(EXTEND32(shape_cb[rind*subvect_size+j]),SIG_SHIFT-5));
192          }
193 #else
194          for (j=0;j<subvect_size;j++)
195             e[subvect_size*i+j]=sign*0.03125*shape_cb[rind*subvect_size+j];
196 #endif
197       
198       }
199             
200       for (m=0;m<subvect_size;m++)
201       {
202          spx_word16_t g;
203          int rind;
204          spx_word16_t sign=1;
205          rind = best_index;
206          if (rind>=shape_cb_size)
207          {
208             sign=-1;
209             rind-=shape_cb_size;
210          }
211          
212          q=subvect_size-m;
213 #ifdef FIXED_POINT
214          g=sign*shape_cb[rind*subvect_size+m];
215 #else
216          g=sign*0.03125*shape_cb[rind*subvect_size+m];
217 #endif
218          target_update(t+subvect_size*(i+1), g, r+q, nsf-subvect_size*(i+1));
219       }
220    }
221
222    /* Update excitation */
223    /* FIXME: We could update the excitation directly above */
224    for (j=0;j<nsf;j++)
225       exc[j]=ADD32(exc[j],e[j]);
226    
227    /* Update target: only update target if necessary */
228    if (update_target)
229    {
230       VARDECL(spx_word16_t *r2);
231       ALLOC(r2, nsf, spx_word16_t);
232       for (j=0;j<nsf;j++)
233          r2[j] = EXTRACT16(PSHR32(e[j] ,6));
234       syn_percep_zero16(r2, ak, awk1, awk2, r2, nsf,p, stack);
235       for (j=0;j<nsf;j++)
236          target[j]=SUB16(target[j],PSHR16(r2[j],2));
237    }
238 }
239
240
241
242 void split_cb_search_shape_sign(
243 spx_word16_t target[],                  /* target vector */
244 spx_coef_t ak[],                        /* LPCs for this subframe */
245 spx_coef_t awk1[],                      /* Weighted LPCs for this subframe */
246 spx_coef_t awk2[],                      /* Weighted LPCs for this subframe */
247 const void *par,                      /* Codebook/search parameters*/
248 int   p,                        /* number of LPC coeffs */
249 int   nsf,                      /* number of samples in subframe */
250 spx_sig_t *exc,
251 spx_word16_t *r,
252 SpeexBits *bits,
253 char *stack,
254 int   complexity,
255 int   update_target
256 )
257 {
258    int i,j,k,m,n,q;
259    VARDECL(spx_word16_t *resp);
260 #ifdef _USE_SSE
261    VARDECL(__m128 *resp2);
262    VARDECL(__m128 *E);
263 #else
264    spx_word16_t *resp2;
265    VARDECL(spx_word32_t *E);
266 #endif
267    VARDECL(spx_word16_t *t);
268    VARDECL(spx_sig_t *e);
269    VARDECL(spx_word16_t *tmp);
270    VARDECL(spx_word32_t *ndist);
271    VARDECL(spx_word32_t *odist);
272    VARDECL(int *itmp);
273    VARDECL(spx_word16_t **ot2);
274    VARDECL(spx_word16_t **nt2);
275    spx_word16_t **ot, **nt;
276    VARDECL(int **nind);
277    VARDECL(int **oind);
278    VARDECL(int *ind);
279    const signed char *shape_cb;
280    int shape_cb_size, subvect_size, nb_subvect;
281    const split_cb_params *params;
282    int N=2;
283    VARDECL(int *best_index);
284    VARDECL(spx_word32_t *best_dist);
285    VARDECL(int *best_nind);
286    VARDECL(int *best_ntarget);
287    int have_sign;
288    N=complexity;
289    if (N>10)
290       N=10;
291    /* Complexity isn't as important for the codebooks as it is for the pitch */
292    N=(2*N)/3;
293    if (N<1)
294       N=1;
295    if (N==1)
296    {
297       split_cb_search_shape_sign_N1(target,ak,awk1,awk2,par,p,nsf,exc,r,bits,stack,update_target);
298       return;
299    }
300    ALLOC(ot2, N, spx_word16_t*);
301    ALLOC(nt2, N, spx_word16_t*);
302    ALLOC(oind, N, int*);
303    ALLOC(nind, N, int*);
304
305    params = (const split_cb_params *) par;
306    subvect_size = params->subvect_size;
307    nb_subvect = params->nb_subvect;
308    shape_cb_size = 1<<params->shape_bits;
309    shape_cb = params->shape_cb;
310    have_sign = params->have_sign;
311    ALLOC(resp, shape_cb_size*subvect_size, spx_word16_t);
312 #ifdef _USE_SSE
313    ALLOC(resp2, (shape_cb_size*subvect_size)>>2, __m128);
314    ALLOC(E, shape_cb_size>>2, __m128);
315 #else
316    resp2 = resp;
317    ALLOC(E, shape_cb_size, spx_word32_t);
318 #endif
319    ALLOC(t, nsf, spx_word16_t);
320    ALLOC(e, nsf, spx_sig_t);
321    ALLOC(ind, nb_subvect, int);
322
323    ALLOC(tmp, 2*N*nsf, spx_word16_t);
324    for (i=0;i<N;i++)
325    {
326       ot2[i]=tmp+2*i*nsf;
327       nt2[i]=tmp+(2*i+1)*nsf;
328    }
329    ot=ot2;
330    nt=nt2;
331    ALLOC(best_index, N, int);
332    ALLOC(best_dist, N, spx_word32_t);
333    ALLOC(best_nind, N, int);
334    ALLOC(best_ntarget, N, int);
335    ALLOC(ndist, N, spx_word32_t);
336    ALLOC(odist, N, spx_word32_t);
337    
338    ALLOC(itmp, 2*N*nb_subvect, int);
339    for (i=0;i<N;i++)
340    {
341       nind[i]=itmp+2*i*nb_subvect;
342       oind[i]=itmp+(2*i+1)*nb_subvect;
343    }
344    
345    SPEEX_COPY(t, target, nsf);
346
347    for (j=0;j<N;j++)
348       SPEEX_COPY(&ot[j][0], t, nsf);
349
350    /* Pre-compute codewords response and energy */
351    compute_weighted_codebook(shape_cb, r, resp, resp2, E, shape_cb_size, subvect_size, stack);
352
353    for (j=0;j<N;j++)
354       odist[j]=0;
355    
356    /*For all subvectors*/
357    for (i=0;i<nb_subvect;i++)
358    {
359       /*"erase" nbest list*/
360       for (j=0;j<N;j++)
361          ndist[j]=VERY_LARGE32;
362       /* This is not strictly necessary, but it provides an additonal safety 
363          to prevent crashes in case something goes wrong in the previous
364          steps (e.g. NaNs) */
365       for (j=0;j<N;j++)
366          best_nind[j] = best_ntarget[j] = 0;
367       /*For all n-bests of previous subvector*/
368       for (j=0;j<N;j++)
369       {
370          spx_word16_t *x=ot[j]+subvect_size*i;
371          spx_word32_t tener = 0;
372          for (m=0;m<subvect_size;m++)
373             tener = MAC16_16(tener, x[m],x[m]);
374 #ifdef FIXED_POINT
375          tener = SHR32(tener,1);
376 #else
377          tener *= .5;
378 #endif
379          /*Find new n-best based on previous n-best j*/
380          if (have_sign)
381             vq_nbest_sign(x, resp2, subvect_size, shape_cb_size, E, N, best_index, best_dist, stack);
382          else
383             vq_nbest(x, resp2, subvect_size, shape_cb_size, E, N, best_index, best_dist, stack);
384
385          /*For all new n-bests*/
386          for (k=0;k<N;k++)
387          {
388             /* Compute total distance (including previous sub-vectors */
389             spx_word32_t err = ADD32(ADD32(odist[j],best_dist[k]),tener);
390             
391             /*update n-best list*/
392             if (err<ndist[N-1])
393             {
394                for (m=0;m<N;m++)
395                {
396                   if (err < ndist[m])
397                   {
398                      for (n=N-1;n>m;n--)
399                      {
400                         ndist[n] = ndist[n-1];
401                         best_nind[n] = best_nind[n-1];
402                         best_ntarget[n] = best_ntarget[n-1];
403                      }
404                      /* n is equal to m here, so they're interchangeable */
405                      ndist[m] = err;
406                      best_nind[n] = best_index[k];
407                      best_ntarget[n] = j;
408                      break;
409                   }
410                }
411             }
412          }
413          if (i==0)
414             break;
415       }
416       for (j=0;j<N;j++)
417       {
418          /*previous target (we don't care what happened before*/
419          for (m=(i+1)*subvect_size;m<nsf;m++)
420             nt[j][m]=ot[best_ntarget[j]][m];
421          
422          /* New code: update the rest of the target only if it's worth it */
423          for (m=0;m<subvect_size;m++)
424          {
425             spx_word16_t g;
426             int rind;
427             spx_word16_t sign=1;
428             rind = best_nind[j];
429             if (rind>=shape_cb_size)
430             {
431                sign=-1;
432                rind-=shape_cb_size;
433             }
434
435             q=subvect_size-m;
436 #ifdef FIXED_POINT
437             g=sign*shape_cb[rind*subvect_size+m];
438 #else
439             g=sign*0.03125*shape_cb[rind*subvect_size+m];
440 #endif
441             target_update(nt[j]+subvect_size*(i+1), g, r+q, nsf-subvect_size*(i+1));
442          }
443
444          for (q=0;q<nb_subvect;q++)
445             nind[j][q]=oind[best_ntarget[j]][q];
446          nind[j][i]=best_nind[j];
447       }
448
449       /*update old-new data*/
450       /* just swap pointers instead of a long copy */
451       {
452          spx_word16_t **tmp2;
453          tmp2=ot;
454          ot=nt;
455          nt=tmp2;
456       }
457       for (j=0;j<N;j++)
458          for (m=0;m<nb_subvect;m++)
459             oind[j][m]=nind[j][m];
460       for (j=0;j<N;j++)
461          odist[j]=ndist[j];
462    }
463
464    /*save indices*/
465    for (i=0;i<nb_subvect;i++)
466    {
467       ind[i]=nind[0][i];
468       speex_bits_pack(bits,ind[i],params->shape_bits+have_sign);
469    }
470    
471    /* Put everything back together */
472    for (i=0;i<nb_subvect;i++)
473    {
474       int rind;
475       spx_word16_t sign=1;
476       rind = ind[i];
477       if (rind>=shape_cb_size)
478       {
479          sign=-1;
480          rind-=shape_cb_size;
481       }
482 #ifdef FIXED_POINT
483       if (sign==1)
484       {
485          for (j=0;j<subvect_size;j++)
486             e[subvect_size*i+j]=SHL32(EXTEND32(shape_cb[rind*subvect_size+j]),SIG_SHIFT-5);
487       } else {
488          for (j=0;j<subvect_size;j++)
489             e[subvect_size*i+j]=NEG32(SHL32(EXTEND32(shape_cb[rind*subvect_size+j]),SIG_SHIFT-5));
490       }
491 #else
492       for (j=0;j<subvect_size;j++)
493          e[subvect_size*i+j]=sign*0.03125*shape_cb[rind*subvect_size+j];
494 #endif
495    }   
496    /* Update excitation */
497    for (j=0;j<nsf;j++)
498       exc[j]=ADD32(exc[j],e[j]);
499    
500    /* Update target: only update target if necessary */
501    if (update_target)
502    {
503       VARDECL(spx_word16_t *r2);
504       ALLOC(r2, nsf, spx_word16_t);
505       for (j=0;j<nsf;j++)
506          r2[j] = EXTRACT16(PSHR32(e[j] ,6));
507       syn_percep_zero16(r2, ak, awk1, awk2, r2, nsf,p, stack);
508       for (j=0;j<nsf;j++)
509          target[j]=SUB16(target[j],PSHR16(r2[j],2));
510    }
511 }
512
513
514 void split_cb_shape_sign_unquant(
515 spx_sig_t *exc,
516 const void *par,                      /* non-overlapping codebook */
517 int   nsf,                      /* number of samples in subframe */
518 SpeexBits *bits,
519 char *stack,
520 spx_int32_t *seed
521 )
522 {
523    int i,j;
524    VARDECL(int *ind);
525    VARDECL(int *signs);
526    const signed char *shape_cb;
527    int shape_cb_size, subvect_size, nb_subvect;
528    const split_cb_params *params;
529    int have_sign;
530
531    params = (const split_cb_params *) par;
532    subvect_size = params->subvect_size;
533    nb_subvect = params->nb_subvect;
534    shape_cb_size = 1<<params->shape_bits;
535    shape_cb = params->shape_cb;
536    have_sign = params->have_sign;
537
538    ALLOC(ind, nb_subvect, int);
539    ALLOC(signs, nb_subvect, int);
540
541    /* Decode codewords and gains */
542    for (i=0;i<nb_subvect;i++)
543    {
544       if (have_sign)
545          signs[i] = speex_bits_unpack_unsigned(bits, 1);
546       else
547          signs[i] = 0;
548       ind[i] = speex_bits_unpack_unsigned(bits, params->shape_bits);
549    }
550    /* Compute decoded excitation */
551    for (i=0;i<nb_subvect;i++)
552    {
553       spx_word16_t s=1;
554       if (signs[i])
555          s=-1;
556 #ifdef FIXED_POINT
557       if (s==1)
558       {
559          for (j=0;j<subvect_size;j++)
560             exc[subvect_size*i+j]=SHL32(EXTEND32(shape_cb[ind[i]*subvect_size+j]),SIG_SHIFT-5);
561       } else {
562          for (j=0;j<subvect_size;j++)
563             exc[subvect_size*i+j]=NEG32(SHL32(EXTEND32(shape_cb[ind[i]*subvect_size+j]),SIG_SHIFT-5));
564       }
565 #else
566       for (j=0;j<subvect_size;j++)
567          exc[subvect_size*i+j]+=s*0.03125*shape_cb[ind[i]*subvect_size+j];      
568 #endif
569    }
570 }
571
572 void noise_codebook_quant(
573 spx_word16_t target[],                  /* target vector */
574 spx_coef_t ak[],                        /* LPCs for this subframe */
575 spx_coef_t awk1[],                      /* Weighted LPCs for this subframe */
576 spx_coef_t awk2[],                      /* Weighted LPCs for this subframe */
577 const void *par,                      /* Codebook/search parameters*/
578 int   p,                        /* number of LPC coeffs */
579 int   nsf,                      /* number of samples in subframe */
580 spx_sig_t *exc,
581 spx_word16_t *r,
582 SpeexBits *bits,
583 char *stack,
584 int   complexity,
585 int   update_target
586 )
587 {
588    int i;
589    VARDECL(spx_word16_t *tmp);
590    ALLOC(tmp, nsf, spx_word16_t);
591    residue_percep_zero16(target, ak, awk1, awk2, tmp, nsf, p, stack);
592
593    for (i=0;i<nsf;i++)
594       exc[i]+=SHL32(EXTEND32(tmp[i]),8);
595    SPEEX_MEMSET(target, 0, nsf);
596 }
597
598
599 void noise_codebook_unquant(
600 spx_sig_t *exc,
601 const void *par,                      /* non-overlapping codebook */
602 int   nsf,                      /* number of samples in subframe */
603 SpeexBits *bits,
604 char *stack,
605 spx_int32_t *seed
606 )
607 {
608    int i;
609    /* FIXME: This is bad, but I don't think the function ever gets called anyway */
610    for (i=0;i<nsf;i++)
611       exc[i]=SHL32(EXTEND32(speex_rand(1, seed)),SIG_SHIFT);
612 }