8ef80e507f262dad55e5eb6dbbf5ae2b143ad107
[opus.git] / celt / vq.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15
16    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #ifdef HAVE_CONFIG_H
30 #include "config.h"
31 #endif
32
33 #include "mathops.h"
34 #include "cwrs.h"
35 #include "vq.h"
36 #include "arch.h"
37 #include "os_support.h"
38 #include "bands.h"
39 #include "rate.h"
40 #include "pitch.h"
41
42 #ifndef OVERRIDE_vq_exp_rotation1
43 static void exp_rotation1(celt_norm *X, int len, int stride, opus_val16 c, opus_val16 s)
44 {
45    int i;
46    opus_val16 ms;
47    celt_norm *Xptr;
48    Xptr = X;
49    ms = NEG16(s);
50    for (i=0;i<len-stride;i++)
51    {
52       celt_norm x1, x2;
53       x1 = Xptr[0];
54       x2 = Xptr[stride];
55       Xptr[stride] = EXTRACT16(PSHR32(MAC16_16(MULT16_16(c, x2),  s, x1), 15));
56       *Xptr++      = EXTRACT16(PSHR32(MAC16_16(MULT16_16(c, x1), ms, x2), 15));
57    }
58    Xptr = &X[len-2*stride-1];
59    for (i=len-2*stride-1;i>=0;i--)
60    {
61       celt_norm x1, x2;
62       x1 = Xptr[0];
63       x2 = Xptr[stride];
64       Xptr[stride] = EXTRACT16(PSHR32(MAC16_16(MULT16_16(c, x2),  s, x1), 15));
65       *Xptr--      = EXTRACT16(PSHR32(MAC16_16(MULT16_16(c, x1), ms, x2), 15));
66    }
67 }
68 #endif /* OVERRIDE_vq_exp_rotation1 */
69
70 void exp_rotation(celt_norm *X, int len, int dir, int stride, int K, int spread)
71 {
72    static const int SPREAD_FACTOR[3]={15,10,5};
73    int i;
74    opus_val16 c, s;
75    opus_val16 gain, theta;
76    int stride2=0;
77    int factor;
78
79    if (2*K>=len || spread==SPREAD_NONE)
80       return;
81    factor = SPREAD_FACTOR[spread-1];
82
83    gain = celt_div((opus_val32)MULT16_16(Q15_ONE,len),(opus_val32)(len+factor*K));
84    theta = HALF16(MULT16_16_Q15(gain,gain));
85
86    c = celt_cos_norm(EXTEND32(theta));
87    s = celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
88
89    if (len>=8*stride)
90    {
91       stride2 = 1;
92       /* This is just a simple (equivalent) way of computing sqrt(len/stride) with rounding.
93          It's basically incrementing long as (stride2+0.5)^2 < len/stride. */
94       while ((stride2*stride2+stride2)*stride + (stride>>2) < len)
95          stride2++;
96    }
97    /*NOTE: As a minor optimization, we could be passing around log2(B), not B, for both this and for
98       extract_collapse_mask().*/
99    len = celt_udiv(len, stride);
100    for (i=0;i<stride;i++)
101    {
102       if (dir < 0)
103       {
104          if (stride2)
105             exp_rotation1(X+i*len, len, stride2, s, c);
106          exp_rotation1(X+i*len, len, 1, c, s);
107       } else {
108          exp_rotation1(X+i*len, len, 1, c, -s);
109          if (stride2)
110             exp_rotation1(X+i*len, len, stride2, s, -c);
111       }
112    }
113 }
114
115 /** Takes the pitch vector and the decoded residual vector, computes the gain
116     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
117 static void normalise_residual(int * OPUS_RESTRICT iy, celt_norm * OPUS_RESTRICT X,
118       int N, opus_val32 Ryy, opus_val16 gain)
119 {
120    int i;
121 #ifdef FIXED_POINT
122    int k;
123 #endif
124    opus_val32 t;
125    opus_val16 g;
126
127 #ifdef FIXED_POINT
128    k = celt_ilog2(Ryy)>>1;
129 #endif
130    t = VSHR32(Ryy, 2*(k-7));
131    g = MULT16_16_P15(celt_rsqrt_norm(t),gain);
132
133    i=0;
134    do
135       X[i] = EXTRACT16(PSHR32(MULT16_16(g, iy[i]), k+1));
136    while (++i < N);
137 }
138
139 static unsigned extract_collapse_mask(int *iy, int N, int B)
140 {
141    unsigned collapse_mask;
142    int N0;
143    int i;
144    if (B<=1)
145       return 1;
146    /*NOTE: As a minor optimization, we could be passing around log2(B), not B, for both this and for
147       exp_rotation().*/
148    N0 = celt_udiv(N, B);
149    collapse_mask = 0;
150    i=0; do {
151       int j;
152       unsigned tmp=0;
153       j=0; do {
154          tmp |= iy[i*N0+j];
155       } while (++j<N0);
156       collapse_mask |= (tmp!=0)<<i;
157    } while (++i<B);
158    return collapse_mask;
159 }
160
161 opus_val16 op_pvq_search_c(celt_norm *X, int *iy, int K, int N, int arch)
162 {
163    VARDECL(celt_norm, y);
164    VARDECL(int, signx);
165    int i, j;
166    int pulsesLeft;
167    opus_val32 sum;
168    opus_val32 xy;
169    opus_val16 yy;
170    SAVE_STACK;
171
172    (void)arch;
173    ALLOC(y, N, celt_norm);
174    ALLOC(signx, N, int);
175
176    /* Get rid of the sign */
177    sum = 0;
178    j=0; do {
179       signx[j] = X[j]<0;
180       /* OPT: Make sure the compiler doesn't use a branch on ABS16(). */
181       X[j] = ABS16(X[j]);
182       iy[j] = 0;
183       y[j] = 0;
184    } while (++j<N);
185
186    xy = yy = 0;
187
188    pulsesLeft = K;
189
190    /* Do a pre-search by projecting on the pyramid */
191    if (K > (N>>1))
192    {
193       opus_val16 rcp;
194       j=0; do {
195          sum += X[j];
196       }  while (++j<N);
197
198       /* If X is too small, just replace it with a pulse at 0 */
199 #ifdef FIXED_POINT
200       if (sum <= K)
201 #else
202       /* Prevents infinities and NaNs from causing too many pulses
203          to be allocated. 64 is an approximation of infinity here. */
204       if (!(sum > EPSILON && sum < 64))
205 #endif
206       {
207          X[0] = QCONST16(1.f,14);
208          j=1; do
209             X[j]=0;
210          while (++j<N);
211          sum = QCONST16(1.f,14);
212       }
213 #ifdef FIXED_POINT
214       rcp = EXTRACT16(MULT16_32_Q16(K, celt_rcp(sum)));
215 #else
216       /* Using K+e with e < 1 guarantees we cannot get more than K pulses. */
217       rcp = EXTRACT16(MULT16_32_Q16(K+0.8f, celt_rcp(sum)));
218 #endif
219       j=0; do {
220 #ifdef FIXED_POINT
221          /* It's really important to round *towards zero* here */
222          iy[j] = MULT16_16_Q15(X[j],rcp);
223 #else
224          iy[j] = (int)floor(rcp*X[j]);
225 #endif
226          y[j] = (celt_norm)iy[j];
227          yy = MAC16_16(yy, y[j],y[j]);
228          xy = MAC16_16(xy, X[j],y[j]);
229          y[j] *= 2;
230          pulsesLeft -= iy[j];
231       }  while (++j<N);
232    }
233    celt_assert2(pulsesLeft>=0, "Allocated too many pulses in the quick pass");
234
235    /* This should never happen, but just in case it does (e.g. on silence)
236       we fill the first bin with pulses. */
237 #ifdef FIXED_POINT_DEBUG
238    celt_assert2(pulsesLeft<=N+3, "Not enough pulses in the quick pass");
239 #endif
240    if (pulsesLeft > N+3)
241    {
242       opus_val16 tmp = (opus_val16)pulsesLeft;
243       yy = MAC16_16(yy, tmp, tmp);
244       yy = MAC16_16(yy, tmp, y[0]);
245       iy[0] += pulsesLeft;
246       pulsesLeft=0;
247    }
248
249    for (i=0;i<pulsesLeft;i++)
250    {
251       opus_val16 Rxy, Ryy;
252       int best_id;
253       opus_val32 best_num;
254       opus_val16 best_den;
255 #ifdef FIXED_POINT
256       int rshift;
257 #endif
258 #ifdef FIXED_POINT
259       rshift = 1+celt_ilog2(K-pulsesLeft+i+1);
260 #endif
261       best_id = 0;
262       /* The squared magnitude term gets added anyway, so we might as well
263          add it outside the loop */
264       yy = ADD16(yy, 1);
265
266       /* Calculations for position 0 are out of the loop, in part to reduce
267          mispredicted branches (since the if condition is usually false)
268          in the loop. */
269       /* Temporary sums of the new pulse(s) */
270       Rxy = EXTRACT16(SHR32(ADD32(xy, EXTEND32(X[0])),rshift));
271       /* We're multiplying y[j] by two so we don't have to do it here */
272       Ryy = ADD16(yy, y[0]);
273
274       /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that
275          Rxy is positive because the sign is pre-computed) */
276       Rxy = MULT16_16_Q15(Rxy,Rxy);
277       best_den = Ryy;
278       best_num = Rxy;
279       j=1;
280       do {
281          /* Temporary sums of the new pulse(s) */
282          Rxy = EXTRACT16(SHR32(ADD32(xy, EXTEND32(X[j])),rshift));
283          /* We're multiplying y[j] by two so we don't have to do it here */
284          Ryy = ADD16(yy, y[j]);
285
286          /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that
287             Rxy is positive because the sign is pre-computed) */
288          Rxy = MULT16_16_Q15(Rxy,Rxy);
289          /* The idea is to check for num/den >= best_num/best_den, but that way
290             we can do it without any division */
291          /* OPT: It's not clear whether a cmov is faster than a branch here
292             since the condition is more often false than true and using
293             a cmov introduces data dependencies across iterations. The optimal
294             choice may be architecture-dependent. */
295          if (opus_unlikely(MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num)))
296          {
297             best_den = Ryy;
298             best_num = Rxy;
299             best_id = j;
300          }
301       } while (++j<N);
302
303       /* Updating the sums of the new pulse(s) */
304       xy = ADD32(xy, EXTEND32(X[best_id]));
305       /* We're multiplying y[j] by two so we don't have to do it here */
306       yy = ADD16(yy, y[best_id]);
307
308       /* Only now that we've made the final choice, update y/iy */
309       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
310       y[best_id] += 2;
311       iy[best_id]++;
312    }
313
314    /* Put the original sign back */
315    j=0;
316    do {
317       /*iy[j] = signx[j] ? -iy[j] : iy[j];*/
318       /* OPT: The is more likely to be compiled without a branch than the code above
319          but has the same performance otherwise. */
320       iy[j] = (iy[j]^-signx[j]) + signx[j];
321    } while (++j<N);
322    RESTORE_STACK;
323    return yy;
324 }
325
326 unsigned alg_quant(celt_norm *X, int N, int K, int spread, int B, ec_enc *enc,
327       opus_val16 gain, int resynth, int arch)
328 {
329    VARDECL(int, iy);
330    opus_val16 yy;
331    unsigned collapse_mask;
332    SAVE_STACK;
333
334    celt_assert2(K>0, "alg_quant() needs at least one pulse");
335    celt_assert2(N>1, "alg_quant() needs at least two dimensions");
336
337    /* Covers vectorization by up to 4. */
338    ALLOC(iy, N+3, int);
339
340    exp_rotation(X, N, 1, B, K, spread);
341
342    yy = op_pvq_search(X, iy, K, N, arch);
343
344    encode_pulses(iy, N, K, enc);
345
346    if (resynth)
347    {
348       normalise_residual(iy, X, N, yy, gain);
349       exp_rotation(X, N, -1, B, K, spread);
350    }
351
352    collapse_mask = extract_collapse_mask(iy, N, B);
353    RESTORE_STACK;
354    return collapse_mask;
355 }
356
357 /** Decode pulse vector and combine the result with the pitch vector to produce
358     the final normalised signal in the current band. */
359 unsigned alg_unquant(celt_norm *X, int N, int K, int spread, int B,
360       ec_dec *dec, opus_val16 gain)
361 {
362    opus_val32 Ryy;
363    unsigned collapse_mask;
364    VARDECL(int, iy);
365    SAVE_STACK;
366
367    celt_assert2(K>0, "alg_unquant() needs at least one pulse");
368    celt_assert2(N>1, "alg_unquant() needs at least two dimensions");
369    ALLOC(iy, N, int);
370    Ryy = decode_pulses(iy, N, K, dec);
371    normalise_residual(iy, X, N, Ryy, gain);
372    exp_rotation(X, N, -1, B, K, spread);
373    collapse_mask = extract_collapse_mask(iy, N, B);
374    RESTORE_STACK;
375    return collapse_mask;
376 }
377
378 #ifndef OVERRIDE_renormalise_vector
379 void renormalise_vector(celt_norm *X, int N, opus_val16 gain, int arch)
380 {
381    int i;
382 #ifdef FIXED_POINT
383    int k;
384 #endif
385    opus_val32 E;
386    opus_val16 g;
387    opus_val32 t;
388    celt_norm *xptr;
389    E = EPSILON + celt_inner_prod(X, X, N, arch);
390 #ifdef FIXED_POINT
391    k = celt_ilog2(E)>>1;
392 #endif
393    t = VSHR32(E, 2*(k-7));
394    g = MULT16_16_P15(celt_rsqrt_norm(t),gain);
395
396    xptr = X;
397    for (i=0;i<N;i++)
398    {
399       *xptr = EXTRACT16(PSHR32(MULT16_16(g, *xptr), k+1));
400       xptr++;
401    }
402    /*return celt_sqrt(E);*/
403 }
404 #endif /* OVERRIDE_renormalise_vector */
405
406 int stereo_itheta(const celt_norm *X, const celt_norm *Y, int stereo, int N, int arch)
407 {
408    int i;
409    int itheta;
410    opus_val16 mid, side;
411    opus_val32 Emid, Eside;
412
413    Emid = Eside = EPSILON;
414    if (stereo)
415    {
416       for (i=0;i<N;i++)
417       {
418          celt_norm m, s;
419          m = ADD16(SHR16(X[i],1),SHR16(Y[i],1));
420          s = SUB16(SHR16(X[i],1),SHR16(Y[i],1));
421          Emid = MAC16_16(Emid, m, m);
422          Eside = MAC16_16(Eside, s, s);
423       }
424    } else {
425       Emid += celt_inner_prod(X, X, N, arch);
426       Eside += celt_inner_prod(Y, Y, N, arch);
427    }
428    mid = celt_sqrt(Emid);
429    side = celt_sqrt(Eside);
430 #ifdef FIXED_POINT
431    /* 0.63662 = 2/pi */
432    itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
433 #else
434    itheta = (int)floor(.5f+16384*0.63662f*fast_atan2f(side,mid));
435 #endif
436
437    return itheta;
438 }