ed2c46a95ac57f8dd3698c45f81ba7aa5c4dcfab
[opus.git] / libcelt / vq.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include "mathops.h"
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42 #include "bands.h"
43 #include "rate.h"
44
45 #ifndef M_PI
46 #define M_PI 3.141592653
47 #endif
48
49 static void exp_rotation1(celt_norm *X, int len, int stride, celt_word16 c, celt_word16 s)
50 {
51    int i;
52    celt_norm *Xptr;
53    Xptr = X;
54    for (i=0;i<len-stride;i++)
55    {
56       celt_norm x1, x2;
57       x1 = Xptr[0];
58       x2 = Xptr[stride];
59       Xptr[stride] = EXTRACT16(SHR32(MULT16_16(c,x2) + MULT16_16(s,x1), 15));
60       *Xptr++      = EXTRACT16(SHR32(MULT16_16(c,x1) - MULT16_16(s,x2), 15));
61    }
62    Xptr = &X[len-2*stride-1];
63    for (i=len-2*stride-1;i>=0;i--)
64    {
65       celt_norm x1, x2;
66       x1 = Xptr[0];
67       x2 = Xptr[stride];
68       Xptr[stride] = EXTRACT16(SHR32(MULT16_16(c,x2) + MULT16_16(s,x1), 15));
69       *Xptr--      = EXTRACT16(SHR32(MULT16_16(c,x1) - MULT16_16(s,x2), 15));
70    }
71 }
72
73 static void exp_rotation(celt_norm *X, int len, int dir, int stride, int K, int spread)
74 {
75    static const int SPREAD_FACTOR[3]={5,10,15};
76    int i;
77    celt_word16 c, s;
78    celt_word16 gain, theta;
79    int stride2=0;
80    int factor;
81    /*int i;
82    if (len>=30)
83    {
84       for (i=0;i<len;i++)
85          X[i] = 0;
86       X[14] = 1;
87       K=5;
88    }*/
89    if (2*K>=len || spread==SPREAD_NONE)
90       return;
91    factor = SPREAD_FACTOR[spread-1];
92
93    gain = celt_div((celt_word32)MULT16_16(Q15_ONE,len),(celt_word32)(len+factor*K));
94    /* FIXME: Make that HALF16 instead of HALF32 */
95    theta = HALF32(MULT16_16_Q15(gain,gain));
96
97    c = celt_cos_norm(EXTEND32(theta));
98    s = celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
99
100    if (len>=8*stride)
101    {
102       stride2 = 1;
103       /* This is just a simple way of computing sqrt(len/stride) with rounding.
104          It's basically incrementing long as (stride2+0.5)^2 < len/stride.
105          I _think_ it is bit-exact */
106       while ((stride2*stride2+stride2)*stride + (stride>>2) < len)
107          stride2++;
108    }
109    len /= stride;
110    for (i=0;i<stride;i++)
111    {
112       if (dir < 0)
113       {
114          if (stride2)
115             exp_rotation1(X+i*len, len, stride2, s, c);
116          exp_rotation1(X+i*len, len, 1, c, s);
117       } else {
118          exp_rotation1(X+i*len, len, 1, c, -s);
119          if (stride2)
120             exp_rotation1(X+i*len, len, stride2, s, -c);
121       }
122    }
123    /*if (len>=30)
124    {
125       for (i=0;i<len;i++)
126          printf ("%f ", X[i]);
127       printf ("\n");
128       exit(0);
129    }*/
130 }
131
132 /** Takes the pitch vector and the decoded residual vector, computes the gain
133     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
134 static void normalise_residual(int * restrict iy, celt_norm * restrict X,
135       int N, int K, celt_word32 Ryy, celt_word16 gain)
136 {
137    int i;
138 #ifdef FIXED_POINT
139    int k;
140 #endif
141    celt_word32 t;
142    celt_word16 g;
143
144 #ifdef FIXED_POINT
145    k = celt_ilog2(Ryy)>>1;
146 #endif
147    t = VSHR32(Ryy, (k-7)<<1);
148    g = MULT16_16_P15(celt_rsqrt_norm(t),gain);
149
150    i=0;
151    do
152       X[i] = EXTRACT16(PSHR32(MULT16_16(g, iy[i]), k+1));
153    while (++i < N);
154 }
155
156 void alg_quant(celt_norm *X, int N, int K, int spread, int B, celt_norm *lowband,
157       int resynth, ec_enc *enc, celt_int32 *seed, celt_word16 gain)
158 {
159    VARDECL(celt_norm, y);
160    VARDECL(int, iy);
161    VARDECL(celt_word16, signx);
162    int i, j;
163    celt_word16 s;
164    int pulsesLeft;
165    celt_word32 sum;
166    celt_word32 xy;
167    celt_word16 yy;
168    SAVE_STACK;
169
170    celt_assert2(K!=0, "alg_quant() needs at least one pulse");
171
172    ALLOC(y, N, celt_norm);
173    ALLOC(iy, N, int);
174    ALLOC(signx, N, celt_word16);
175    
176    exp_rotation(X, N, 1, B, K, spread);
177
178    /* Get rid of the sign */
179    sum = 0;
180    j=0; do {
181       if (X[j]>0)
182          signx[j]=1;
183       else {
184          signx[j]=-1;
185          X[j]=-X[j];
186       }
187       iy[j] = 0;
188       y[j] = 0;
189    } while (++j<N);
190
191    xy = yy = 0;
192
193    pulsesLeft = K;
194
195    /* Do a pre-search by projecting on the pyramid */
196    if (K > (N>>1))
197    {
198       celt_word16 rcp;
199       j=0; do {
200          sum += X[j];
201       }  while (++j<N);
202
203       /* If X is too small, just replace it with a pulse at 0 */
204 #ifdef FIXED_POINT
205       if (sum <= K)
206 #else
207       if (sum <= EPSILON)
208 #endif
209       {
210          X[0] = QCONST16(1.f,14);
211          j=1; do
212             X[j]=0;
213          while (++j<N);
214          sum = QCONST16(1.f,14);
215       }
216       /* Do we have sufficient accuracy here? */
217       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
218       j=0; do {
219 #ifdef FIXED_POINT
220          /* It's really important to round *towards zero* here */
221          iy[j] = MULT16_16_Q15(X[j],rcp);
222 #else
223          iy[j] = (int)floor(rcp*X[j]);
224 #endif
225          y[j] = iy[j];
226          yy = MAC16_16(yy, y[j],y[j]);
227          xy = MAC16_16(xy, X[j],y[j]);
228          y[j] *= 2;
229          pulsesLeft -= iy[j];
230       }  while (++j<N);
231    }
232    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
233
234    /* This should never happen, but just in case it does (e.g. on silence)
235       we fill the first bin with pulses. */
236 #ifdef FIXED_POINT_DEBUG
237    celt_assert2(pulsesLeft<=N+3, "Not enough pulses in the quick pass");
238 #endif
239    if (pulsesLeft > N+3)
240    {
241       celt_word16 tmp = pulsesLeft;
242       yy = MAC16_16(yy, tmp, tmp);
243       yy = MAC16_16(yy, tmp, y[0]);
244       iy[0] += pulsesLeft;
245       pulsesLeft=0;
246    }
247
248    s = 1;
249    for (i=0;i<pulsesLeft;i++)
250    {
251       int best_id;
252       celt_word32 best_num = -VERY_LARGE16;
253       celt_word16 best_den = 0;
254 #ifdef FIXED_POINT
255       int rshift;
256 #endif
257 #ifdef FIXED_POINT
258       rshift = 1+celt_ilog2(K-pulsesLeft+i+1);
259 #endif
260       best_id = 0;
261       /* The squared magnitude term gets added anyway, so we might as well 
262          add it outside the loop */
263       yy = ADD32(yy, 1);
264       j=0;
265       do {
266          celt_word16 Rxy, Ryy;
267          /* Temporary sums of the new pulse(s) */
268          Rxy = EXTRACT16(SHR32(ADD32(xy, EXTEND32(X[j])),rshift));
269          /* We're multiplying y[j] by two so we don't have to do it here */
270          Ryy = ADD16(yy, y[j]);
271
272          /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that
273             Rxy is positive because the sign is pre-computed) */
274          Rxy = MULT16_16_Q15(Rxy,Rxy);
275          /* The idea is to check for num/den >= best_num/best_den, but that way
276             we can do it without any division */
277          /* OPT: Make sure to use conditional moves here */
278          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
279          {
280             best_den = Ryy;
281             best_num = Rxy;
282             best_id = j;
283          }
284       } while (++j<N);
285       
286       /* Updating the sums of the new pulse(s) */
287       xy = ADD32(xy, EXTEND32(X[best_id]));
288       /* We're multiplying y[j] by two so we don't have to do it here */
289       yy = ADD16(yy, y[best_id]);
290
291       /* Only now that we've made the final choice, update y/iy */
292       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
293       y[best_id] += 2*s;
294       iy[best_id]++;
295    }
296
297    /* Put the original sign back */
298    j=0;
299    do {
300       X[j] = MULT16_16(signx[j],X[j]);
301       if (signx[j] < 0)
302          iy[j] = -iy[j];
303    } while (++j<N);
304    encode_pulses(iy, N, K, enc);
305    
306    if (resynth)
307    {
308       normalise_residual(iy, X, N, K, yy, gain);
309       exp_rotation(X, N, -1, B, K, spread);
310    }
311    RESTORE_STACK;
312 }
313
314
315 /** Decode pulse vector and combine the result with the pitch vector to produce
316     the final normalised signal in the current band. */
317 void alg_unquant(celt_norm *X, int N, int K, int spread, int B,
318       celt_norm *lowband, ec_dec *dec, celt_int32 *seed, celt_word16 gain)
319 {
320    int i;
321    celt_word32 Ryy;
322    VARDECL(int, iy);
323    SAVE_STACK;
324
325    celt_assert2(K!=0, "alg_unquant() needs at least one pulse");
326    ALLOC(iy, N, int);
327    decode_pulses(iy, N, K, dec);
328    Ryy = 0;
329    i=0;
330    do {
331       Ryy = MAC16_16(Ryy, iy[i], iy[i]);
332    } while (++i < N);
333    normalise_residual(iy, X, N, K, Ryy, gain);
334    exp_rotation(X, N, -1, B, K, spread);
335    RESTORE_STACK;
336 }
337
338 void renormalise_vector(celt_norm *X, int N, celt_word16 gain)
339 {
340    int i;
341 #ifdef FIXED_POINT
342    int k;
343 #endif
344    celt_word32 E = EPSILON;
345    celt_word16 g;
346    celt_word32 t;
347    celt_norm *xptr = X;
348    for (i=0;i<N;i++)
349    {
350       E = MAC16_16(E, *xptr, *xptr);
351       xptr++;
352    }
353 #ifdef FIXED_POINT
354    k = celt_ilog2(E)>>1;
355 #endif
356    t = VSHR32(E, (k-7)<<1);
357    g = MULT16_16_P15(celt_rsqrt_norm(t),gain);
358
359    xptr = X;
360    for (i=0;i<N;i++)
361    {
362       *xptr = EXTRACT16(PSHR32(MULT16_16(g, *xptr), k+1));
363       xptr++;
364    }
365    /*return celt_sqrt(E);*/
366 }
367
368 int stereo_itheta(celt_norm *X, celt_norm *Y, int stereo, int N)
369 {
370    int i;
371    int itheta;
372    celt_word16 mid, side;
373    celt_word32 Emid, Eside;
374
375    Emid = Eside = EPSILON;
376    if (stereo)
377    {
378       for (i=0;i<N;i++)
379       {
380          celt_norm m, s;
381          m = ADD16(SHR16(X[i],1),SHR16(Y[i],1));
382          s = SUB16(SHR16(X[i],1),SHR16(Y[i],1));
383          Emid = MAC16_16(Emid, m, m);
384          Eside = MAC16_16(Eside, s, s);
385       }
386    } else {
387       for (i=0;i<N;i++)
388       {
389          celt_norm m, s;
390          m = X[i];
391          s = Y[i];
392          Emid = MAC16_16(Emid, m, m);
393          Eside = MAC16_16(Eside, s, s);
394       }
395    }
396    mid = celt_sqrt(Emid);
397    side = celt_sqrt(Eside);
398 #ifdef FIXED_POINT
399    /* 0.63662 = 2/pi */
400    itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
401 #else
402    itheta = (int)floor(.5f+16384*0.63662f*atan2(side,mid));
403 #endif
404
405    return itheta;
406 }