A bit more tuning on the pseudo-frac-Hadamard. Also Trying to improve
[opus.git] / libcelt / vq.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include "mathops.h"
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42 #include "rate.h"
43
44 #ifndef M_PI
45 #define M_PI 3.141592653
46 #endif
47
48 static void frac_hadamard1(celt_norm *X, int len, int stride, celt_word16 c, celt_word16 s)
49 {
50    int j;
51    celt_norm *x, *y;
52    celt_norm * end;
53
54    j = 0;
55    x = X;
56    y = X+stride;
57    end = X+len;
58    do
59    {
60       celt_norm x1, x2;
61       x1 = *x;
62       x2 = *y;
63       *x++ = EXTRACT16(SHR32(MULT16_16(c,x1) + MULT16_16(s,x2),15));
64       *y++ = EXTRACT16(SHR32(MULT16_16(s,x1) - MULT16_16(c,x2),15));
65       j++;
66       if (j>=stride)
67       {
68          j=0;
69          x+=stride;
70          y+=stride;
71       }
72    } while (y<end);
73
74    /* Reverse samples so that the next level starts from the other end */
75    for (j=0;j<len>>1;j++)
76    {
77       celt_norm tmp = X[j];
78       X[j] = X[len-j-1];
79       X[len-j-1] = tmp;
80    }
81 }
82
83 #define MAX_LEVELS 8
84 static void exp_rotation(celt_norm *X, int len, int dir, int stride, int K)
85 {
86    int i, N=0;
87    int transient;
88    celt_word16 gain, theta;
89    int istride[MAX_LEVELS];
90    celt_word16 c[MAX_LEVELS], s[MAX_LEVELS];
91
92    if (K >= len)
93       return;
94    transient = stride>1;
95    /*if (len>=30)
96    {
97       for (i=0;i<len;i++)
98          X[i] = 0;
99       X[30] = 1;
100       dir = -1;
101       transient = 1;
102    }*/
103    gain = celt_div((celt_word32)MULT16_16(Q15_ONE,len),(celt_word32)(3+len+4*K));
104    /* FIXME: Make that HALF16 instead of HALF32 */
105    theta = HALF32(MULT16_16_Q15(gain,gain));
106    c[0] = celt_cos_norm(EXTEND32(theta));
107    s[0] = celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
108
109    do {
110       istride[N] = stride;
111       stride *= 2;
112       if (K==1)
113          theta = QCONST16(.25f,15);
114       c[N] = c[0];
115       s[N] = s[0];
116       N++;
117    } while (N<MAX_LEVELS && stride < len);
118
119    /* This should help a little bit with the transients */
120    if (transient)
121       c[0] = s[0] = QCONST16(.7071068, 15);
122
123    /* Needs to be < 0 to prevent gaps on the side of the spreading */
124    if (dir < 0)
125    {
126       for (i=0;i<N;i++)
127          frac_hadamard1(X, len, istride[i], c[i], s[i]);
128    } else {
129       for (i=N-1;i>=0;i--)
130          frac_hadamard1(X, len, istride[i], c[i], s[i]);
131    }
132
133    /* Undo last reversal */
134    for (i=0;i<len>>1;i++)
135    {
136       celt_norm tmp = X[i];
137       X[i] = X[len-i-1];
138       X[len-i-1] = tmp;
139    }
140    /*if (len>=30)
141    {
142       for (i=0;i<len;i++)
143          printf ("%f ", X[i]);
144       printf ("\n");
145       exit(0);
146    }*/
147 }
148
149 /** Takes the pitch vector and the decoded residual vector, computes the gain
150     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
151 static void normalise_residual(int * restrict iy, celt_norm * restrict X, int N, int K, celt_word32 Ryy)
152 {
153    int i;
154 #ifdef FIXED_POINT
155    int k;
156 #endif
157    celt_word32 t;
158    celt_word16 g;
159
160 #ifdef FIXED_POINT
161    k = celt_ilog2(Ryy)>>1;
162 #endif
163    t = VSHR32(Ryy, (k-7)<<1);
164    g = celt_rsqrt_norm(t);
165
166    i=0;
167    do
168       X[i] = EXTRACT16(PSHR32(MULT16_16(g, iy[i]), k+1));
169    while (++i < N);
170 }
171
172 void alg_quant(celt_norm *X, int N, int K, int spread, ec_enc *enc)
173 {
174    VARDECL(celt_norm, y);
175    VARDECL(int, iy);
176    VARDECL(celt_word16, signx);
177    int j, is;
178    celt_word16 s;
179    int pulsesLeft;
180    celt_word32 sum;
181    celt_word32 xy, yy;
182    int N_1; /* Inverse of N, in Q14 format (even for float) */
183 #ifdef FIXED_POINT
184    int yshift;
185 #endif
186    SAVE_STACK;
187
188    K = get_pulses(K);
189 #ifdef FIXED_POINT
190    yshift = 13-celt_ilog2(K);
191 #endif
192
193    ALLOC(y, N, celt_norm);
194    ALLOC(iy, N, int);
195    ALLOC(signx, N, celt_word16);
196    N_1 = 512/N;
197    
198    if (spread)
199       exp_rotation(X, N, 1, spread, K);
200
201    sum = 0;
202    j=0; do {
203       if (X[j]>0)
204          signx[j]=1;
205       else {
206          signx[j]=-1;
207          X[j]=-X[j];
208       }
209       iy[j] = 0;
210       y[j] = 0;
211    } while (++j<N);
212
213    xy = yy = 0;
214
215    pulsesLeft = K;
216
217    /* Do a pre-search by projecting on the pyramid */
218    if (K > (N>>1))
219    {
220       celt_word16 rcp;
221       j=0; do {
222          sum += X[j];
223       }  while (++j<N);
224
225 #ifdef FIXED_POINT
226       if (sum <= K)
227 #else
228       if (sum <= EPSILON)
229 #endif
230       {
231          X[0] = QCONST16(1.f,14);
232          j=1; do
233             X[j]=0;
234          while (++j<N);
235          sum = QCONST16(1.f,14);
236       }
237       /* Do we have sufficient accuracy here? */
238       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
239       j=0; do {
240 #ifdef FIXED_POINT
241          /* It's really important to round *towards zero* here */
242          iy[j] = MULT16_16_Q15(X[j],rcp);
243 #else
244          iy[j] = floor(rcp*X[j]);
245 #endif
246          y[j] = SHL16(iy[j],yshift);
247          yy = MAC16_16(yy, y[j],y[j]);
248          xy = MAC16_16(xy, X[j],y[j]);
249          y[j] *= 2;
250          pulsesLeft -= iy[j];
251       }  while (++j<N);
252    }
253    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
254
255    while (pulsesLeft > 0)
256    {
257       int pulsesAtOnce=1;
258       int best_id;
259       celt_word16 magnitude;
260       celt_word32 best_num = -VERY_LARGE16;
261       celt_word16 best_den = 0;
262 #ifdef FIXED_POINT
263       int rshift;
264 #endif
265       /* Decide on how many pulses to find at once */
266       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
267       if (pulsesAtOnce<1)
268          pulsesAtOnce = 1;
269 #ifdef FIXED_POINT
270       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
271 #endif
272       magnitude = SHL16(pulsesAtOnce, yshift);
273
274       best_id = 0;
275       /* The squared magnitude term gets added anyway, so we might as well 
276          add it outside the loop */
277       yy = MAC16_16(yy, magnitude,magnitude);
278       /* Choose between fast and accurate strategy depending on where we are in the search */
279          /* This should ensure that anything we can process will have a better score */
280       j=0;
281       do {
282          celt_word16 Rxy, Ryy;
283          /* Select sign based on X[j] alone */
284          s = magnitude;
285          /* Temporary sums of the new pulse(s) */
286          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
287          /* We're multiplying y[j] by two so we don't have to do it here */
288          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
289             
290             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
291          Rxy is positive because the sign is pre-computed) */
292          Rxy = MULT16_16_Q15(Rxy,Rxy);
293             /* The idea is to check for num/den >= best_num/best_den, but that way
294          we can do it without any division */
295          /* OPT: Make sure to use conditional moves here */
296          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
297          {
298             best_den = Ryy;
299             best_num = Rxy;
300             best_id = j;
301          }
302       } while (++j<N);
303       
304       j = best_id;
305       is = pulsesAtOnce;
306       s = SHL16(is, yshift);
307
308       /* Updating the sums of the new pulse(s) */
309       xy = xy + MULT16_16(s,X[j]);
310       /* We're multiplying y[j] by two so we don't have to do it here */
311       yy = yy + MULT16_16(s,y[j]);
312
313       /* Only now that we've made the final choice, update y/iy */
314       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
315       y[j] += 2*s;
316       iy[j] += is;
317       pulsesLeft -= pulsesAtOnce;
318    }
319    j=0;
320    do {
321       X[j] = MULT16_16(signx[j],X[j]);
322       if (signx[j] < 0)
323          iy[j] = -iy[j];
324    } while (++j<N);
325    encode_pulses(iy, N, K, enc);
326    
327    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
328    due to the recursive computation used in quantisation. */
329    normalise_residual(iy, X, N, K, EXTRACT16(SHR32(yy,2*yshift)));
330    if (spread)
331       exp_rotation(X, N, -1, spread, K);
332    RESTORE_STACK;
333 }
334
335
336 /** Decode pulse vector and combine the result with the pitch vector to produce
337     the final normalised signal in the current band. */
338 void alg_unquant(celt_norm *X, int N, int K, int spread, ec_dec *dec)
339 {
340    int i;
341    celt_word32 Ryy;
342    VARDECL(int, iy);
343    SAVE_STACK;
344    K = get_pulses(K);
345    ALLOC(iy, N, int);
346    decode_pulses(iy, N, K, dec);
347    Ryy = 0;
348    i=0;
349    do {
350       Ryy = MAC16_16(Ryy, iy[i], iy[i]);
351    } while (++i < N);
352    normalise_residual(iy, X, N, K, Ryy);
353    if (spread)
354       exp_rotation(X, N, -1, spread, K);
355    RESTORE_STACK;
356 }
357
358 celt_word16 renormalise_vector(celt_norm *X, celt_word16 value, int N, int stride)
359 {
360    int i;
361    celt_word32 E = EPSILON;
362    celt_word16 g;
363    celt_norm *xptr = X;
364    for (i=0;i<N;i++)
365    {
366       E = MAC16_16(E, *xptr, *xptr);
367       xptr += stride;
368    }
369 #ifdef FIXED_POINT
370    int k = celt_ilog2(E)>>1;
371 #endif
372    celt_word32 t = VSHR32(E, (k-7)<<1);
373    g = MULT16_16_Q15(value, celt_rsqrt_norm(t));
374
375    xptr = X;
376    for (i=0;i<N;i++)
377    {
378       *xptr = EXTRACT16(PSHR32(MULT16_16(g, *xptr), k+1));
379       xptr += stride;
380    }
381    return celt_sqrt(E);
382 }
383
384 static void fold(const CELTMode *m, int start, int N, const celt_norm * restrict Y, celt_norm * restrict P, int N0, int B)
385 {
386    int j;
387    int id = N0 % B;
388    while (id < m->eBands[start])
389       id += B;
390    /* Here, we assume that id will never be greater than N0, i.e. that 
391       no band is wider than N0. In the unlikely case it happens, we set
392       everything to zero */
393    /*{
394            int offset = (N0*C - (id+C*N))/2;
395            if (offset > C*N0/16)
396                    offset = C*N0/16;
397            offset -= offset % (C*B);
398            if (offset < 0)
399                    offset = 0;
400            //printf ("%d\n", offset);
401            id += offset;
402    }*/
403    if (id+N>N0)
404       for (j=0;j<N;j++)
405          P[j] = 0;
406    else
407       for (j=0;j<N;j++)
408          P[j] = Y[id++];
409 }
410
411 void intra_fold(const CELTMode *m, int start, int N, const celt_norm * restrict Y, celt_norm * restrict P, int N0, int B)
412 {
413    fold(m, start, N, Y, P, N0, B);
414    renormalise_vector(P, Q15ONE, N, 1);
415 }
416