Making the band definition the same at all frame sizes.
[opus.git] / libcelt / vq.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include "mathops.h"
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42 #include "rate.h"
43
44 #ifndef M_PI
45 #define M_PI 3.141592653
46 #endif
47 #if 0
48 static void frac_hadamard1(celt_norm *X, int len, int stride, celt_word16 c, celt_word16 s)
49 {
50    int j;
51    celt_norm *x, *y;
52    celt_norm * end;
53
54    j = 0;
55    x = X;
56    y = X+stride;
57    end = X+len;
58    do
59    {
60       celt_norm x1, x2;
61       x1 = *x;
62       x2 = *y;
63       *x++ = EXTRACT16(SHR32(MULT16_16(c,x1) + MULT16_16(s,x2),15));
64       *y++ = EXTRACT16(SHR32(MULT16_16(s,x1) - MULT16_16(c,x2),15));
65       j++;
66       if (j>=stride)
67       {
68          j=0;
69          x+=stride;
70          y+=stride;
71       }
72    } while (y<end);
73
74    /* Reverse samples so that the next level starts from the other end */
75    for (j=0;j<len>>1;j++)
76    {
77       celt_norm tmp = X[j];
78       X[j] = X[len-j-1];
79       X[len-j-1] = tmp;
80    }
81 }
82
83 #define MAX_LEVELS 8
84 static void pseudo_hadamard(celt_norm *X, int len, int dir, int stride, int K)
85 {
86    int i, N=0;
87    int transient;
88    celt_word16 gain, theta;
89    int istride[MAX_LEVELS];
90    celt_word16 c[MAX_LEVELS], s[MAX_LEVELS];
91
92    if (K >= len)
93       return;
94    transient = stride>1;
95    /*if (len>=30)
96    {
97       for (i=0;i<len;i++)
98          X[i] = 0;
99       X[30] = 1;
100       dir = -1;
101       transient = 1;
102    }*/
103    gain = celt_div((celt_word32)MULT16_16(Q15_ONE,len),(celt_word32)(3+len+4*K));
104    /* FIXME: Make that HALF16 instead of HALF32 */
105    theta = MIN16(QCONST16(.25f,15),HALF32(MULT16_16_Q15(gain,gain)));
106    c[0] = celt_cos_norm(EXTEND32(theta));
107    s[0] = celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
108
109    do {
110       istride[N] = stride;
111       stride *= 2;
112       c[N] = c[0];
113       s[N] = s[0];
114       N++;
115    } while (N<MAX_LEVELS && stride < len);
116
117    /* This should help a little bit with the transients */
118    if (transient)
119       c[0] = s[0] = QCONST16(.7071068f, 15);
120
121    /* Needs to be < 0 to prevent gaps on the side of the spreading */
122    if (dir < 0)
123    {
124       for (i=0;i<N;i++)
125          frac_hadamard1(X, len, istride[i], c[i], s[i]);
126    } else {
127       for (i=N-1;i>=0;i--)
128          frac_hadamard1(X, len, istride[i], c[i], s[i]);
129    }
130
131    /* Undo last reversal */
132    for (i=0;i<len>>1;i++)
133    {
134       celt_norm tmp = X[i];
135       X[i] = X[len-i-1];
136       X[len-i-1] = tmp;
137    }
138    /*if (len>=30)
139    {
140       for (i=0;i<len;i++)
141          printf ("%f ", X[i]);
142       printf ("\n");
143       exit(0);
144    }*/
145 }
146 #endif
147
148
149 static void exp_rotation1(celt_norm *X, int len, int dir, int stride, celt_word16 c, celt_word16 s)
150 {
151    int i;
152    celt_norm *Xptr;
153    if (dir>0)
154       s = -s;
155    Xptr = X;
156    for (i=0;i<len-stride;i++)
157    {
158       celt_norm x1, x2;
159       x1 = Xptr[0];
160       x2 = Xptr[stride];
161       Xptr[stride] = EXTRACT16(SHR32(MULT16_16(c,x2) + MULT16_16(s,x1), 15));
162       *Xptr++      = EXTRACT16(SHR32(MULT16_16(c,x1) - MULT16_16(s,x2), 15));
163    }
164    Xptr = &X[len-2*stride-1];
165    for (i=len-2*stride-1;i>=0;i--)
166    {
167       celt_norm x1, x2;
168       x1 = Xptr[0];
169       x2 = Xptr[stride];
170       Xptr[stride] = EXTRACT16(SHR32(MULT16_16(c,x2) + MULT16_16(s,x1), 15));
171       *Xptr--      = EXTRACT16(SHR32(MULT16_16(c,x1) - MULT16_16(s,x2), 15));
172    }
173 }
174
175 static void exp_rotation(celt_norm *X, int len, int dir, int stride, int K)
176 {
177    celt_word16 c, s;
178    celt_word16 gain, theta;
179    int stride2=0;
180    /*int i;
181    if (len>=30)
182    {
183       for (i=0;i<len;i++)
184          X[i] = 0;
185       X[14] = 1;
186       K=5;
187    }*/
188    /*if (stride>1)
189    {
190       pseudo_hadamard(X, len, dir, stride, K);
191       return;
192    }*/
193    if (2*K>=len)
194       return;
195    gain = celt_div((celt_word32)MULT16_16(Q15_ONE,len),(celt_word32)(len+10*K));
196    /* FIXME: Make that HALF16 instead of HALF32 */
197    theta = HALF32(MULT16_16_Q15(gain,gain));
198
199    c = celt_cos_norm(EXTEND32(theta));
200    s = celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
201
202 #if 0
203    if (len>=8*stride)
204       stride2 = stride*floor(.5+sqrt(len*1.f/stride));
205 #else
206    if (len>=8*stride)
207    {
208       stride2 = 1;
209       /* This is just a simple way of computing sqrt(len/stride) with rounding.
210          It's basically incrementing long as (stride2+0.5)^2 < len/stride.
211          I _think_ it is bit-exact */
212       while ((stride2*stride2+stride2)*stride + (stride>>2) < len)
213          stride2++;
214       stride2 *= stride;
215    }
216 #endif
217    if (dir < 0)
218    {
219       if (stride2)
220          exp_rotation1(X, len, dir, stride2, s, c);
221       exp_rotation1(X, len, dir, stride, c, s);
222    } else {
223       exp_rotation1(X, len, dir, stride, c, s);
224       if (stride2)
225          exp_rotation1(X, len, dir, stride2, s, c);
226    }
227
228    /*if (len>=30)
229    {
230       for (i=0;i<len;i++)
231          printf ("%f ", X[i]);
232       printf ("\n");
233       exit(0);
234    }*/
235 }
236
237 /** Takes the pitch vector and the decoded residual vector, computes the gain
238     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
239 static void normalise_residual(int * restrict iy, celt_norm * restrict X, int N, int K, celt_word32 Ryy)
240 {
241    int i;
242 #ifdef FIXED_POINT
243    int k;
244 #endif
245    celt_word32 t;
246    celt_word16 g;
247
248 #ifdef FIXED_POINT
249    k = celt_ilog2(Ryy)>>1;
250 #endif
251    t = VSHR32(Ryy, (k-7)<<1);
252    g = celt_rsqrt_norm(t);
253
254    i=0;
255    do
256       X[i] = EXTRACT16(PSHR32(MULT16_16(g, iy[i]), k+1));
257    while (++i < N);
258 }
259
260 void alg_quant(celt_norm *X, int N, int K, int spread, int resynth, ec_enc *enc)
261 {
262    VARDECL(celt_norm, y);
263    VARDECL(int, iy);
264    VARDECL(celt_word16, signx);
265    int j, is;
266    celt_word16 s;
267    int pulsesLeft;
268    celt_word32 sum;
269    celt_word32 xy, yy;
270    int N_1; /* Inverse of N, in Q14 format (even for float) */
271 #ifdef FIXED_POINT
272    int yshift;
273 #endif
274    SAVE_STACK;
275
276    K = get_pulses(K);
277 #ifdef FIXED_POINT
278    yshift = 13-celt_ilog2(K);
279 #endif
280
281    ALLOC(y, N, celt_norm);
282    ALLOC(iy, N, int);
283    ALLOC(signx, N, celt_word16);
284    N_1 = 512/N;
285    
286    if (spread)
287       exp_rotation(X, N, 1, spread, K);
288
289    sum = 0;
290    j=0; do {
291       if (X[j]>0)
292          signx[j]=1;
293       else {
294          signx[j]=-1;
295          X[j]=-X[j];
296       }
297       iy[j] = 0;
298       y[j] = 0;
299    } while (++j<N);
300
301    xy = yy = 0;
302
303    pulsesLeft = K;
304
305    /* Do a pre-search by projecting on the pyramid */
306    if (K > (N>>1))
307    {
308       celt_word16 rcp;
309       j=0; do {
310          sum += X[j];
311       }  while (++j<N);
312
313 #ifdef FIXED_POINT
314       if (sum <= K)
315 #else
316       if (sum <= EPSILON)
317 #endif
318       {
319          X[0] = QCONST16(1.f,14);
320          j=1; do
321             X[j]=0;
322          while (++j<N);
323          sum = QCONST16(1.f,14);
324       }
325       /* Do we have sufficient accuracy here? */
326       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
327       j=0; do {
328 #ifdef FIXED_POINT
329          /* It's really important to round *towards zero* here */
330          iy[j] = MULT16_16_Q15(X[j],rcp);
331 #else
332          iy[j] = floor(rcp*X[j]);
333 #endif
334          y[j] = SHL16(iy[j],yshift);
335          yy = MAC16_16(yy, y[j],y[j]);
336          xy = MAC16_16(xy, X[j],y[j]);
337          y[j] *= 2;
338          pulsesLeft -= iy[j];
339       }  while (++j<N);
340    }
341    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
342
343    while (pulsesLeft > 0)
344    {
345       int pulsesAtOnce=1;
346       int best_id;
347       celt_word16 magnitude;
348       celt_word32 best_num = -VERY_LARGE16;
349       celt_word16 best_den = 0;
350 #ifdef FIXED_POINT
351       int rshift;
352 #endif
353       /* Decide on how many pulses to find at once */
354       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
355       if (pulsesAtOnce<1)
356          pulsesAtOnce = 1;
357 #ifdef FIXED_POINT
358       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
359 #endif
360       magnitude = SHL16(pulsesAtOnce, yshift);
361
362       best_id = 0;
363       /* The squared magnitude term gets added anyway, so we might as well 
364          add it outside the loop */
365       yy = MAC16_16(yy, magnitude,magnitude);
366       /* Choose between fast and accurate strategy depending on where we are in the search */
367          /* This should ensure that anything we can process will have a better score */
368       j=0;
369       do {
370          celt_word16 Rxy, Ryy;
371          /* Select sign based on X[j] alone */
372          s = magnitude;
373          /* Temporary sums of the new pulse(s) */
374          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
375          /* We're multiplying y[j] by two so we don't have to do it here */
376          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
377             
378             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
379          Rxy is positive because the sign is pre-computed) */
380          Rxy = MULT16_16_Q15(Rxy,Rxy);
381             /* The idea is to check for num/den >= best_num/best_den, but that way
382          we can do it without any division */
383          /* OPT: Make sure to use conditional moves here */
384          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
385          {
386             best_den = Ryy;
387             best_num = Rxy;
388             best_id = j;
389          }
390       } while (++j<N);
391       
392       j = best_id;
393       is = pulsesAtOnce;
394       s = SHL16(is, yshift);
395
396       /* Updating the sums of the new pulse(s) */
397       xy = xy + MULT16_16(s,X[j]);
398       /* We're multiplying y[j] by two so we don't have to do it here */
399       yy = yy + MULT16_16(s,y[j]);
400
401       /* Only now that we've made the final choice, update y/iy */
402       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
403       y[j] += 2*s;
404       iy[j] += is;
405       pulsesLeft -= pulsesAtOnce;
406    }
407    j=0;
408    do {
409       X[j] = MULT16_16(signx[j],X[j]);
410       if (signx[j] < 0)
411          iy[j] = -iy[j];
412    } while (++j<N);
413    encode_pulses(iy, N, K, enc);
414    
415    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
416    due to the recursive computation used in quantisation. */
417    if (resynth)
418    {
419       normalise_residual(iy, X, N, K, EXTRACT16(SHR32(yy,2*yshift)));
420       if (spread)
421          exp_rotation(X, N, -1, spread, K);
422    }
423    RESTORE_STACK;
424 }
425
426
427 /** Decode pulse vector and combine the result with the pitch vector to produce
428     the final normalised signal in the current band. */
429 void alg_unquant(celt_norm *X, int N, int K, int spread, ec_dec *dec)
430 {
431    int i;
432    celt_word32 Ryy;
433    VARDECL(int, iy);
434    SAVE_STACK;
435    K = get_pulses(K);
436    ALLOC(iy, N, int);
437    decode_pulses(iy, N, K, dec);
438    Ryy = 0;
439    i=0;
440    do {
441       Ryy = MAC16_16(Ryy, iy[i], iy[i]);
442    } while (++i < N);
443    normalise_residual(iy, X, N, K, Ryy);
444    if (spread)
445       exp_rotation(X, N, -1, spread, K);
446    RESTORE_STACK;
447 }
448
449 celt_word16 renormalise_vector(celt_norm *X, celt_word16 value, int N, int stride)
450 {
451    int i;
452    celt_word32 E = EPSILON;
453    celt_word16 g;
454    celt_word32 t;
455    celt_norm *xptr = X;
456    for (i=0;i<N;i++)
457    {
458       E = MAC16_16(E, *xptr, *xptr);
459       xptr += stride;
460    }
461 #ifdef FIXED_POINT
462    int k = celt_ilog2(E)>>1;
463 #endif
464    t = VSHR32(E, (k-7)<<1);
465    g = MULT16_16_Q15(value, celt_rsqrt_norm(t));
466
467    xptr = X;
468    for (i=0;i<N;i++)
469    {
470       *xptr = EXTRACT16(PSHR32(MULT16_16(g, *xptr), k+1));
471       xptr += stride;
472    }
473    return celt_sqrt(E);
474 }
475
476 static void fold(const CELTMode *m, int start, int N, const celt_norm * restrict Y, celt_norm * restrict P, int N0, int B, int M)
477 {
478    int j;
479    int id = N0 % B;
480    while (id < M*m->eBands[start])
481       id += B;
482    /* Here, we assume that id will never be greater than N0, i.e. that 
483       no band is wider than N0. In the unlikely case it happens, we set
484       everything to zero */
485    /*{
486            int offset = (N0*C - (id+C*N))/2;
487            if (offset > C*N0/16)
488                    offset = C*N0/16;
489            offset -= offset % (C*B);
490            if (offset < 0)
491                    offset = 0;
492            //printf ("%d\n", offset);
493            id += offset;
494    }*/
495    if (id+N>N0)
496       for (j=0;j<N;j++)
497          P[j] = 0;
498    else
499       for (j=0;j<N;j++)
500          P[j] = Y[id++];
501 }
502
503 void intra_fold(const CELTMode *m, int start, int N, const celt_norm * restrict Y, celt_norm * restrict P, int N0, int B, int M)
504 {
505    fold(m, start, N, Y, P, N0, B, M);
506    renormalise_vector(P, Q15ONE, N, 1);
507 }
508