Brought back a version split_cb_search_shape_sign optimized for complexity 1
[speexdsp.git] / libspeex / cb_search.c
1 /* Copyright (C) 2002 Jean-Marc Valin 
2    File: cb_search.c
3
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "cb_search.h"
37 #include "filters.h"
38 #include "stack_alloc.h"
39 #include "vq.h"
40 #include "misc.h"
41 #include <stdio.h>
42
43 #ifdef _USE_SSE
44 #include "cb_search_sse.h"
45 #else
46
47 static void compute_weighted_codebook(const signed char *shape_cb, const spx_sig_t *r, spx_word16_t *resp, spx_word16_t *resp2, spx_word32_t *E, int shape_cb_size, int subvect_size, char *stack)
48 {
49    int i, j, k;
50    for (i=0;i<shape_cb_size;i++)
51    {
52       spx_word16_t *res;
53       const signed char *shape;
54
55       res = resp+i*subvect_size;
56       shape = shape_cb+i*subvect_size;
57
58       /* Compute codeword response using convolution with impulse response */
59       for(j=0;j<subvect_size;j++)
60       {
61          spx_word32_t resj=0;
62          for (k=0;k<=j;k++)
63             resj = MAC16_16_Q11(resj,shape[k],r[j-k]);
64 #ifndef FIXED_POINT
65          resj *= 0.03125;
66 #endif
67          res[j] = resj;
68          /*printf ("%d\n", (int)res[j]);*/
69       }
70       
71       /* Compute codeword energy */
72       E[i]=0;
73       for(j=0;j<subvect_size;j++)
74          E[i]=ADD32(E[i],MULT16_16(res[j],res[j]));
75    }
76
77 }
78
79 #endif
80
81
82
83 #if 0
84
85 void split_cb_search_shape_sign(
86 spx_sig_t target[],                     /* target vector */
87 spx_coef_t ak[],                        /* LPCs for this subframe */
88 spx_coef_t awk1[],                      /* Weighted LPCs for this subframe */
89 spx_coef_t awk2[],                      /* Weighted LPCs for this subframe */
90 const void *par,                      /* Codebook/search parameters*/
91 int   p,                        /* number of LPC coeffs */
92 int   nsf,                      /* number of samples in subframe */
93 spx_sig_t *exc,
94 spx_sig_t *r,
95 SpeexBits *bits,
96 char *stack,
97 int   complexity,
98 int   update_target
99                                )
100 {
101    int i,j,k,m,n,q;
102    spx_word16_t *resp;
103 #ifdef _USE_SSE
104    __m128 *resp2;
105    __m128 *E;
106 #else
107    spx_word16_t *resp2;
108    spx_word32_t *E;
109 #endif
110    spx_word16_t *t;
111    spx_sig_t *e, *r2;
112    const signed char *shape_cb;
113    int shape_cb_size, subvect_size, nb_subvect;
114    split_cb_params *params;
115    int N=2;
116    int best_index;
117    spx_word32_t best_dist;
118    int have_sign;
119    N=complexity;
120    if (N>10)
121       N=10;
122    if (N<1)
123       N=1;
124    
125    params = (split_cb_params *) par;
126    subvect_size = params->subvect_size;
127    nb_subvect = params->nb_subvect;
128    shape_cb_size = 1<<params->shape_bits;
129    shape_cb = params->shape_cb;
130    have_sign = params->have_sign;
131    resp = PUSH(stack, shape_cb_size*subvect_size, spx_word16_t);
132 #ifdef _USE_SSE
133    resp2 = PUSH(stack, (shape_cb_size*subvect_size)>>2, __m128);
134    E = PUSH(stack, shape_cb_size>>2, __m128);
135 #else
136    resp2 = resp;
137    E = PUSH(stack, shape_cb_size, spx_word32_t);
138 #endif
139    t = PUSH(stack, nsf, spx_word16_t);
140    e = PUSH(stack, nsf, spx_sig_t);
141    r2 = PUSH(stack, nsf, spx_sig_t);
142    
143    /* FIXME: make that adaptive? */
144    for (i=0;i<nsf;i++)
145       t[i]=SHR(target[i],6);
146
147    compute_weighted_codebook(shape_cb, r, resp, resp2, E, shape_cb_size, subvect_size, stack);
148
149    for (i=0;i<nb_subvect;i++)
150    {
151       spx_word16_t *x=t+subvect_size*i;
152       /*Find new n-best based on previous n-best j*/
153       if (have_sign)
154          vq_nbest_sign(x, resp2, subvect_size, shape_cb_size, E, 1, &best_index, &best_dist, stack);
155       else
156          vq_nbest(x, resp2, subvect_size, shape_cb_size, E, 1, &best_index, &best_dist, stack);
157       
158       speex_bits_pack(bits,best_index,params->shape_bits+have_sign);
159       
160       /* New code: update only enough of the target to calculate error*/
161       {
162          int rind;
163          spx_word16_t *res;
164          spx_word16_t sign=1;
165          rind = best_index;
166          if (rind>=shape_cb_size)
167          {
168             sign=-1;
169             rind-=shape_cb_size;
170          }
171          res = resp+rind*subvect_size;
172          if (sign>0)
173             for (m=0;m<subvect_size;m++)
174                t[subvect_size*i+m] -= res[m];
175          else
176             for (m=0;m<subvect_size;m++)
177                t[subvect_size*i+m] += res[m];
178
179 #ifdef FIXED_POINT
180          if (sign)
181          {
182             for (j=0;j<subvect_size;j++)
183                e[subvect_size*i+j]=SHL((spx_word32_t)shape_cb[rind*subvect_size+j],SIG_SHIFT-5);
184          } else {
185             for (j=0;j<subvect_size;j++)
186                e[subvect_size*i+j]=-SHL((spx_word32_t)shape_cb[rind*subvect_size+j],SIG_SHIFT-5);
187          }
188 #else
189          for (j=0;j<subvect_size;j++)
190             e[subvect_size*i+j]=sign*0.03125*shape_cb[rind*subvect_size+j];
191 #endif
192       
193       }
194             
195       for (m=0;m<subvect_size;m++)
196       {
197          spx_word16_t g;
198          int rind;
199          spx_word16_t sign=1;
200          rind = best_index;
201          if (rind>=shape_cb_size)
202          {
203             sign=-1;
204             rind-=shape_cb_size;
205          }
206          
207          q=subvect_size-m;
208 #ifdef FIXED_POINT
209          g=sign*shape_cb[rind*subvect_size+m];
210          for (n=subvect_size*(i+1);n<nsf;n++,q++)
211             t[n] = SUB32(t[n],MULT16_16_Q11(g,r[q]));
212 #else
213          g=sign*0.03125*shape_cb[rind*subvect_size+m];
214          for (n=subvect_size*(i+1);n<nsf;n++,q++)
215             t[n] = SUB32(t[n],g*r[q]);
216 #endif
217       }
218
219
220    }
221
222    /* Update excitation */
223    for (j=0;j<nsf;j++)
224       exc[j]+=e[j];
225    
226    /* Update target: only update target if necessary */
227    if (update_target)
228    {
229       syn_percep_zero(e, ak, awk1, awk2, r2, nsf,p, stack);
230       for (j=0;j<nsf;j++)
231          target[j]-=r2[j];
232    }
233 }
234
235 #else
236
237 void split_cb_search_shape_sign(
238 spx_sig_t target[],                     /* target vector */
239 spx_coef_t ak[],                        /* LPCs for this subframe */
240 spx_coef_t awk1[],                      /* Weighted LPCs for this subframe */
241 spx_coef_t awk2[],                      /* Weighted LPCs for this subframe */
242 const void *par,                      /* Codebook/search parameters*/
243 int   p,                        /* number of LPC coeffs */
244 int   nsf,                      /* number of samples in subframe */
245 spx_sig_t *exc,
246 spx_sig_t *r,
247 SpeexBits *bits,
248 char *stack,
249 int   complexity,
250 int   update_target
251 )
252 {
253    int i,j,k,m,n,q;
254    spx_word16_t *resp;
255 #ifdef _USE_SSE
256    __m128 *resp2;
257    __m128 *E;
258 #else
259    spx_word16_t *resp2;
260    spx_word32_t *E;
261 #endif
262    spx_word16_t *t;
263    spx_sig_t *e, *r2;
264    spx_word16_t *tmp;
265    spx_word32_t *ndist, *odist;
266    int *itmp;
267    spx_word16_t **ot, **nt;
268    int **nind, **oind;
269    int *ind;
270    const signed char *shape_cb;
271    int shape_cb_size, subvect_size, nb_subvect;
272    split_cb_params *params;
273    int N=2;
274    int *best_index;
275    spx_word32_t *best_dist;
276    int have_sign;
277    N=complexity;
278    if (N>10)
279       N=10;
280    if (N<1)
281       N=1;
282    
283    ot=PUSH(stack, N, spx_word16_t*);
284    nt=PUSH(stack, N, spx_word16_t*);
285    oind=PUSH(stack, N, int*);
286    nind=PUSH(stack, N, int*);
287
288    params = (split_cb_params *) par;
289    subvect_size = params->subvect_size;
290    nb_subvect = params->nb_subvect;
291    shape_cb_size = 1<<params->shape_bits;
292    shape_cb = params->shape_cb;
293    have_sign = params->have_sign;
294    resp = PUSH(stack, shape_cb_size*subvect_size, spx_word16_t);
295 #ifdef _USE_SSE
296    resp2 = PUSH(stack, (shape_cb_size*subvect_size)>>2, __m128);
297    E = PUSH(stack, shape_cb_size>>2, __m128);
298 #else
299    resp2 = resp;
300    E = PUSH(stack, shape_cb_size, spx_word32_t);
301 #endif
302    t = PUSH(stack, nsf, spx_word16_t);
303    e = PUSH(stack, nsf, spx_sig_t);
304    r2 = PUSH(stack, nsf, spx_sig_t);
305    ind = PUSH(stack, nb_subvect, int);
306
307    tmp = PUSH(stack, 2*N*nsf, spx_word16_t);
308    for (i=0;i<N;i++)
309    {
310       ot[i]=tmp;
311       tmp += nsf;
312       nt[i]=tmp;
313       tmp += nsf;
314    }
315    best_index = PUSH(stack, N, int);
316    best_dist = PUSH(stack, N, spx_word32_t);
317    ndist = PUSH(stack, N, spx_word32_t);
318    odist = PUSH(stack, N, spx_word32_t);
319    
320    itmp = PUSH(stack, 2*N*nb_subvect, int);
321    for (i=0;i<N;i++)
322    {
323       nind[i]=itmp;
324       itmp+=nb_subvect;
325       oind[i]=itmp;
326       itmp+=nb_subvect;
327       for (j=0;j<nb_subvect;j++)
328          nind[i][j]=oind[i][j]=-1;
329    }
330    
331    /* FIXME: make that adaptive? */
332    for (i=0;i<nsf;i++)
333       t[i]=SHR(target[i],6);
334
335    for (j=0;j<N;j++)
336       for (i=0;i<nsf;i++)
337          ot[j][i]=t[i];
338
339    /*for (i=0;i<nsf;i++)
340      printf ("%d\n", (int)t[i]);*/
341
342    /* Pre-compute codewords response and energy */
343    compute_weighted_codebook(shape_cb, r, resp, resp2, E, shape_cb_size, subvect_size, stack);
344
345    for (j=0;j<N;j++)
346       odist[j]=0;
347    /*For all subvectors*/
348    for (i=0;i<nb_subvect;i++)
349    {
350       /*"erase" nbest list*/
351       for (j=0;j<N;j++)
352          ndist[j]=-2;
353
354       /*For all n-bests of previous subvector*/
355       for (j=0;j<N;j++)
356       {
357          spx_word16_t *x=ot[j]+subvect_size*i;
358          /*Find new n-best based on previous n-best j*/
359          if (have_sign)
360             vq_nbest_sign(x, resp2, subvect_size, shape_cb_size, E, N, best_index, best_dist, stack);
361          else
362             vq_nbest(x, resp2, subvect_size, shape_cb_size, E, N, best_index, best_dist, stack);
363
364          /*For all new n-bests*/
365          for (k=0;k<N;k++)
366          {
367             spx_word16_t *ct;
368             spx_word32_t err=0;
369             ct = ot[j];
370             /*update target*/
371
372             /*previous target*/
373             for (m=i*subvect_size;m<(i+1)*subvect_size;m++)
374                t[m]=ct[m];
375
376             /* New code: update only enough of the target to calculate error*/
377             {
378                int rind;
379                spx_word16_t *res;
380                spx_word16_t sign=1;
381                rind = best_index[k];
382                if (rind>=shape_cb_size)
383                {
384                   sign=-1;
385                   rind-=shape_cb_size;
386                }
387                res = resp+rind*subvect_size;
388                if (sign>0)
389                   for (m=0;m<subvect_size;m++)
390                      t[subvect_size*i+m] -= res[m];
391                else
392                   for (m=0;m<subvect_size;m++)
393                      t[subvect_size*i+m] += res[m];
394             }
395             
396             /*compute error (distance)*/
397             err=odist[j];
398             for (m=i*subvect_size;m<(i+1)*subvect_size;m++)
399                err = MAC16_16(err, t[m],t[m]);
400             /*update n-best list*/
401             if (err<ndist[N-1] || ndist[N-1]<-1)
402             {
403
404                /*previous target (we don't care what happened before*/
405                for (m=(i+1)*subvect_size;m<nsf;m++)
406                   t[m]=ct[m];
407                /* New code: update the rest of the target only if it's worth it */
408                for (m=0;m<subvect_size;m++)
409                {
410                   spx_word16_t g;
411                   int rind;
412                   spx_word16_t sign=1;
413                   rind = best_index[k];
414                   if (rind>=shape_cb_size)
415                   {
416                      sign=-1;
417                      rind-=shape_cb_size;
418                   }
419
420                   q=subvect_size-m;
421 #ifdef FIXED_POINT
422                   g=sign*shape_cb[rind*subvect_size+m];
423                   for (n=subvect_size*(i+1);n<nsf;n++,q++)
424                      t[n] = SUB32(t[n],MULT16_16_Q11(g,r[q]));
425 #else
426                   g=sign*0.03125*shape_cb[rind*subvect_size+m];
427                   for (n=subvect_size*(i+1);n<nsf;n++,q++)
428                      t[n] = SUB32(t[n],g*r[q]);
429 #endif
430                }
431
432
433                for (m=0;m<N;m++)
434                {
435                   if (err < ndist[m] || ndist[m]<-1)
436                   {
437                      for (n=N-1;n>m;n--)
438                      {
439                         for (q=(i+1)*subvect_size;q<nsf;q++)
440                            nt[n][q]=nt[n-1][q];
441                         for (q=0;q<nb_subvect;q++)
442                            nind[n][q]=nind[n-1][q];
443                         ndist[n]=ndist[n-1];
444                      }
445                      for (q=(i+1)*subvect_size;q<nsf;q++)
446                         nt[m][q]=t[q];
447                      for (q=0;q<nb_subvect;q++)
448                         nind[m][q]=oind[j][q];
449                      nind[m][i]=best_index[k];
450                      ndist[m]=err;
451                      break;
452                   }
453                }
454             }
455          }
456          if (i==0)
457            break;
458       }
459
460       /*update old-new data*/
461       /* just swap pointers instead of a long copy */
462       {
463          spx_word16_t **tmp2;
464          tmp2=ot;
465          ot=nt;
466          nt=tmp2;
467       }
468       for (j=0;j<N;j++)
469          for (m=0;m<nb_subvect;m++)
470             oind[j][m]=nind[j][m];
471       for (j=0;j<N;j++)
472          odist[j]=ndist[j];
473    }
474
475    /*save indices*/
476    for (i=0;i<nb_subvect;i++)
477    {
478       ind[i]=nind[0][i];
479       speex_bits_pack(bits,ind[i],params->shape_bits+have_sign);
480    }
481    
482    /* Put everything back together */
483    for (i=0;i<nb_subvect;i++)
484    {
485       int rind;
486       spx_word16_t sign=1;
487       rind = ind[i];
488       if (rind>=shape_cb_size)
489       {
490          sign=-1;
491          rind-=shape_cb_size;
492       }
493 #ifdef FIXED_POINT
494       if (sign==1)
495       {
496          for (j=0;j<subvect_size;j++)
497             e[subvect_size*i+j]=SHL((spx_word32_t)shape_cb[rind*subvect_size+j],SIG_SHIFT-5);
498       } else {
499          for (j=0;j<subvect_size;j++)
500             e[subvect_size*i+j]=-SHL((spx_word32_t)shape_cb[rind*subvect_size+j],SIG_SHIFT-5);
501       }
502 #else
503       for (j=0;j<subvect_size;j++)
504          e[subvect_size*i+j]=sign*0.03125*shape_cb[rind*subvect_size+j];
505 #endif
506    }   
507    /* Update excitation */
508    for (j=0;j<nsf;j++)
509       exc[j]+=e[j];
510    
511    /* Update target: only update target if necessary */
512    if (update_target)
513    {
514       syn_percep_zero(e, ak, awk1, awk2, r2, nsf,p, stack);
515       for (j=0;j<nsf;j++)
516          target[j]-=r2[j];
517    }
518 }
519 #endif
520
521 void split_cb_shape_sign_unquant(
522 spx_sig_t *exc,
523 const void *par,                      /* non-overlapping codebook */
524 int   nsf,                      /* number of samples in subframe */
525 SpeexBits *bits,
526 char *stack
527 )
528 {
529    int i,j;
530    int *ind, *signs;
531    const signed char *shape_cb;
532    int shape_cb_size, subvect_size, nb_subvect;
533    split_cb_params *params;
534    int have_sign;
535
536    params = (split_cb_params *) par;
537    subvect_size = params->subvect_size;
538    nb_subvect = params->nb_subvect;
539    shape_cb_size = 1<<params->shape_bits;
540    shape_cb = params->shape_cb;
541    have_sign = params->have_sign;
542
543    ind = PUSH(stack, nb_subvect, int);
544    signs = PUSH(stack, nb_subvect, int);
545
546    /* Decode codewords and gains */
547    for (i=0;i<nb_subvect;i++)
548    {
549       if (have_sign)
550          signs[i] = speex_bits_unpack_unsigned(bits, 1);
551       else
552          signs[i] = 0;
553       ind[i] = speex_bits_unpack_unsigned(bits, params->shape_bits);
554    }
555    /* Compute decoded excitation */
556    for (i=0;i<nb_subvect;i++)
557    {
558       spx_word16_t s=1;
559       if (signs[i])
560          s=-1;
561 #ifdef FIXED_POINT
562       if (s==1)
563       {
564          for (j=0;j<subvect_size;j++)
565             exc[subvect_size*i+j]=SHL((spx_word32_t)shape_cb[ind[i]*subvect_size+j],SIG_SHIFT-5);
566       } else {
567          for (j=0;j<subvect_size;j++)
568             exc[subvect_size*i+j]=-SHL((spx_word32_t)shape_cb[ind[i]*subvect_size+j],SIG_SHIFT-5);
569       }
570 #else
571       for (j=0;j<subvect_size;j++)
572          exc[subvect_size*i+j]+=s*0.03125*shape_cb[ind[i]*subvect_size+j];      
573 #endif
574    }
575 }
576
577 void noise_codebook_quant(
578 spx_sig_t target[],                     /* target vector */
579 spx_coef_t ak[],                        /* LPCs for this subframe */
580 spx_coef_t awk1[],                      /* Weighted LPCs for this subframe */
581 spx_coef_t awk2[],                      /* Weighted LPCs for this subframe */
582 const void *par,                      /* Codebook/search parameters*/
583 int   p,                        /* number of LPC coeffs */
584 int   nsf,                      /* number of samples in subframe */
585 spx_sig_t *exc,
586 spx_sig_t *r,
587 SpeexBits *bits,
588 char *stack,
589 int   complexity,
590 int   update_target
591 )
592 {
593    int i;
594    spx_sig_t *tmp=PUSH(stack, nsf, spx_sig_t);
595    residue_percep_zero(target, ak, awk1, awk2, tmp, nsf, p, stack);
596
597    for (i=0;i<nsf;i++)
598       exc[i]+=tmp[i];
599    for (i=0;i<nsf;i++)
600       target[i]=0;
601
602 }
603
604
605 void noise_codebook_unquant(
606 spx_sig_t *exc,
607 const void *par,                      /* non-overlapping codebook */
608 int   nsf,                      /* number of samples in subframe */
609 SpeexBits *bits,
610 char *stack
611 )
612 {
613    speex_rand_vec(1, exc, nsf);
614 }