Doing the spreading with a "pseudo-fractional-Hadamard" transform
[opus.git] / libcelt / vq.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include "mathops.h"
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42 #include "rate.h"
43
44 #ifndef M_PI
45 #define M_PI 3.141592653
46 #endif
47
48 static void frac_hadamard1(celt_norm *X, int len, int stride, celt_word16 c, celt_word16 s)
49 {
50    int j;
51    celt_norm *x, *y;
52    celt_norm * end;
53
54    j = 0;
55    x = X;
56    y = X+stride;
57    end = X+len;
58    do
59    {
60       celt_norm x1, x2;
61       x1 = *x;
62       x2 = *y;
63       *x++ = EXTRACT16(SHR32(MULT16_16(c,x1) + MULT16_16(s,x2),15));
64       *y++ = EXTRACT16(SHR32(MULT16_16(s,x1) - MULT16_16(c,x2),15));
65       j++;
66       if (j>=stride)
67       {
68          j=0;
69          x+=stride;
70          y+=stride;
71       }
72    } while (y<end);
73
74    /* Reverse samples so that the next level starts from the other end */
75    for (j=0;j<len>>1;j++)
76    {
77       celt_norm tmp = X[j];
78       X[j] = X[len-j-1];
79       X[len-j-1] = tmp;
80    }
81 }
82
83 #define MAX_LEVELS 8
84 static void exp_rotation(celt_norm *X, int len, int dir, int stride, int K)
85 {
86    int i, N=0;
87    celt_word16 gain, theta;
88    int istride[MAX_LEVELS];
89    celt_word16 c, s;
90
91    do {
92       istride[N] = stride;
93       stride *= 2;
94       N++;
95    } while (N<MAX_LEVELS && stride < len);
96
97    gain = celt_div((celt_word32)MULT16_16(Q15_ONE,len),(celt_word32)(3+len+4*K));
98    /* FIXME: Make that HALF16 instead of HALF32 */
99    theta = HALF32(MULT16_16_Q15(gain,gain));
100    c = celt_cos_norm(EXTEND32(theta));
101    s = celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
102
103    if (dir > 0)
104    {
105       for (i=0;i<N;i++)
106          frac_hadamard1(X, len, istride[i], c, s);
107    } else {
108       for (i=N-1;i>=0;i--)
109          frac_hadamard1(X, len, istride[i], c, s);
110    }
111
112    /* Undo last reversal */
113    for (i=0;i<len>>1;i++)
114    {
115       celt_norm tmp = X[i];
116       X[i] = X[len-i-1];
117       X[len-i-1] = tmp;
118    }
119 }
120
121 /** Takes the pitch vector and the decoded residual vector, computes the gain
122     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
123 static void normalise_residual(int * restrict iy, celt_norm * restrict X, int N, int K, celt_word32 Ryy)
124 {
125    int i;
126 #ifdef FIXED_POINT
127    int k;
128 #endif
129    celt_word32 t;
130    celt_word16 g;
131
132 #ifdef FIXED_POINT
133    k = celt_ilog2(Ryy)>>1;
134 #endif
135    t = VSHR32(Ryy, (k-7)<<1);
136    g = celt_rsqrt_norm(t);
137
138    i=0;
139    do
140       X[i] = EXTRACT16(PSHR32(MULT16_16(g, iy[i]), k+1));
141    while (++i < N);
142 }
143
144 void alg_quant(celt_norm *X, int N, int K, int spread, ec_enc *enc)
145 {
146    VARDECL(celt_norm, y);
147    VARDECL(int, iy);
148    VARDECL(celt_word16, signx);
149    int j, is;
150    celt_word16 s;
151    int pulsesLeft;
152    celt_word32 sum;
153    celt_word32 xy, yy;
154    int N_1; /* Inverse of N, in Q14 format (even for float) */
155 #ifdef FIXED_POINT
156    int yshift;
157 #endif
158    SAVE_STACK;
159
160    K = get_pulses(K);
161 #ifdef FIXED_POINT
162    yshift = 13-celt_ilog2(K);
163 #endif
164
165    ALLOC(y, N, celt_norm);
166    ALLOC(iy, N, int);
167    ALLOC(signx, N, celt_word16);
168    N_1 = 512/N;
169    
170    if (spread)
171       exp_rotation(X, N, 1, spread, K);
172
173    sum = 0;
174    j=0; do {
175       if (X[j]>0)
176          signx[j]=1;
177       else {
178          signx[j]=-1;
179          X[j]=-X[j];
180       }
181       iy[j] = 0;
182       y[j] = 0;
183    } while (++j<N);
184
185    xy = yy = 0;
186
187    pulsesLeft = K;
188
189    /* Do a pre-search by projecting on the pyramid */
190    if (K > (N>>1))
191    {
192       celt_word16 rcp;
193       j=0; do {
194          sum += X[j];
195       }  while (++j<N);
196
197 #ifdef FIXED_POINT
198       if (sum <= K)
199 #else
200       if (sum <= EPSILON)
201 #endif
202       {
203          X[0] = QCONST16(1.f,14);
204          j=1; do
205             X[j]=0;
206          while (++j<N);
207          sum = QCONST16(1.f,14);
208       }
209       /* Do we have sufficient accuracy here? */
210       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
211       j=0; do {
212 #ifdef FIXED_POINT
213          /* It's really important to round *towards zero* here */
214          iy[j] = MULT16_16_Q15(X[j],rcp);
215 #else
216          iy[j] = floor(rcp*X[j]);
217 #endif
218          y[j] = SHL16(iy[j],yshift);
219          yy = MAC16_16(yy, y[j],y[j]);
220          xy = MAC16_16(xy, X[j],y[j]);
221          y[j] *= 2;
222          pulsesLeft -= iy[j];
223       }  while (++j<N);
224    }
225    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
226
227    while (pulsesLeft > 0)
228    {
229       int pulsesAtOnce=1;
230       int best_id;
231       celt_word16 magnitude;
232       celt_word32 best_num = -VERY_LARGE16;
233       celt_word16 best_den = 0;
234 #ifdef FIXED_POINT
235       int rshift;
236 #endif
237       /* Decide on how many pulses to find at once */
238       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
239       if (pulsesAtOnce<1)
240          pulsesAtOnce = 1;
241 #ifdef FIXED_POINT
242       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
243 #endif
244       magnitude = SHL16(pulsesAtOnce, yshift);
245
246       best_id = 0;
247       /* The squared magnitude term gets added anyway, so we might as well 
248          add it outside the loop */
249       yy = MAC16_16(yy, magnitude,magnitude);
250       /* Choose between fast and accurate strategy depending on where we are in the search */
251          /* This should ensure that anything we can process will have a better score */
252       j=0;
253       do {
254          celt_word16 Rxy, Ryy;
255          /* Select sign based on X[j] alone */
256          s = magnitude;
257          /* Temporary sums of the new pulse(s) */
258          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
259          /* We're multiplying y[j] by two so we don't have to do it here */
260          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
261             
262             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
263          Rxy is positive because the sign is pre-computed) */
264          Rxy = MULT16_16_Q15(Rxy,Rxy);
265             /* The idea is to check for num/den >= best_num/best_den, but that way
266          we can do it without any division */
267          /* OPT: Make sure to use conditional moves here */
268          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
269          {
270             best_den = Ryy;
271             best_num = Rxy;
272             best_id = j;
273          }
274       } while (++j<N);
275       
276       j = best_id;
277       is = pulsesAtOnce;
278       s = SHL16(is, yshift);
279
280       /* Updating the sums of the new pulse(s) */
281       xy = xy + MULT16_16(s,X[j]);
282       /* We're multiplying y[j] by two so we don't have to do it here */
283       yy = yy + MULT16_16(s,y[j]);
284
285       /* Only now that we've made the final choice, update y/iy */
286       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
287       y[j] += 2*s;
288       iy[j] += is;
289       pulsesLeft -= pulsesAtOnce;
290    }
291    j=0;
292    do {
293       X[j] = MULT16_16(signx[j],X[j]);
294       if (signx[j] < 0)
295          iy[j] = -iy[j];
296    } while (++j<N);
297    encode_pulses(iy, N, K, enc);
298    
299    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
300    due to the recursive computation used in quantisation. */
301    normalise_residual(iy, X, N, K, EXTRACT16(SHR32(yy,2*yshift)));
302    if (spread)
303       exp_rotation(X, N, -1, spread, K);
304    RESTORE_STACK;
305 }
306
307
308 /** Decode pulse vector and combine the result with the pitch vector to produce
309     the final normalised signal in the current band. */
310 void alg_unquant(celt_norm *X, int N, int K, int spread, ec_dec *dec)
311 {
312    int i;
313    celt_word32 Ryy;
314    VARDECL(int, iy);
315    SAVE_STACK;
316    K = get_pulses(K);
317    ALLOC(iy, N, int);
318    decode_pulses(iy, N, K, dec);
319    Ryy = 0;
320    i=0;
321    do {
322       Ryy = MAC16_16(Ryy, iy[i], iy[i]);
323    } while (++i < N);
324    normalise_residual(iy, X, N, K, Ryy);
325    if (spread)
326       exp_rotation(X, N, -1, spread, K);
327    RESTORE_STACK;
328 }
329
330 celt_word16 renormalise_vector(celt_norm *X, celt_word16 value, int N, int stride)
331 {
332    int i;
333    celt_word32 E = EPSILON;
334    celt_word16 g;
335    celt_norm *xptr = X;
336    for (i=0;i<N;i++)
337    {
338       E = MAC16_16(E, *xptr, *xptr);
339       xptr += stride;
340    }
341 #ifdef FIXED_POINT
342    int k = celt_ilog2(E)>>1;
343 #endif
344    celt_word32 t = VSHR32(E, (k-7)<<1);
345    g = MULT16_16_Q15(value, celt_rsqrt_norm(t));
346
347    xptr = X;
348    for (i=0;i<N;i++)
349    {
350       *xptr = EXTRACT16(PSHR32(MULT16_16(g, *xptr), k+1));
351       xptr += stride;
352    }
353    return celt_sqrt(E);
354 }
355
356 static void fold(const CELTMode *m, int start, int N, const celt_norm * restrict Y, celt_norm * restrict P, int N0, int B)
357 {
358    int j;
359    int id = N0 % B;
360    while (id < m->eBands[start])
361       id += B;
362    /* Here, we assume that id will never be greater than N0, i.e. that 
363       no band is wider than N0. In the unlikely case it happens, we set
364       everything to zero */
365    /*{
366            int offset = (N0*C - (id+C*N))/2;
367            if (offset > C*N0/16)
368                    offset = C*N0/16;
369            offset -= offset % (C*B);
370            if (offset < 0)
371                    offset = 0;
372            //printf ("%d\n", offset);
373            id += offset;
374    }*/
375    if (id+N>N0)
376       for (j=0;j<N;j++)
377          P[j] = 0;
378    else
379       for (j=0;j<N;j++)
380          P[j] = Y[id++];
381 }
382
383 void intra_fold(const CELTMode *m, int start, int N, const celt_norm * restrict Y, celt_norm * restrict P, int N0, int B)
384 {
385    fold(m, start, N, Y, P, N0, B);
386    renormalise_vector(P, Q15ONE, N, 1);
387 }
388