62a8784816a031ccd684c3c8a4cda799d0751068
[opus.git] / libcelt / vq.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41 #include "rate.h"
42
43 #ifndef M_PI
44 #define M_PI 3.141592653
45 #endif
46
47 static void exp_rotation(celt_norm_t *X, int len, int dir, int stride, int K)
48 {
49    int i, k, iter;
50    celt_word16_t c, s;
51    celt_word16_t gain, theta;
52    celt_norm_t *Xptr;
53    gain = celt_div((celt_word32_t)MULT16_16(Q15_ONE,len),(celt_word32_t)(3+len+6*K));
54    /* FIXME: Make that HALF16 instead of HALF32 */
55    theta = SUB16(Q15ONE, HALF32(MULT16_16_Q15(gain,gain)));
56    /*if (len==30)
57    {
58    for (i=0;i<len;i++)
59    X[i] = 0;
60    X[14] = 1;
61 }*/ 
62    c = celt_cos_norm(EXTEND32(theta));
63    s = dir*celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
64    if (len > 8*stride)
65       stride *= len/(8*stride);
66    iter = 1;
67    for (k=0;k<iter;k++)
68    {
69       /* We could use MULT16_16_P15 instead of MULT16_16_Q15 for more accuracy, 
70       but at this point, I really don't think it's necessary */
71       Xptr = X;
72       for (i=0;i<len-stride;i++)
73       {
74          celt_norm_t x1, x2;
75          x1 = Xptr[0];
76          x2 = Xptr[stride];
77          Xptr[stride] = MULT16_16_Q15(c,x2) + MULT16_16_Q15(s,x1);
78          *Xptr++      = MULT16_16_Q15(c,x1) - MULT16_16_Q15(s,x2);
79       }
80       Xptr = &X[len-2*stride-1];
81       for (i=len-2*stride-1;i>=0;i--)
82       {
83          celt_norm_t x1, x2;
84          x1 = Xptr[0];
85          x2 = Xptr[stride];
86          Xptr[stride] = MULT16_16_Q15(c,x2) + MULT16_16_Q15(s,x1);
87          *Xptr--      = MULT16_16_Q15(c,x1) - MULT16_16_Q15(s,x2);
88       }
89    }
90    /*if (len==30)
91    {
92    for (i=0;i<len;i++)
93    printf ("%f ", X[i]);
94    printf ("\n");
95    exit(0);
96 }*/
97 }
98
99
100 /** Takes the pitch vector and the decoded residual vector, computes the gain
101     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
102 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, celt_word32_t Ryy)
103 {
104    int i;
105    celt_word32_t g;
106
107    g = celt_rsqrt(Ryy);
108
109    i=0;
110    do
111       X[i] = MULT16_32_P15(g, SHL32(EXTEND32(iy[i]),14));
112    while (++i < N);
113 }
114
115 void alg_quant(celt_norm_t *X, int N, int K, int spread, ec_enc *enc)
116 {
117    VARDECL(celt_norm_t, y);
118    VARDECL(int, iy);
119    VARDECL(celt_word16_t, signx);
120    int j, is;
121    celt_word16_t s;
122    int pulsesLeft;
123    celt_word32_t sum;
124    celt_word32_t xy, yy;
125    int N_1; /* Inverse of N, in Q14 format (even for float) */
126 #ifdef FIXED_POINT
127    int yshift;
128 #endif
129    SAVE_STACK;
130
131    K = get_pulses(K);
132 #ifdef FIXED_POINT
133    yshift = 13-celt_ilog2(K);
134 #endif
135
136    ALLOC(y, N, celt_norm_t);
137    ALLOC(iy, N, int);
138    ALLOC(signx, N, celt_word16_t);
139    N_1 = 512/N;
140    
141    if (spread)
142       exp_rotation(X, N, 1, spread, K);
143
144    sum = 0;
145    j=0; do {
146       if (X[j]>0)
147          signx[j]=1;
148       else {
149          signx[j]=-1;
150          X[j]=-X[j];
151       }
152       iy[j] = 0;
153       y[j] = 0;
154    } while (++j<N);
155
156    xy = yy = 0;
157
158    pulsesLeft = K;
159
160    /* Do a pre-search by projecting on the pyramid */
161    if (K > (N>>1))
162    {
163       celt_word16_t rcp;
164       sum=0;
165       j=0; do {
166          sum += X[j];
167       }  while (++j<N);
168
169 #ifdef FIXED_POINT
170       if (sum <= K)
171 #else
172       if (sum <= EPSILON)
173 #endif
174       {
175          X[0] = QCONST16(1.f,14);
176          j=1; do
177             X[j]=0;
178          while (++j<N);
179          sum = QCONST16(1.f,14);
180       }
181       /* Do we have sufficient accuracy here? */
182       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
183       j=0; do {
184 #ifdef FIXED_POINT
185          /* It's really important to round *towards zero* here */
186          iy[j] = MULT16_16_Q15(X[j],rcp);
187 #else
188          iy[j] = floor(rcp*X[j]);
189 #endif
190          y[j] = SHL16(iy[j],yshift);
191          yy = MAC16_16(yy, y[j],y[j]);
192          xy = MAC16_16(xy, X[j],y[j]);
193          y[j] *= 2;
194          pulsesLeft -= iy[j];
195       }  while (++j<N);
196    }
197    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
198
199    while (pulsesLeft > 0)
200    {
201       int pulsesAtOnce=1;
202       int best_id;
203       celt_word16_t magnitude;
204       celt_word32_t best_num = -VERY_LARGE16;
205       celt_word16_t best_den = 0;
206 #ifdef FIXED_POINT
207       int rshift;
208 #endif
209       /* Decide on how many pulses to find at once */
210       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
211       if (pulsesAtOnce<1)
212          pulsesAtOnce = 1;
213 #ifdef FIXED_POINT
214       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
215 #endif
216       magnitude = SHL16(pulsesAtOnce, yshift);
217
218       best_id = 0;
219       /* The squared magnitude term gets added anyway, so we might as well 
220          add it outside the loop */
221       yy = MAC16_16(yy, magnitude,magnitude);
222       /* Choose between fast and accurate strategy depending on where we are in the search */
223          /* This should ensure that anything we can process will have a better score */
224       j=0;
225       do {
226          celt_word16_t Rxy, Ryy;
227          /* Select sign based on X[j] alone */
228          s = magnitude;
229          /* Temporary sums of the new pulse(s) */
230          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
231          /* We're multiplying y[j] by two so we don't have to do it here */
232          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
233             
234             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
235          Rxy is positive because the sign is pre-computed) */
236          Rxy = MULT16_16_Q15(Rxy,Rxy);
237             /* The idea is to check for num/den >= best_num/best_den, but that way
238          we can do it without any division */
239          /* OPT: Make sure to use conditional moves here */
240          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
241          {
242             best_den = Ryy;
243             best_num = Rxy;
244             best_id = j;
245          }
246       } while (++j<N);
247       
248       j = best_id;
249       is = pulsesAtOnce;
250       s = SHL16(is, yshift);
251
252       /* Updating the sums of the new pulse(s) */
253       xy = xy + MULT16_16(s,X[j]);
254       /* We're multiplying y[j] by two so we don't have to do it here */
255       yy = yy + MULT16_16(s,y[j]);
256
257       /* Only now that we've made the final choice, update y/iy */
258       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
259       y[j] += 2*s;
260       iy[j] += is;
261       pulsesLeft -= pulsesAtOnce;
262    }
263    j=0;
264    do {
265       X[j] = MULT16_16(signx[j],X[j]);
266       if (signx[j] < 0)
267          iy[j] = -iy[j];
268    } while (++j<N);
269    encode_pulses(iy, N, K, enc);
270    
271    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
272    due to the recursive computation used in quantisation. */
273    mix_pitch_and_residual(iy, X, N, K, EXTRACT16(SHR32(yy,2*yshift)));
274    if (spread)
275       exp_rotation(X, N, -1, spread, K);
276    RESTORE_STACK;
277 }
278
279
280 /** Decode pulse vector and combine the result with the pitch vector to produce
281     the final normalised signal in the current band. */
282 void alg_unquant(celt_norm_t *X, int N, int K, int spread, ec_dec *dec)
283 {
284    int i;
285    celt_word32_t Ryy;
286    VARDECL(int, iy);
287    SAVE_STACK;
288    K = get_pulses(K);
289    ALLOC(iy, N, int);
290    decode_pulses(iy, N, K, dec);
291    Ryy = 0;
292    i=0;
293    do {
294       Ryy = MAC16_16(Ryy, iy[i], iy[i]);
295    } while (++i < N);
296    mix_pitch_and_residual(iy, X, N, K, Ryy);
297    if (spread)
298       exp_rotation(X, N, -1, spread, K);
299    RESTORE_STACK;
300 }
301
302 celt_word16_t renormalise_vector(celt_norm_t *X, celt_word16_t value, int N, int stride)
303 {
304    int i;
305    celt_word32_t E = EPSILON;
306    celt_word16_t rE;
307    celt_word16_t g;
308    celt_norm_t *xptr = X;
309    for (i=0;i<N;i++)
310    {
311       E = MAC16_16(E, *xptr, *xptr);
312       xptr += stride;
313    }
314
315    rE = celt_sqrt(E);
316 #ifdef FIXED_POINT
317    if (rE <= 128)
318       g = Q15ONE;
319    else
320 #endif
321       g = MULT16_16_Q15(value,celt_rcp(SHL32(rE,9)));
322    xptr = X;
323    for (i=0;i<N;i++)
324    {
325       *xptr = PSHR32(MULT16_16(g, *xptr),8);
326       xptr += stride;
327    }
328    return rE;
329 }
330
331 static void fold(const CELTMode *m, int N, const celt_norm_t * restrict Y, celt_norm_t * restrict P, int N0, int B)
332 {
333    int j;
334    int id = N0 % B;
335    /* Here, we assume that id will never be greater than N0, i.e. that 
336       no band is wider than N0. In the unlikely case it happens, we set
337       everything to zero */
338    /*{
339            int offset = (N0*C - (id+C*N))/2;
340            if (offset > C*N0/16)
341                    offset = C*N0/16;
342            offset -= offset % (C*B);
343            if (offset < 0)
344                    offset = 0;
345            //printf ("%d\n", offset);
346            id += offset;
347    }*/
348    if (id+N>N0)
349       for (j=0;j<N;j++)
350          P[j] = 0;
351    else
352       for (j=0;j<N;j++)
353          P[j] = Y[id++];
354 }
355
356 void intra_fold(const CELTMode *m, int N, const celt_norm_t * restrict Y, celt_norm_t * restrict P, int N0, int B)
357 {
358    fold(m, N, Y, P, N0, B);
359    renormalise_vector(P, Q15ONE, N, 1);
360 }
361