Re-introducing the successive rotations as a way to control low-bitrate
[opus.git] / libcelt / vq.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41 #include "rate.h"
42
43 static void exp_rotation(celt_norm_t *X, int len, int dir, int stride, int K)
44 {
45    int i, k, iter;
46    celt_word16_t c, s;
47    celt_word16_t gain, theta;
48    celt_norm_t *Xptr;
49    gain = celt_div((celt_word32_t)MULT16_16(Q15_ONE,len),(celt_word32_t)(len+2*K*((K>>1)+1)));
50    /* FIXME: Make that HALF16 instead of HALF32 */
51    theta = SUB16(Q15ONE, HALF32(MULT16_16_Q15(gain,gain)));
52    /*if (len==30)
53    {
54    for (i=0;i<len;i++)
55    X[i] = 0;
56    X[14] = 1;
57 }*/ 
58    c = celt_cos_norm(EXTEND32(theta));
59    s = dir*celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
60    if (stride == 1)
61       stride = 2;
62    iter = 1;
63    for (k=0;k<iter;k++)
64    {
65       /* We could use MULT16_16_P15 instead of MULT16_16_Q15 for more accuracy, 
66       but at this point, I really don't think it's necessary */
67       Xptr = X;
68       for (i=0;i<len-stride;i++)
69       {
70          celt_norm_t x1, x2;
71          x1 = Xptr[0];
72          x2 = Xptr[stride];
73          Xptr[stride] = MULT16_16_Q15(c,x2) + MULT16_16_Q15(s,x1);
74          *Xptr++      = MULT16_16_Q15(c,x1) - MULT16_16_Q15(s,x2);
75       }
76       Xptr = &X[len-2*stride-1];
77       for (i=len-2*stride-1;i>=0;i--)
78       {
79          celt_norm_t x1, x2;
80          x1 = Xptr[0];
81          x2 = Xptr[stride];
82          Xptr[stride] = MULT16_16_Q15(c,x2) + MULT16_16_Q15(s,x1);
83          *Xptr--      = MULT16_16_Q15(c,x1) - MULT16_16_Q15(s,x2);
84       }
85    }
86    /*if (len==30)
87    {
88    for (i=0;i<len;i++)
89    printf ("%f ", X[i]);
90    printf ("\n");
91    exit(0);
92 }*/
93 }
94
95
96 /** Takes the pitch vector and the decoded residual vector, computes the gain
97     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
98 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, const celt_norm_t * restrict P)
99 {
100    int i;
101    celt_word32_t Ryp, Ryy, Rpp;
102    celt_word16_t ryp, ryy, rpp;
103    celt_word32_t g;
104    VARDECL(celt_norm_t, y);
105 #ifdef FIXED_POINT
106    int yshift;
107 #endif
108    SAVE_STACK;
109 #ifdef FIXED_POINT
110    yshift = 13-celt_ilog2(K);
111 #endif
112    ALLOC(y, N, celt_norm_t);
113
114    Rpp = 0;
115    i=0;
116    do {
117       Rpp = MAC16_16(Rpp,P[i],P[i]);
118       y[i] = SHL16(iy[i],yshift);
119    } while (++i < N);
120
121    Ryp = 0;
122    Ryy = 0;
123    /* If this doesn't generate a dual MAC (on supported archs), fire the compiler guy */
124    i=0;
125    do {
126       Ryp = MAC16_16(Ryp, y[i], P[i]);
127       Ryy = MAC16_16(Ryy, y[i], y[i]);
128    } while (++i < N);
129
130    ryp = ROUND16(Ryp,14);
131    ryy = ROUND16(Ryy,14);
132    rpp = ROUND16(Rpp,14);
133    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
134    g = MULT16_32_Q15(celt_sqrt(MAC16_16(Ryy, ryp,ryp) - MULT16_16(ryy,rpp)) - ryp,
135                      celt_rcp(SHR32(Ryy,9)));
136
137    i=0;
138    do 
139       X[i] = ADD16(P[i], ROUND16(MULT16_16(y[i], g),11));
140    while (++i < N);
141
142    RESTORE_STACK;
143 }
144
145
146 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, int spread, celt_norm_t *P, ec_enc *enc)
147 {
148    VARDECL(celt_norm_t, y);
149    VARDECL(int, iy);
150    VARDECL(celt_word16_t, signx);
151    int j, is;
152    celt_word16_t s;
153    int pulsesLeft;
154    celt_word32_t sum;
155    celt_word32_t xy, yy, yp;
156    celt_word16_t Rpp;
157    int N_1; /* Inverse of N, in Q14 format (even for float) */
158 #ifdef FIXED_POINT
159    int yshift;
160 #endif
161    SAVE_STACK;
162
163    K = get_pulses(K);
164 #ifdef FIXED_POINT
165    yshift = 13-celt_ilog2(K);
166 #endif
167
168    ALLOC(y, N, celt_norm_t);
169    ALLOC(iy, N, int);
170    ALLOC(signx, N, celt_word16_t);
171    N_1 = 512/N;
172    
173    if (spread)
174       exp_rotation(X, N, 1, spread, K);
175
176    sum = 0;
177    j=0; do {
178       X[j] -= P[j];
179       if (X[j]>0)
180          signx[j]=1;
181       else {
182          signx[j]=-1;
183          X[j]=-X[j];
184          P[j]=-P[j];
185       }
186       iy[j] = 0;
187       y[j] = 0;
188       sum = MAC16_16(sum, P[j],P[j]);
189    } while (++j<N);
190    Rpp = ROUND16(sum, NORM_SHIFT);
191
192    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
193
194    xy = yy = yp = 0;
195
196    pulsesLeft = K;
197
198    /* Do a pre-search by projecting on the pyramid */
199    if (K > (N>>1))
200    {
201       celt_word16_t rcp;
202       sum=0;
203       j=0; do {
204          sum += X[j];
205       }  while (++j<N);
206
207 #ifdef FIXED_POINT
208       if (sum <= K)
209 #else
210       if (sum <= EPSILON)
211 #endif
212       {
213          X[0] = QCONST16(1.f,14);
214          j=1; do
215             X[j]=0;
216          while (++j<N);
217          sum = QCONST16(1.f,14);
218       }
219       /* Do we have sufficient accuracy here? */
220       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
221       j=0; do {
222 #ifdef FIXED_POINT
223          /* It's really important to round *towards zero* here */
224          iy[j] = MULT16_16_Q15(X[j],rcp);
225 #else
226          iy[j] = floor(rcp*X[j]);
227 #endif
228          y[j] = SHL16(iy[j],yshift);
229          yy = MAC16_16(yy, y[j],y[j]);
230          xy = MAC16_16(xy, X[j],y[j]);
231          yp += P[j]*y[j];
232          y[j] *= 2;
233          pulsesLeft -= iy[j];
234       }  while (++j<N);
235    }
236    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
237
238    while (pulsesLeft > 1)
239    {
240       int pulsesAtOnce=1;
241       int best_id;
242       celt_word16_t magnitude;
243       celt_word32_t best_num = -VERY_LARGE16;
244       celt_word16_t best_den = 0;
245 #ifdef FIXED_POINT
246       int rshift;
247 #endif
248       /* Decide on how many pulses to find at once */
249       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
250       if (pulsesAtOnce<1)
251          pulsesAtOnce = 1;
252 #ifdef FIXED_POINT
253       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
254 #endif
255       magnitude = SHL16(pulsesAtOnce, yshift);
256
257       best_id = 0;
258       /* The squared magnitude term gets added anyway, so we might as well 
259          add it outside the loop */
260       yy = MAC16_16(yy, magnitude,magnitude);
261       /* Choose between fast and accurate strategy depending on where we are in the search */
262          /* This should ensure that anything we can process will have a better score */
263       j=0;
264       do {
265          celt_word16_t Rxy, Ryy;
266          /* Select sign based on X[j] alone */
267          s = magnitude;
268          /* Temporary sums of the new pulse(s) */
269          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
270          /* We're multiplying y[j] by two so we don't have to do it here */
271          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
272             
273             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
274          Rxy is positive because the sign is pre-computed) */
275          Rxy = MULT16_16_Q15(Rxy,Rxy);
276             /* The idea is to check for num/den >= best_num/best_den, but that way
277          we can do it without any division */
278          /* OPT: Make sure to use conditional moves here */
279          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
280          {
281             best_den = Ryy;
282             best_num = Rxy;
283             best_id = j;
284          }
285       } while (++j<N);
286       
287       j = best_id;
288       is = pulsesAtOnce;
289       s = SHL16(is, yshift);
290
291       /* Updating the sums of the new pulse(s) */
292       xy = xy + MULT16_16(s,X[j]);
293       /* We're multiplying y[j] by two so we don't have to do it here */
294       yy = yy + MULT16_16(s,y[j]);
295       yp = yp + MULT16_16(s, P[j]);
296
297       /* Only now that we've made the final choice, update y/iy */
298       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
299       y[j] += 2*s;
300       iy[j] += is;
301       pulsesLeft -= pulsesAtOnce;
302    }
303    
304    if (pulsesLeft > 0)
305    {
306       celt_word16_t g;
307       celt_word16_t best_num = -VERY_LARGE16;
308       celt_word16_t best_den = 0;
309       int best_id = 0;
310       celt_word16_t magnitude = SHL16(1, yshift);
311
312       /* The squared magnitude term gets added anyway, so we might as well 
313       add it outside the loop */
314       yy = MAC16_16(yy, magnitude,magnitude);
315       j=0;
316       do {
317          celt_word16_t Rxy, Ryy, Ryp;
318          celt_word16_t num;
319          /* Select sign based on X[j] alone */
320          s = magnitude;
321          /* Temporary sums of the new pulse(s) */
322          Rxy = ROUND16(MAC16_16(xy, s,X[j]), 14);
323          /* We're multiplying y[j] by two so we don't have to do it here */
324          Ryy = ROUND16(MAC16_16(yy, s,y[j]), 14);
325          Ryp = ROUND16(MAC16_16(yp, s,P[j]), 14);
326
327             /* Compute the gain such that ||p + g*y|| = 1 
328          ...but instead, we compute g*Ryy to avoid dividing */
329          g = celt_psqrt(MULT16_16(Ryp,Ryp) + MULT16_16(Ryy,QCONST16(1.f,14)-Rpp)) - Ryp;
330             /* Knowing that gain, what's the error: (x-g*y)^2 
331          (result is negated and we discard x^2 because it's constant) */
332          /* score = 2*g*Rxy - g*g*Ryy;*/
333 #ifdef FIXED_POINT
334          /* No need to multiply Rxy by 2 because we did it earlier */
335          num = MULT16_16_Q15(ADD16(SUB16(Rxy,g),Rxy),g);
336 #else
337          num = g*(2*Rxy-g);
338 #endif
339          if (MULT16_16(best_den, num) > MULT16_16(Ryy, best_num))
340          {
341             best_den = Ryy;
342             best_num = num;
343             best_id = j;
344          }
345       } while (++j<N);
346       iy[best_id] += 1;
347    }
348    j=0;
349    do {
350       P[j] = MULT16_16(signx[j],P[j]);
351       X[j] = MULT16_16(signx[j],X[j]);
352       if (signx[j] < 0)
353          iy[j] = -iy[j];
354    } while (++j<N);
355    encode_pulses(iy, N, K, enc);
356    
357    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
358    due to the recursive computation used in quantisation. */
359    mix_pitch_and_residual(iy, X, N, K, P);
360    if (spread)
361       exp_rotation(X, N, -1, spread, K);
362    RESTORE_STACK;
363 }
364
365
366 /** Decode pulse vector and combine the result with the pitch vector to produce
367     the final normalised signal in the current band. */
368 void alg_unquant(celt_norm_t *X, int N, int K, int spread, celt_norm_t *P, ec_dec *dec)
369 {
370    VARDECL(int, iy);
371    SAVE_STACK;
372    K = get_pulses(K);
373    ALLOC(iy, N, int);
374    decode_pulses(iy, N, K, dec);
375    mix_pitch_and_residual(iy, X, N, K, P);
376    if (spread)
377       exp_rotation(X, N, -1, spread, K);
378    RESTORE_STACK;
379 }
380
381 celt_word16_t renormalise_vector(celt_norm_t *X, celt_word16_t value, int N, int stride)
382 {
383    int i;
384    celt_word32_t E = EPSILON;
385    celt_word16_t rE;
386    celt_word16_t g;
387    celt_norm_t *xptr = X;
388    for (i=0;i<N;i++)
389    {
390       E = MAC16_16(E, *xptr, *xptr);
391       xptr += stride;
392    }
393
394    rE = celt_sqrt(E);
395 #ifdef FIXED_POINT
396    if (rE <= 128)
397       g = Q15ONE;
398    else
399 #endif
400       g = MULT16_16_Q15(value,celt_rcp(SHL32(rE,9)));
401    xptr = X;
402    for (i=0;i<N;i++)
403    {
404       *xptr = PSHR32(MULT16_16(g, *xptr),8);
405       xptr += stride;
406    }
407    return rE;
408 }
409
410 static void fold(const CELTMode *m, int N, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
411 {
412    int j;
413    const int C = CHANNELS(m);
414    int id = (N0*C) % (C*B);
415    /* Here, we assume that id will never be greater than N0, i.e. that 
416       no band is wider than N0. In the unlikely case it happens, we set
417       everything to zero */
418    /*{
419            int offset = (N0*C - (id+C*N))/2;
420            if (offset > C*N0/16)
421                    offset = C*N0/16;
422            offset -= offset % (C*B);
423            if (offset < 0)
424                    offset = 0;
425            //printf ("%d\n", offset);
426            id += offset;
427    }*/
428    if (id+C*N>N0*C)
429       for (j=0;j<C*N;j++)
430          P[j] = 0;
431    else
432       for (j=0;j<C*N;j++)
433          P[j] = Y[id++];
434 }
435
436 void intra_fold(const CELTMode *m, celt_norm_t * restrict x, int N, int *pulses, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
437 {
438    int c;
439    const int C = CHANNELS(m);
440
441    fold(m, N, Y, P, N0, B);
442    c=0;
443    do {
444       int K = get_pulses(pulses[c]);
445       renormalise_vector(P+c, K==0 ? Q15ONE : 0, N, C);
446    } while (++c < C);
447 }
448