Defining celt_inner_prod() and using it instead of explicit loops.
[opus.git] / celt / vq.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15
16    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #ifdef HAVE_CONFIG_H
30 #include "config.h"
31 #endif
32
33 #include "mathops.h"
34 #include "cwrs.h"
35 #include "vq.h"
36 #include "arch.h"
37 #include "os_support.h"
38 #include "bands.h"
39 #include "rate.h"
40 #include "pitch.h"
41
42 static void exp_rotation1(celt_norm *X, int len, int stride, opus_val16 c, opus_val16 s)
43 {
44    int i;
45    celt_norm *Xptr;
46    Xptr = X;
47    for (i=0;i<len-stride;i++)
48    {
49       celt_norm x1, x2;
50       x1 = Xptr[0];
51       x2 = Xptr[stride];
52       Xptr[stride] = EXTRACT16(SHR32(MULT16_16(c,x2) + MULT16_16(s,x1), 15));
53       *Xptr++      = EXTRACT16(SHR32(MULT16_16(c,x1) - MULT16_16(s,x2), 15));
54    }
55    Xptr = &X[len-2*stride-1];
56    for (i=len-2*stride-1;i>=0;i--)
57    {
58       celt_norm x1, x2;
59       x1 = Xptr[0];
60       x2 = Xptr[stride];
61       Xptr[stride] = EXTRACT16(SHR32(MULT16_16(c,x2) + MULT16_16(s,x1), 15));
62       *Xptr--      = EXTRACT16(SHR32(MULT16_16(c,x1) - MULT16_16(s,x2), 15));
63    }
64 }
65
66 static void exp_rotation(celt_norm *X, int len, int dir, int stride, int K, int spread)
67 {
68    static const int SPREAD_FACTOR[3]={15,10,5};
69    int i;
70    opus_val16 c, s;
71    opus_val16 gain, theta;
72    int stride2=0;
73    int factor;
74
75    if (2*K>=len || spread==SPREAD_NONE)
76       return;
77    factor = SPREAD_FACTOR[spread-1];
78
79    gain = celt_div((opus_val32)MULT16_16(Q15_ONE,len),(opus_val32)(len+factor*K));
80    theta = HALF16(MULT16_16_Q15(gain,gain));
81
82    c = celt_cos_norm(EXTEND32(theta));
83    s = celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
84
85    if (len>=8*stride)
86    {
87       stride2 = 1;
88       /* This is just a simple (equivalent) way of computing sqrt(len/stride) with rounding.
89          It's basically incrementing long as (stride2+0.5)^2 < len/stride. */
90       while ((stride2*stride2+stride2)*stride + (stride>>2) < len)
91          stride2++;
92    }
93    /*NOTE: As a minor optimization, we could be passing around log2(B), not B, for both this and for
94       extract_collapse_mask().*/
95    len /= stride;
96    for (i=0;i<stride;i++)
97    {
98       if (dir < 0)
99       {
100          if (stride2)
101             exp_rotation1(X+i*len, len, stride2, s, c);
102          exp_rotation1(X+i*len, len, 1, c, s);
103       } else {
104          exp_rotation1(X+i*len, len, 1, c, -s);
105          if (stride2)
106             exp_rotation1(X+i*len, len, stride2, s, -c);
107       }
108    }
109 }
110
111 /** Takes the pitch vector and the decoded residual vector, computes the gain
112     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
113 static void normalise_residual(int * OPUS_RESTRICT iy, celt_norm * OPUS_RESTRICT X,
114       int N, opus_val32 Ryy, opus_val16 gain)
115 {
116    int i;
117 #ifdef FIXED_POINT
118    int k;
119 #endif
120    opus_val32 t;
121    opus_val16 g;
122
123 #ifdef FIXED_POINT
124    k = celt_ilog2(Ryy)>>1;
125 #endif
126    t = VSHR32(Ryy, 2*(k-7));
127    g = MULT16_16_P15(celt_rsqrt_norm(t),gain);
128
129    i=0;
130    do
131       X[i] = EXTRACT16(PSHR32(MULT16_16(g, iy[i]), k+1));
132    while (++i < N);
133 }
134
135 static unsigned extract_collapse_mask(int *iy, int N, int B)
136 {
137    unsigned collapse_mask;
138    int N0;
139    int i;
140    if (B<=1)
141       return 1;
142    /*NOTE: As a minor optimization, we could be passing around log2(B), not B, for both this and for
143       exp_rotation().*/
144    N0 = N/B;
145    collapse_mask = 0;
146    i=0; do {
147       int j;
148       j=0; do {
149          collapse_mask |= (iy[i*N0+j]!=0)<<i;
150       } while (++j<N0);
151    } while (++i<B);
152    return collapse_mask;
153 }
154
155 unsigned alg_quant(celt_norm *X, int N, int K, int spread, int B, ec_enc *enc
156 #ifdef RESYNTH
157    , opus_val16 gain
158 #endif
159    )
160 {
161    VARDECL(celt_norm, y);
162    VARDECL(int, iy);
163    VARDECL(opus_val16, signx);
164    int i, j;
165    opus_val16 s;
166    int pulsesLeft;
167    opus_val32 sum;
168    opus_val32 xy;
169    opus_val16 yy;
170    unsigned collapse_mask;
171    SAVE_STACK;
172
173    celt_assert2(K>0, "alg_quant() needs at least one pulse");
174    celt_assert2(N>1, "alg_quant() needs at least two dimensions");
175
176    ALLOC(y, N, celt_norm);
177    ALLOC(iy, N, int);
178    ALLOC(signx, N, opus_val16);
179
180    exp_rotation(X, N, 1, B, K, spread);
181
182    /* Get rid of the sign */
183    sum = 0;
184    j=0; do {
185       if (X[j]>0)
186          signx[j]=1;
187       else {
188          signx[j]=-1;
189          X[j]=-X[j];
190       }
191       iy[j] = 0;
192       y[j] = 0;
193    } while (++j<N);
194
195    xy = yy = 0;
196
197    pulsesLeft = K;
198
199    /* Do a pre-search by projecting on the pyramid */
200    if (K > (N>>1))
201    {
202       opus_val16 rcp;
203       j=0; do {
204          sum += X[j];
205       }  while (++j<N);
206
207       /* If X is too small, just replace it with a pulse at 0 */
208 #ifdef FIXED_POINT
209       if (sum <= K)
210 #else
211       /* Prevents infinities and NaNs from causing too many pulses
212          to be allocated. 64 is an approximation of infinity here. */
213       if (!(sum > EPSILON && sum < 64))
214 #endif
215       {
216          X[0] = QCONST16(1.f,14);
217          j=1; do
218             X[j]=0;
219          while (++j<N);
220          sum = QCONST16(1.f,14);
221       }
222       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
223       j=0; do {
224 #ifdef FIXED_POINT
225          /* It's really important to round *towards zero* here */
226          iy[j] = MULT16_16_Q15(X[j],rcp);
227 #else
228          iy[j] = (int)floor(rcp*X[j]);
229 #endif
230          y[j] = (celt_norm)iy[j];
231          yy = MAC16_16(yy, y[j],y[j]);
232          xy = MAC16_16(xy, X[j],y[j]);
233          y[j] *= 2;
234          pulsesLeft -= iy[j];
235       }  while (++j<N);
236    }
237    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
238
239    /* This should never happen, but just in case it does (e.g. on silence)
240       we fill the first bin with pulses. */
241 #ifdef FIXED_POINT_DEBUG
242    celt_assert2(pulsesLeft<=N+3, "Not enough pulses in the quick pass");
243 #endif
244    if (pulsesLeft > N+3)
245    {
246       opus_val16 tmp = (opus_val16)pulsesLeft;
247       yy = MAC16_16(yy, tmp, tmp);
248       yy = MAC16_16(yy, tmp, y[0]);
249       iy[0] += pulsesLeft;
250       pulsesLeft=0;
251    }
252
253    s = 1;
254    for (i=0;i<pulsesLeft;i++)
255    {
256       int best_id;
257       opus_val32 best_num = -VERY_LARGE16;
258       opus_val16 best_den = 0;
259 #ifdef FIXED_POINT
260       int rshift;
261 #endif
262 #ifdef FIXED_POINT
263       rshift = 1+celt_ilog2(K-pulsesLeft+i+1);
264 #endif
265       best_id = 0;
266       /* The squared magnitude term gets added anyway, so we might as well
267          add it outside the loop */
268       yy = ADD32(yy, 1);
269       j=0;
270       do {
271          opus_val16 Rxy, Ryy;
272          /* Temporary sums of the new pulse(s) */
273          Rxy = EXTRACT16(SHR32(ADD32(xy, EXTEND32(X[j])),rshift));
274          /* We're multiplying y[j] by two so we don't have to do it here */
275          Ryy = ADD16(yy, y[j]);
276
277          /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that
278             Rxy is positive because the sign is pre-computed) */
279          Rxy = MULT16_16_Q15(Rxy,Rxy);
280          /* The idea is to check for num/den >= best_num/best_den, but that way
281             we can do it without any division */
282          /* OPT: Make sure to use conditional moves here */
283          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
284          {
285             best_den = Ryy;
286             best_num = Rxy;
287             best_id = j;
288          }
289       } while (++j<N);
290
291       /* Updating the sums of the new pulse(s) */
292       xy = ADD32(xy, EXTEND32(X[best_id]));
293       /* We're multiplying y[j] by two so we don't have to do it here */
294       yy = ADD16(yy, y[best_id]);
295
296       /* Only now that we've made the final choice, update y/iy */
297       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
298       y[best_id] += 2*s;
299       iy[best_id]++;
300    }
301
302    /* Put the original sign back */
303    j=0;
304    do {
305       X[j] = MULT16_16(signx[j],X[j]);
306       if (signx[j] < 0)
307          iy[j] = -iy[j];
308    } while (++j<N);
309    encode_pulses(iy, N, K, enc);
310
311 #ifdef RESYNTH
312    normalise_residual(iy, X, N, yy, gain);
313    exp_rotation(X, N, -1, B, K, spread);
314 #endif
315
316    collapse_mask = extract_collapse_mask(iy, N, B);
317    RESTORE_STACK;
318    return collapse_mask;
319 }
320
321 /** Decode pulse vector and combine the result with the pitch vector to produce
322     the final normalised signal in the current band. */
323 unsigned alg_unquant(celt_norm *X, int N, int K, int spread, int B,
324       ec_dec *dec, opus_val16 gain)
325 {
326    int i;
327    opus_val32 Ryy;
328    unsigned collapse_mask;
329    VARDECL(int, iy);
330    SAVE_STACK;
331
332    celt_assert2(K>0, "alg_unquant() needs at least one pulse");
333    celt_assert2(N>1, "alg_unquant() needs at least two dimensions");
334    ALLOC(iy, N, int);
335    decode_pulses(iy, N, K, dec);
336    Ryy = 0;
337    i=0;
338    do {
339       Ryy = MAC16_16(Ryy, iy[i], iy[i]);
340    } while (++i < N);
341    normalise_residual(iy, X, N, Ryy, gain);
342    exp_rotation(X, N, -1, B, K, spread);
343    collapse_mask = extract_collapse_mask(iy, N, B);
344    RESTORE_STACK;
345    return collapse_mask;
346 }
347
348 void renormalise_vector(celt_norm *X, int N, opus_val16 gain)
349 {
350    int i;
351 #ifdef FIXED_POINT
352    int k;
353 #endif
354    opus_val32 E;
355    opus_val16 g;
356    opus_val32 t;
357    celt_norm *xptr;
358    E = EPSILON + celt_inner_prod(X, X, N);
359 #ifdef FIXED_POINT
360    k = celt_ilog2(E)>>1;
361 #endif
362    t = VSHR32(E, 2*(k-7));
363    g = MULT16_16_P15(celt_rsqrt_norm(t),gain);
364
365    xptr = X;
366    for (i=0;i<N;i++)
367    {
368       *xptr = EXTRACT16(PSHR32(MULT16_16(g, *xptr), k+1));
369       xptr++;
370    }
371    /*return celt_sqrt(E);*/
372 }
373
374 int stereo_itheta(celt_norm *X, celt_norm *Y, int stereo, int N)
375 {
376    int i;
377    int itheta;
378    opus_val16 mid, side;
379    opus_val32 Emid, Eside;
380
381    Emid = Eside = EPSILON;
382    if (stereo)
383    {
384       for (i=0;i<N;i++)
385       {
386          celt_norm m, s;
387          m = ADD16(SHR16(X[i],1),SHR16(Y[i],1));
388          s = SUB16(SHR16(X[i],1),SHR16(Y[i],1));
389          Emid = MAC16_16(Emid, m, m);
390          Eside = MAC16_16(Eside, s, s);
391       }
392    } else {
393       Emid += celt_inner_prod(X, X, N);
394       Eside += celt_inner_prod(Y, Y, N);
395    }
396    mid = celt_sqrt(Emid);
397    side = celt_sqrt(Eside);
398 #ifdef FIXED_POINT
399    /* 0.63662 = 2/pi */
400    itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
401 #else
402    itheta = (int)floor(.5f+16384*0.63662f*atan2(side,mid));
403 #endif
404
405    return itheta;
406 }