removed useless comments
[opus.git] / libcelt / vq.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector, computes the gain
43     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
44 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, const celt_norm_t * restrict P)
45 {
46    int i;
47    celt_word32_t Ryp, Ryy, Rpp;
48    celt_word16_t ryp, ryy, rpp;
49    celt_word32_t g;
50    VARDECL(celt_norm_t, y);
51 #ifdef FIXED_POINT
52    int yshift;
53 #endif
54    SAVE_STACK;
55 #ifdef FIXED_POINT
56    yshift = 13-celt_ilog2(K);
57 #endif
58    ALLOC(y, N, celt_norm_t);
59
60    Rpp = 0;
61    i=0;
62    do {
63       Rpp = MAC16_16(Rpp,P[i],P[i]);
64       y[i] = SHL16(iy[i],yshift);
65    } while (++i < N);
66
67    Ryp = 0;
68    Ryy = 0;
69    /* If this doesn't generate a dual MAC (on supported archs), fire the compiler guy */
70    i=0;
71    do {
72       Ryp = MAC16_16(Ryp, y[i], P[i]);
73       Ryy = MAC16_16(Ryy, y[i], y[i]);
74    } while (++i < N);
75
76    ryp = ROUND16(Ryp,14);
77    ryy = ROUND16(Ryy,14);
78    rpp = ROUND16(Rpp,14);
79    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
80    g = MULT16_32_Q15(celt_sqrt(MAC16_16(Ryy, ryp,ryp) - MULT16_16(ryy,rpp)) - ryp,
81                      celt_rcp(SHR32(Ryy,9)));
82
83    i=0;
84    do 
85       X[i] = ADD16(P[i], ROUND16(MULT16_16(y[i], g),11));
86    while (++i < N);
87
88    RESTORE_STACK;
89 }
90
91
92 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, ec_enc *enc)
93 {
94    VARDECL(celt_norm_t, y);
95    VARDECL(int, iy);
96    VARDECL(celt_word16_t, signx);
97    int j, is;
98    celt_word16_t s;
99    int pulsesLeft;
100    celt_word32_t sum;
101    celt_word32_t xy, yy, yp;
102    celt_word16_t Rpp;
103    int N_1; /* Inverse of N, in Q14 format (even for float) */
104 #ifdef FIXED_POINT
105    int yshift;
106 #endif
107    SAVE_STACK;
108
109 #ifdef FIXED_POINT
110    yshift = 13-celt_ilog2(K);
111 #endif
112
113    ALLOC(y, N, celt_norm_t);
114    ALLOC(iy, N, int);
115    ALLOC(signx, N, celt_word16_t);
116    N_1 = 512/N;
117
118    sum = 0;
119    j=0; do {
120       X[j] -= P[j];
121       if (X[j]>0)
122          signx[j]=1;
123       else {
124          signx[j]=-1;
125          X[j]=-X[j];
126          P[j]=-P[j];
127       }
128       iy[j] = 0;
129       y[j] = 0;
130       sum = MAC16_16(sum, P[j],P[j]);
131    } while (++j<N);
132    Rpp = ROUND16(sum, NORM_SHIFT);
133
134    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
135
136    xy = yy = yp = 0;
137
138    pulsesLeft = K;
139
140    /* Do a pre-search by projecting on the pyramid */
141    if (K > (N>>1))
142    {
143       celt_word16_t rcp;
144       sum=0;
145       j=0; do {
146          sum += X[j];
147       }  while (++j<N);
148
149 #ifdef FIXED_POINT
150       if (sum <= K)
151 #else
152       if (sum <= EPSILON)
153 #endif
154       {
155          X[0] = QCONST16(1.f,14);
156          j=1; do
157             X[j]=0;
158          while (++j<N);
159          sum = QCONST16(1.f,14);
160       }
161       /* Do we have sufficient accuracy here? */
162       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
163       j=0; do {
164 #ifdef FIXED_POINT
165          /* It's really important to round *towards zero* here */
166          iy[j] = MULT16_16_Q15(X[j],rcp);
167 #else
168          iy[j] = floor(rcp*X[j]);
169 #endif
170          y[j] = SHL16(iy[j],yshift);
171          yy = MAC16_16(yy, y[j],y[j]);
172          xy = MAC16_16(xy, X[j],y[j]);
173          yp += P[j]*y[j];
174          y[j] *= 2;
175          pulsesLeft -= iy[j];
176       }  while (++j<N);
177    }
178    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
179
180    while (pulsesLeft > 1)
181    {
182       int pulsesAtOnce=1;
183       int best_id;
184       celt_word16_t magnitude;
185       celt_word32_t best_num = -VERY_LARGE16;
186       celt_word16_t best_den = 0;
187 #ifdef FIXED_POINT
188       int rshift;
189 #endif
190       /* Decide on how many pulses to find at once */
191       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
192       if (pulsesAtOnce<1)
193          pulsesAtOnce = 1;
194 #ifdef FIXED_POINT
195       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
196 #endif
197       magnitude = SHL16(pulsesAtOnce, yshift);
198
199       best_id = 0;
200       /* The squared magnitude term gets added anyway, so we might as well 
201          add it outside the loop */
202       yy = MAC16_16(yy, magnitude,magnitude);
203       /* Choose between fast and accurate strategy depending on where we are in the search */
204          /* This should ensure that anything we can process will have a better score */
205       j=0;
206       do {
207          celt_word16_t Rxy, Ryy;
208          /* Select sign based on X[j] alone */
209          s = magnitude;
210          /* Temporary sums of the new pulse(s) */
211          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
212          /* We're multiplying y[j] by two so we don't have to do it here */
213          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
214             
215             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
216          Rxy is positive because the sign is pre-computed) */
217          Rxy = MULT16_16_Q15(Rxy,Rxy);
218             /* The idea is to check for num/den >= best_num/best_den, but that way
219          we can do it without any division */
220          /* OPT: Make sure to use conditional moves here */
221          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
222          {
223             best_den = Ryy;
224             best_num = Rxy;
225             best_id = j;
226          }
227       } while (++j<N);
228       
229       j = best_id;
230       is = pulsesAtOnce;
231       s = SHL16(is, yshift);
232
233       /* Updating the sums of the new pulse(s) */
234       xy = xy + MULT16_16(s,X[j]);
235       /* We're multiplying y[j] by two so we don't have to do it here */
236       yy = yy + MULT16_16(s,y[j]);
237       yp = yp + MULT16_16(s, P[j]);
238
239       /* Only now that we've made the final choice, update y/iy */
240       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
241       y[j] += 2*s;
242       iy[j] += is;
243       pulsesLeft -= pulsesAtOnce;
244    }
245    
246    if (pulsesLeft > 0)
247    {
248       celt_word16_t g;
249       celt_word16_t best_num = -VERY_LARGE16;
250       celt_word16_t best_den = 0;
251       int best_id = 0;
252       celt_word16_t magnitude = SHL16(1, yshift);
253
254       /* The squared magnitude term gets added anyway, so we might as well 
255       add it outside the loop */
256       yy = MAC16_16(yy, magnitude,magnitude);
257       j=0;
258       do {
259          celt_word16_t Rxy, Ryy, Ryp;
260          celt_word16_t num;
261          /* Select sign based on X[j] alone */
262          s = magnitude;
263          /* Temporary sums of the new pulse(s) */
264          Rxy = ROUND16(MAC16_16(xy, s,X[j]), 14);
265          /* We're multiplying y[j] by two so we don't have to do it here */
266          Ryy = ROUND16(MAC16_16(yy, s,y[j]), 14);
267          Ryp = ROUND16(MAC16_16(yp, s,P[j]), 14);
268
269             /* Compute the gain such that ||p + g*y|| = 1 
270          ...but instead, we compute g*Ryy to avoid dividing */
271          g = celt_psqrt(MULT16_16(Ryp,Ryp) + MULT16_16(Ryy,QCONST16(1.f,14)-Rpp)) - Ryp;
272             /* Knowing that gain, what's the error: (x-g*y)^2 
273          (result is negated and we discard x^2 because it's constant) */
274          /* score = 2*g*Rxy - g*g*Ryy;*/
275 #ifdef FIXED_POINT
276          /* No need to multiply Rxy by 2 because we did it earlier */
277          num = MULT16_16_Q15(ADD16(SUB16(Rxy,g),Rxy),g);
278 #else
279          num = g*(2*Rxy-g);
280 #endif
281          if (MULT16_16(best_den, num) > MULT16_16(Ryy, best_num))
282          {
283             best_den = Ryy;
284             best_num = num;
285             best_id = j;
286          }
287       } while (++j<N);
288       iy[best_id] += 1;
289    }
290    j=0;
291    do {
292       P[j] = MULT16_16(signx[j],P[j]);
293       X[j] = MULT16_16(signx[j],X[j]);
294       if (signx[j] < 0)
295          iy[j] = -iy[j];
296    } while (++j<N);
297    encode_pulses(iy, N, K, enc);
298    
299    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
300    due to the recursive computation used in quantisation. */
301    mix_pitch_and_residual(iy, X, N, K, P);
302    RESTORE_STACK;
303 }
304
305
306 /** Decode pulse vector and combine the result with the pitch vector to produce
307     the final normalised signal in the current band. */
308 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
309 {
310    VARDECL(int, iy);
311    SAVE_STACK;
312    ALLOC(iy, N, int);
313    decode_pulses(iy, N, K, dec);
314    mix_pitch_and_residual(iy, X, N, K, P);
315    RESTORE_STACK;
316 }
317
318 celt_word16_t renormalise_vector(celt_norm_t *X, celt_word16_t value, int N, int stride)
319 {
320    int i;
321    celt_word32_t E = EPSILON;
322    celt_word16_t rE;
323    celt_word16_t g;
324    celt_norm_t *xptr = X;
325    for (i=0;i<N;i++)
326    {
327       E = MAC16_16(E, *xptr, *xptr);
328       xptr += stride;
329    }
330
331    rE = celt_sqrt(E);
332 #ifdef FIXED_POINT
333    if (rE <= 128)
334       g = Q15ONE;
335    else
336 #endif
337       g = MULT16_16_Q15(value,celt_rcp(SHL32(rE,9)));
338    xptr = X;
339    for (i=0;i<N;i++)
340    {
341       *xptr = PSHR32(MULT16_16(g, *xptr),8);
342       xptr += stride;
343    }
344    return rE;
345 }
346
347 static void fold(const CELTMode *m, int N, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
348 {
349    int j;
350    const int C = CHANNELS(m);
351    int id = (N0*C) % (C*B);
352    /* Here, we assume that id will never be greater than N0, i.e. that 
353       no band is wider than N0. In the unlikely case it happens, we set
354       everything to zero */
355    /*{
356            int offset = (N0*C - (id+C*N))/2;
357            if (offset > C*N0/16)
358                    offset = C*N0/16;
359            offset -= offset % (C*B);
360            if (offset < 0)
361                    offset = 0;
362            //printf ("%d\n", offset);
363            id += offset;
364    }*/
365    if (id+C*N>N0*C)
366       for (j=0;j<C*N;j++)
367          P[j] = 0;
368    else
369       for (j=0;j<C*N;j++)
370          P[j] = Y[id++];
371 }
372
373 void intra_fold(const CELTMode *m, celt_norm_t * restrict x, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
374 {
375    celt_word16_t pred_gain;
376    const int C = CHANNELS(m);
377
378    if (K==0)
379       pred_gain = Q15ONE;
380    else
381       pred_gain = celt_div((celt_word32_t)MULT16_16(Q15_ONE,N),(celt_word32_t)(N+2*K*(K+1)));
382
383    fold(m, N, Y, P, N0, B);
384    renormalise_vector(P, pred_gain, C*N, 1);
385 }
386