Merge branch 'cwrs_speedup' (derf's cwrs changes)
[opus.git] / libcelt / vq.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector, computes the gain
43     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
44 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, const celt_norm_t * restrict P)
45 {
46    int i;
47    celt_word32_t Ryp, Ryy, Rpp;
48    celt_word16_t ryp, ryy, rpp;
49    celt_word32_t g;
50    VARDECL(celt_norm_t, y);
51 #ifdef FIXED_POINT
52    int yshift;
53 #endif
54    SAVE_STACK;
55 #ifdef FIXED_POINT
56    yshift = 13-celt_ilog2(K);
57 #endif
58    ALLOC(y, N, celt_norm_t);
59
60    /*for (i=0;i<N;i++)
61    printf ("%d ", iy[i]);*/
62    Rpp = 0;
63    i=0;
64    do {
65       Rpp = MAC16_16(Rpp,P[i],P[i]);
66       y[i] = SHL16(iy[i],yshift);
67    } while (++i < N);
68
69    Ryp = 0;
70    Ryy = 0;
71    /* If this doesn't generate a dual MAC (on supported archs), fire the compiler guy */
72    i=0;
73    do {
74       Ryp = MAC16_16(Ryp, y[i], P[i]);
75       Ryy = MAC16_16(Ryy, y[i], y[i]);
76    } while (++i < N);
77
78    ryp = ROUND16(Ryp,14);
79    ryy = ROUND16(Ryy,14);
80    rpp = ROUND16(Rpp,14);
81    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
82    g = MULT16_32_Q15(celt_sqrt(MAC16_16(Ryy, ryp,ryp) - MULT16_16(ryy,rpp)) - ryp,
83                      celt_rcp(SHR32(Ryy,9)));
84
85    i=0;
86    do 
87       X[i] = ADD16(P[i], ROUND16(MULT16_16(y[i], g),11));
88    while (++i < N);
89
90    RESTORE_STACK;
91 }
92
93
94 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
95 {
96    VARDECL(celt_norm_t, y);
97    VARDECL(int, iy);
98    VARDECL(celt_word16_t, signx);
99    int j, is;
100    celt_word16_t s;
101    int pulsesLeft;
102    celt_word32_t sum;
103    celt_word32_t xy, yy, yp;
104    celt_word16_t Rpp;
105    int N_1; /* Inverse of N, in Q14 format (even for float) */
106 #ifdef FIXED_POINT
107    int yshift;
108 #endif
109    SAVE_STACK;
110
111 #ifdef FIXED_POINT
112    yshift = 13-celt_ilog2(K);
113 #endif
114
115    ALLOC(y, N, celt_norm_t);
116    ALLOC(iy, N, int);
117    ALLOC(signx, N, celt_word16_t);
118    N_1 = 512/N;
119
120    sum = 0;
121    j=0; do {
122       X[j] -= P[j];
123       if (X[j]>0)
124          signx[j]=1;
125       else
126          signx[j]=-1;
127       iy[j] = 0;
128       y[j] = 0;
129       sum = MAC16_16(sum, P[j],P[j]);
130    } while (++j<N);
131    Rpp = ROUND16(sum, NORM_SHIFT);
132
133    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
134
135    xy = yy = yp = 0;
136
137    pulsesLeft = K;
138    while (pulsesLeft > 0)
139    {
140       int pulsesAtOnce=1;
141       int best_id;
142       celt_word16_t magnitude;
143 #ifdef FIXED_POINT
144       int rshift;
145 #endif
146       /* Decide on how many pulses to find at once */
147       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
148       if (pulsesAtOnce<1)
149          pulsesAtOnce = 1;
150 #ifdef FIXED_POINT
151       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
152 #endif
153       magnitude = SHL16(pulsesAtOnce, yshift);
154
155       best_id = 0;
156       /* The squared magnitude term gets added anyway, so we might as well 
157          add it outside the loop */
158       yy = MAC16_16(yy, magnitude,magnitude);
159       /* Choose between fast and accurate strategy depending on where we are in the search */
160       if (pulsesLeft>1)
161       {
162          /* This should ensure that anything we can process will have a better score */
163          celt_word32_t best_num = -VERY_LARGE16;
164          celt_word16_t best_den = 0;
165          j=0;
166          do {
167             celt_word16_t Rxy, Ryy;
168             /* Select sign based on X[j] alone */
169             s = MULT16_16(signx[j],magnitude);
170             /* Temporary sums of the new pulse(s) */
171             Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
172             /* We're multiplying y[j] by two so we don't have to do it here */
173             Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
174             
175             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
176                Rxy is positive because the sign is pre-computed) */
177             Rxy = MULT16_16_Q15(Rxy,Rxy);
178             /* The idea is to check for num/den >= best_num/best_den, but that way
179                we can do it without any division */
180             /* OPT: Make sure to use conditional moves here */
181             if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
182             {
183                best_den = Ryy;
184                best_num = Rxy;
185                best_id = j;
186             }
187          } while (++j<N);
188       } else {
189          celt_word16_t g;
190          celt_word16_t best_num = -VERY_LARGE16;
191          celt_word16_t best_den = 0;
192          j=0;
193          do {
194             celt_word16_t Rxy, Ryy, Ryp;
195             celt_word16_t num;
196             /* Select sign based on X[j] alone */
197             s = MULT16_16(signx[j],magnitude);
198             /* Temporary sums of the new pulse(s) */
199             Rxy = ROUND16(MAC16_16(xy, s,X[j]), 14);
200             /* We're multiplying y[j] by two so we don't have to do it here */
201             Ryy = ROUND16(MAC16_16(yy, s,y[j]), 14);
202             Ryp = ROUND16(MAC16_16(yp, s,P[j]), 14);
203
204             /* Compute the gain such that ||p + g*y|| = 1 
205                ...but instead, we compute g*Ryy to avoid dividing */
206             g = celt_psqrt(MULT16_16(Ryp,Ryp) + MULT16_16(Ryy,QCONST16(1.f,14)-Rpp)) - Ryp;
207             /* Knowing that gain, what's the error: (x-g*y)^2 
208                (result is negated and we discard x^2 because it's constant) */
209             /* score = 2*g*Rxy - g*g*Ryy;*/
210 #ifdef FIXED_POINT
211             /* No need to multiply Rxy by 2 because we did it earlier */
212             num = MULT16_16_Q15(ADD16(SUB16(Rxy,g),Rxy),g);
213 #else
214             num = g*(2*Rxy-g);
215 #endif
216             if (MULT16_16(best_den, num) > MULT16_16(Ryy, best_num))
217             {
218                best_den = Ryy;
219                best_num = num;
220                best_id = j;
221             }
222          } while (++j<N);
223       }
224       
225       j = best_id;
226       is = MULT16_16(signx[j],pulsesAtOnce);
227       s = SHL16(is, yshift);
228
229       /* Updating the sums of the new pulse(s) */
230       xy = xy + MULT16_16(s,X[j]);
231       /* We're multiplying y[j] by two so we don't have to do it here */
232       yy = yy + MULT16_16(s,y[j]);
233       yp = yp + MULT16_16(s, P[j]);
234
235       /* Only now that we've made the final choice, update y/iy */
236       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
237       y[j] += 2*s;
238       iy[j] += is;
239       pulsesLeft -= pulsesAtOnce;
240    }
241    
242    encode_pulses(iy, N, K, enc);
243    
244    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
245    due to the recursive computation used in quantisation. */
246    mix_pitch_and_residual(iy, X, N, K, P);
247    RESTORE_STACK;
248 }
249
250
251 /** Decode pulse vector and combine the result with the pitch vector to produce
252     the final normalised signal in the current band. */
253 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
254 {
255    VARDECL(int, iy);
256    SAVE_STACK;
257    ALLOC(iy, N, int);
258    decode_pulses(iy, N, K, dec);
259    mix_pitch_and_residual(iy, X, N, K, P);
260    RESTORE_STACK;
261 }
262
263 void renormalise_vector(celt_norm_t *X, celt_word16_t value, int N, int stride)
264 {
265    int i;
266    celt_word32_t E = EPSILON;
267    celt_word16_t g;
268    celt_norm_t *xptr = X;
269    for (i=0;i<N;i++)
270    {
271       E = MAC16_16(E, *xptr, *xptr);
272       xptr += stride;
273    }
274
275    g = MULT16_16_Q15(value,celt_rcp(SHL32(celt_sqrt(E),9)));
276    xptr = X;
277    for (i=0;i<N;i++)
278    {
279       *xptr = PSHR32(MULT16_16(g, *xptr),8);
280       xptr += stride;
281    }
282 }
283
284 static void fold(const CELTMode *m, int N, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
285 {
286    int j;
287    const int C = CHANNELS(m);
288    int id = N0 % (C*B);
289    /* Here, we assume that id will never be greater than N0, i.e. that 
290       no band is wider than N0. In the unlikely case it happens, we set
291       everything to zero */
292    if (id+C*N>N0)
293       for (j=0;j<C*N;j++)
294          P[j] = 0;
295    else
296       for (j=0;j<C*N;j++)
297          P[j] = Y[id++];
298 }
299
300 #define KGAIN 6
301
302 void intra_prediction(const CELTMode *m, celt_norm_t * restrict x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B, ec_enc *enc)
303 {
304    int j;
305    celt_word16_t s = 1;
306    int sign;
307    celt_word16_t pred_gain;
308    celt_word32_t xy=0;
309    const int C = CHANNELS(m);
310
311    pred_gain = celt_div((celt_word32_t)MULT16_16(Q15_ONE,N),(celt_word32_t)(N+KGAIN*K));
312
313    fold(m, N, Y, P, N0, B);
314
315    for (j=0;j<C*N;j++)
316       xy = MAC16_16(xy, P[j], x[j]);
317    if (xy<0)
318    {
319       s = -1;
320       sign = 1;
321    } else {
322       s = 1;
323       sign = 0;
324    }
325    ec_enc_bits(enc,sign,1);
326
327    renormalise_vector(P, s*pred_gain, C*N, 1);
328 }
329
330 void intra_unquant(const CELTMode *m, celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B, ec_dec *dec)
331 {
332    celt_word16_t s;
333    celt_word16_t pred_gain;
334    const int C = CHANNELS(m);
335       
336    if (ec_dec_bits(dec, 1) == 0)
337       s = 1;
338    else
339       s = -1;
340    
341    pred_gain = celt_div((celt_word32_t)MULT16_16(Q15_ONE,N),(celt_word32_t)(N+KGAIN*K));
342    
343    fold(m, N, Y, P, N0, B);
344    
345    renormalise_vector(P, s*pred_gain, C*N, 1);
346 }
347
348 void intra_fold(const CELTMode *m, celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
349 {
350    const int C = CHANNELS(m);
351
352    fold(m, N, Y, P, N0, B);
353    
354    renormalise_vector(P, Q15ONE, C*N, 1);
355 }
356