Fixed the default int32 type which was wrong on amd64 (and added testcase).
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
42 {
43    int L = 3;
44    //float tata[200];
45    float y[L][N];
46    int iy[L][N];
47    //float tata2[200];
48    float ny[L][N];
49    int iny[L][N];
50    int i, j, m;
51    float xy[L], nxy[L];
52    float yy[L], nyy[L];
53    float yp[L], nyp[L];
54    float best_scores[L];
55    float Rpp=0, Rxp=0;
56    float gain[L];
57    int maxL = 1;
58    
59    for (j=0;j<N;j++)
60       Rpp += p[j]*p[j];
61    //if (Rpp>.01)
62    //   alpha = (1-sqrt(1-Rpp))/Rpp;
63    for (j=0;j<N;j++)
64       Rxp += x[j]*p[j];
65    for (m=0;m<L;m++)
66       for (i=0;i<N;i++)
67          y[m][i] = 0;
68       
69    for (m=0;m<L;m++)
70       for (i=0;i<N;i++)
71          ny[m][i] = 0;
72
73    for (m=0;m<L;m++)
74       for (i=0;i<N;i++)
75          iy[m][i] = iny[m][i] = 0;
76
77    for (m=0;m<L;m++)
78       xy[m] = yy[m] = yp[m] = gain[m] = 0;
79    
80    for (i=0;i<K;i++)
81    {
82       int L2 = L;
83       if (L>maxL)
84       {
85          L2 = maxL;
86          maxL *= N;
87       }
88       for (m=0;m<L;m++)
89          best_scores[m] = -1e10;
90
91       for (m=0;m<L2;m++)
92       {
93          for (j=0;j<N;j++)
94          {
95             int sign;
96             for (sign=-1;sign<=1;sign+=2)
97             {
98                if (iy[m][j]*sign < 0)
99                   continue;
100                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
101                float tmp_xy, tmp_yy, tmp_yp;
102                float score;
103                float g;
104                float s = sign;
105                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
106                tmp_yy = yy[m] + 2*s*y[m][j] + 1      +alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*alpha*p[j]*p[j];
107                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
108                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
109                score = 2*g*tmp_xy - g*g*tmp_yy;
110
111                if (score>best_scores[L-1])
112                {
113                   int k, n;
114                   int id = L-1;
115                   while (id > 0 && score > best_scores[id-1])
116                      id--;
117                
118                   for (k=L-1;k>id;k--)
119                   {
120                      nxy[k] = nxy[k-1];
121                      nyy[k] = nyy[k-1];
122                      nyp[k] = nyp[k-1];
123                      //fprintf(stderr, "%d %d \n", N, k);
124                      for (n=0;n<N;n++)
125                         ny[k][n] = ny[k-1][n];
126                      for (n=0;n<N;n++)
127                         iny[k][n] = iny[k-1][n];
128                      gain[k] = gain[k-1];
129                      best_scores[k] = best_scores[k-1];
130                   }
131
132                   nxy[id] = tmp_xy;
133                   nyy[id] = tmp_yy;
134                   nyp[id] = tmp_yp;
135                   gain[id] = g;
136                   for (n=0;n<N;n++)
137                      ny[id][n] = y[m][n];
138                   ny[id][j] += s;
139                   for (n=0;n<N;n++)
140                      ny[id][n] -= alpha*s*p[j]*p[n];
141                
142                   for (n=0;n<N;n++)
143                      iny[id][n] = iy[m][n];
144                   if (s>0)
145                      iny[id][j] += 1;
146                   else
147                      iny[id][j] -= 1;
148                   best_scores[id] = score;
149                }
150             }   
151          }
152          
153       }
154       int k,n;
155       for (k=0;k<L;k++)
156       {
157          xy[k] = nxy[k];
158          yy[k] = nyy[k];
159          yp[k] = nyp[k];
160          for (n=0;n<N;n++)
161             y[k][n] = ny[k][n];
162          for (n=0;n<N;n++)
163             iy[k][n] = iny[k][n];
164       }
165
166    }
167    
168    for (i=0;i<N;i++)
169       x[i] = p[i]+gain[0]*y[0][i];
170    if (0) {
171       float E=1e-15;
172       int ABS = 0;
173       for (i=0;i<N;i++)
174          ABS += abs(iy[0][i]);
175       //if (K != ABS)
176       //   printf ("%d %d\n", K, ABS);
177       for (i=0;i<N;i++)
178          E += x[i]*x[i];
179       //printf ("%f\n", E);
180       E = 1/sqrt(E);
181       for (i=0;i<N;i++)
182          x[i] *= E;
183    }
184    int comb[K];
185    int signs[K];
186    //for (i=0;i<N;i++)
187    //   printf ("%d ", iy[0][i]);
188    pulse2comb(N, K, comb, signs, iy[0]); 
189    ec_enc_uint64(enc,icwrs64(N, K, comb, signs),ncwrs64(N, K));
190    //printf ("%llu ", icwrs64(N, K, comb, signs));
191    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
192       due to the recursive computation used in quantisation */
193    if (1) {
194       float Ryp=0;
195       float Rpp=0;
196       float Ryy=0;
197       float g=0;
198       for (i=0;i<N;i++)
199          Rpp += p[i]*p[i];
200       
201       for (i=0;i<N;i++)
202          Ryp += iy[0][i]*p[i];
203       
204       for (i=0;i<N;i++)
205          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
206       
207       Ryp = 0;
208       for (i=0;i<N;i++)
209          Ryp += y[0][i]*p[i];
210       
211       for (i=0;i<N;i++)
212          Ryy += y[0][i]*y[0][i];
213       
214       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
215         
216       for (i=0;i<N;i++)
217          x[i] = p[i] + g*y[0][i];
218       
219    }
220
221 }
222
223 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
224 {
225    int i;
226    celt_uint64_t id;
227    int comb[K];
228    int signs[K];
229    int iy[N];
230    float y[N];
231    float Rpp=0, Ryp=0, Ryy=0;
232    float g;
233
234    id = ec_dec_uint64(dec, ncwrs64(N, K));
235    //printf ("%llu ", id);
236    cwrsi64(N, K, id, comb, signs);
237    comb2pulse(N, K, iy, comb, signs);
238    //for (i=0;i<N;i++)
239    //   printf ("%d ", iy[i]);
240    for (i=0;i<N;i++)
241       Rpp += p[i]*p[i];
242
243    for (i=0;i<N;i++)
244       Ryp += iy[i]*p[i];
245
246    for (i=0;i<N;i++)
247       y[i] = iy[i] - alpha*Ryp*p[i];
248
249    /* Recompute after the projection (I think it's right) */
250    Ryp = 0;
251    for (i=0;i<N;i++)
252       Ryp += y[i]*p[i];
253
254    for (i=0;i<N;i++)
255       Ryy += y[i]*y[i];
256
257    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
258
259    for (i=0;i<N;i++)
260       x[i] = p[i] + g*y[i];
261 }
262
263
264 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
265
266 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
267 {
268    int i,j;
269    int best=0;
270    float best_score=0;
271    float s = 1;
272    int sign;
273    float E;
274    for (i=0;i<N0*B-N;i+=B)
275    {
276       int j;
277       float xy=0, yy=0;
278       float score;
279       for (j=0;j<N;j++)
280       {
281          xy += x[j]*Y[i+j];
282          yy += Y[i+j]*Y[i+j];
283       }
284       score = xy*xy/(.001+yy);
285       if (score > best_score)
286       {
287          best_score = score;
288          best = i;
289          if (xy>0)
290             s = 1;
291          else
292             s = -1;
293       }
294    }
295    if (s<0)
296       sign = 1;
297    else
298       sign = 0;
299    //printf ("%d %d ", sign, best);
300    ec_enc_uint(enc,sign,2);
301    ec_enc_uint(enc,best/B,N0-N/B);
302    //printf ("%d %f\n", best, best_score);
303    
304    float pred_gain;
305    if (K>10)
306       pred_gain = pg[10];
307    else
308       pred_gain = pg[K];
309    E = 1e-10;
310    for (j=0;j<N;j++)
311    {
312       P[j] = s*Y[best+j];
313       E += P[j]*P[j];
314    }
315    E = pred_gain/sqrt(E);
316    for (j=0;j<N;j++)
317       P[j] *= E;
318    if (K>0)
319    {
320       for (j=0;j<N;j++)
321          x[j] -= P[j];
322    } else {
323       for (j=0;j<N;j++)
324          x[j] = P[j];
325    }
326    //printf ("quant ");
327    //for (j=0;j<N;j++) printf ("%f ", P[j]);
328
329 }
330
331 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
332 {
333    int j;
334    int sign;
335    float s;
336    int best;
337    float E;
338    sign = ec_dec_uint(dec, 2);
339    if (sign == 0)
340       s = 1;
341    else
342       s = -1;
343    
344    best = B*ec_dec_uint(dec, N0-N/B);
345    //printf ("%d %d ", sign, best);
346
347    float pred_gain;
348    if (K>10)
349       pred_gain = pg[10];
350    else
351       pred_gain = pg[K];
352    E = 1e-10;
353    for (j=0;j<N;j++)
354    {
355       P[j] = s*Y[best+j];
356       E += P[j]*P[j];
357    }
358    E = pred_gain/sqrt(E);
359    for (j=0;j<N;j++)
360       P[j] *= E;
361    if (K==0)
362    {
363       for (j=0;j<N;j++)
364          x[j] = P[j];
365    }
366 }