Fix a case where the new search can leave us with no pulse left
[opus.git] / libcelt / vq.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector, computes the gain
43     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
44 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, const celt_norm_t * restrict P)
45 {
46    int i;
47    celt_word32_t Ryp, Ryy, Rpp;
48    celt_word16_t ryp, ryy, rpp;
49    celt_word32_t g;
50    VARDECL(celt_norm_t, y);
51 #ifdef FIXED_POINT
52    int yshift;
53 #endif
54    SAVE_STACK;
55 #ifdef FIXED_POINT
56    yshift = 13-celt_ilog2(K);
57 #endif
58    ALLOC(y, N, celt_norm_t);
59
60    /*for (i=0;i<N;i++)
61    printf ("%d ", iy[i]);*/
62    Rpp = 0;
63    i=0;
64    do {
65       Rpp = MAC16_16(Rpp,P[i],P[i]);
66       y[i] = SHL16(iy[i],yshift);
67    } while (++i < N);
68
69    Ryp = 0;
70    Ryy = 0;
71    /* If this doesn't generate a dual MAC (on supported archs), fire the compiler guy */
72    i=0;
73    do {
74       Ryp = MAC16_16(Ryp, y[i], P[i]);
75       Ryy = MAC16_16(Ryy, y[i], y[i]);
76    } while (++i < N);
77
78    ryp = ROUND16(Ryp,14);
79    ryy = ROUND16(Ryy,14);
80    rpp = ROUND16(Rpp,14);
81    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
82    g = MULT16_32_Q15(celt_sqrt(MAC16_16(Ryy, ryp,ryp) - MULT16_16(ryy,rpp)) - ryp,
83                      celt_rcp(SHR32(Ryy,9)));
84
85    i=0;
86    do 
87       X[i] = ADD16(P[i], ROUND16(MULT16_16(y[i], g),11));
88    while (++i < N);
89
90    RESTORE_STACK;
91 }
92
93
94 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, ec_enc *enc)
95 {
96    VARDECL(celt_norm_t, y);
97    VARDECL(int, iy);
98    VARDECL(celt_word16_t, signx);
99    int j, is;
100    celt_word16_t s;
101    int pulsesLeft;
102    celt_word32_t sum;
103    celt_word32_t xy, yy, yp;
104    celt_word16_t Rpp;
105    int N_1; /* Inverse of N, in Q14 format (even for float) */
106 #ifdef FIXED_POINT
107    int yshift;
108 #endif
109    SAVE_STACK;
110
111 #ifdef FIXED_POINT
112    yshift = 13-celt_ilog2(K);
113 #endif
114
115    ALLOC(y, N, celt_norm_t);
116    ALLOC(iy, N, int);
117    ALLOC(signx, N, celt_word16_t);
118    N_1 = 512/N;
119
120    sum = 0;
121    j=0; do {
122       X[j] -= P[j];
123       if (X[j]>0)
124          signx[j]=1;
125       else {
126          signx[j]=-1;
127          X[j]=-X[j];
128          P[j]=-P[j];
129       }
130       iy[j] = 0;
131       y[j] = 0;
132       sum = MAC16_16(sum, P[j],P[j]);
133    } while (++j<N);
134    Rpp = ROUND16(sum, NORM_SHIFT);
135
136    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
137
138    xy = yy = yp = 0;
139
140    pulsesLeft = K;
141 #if 0
142    if (K > (N>>1))
143    {
144       celt_word32_t sum=0;
145       j=0; do {
146          sum += X[j];
147       }  while (++j<N);
148       sum = DIV32(SHL32(EXTEND32(K-1),15),EPSILON+sum);
149       j=0; do {
150 #ifdef FIXED_POINT
151          /* It's really important to round *towards zero* here */
152          iy[j] = MULT16_32_Q15(X[j],sum);
153 #else
154          iy[j] = floor(sum*X[j]);
155 #endif
156          y[j] = SHL16(iy[j],yshift);
157          yy = MAC16_16(yy, y[j],y[j]);
158          xy = MAC16_16(xy, X[j],y[j]);
159          yp += P[j]*y[j];
160          y[j] *= 2;
161          pulsesLeft -= iy[j];
162       }  while (++j<N);
163    }
164    /*printf ("%d / %d (%d)\n", pulsesLeft, K, N);*/
165    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
166 #endif
167    while (pulsesLeft > 1)
168    {
169       int pulsesAtOnce=1;
170       int best_id;
171       celt_word16_t magnitude;
172       celt_word32_t best_num = -VERY_LARGE16;
173       celt_word16_t best_den = 0;
174 #ifdef FIXED_POINT
175       int rshift;
176 #endif
177       /* Decide on how many pulses to find at once */
178       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
179       if (pulsesAtOnce<1)
180          pulsesAtOnce = 1;
181 #ifdef FIXED_POINT
182       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
183 #endif
184       magnitude = SHL16(pulsesAtOnce, yshift);
185
186       best_id = 0;
187       /* The squared magnitude term gets added anyway, so we might as well 
188          add it outside the loop */
189       yy = MAC16_16(yy, magnitude,magnitude);
190       /* Choose between fast and accurate strategy depending on where we are in the search */
191          /* This should ensure that anything we can process will have a better score */
192       j=0;
193       do {
194          celt_word16_t Rxy, Ryy;
195          /* Select sign based on X[j] alone */
196          s = magnitude;
197          /* Temporary sums of the new pulse(s) */
198          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
199          /* We're multiplying y[j] by two so we don't have to do it here */
200          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
201             
202             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
203          Rxy is positive because the sign is pre-computed) */
204          Rxy = MULT16_16_Q15(Rxy,Rxy);
205             /* The idea is to check for num/den >= best_num/best_den, but that way
206          we can do it without any division */
207          /* OPT: Make sure to use conditional moves here */
208          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
209          {
210             best_den = Ryy;
211             best_num = Rxy;
212             best_id = j;
213          }
214       } while (++j<N);
215       
216       j = best_id;
217       is = pulsesAtOnce;
218       s = SHL16(is, yshift);
219
220       /* Updating the sums of the new pulse(s) */
221       xy = xy + MULT16_16(s,X[j]);
222       /* We're multiplying y[j] by two so we don't have to do it here */
223       yy = yy + MULT16_16(s,y[j]);
224       yp = yp + MULT16_16(s, P[j]);
225
226       /* Only now that we've made the final choice, update y/iy */
227       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
228       y[j] += 2*s;
229       iy[j] += is;
230       pulsesLeft -= pulsesAtOnce;
231    }
232    
233    {
234       celt_word16_t g;
235       celt_word16_t best_num = -VERY_LARGE16;
236       celt_word16_t best_den = 0;
237       int best_id = 0;
238       celt_word16_t magnitude = SHL16(1, yshift);
239
240       /* The squared magnitude term gets added anyway, so we might as well 
241       add it outside the loop */
242       yy = MAC16_16(yy, magnitude,magnitude);
243       j=0;
244       do {
245          celt_word16_t Rxy, Ryy, Ryp;
246          celt_word16_t num;
247          /* Select sign based on X[j] alone */
248          s = magnitude;
249          /* Temporary sums of the new pulse(s) */
250          Rxy = ROUND16(MAC16_16(xy, s,X[j]), 14);
251          /* We're multiplying y[j] by two so we don't have to do it here */
252          Ryy = ROUND16(MAC16_16(yy, s,y[j]), 14);
253          Ryp = ROUND16(MAC16_16(yp, s,P[j]), 14);
254
255             /* Compute the gain such that ||p + g*y|| = 1 
256          ...but instead, we compute g*Ryy to avoid dividing */
257          g = celt_psqrt(MULT16_16(Ryp,Ryp) + MULT16_16(Ryy,QCONST16(1.f,14)-Rpp)) - Ryp;
258             /* Knowing that gain, what's the error: (x-g*y)^2 
259          (result is negated and we discard x^2 because it's constant) */
260          /* score = 2*g*Rxy - g*g*Ryy;*/
261 #ifdef FIXED_POINT
262          /* No need to multiply Rxy by 2 because we did it earlier */
263          num = MULT16_16_Q15(ADD16(SUB16(Rxy,g),Rxy),g);
264 #else
265          num = g*(2*Rxy-g);
266 #endif
267          if (MULT16_16(best_den, num) > MULT16_16(Ryy, best_num))
268          {
269             best_den = Ryy;
270             best_num = num;
271             best_id = j;
272          }
273       } while (++j<N);
274       iy[best_id] += 1;
275    }
276    j=0;
277    do {
278       P[j] = MULT16_16(signx[j],P[j]);
279       X[j] = MULT16_16(signx[j],X[j]);
280       if (signx[j] < 0)
281          iy[j] = -iy[j];
282    } while (++j<N);
283    encode_pulses(iy, N, K, enc);
284    
285    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
286    due to the recursive computation used in quantisation. */
287    mix_pitch_and_residual(iy, X, N, K, P);
288    RESTORE_STACK;
289 }
290
291
292 /** Decode pulse vector and combine the result with the pitch vector to produce
293     the final normalised signal in the current band. */
294 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
295 {
296    VARDECL(int, iy);
297    SAVE_STACK;
298    ALLOC(iy, N, int);
299    decode_pulses(iy, N, K, dec);
300    mix_pitch_and_residual(iy, X, N, K, P);
301    RESTORE_STACK;
302 }
303
304 void renormalise_vector(celt_norm_t *X, celt_word16_t value, int N, int stride)
305 {
306    int i;
307    celt_word32_t E = EPSILON;
308    celt_word16_t g;
309    celt_norm_t *xptr = X;
310    for (i=0;i<N;i++)
311    {
312       E = MAC16_16(E, *xptr, *xptr);
313       xptr += stride;
314    }
315
316    g = MULT16_16_Q15(value,celt_rcp(SHL32(celt_sqrt(E),9)));
317    xptr = X;
318    for (i=0;i<N;i++)
319    {
320       *xptr = PSHR32(MULT16_16(g, *xptr),8);
321       xptr += stride;
322    }
323 }
324
325 static void fold(const CELTMode *m, int N, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
326 {
327    int j;
328    const int C = CHANNELS(m);
329    int id = N0 % (C*B);
330    /* Here, we assume that id will never be greater than N0, i.e. that 
331       no band is wider than N0. In the unlikely case it happens, we set
332       everything to zero */
333    if (id+C*N>N0)
334       for (j=0;j<C*N;j++)
335          P[j] = 0;
336    else
337       for (j=0;j<C*N;j++)
338          P[j] = Y[id++];
339 }
340
341 #define KGAIN 6
342
343 void intra_fold(const CELTMode *m, celt_norm_t * restrict x, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
344 {
345    celt_word16_t pred_gain;
346    const int C = CHANNELS(m);
347
348    if (K==0)
349       pred_gain = Q15ONE;
350    else
351       pred_gain = celt_div((celt_word32_t)MULT16_16(Q15_ONE,N),(celt_word32_t)(N+KGAIN*K));
352
353    fold(m, N, Y, P, N0, B);
354
355    renormalise_vector(P, pred_gain, C*N, 1);
356 }
357