fa5cebdbac883d948813b223e6a1c4a1b15dcacc
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37 /* Enable this or define your own implementation if you want to speed up the
38    VQ search (used in inner loop only) */
39 #if 0
40 #include <xmmintrin.h>
41 static inline float approx_sqrt(float x)
42 {
43    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
44    return x;
45 }
46 static inline float approx_inv(float x)
47 {
48    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
49    return x;
50 }
51 #else
52 #define approx_sqrt(x) (sqrt(x))
53 #define approx_inv(x) (1.f/(x))
54 #endif
55
56 struct NBest {
57    float score;
58    float gain;
59    int sign;
60    int pos;
61    int orig;
62    float xy;
63    float yy;
64    float yp;
65 };
66
67 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
68    a combination of pulses such that its norm is still equal to 1. The only difference with 
69    the quantiser above is that the search is more complete. */
70 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
71 {
72    int L = 3;
73    //float tata[200];
74    float _y[L][N];
75    int _iy[L][N];
76    //float tata2[200];
77    float _ny[L][N];
78    int _iny[L][N];
79    float *(ny[L]), *(y[L]);
80    int *(iny[L]), *(iy[L]);
81    int i, j, k, m;
82    int pulsesLeft;
83    float xy[L];
84    float yy[L];
85    float yp[L];
86    struct NBest _nbest[L];
87    struct NBest *(nbest[L]);
88    float Rpp=0, Rxp=0;
89    int maxL = 1;
90    
91    for (m=0;m<L;m++)
92       nbest[m] = &_nbest[m];
93    
94    for (m=0;m<L;m++)
95    {
96       ny[m] = _ny[m];
97       iny[m] = _iny[m];
98       y[m] = _y[m];
99       iy[m] = _iy[m];
100    }
101    
102    for (j=0;j<N;j++)
103    {
104       Rpp += p[j]*p[j];
105       Rxp += x[j]*p[j];
106    }
107    
108    /* We only need to initialise the zero because the first iteration only uses that */
109    for (i=0;i<N;i++)
110       y[0][i] = 0;
111    for (i=0;i<N;i++)
112       iy[0][i] = 0;
113    xy[0] = yy[0] = yp[0] = 0;
114
115    pulsesLeft = K;
116    while (pulsesLeft > 0)
117    {
118       int pulsesAtOnce=1;
119       int Lupdate = L;
120       int L2 = L;
121       
122       /* Decide on complexity strategy */
123       pulsesAtOnce = pulsesLeft/N;
124       if (pulsesAtOnce<1)
125          pulsesAtOnce = 1;
126       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
127          Lupdate = 1;
128       //printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);
129       L2 = Lupdate;
130       if (L2>maxL)
131       {
132          L2 = maxL;
133          maxL *= N;
134       }
135
136       for (m=0;m<Lupdate;m++)
137          nbest[m]->score = -1e10f;
138
139       for (m=0;m<L2;m++)
140       {
141          for (j=0;j<N;j++)
142          {
143             int sign;
144             /*if (x[j]>0) sign=1; else sign=-1;*/
145             for (sign=-1;sign<=1;sign+=2)
146             {
147                /* All pulses at one location must have the same sign. */
148                if (iy[m][j]*sign < 0)
149                   continue;
150                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
151                float tmp_xy, tmp_yy, tmp_yp;
152                float score;
153                float g;
154                float s = sign*pulsesAtOnce;
155                
156                /* Updating the sums of the new pulse(s) */
157                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
158                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2.f*alpha*s*p[j]*yp[m] - 2.f*s*s*alpha*p[j]*p[j];
159                tmp_yp = yp[m] + s*p[j]               *(1.f-alpha*Rpp);
160                
161                /* Compute the gain such that ||p + g*y|| = 1 */
162                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
163                /* Knowing that gain, what the error: (x-g*y)^2 
164                   (result is negated and we discard x^2 because it's constant) */
165                score = 2.f*g*tmp_xy - g*g*tmp_yy;
166
167                if (score>nbest[Lupdate-1]->score)
168                {
169                   int k;
170                   int id = Lupdate-1;
171                   struct NBest *tmp_best;
172
173                   /* Save some pointers that would be deleted and use them for the current entry*/
174                   tmp_best = nbest[Lupdate-1];
175                   while (id > 0 && score > nbest[id-1]->score)
176                      id--;
177                
178                   for (k=Lupdate-1;k>id;k--)
179                      nbest[k] = nbest[k-1];
180
181                   nbest[id] = tmp_best;
182                   nbest[id]->score = score;
183                   nbest[id]->pos = j;
184                   nbest[id]->orig = m;
185                   nbest[id]->sign = sign;
186                   nbest[id]->gain = g;
187                   nbest[id]->xy = tmp_xy;
188                   nbest[id]->yy = tmp_yy;
189                   nbest[id]->yp = tmp_yp;
190                }
191             }
192          }
193
194       }
195       /* Only now that we've made the final choice, update ny/iny and others */
196       for (k=0;k<Lupdate;k++)
197       {
198          int n;
199          int is;
200          float s;
201          is = nbest[k]->sign*pulsesAtOnce;
202          s = is;
203          for (n=0;n<N;n++)
204             ny[k][n] = y[nbest[k]->orig][n] - alpha*s*p[nbest[k]->pos]*p[n];
205          ny[k][nbest[k]->pos] += s;
206
207          for (n=0;n<N;n++)
208             iny[k][n] = iy[nbest[k]->orig][n];
209          iny[k][nbest[k]->pos] += is;
210
211          xy[k] = nbest[k]->xy;
212          yy[k] = nbest[k]->yy;
213          yp[k] = nbest[k]->yp;
214       }
215       /* Swap ny/iny with y/iy */
216       for (k=0;k<Lupdate;k++)
217       {
218          float *tmp_ny;
219          int *tmp_iny;
220
221          tmp_ny = ny[k];
222          ny[k] = y[k];
223          y[k] = tmp_ny;
224          tmp_iny = iny[k];
225          iny[k] = iy[k];
226          iy[k] = tmp_iny;
227       }
228       pulsesLeft -= pulsesAtOnce;
229    }
230    
231    if (0) {
232       float err=0;
233       for (i=0;i<N;i++)
234          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
235       //if (N<=10)
236       //printf ("%f %d %d\n", err, K, N);
237    }
238    for (i=0;i<N;i++)
239       x[i] = p[i]+nbest[0]->gain*y[0][i];
240    /* Sanity checks, don't bother */
241    if (0) {
242       float E=1e-15;
243       int ABS = 0;
244       for (i=0;i<N;i++)
245          ABS += abs(iy[0][i]);
246       //if (K != ABS)
247       //   printf ("%d %d\n", K, ABS);
248       for (i=0;i<N;i++)
249          E += x[i]*x[i];
250       //printf ("%f\n", E);
251       E = 1/sqrt(E);
252       for (i=0;i<N;i++)
253          x[i] *= E;
254    }
255    
256    encode_pulses(iy[0], N, K, enc);
257    
258    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
259       due to the recursive computation used in quantisation.
260       Not quite sure whether we need that or not */
261    if (1) {
262       float Ryp=0;
263       float Ryy=0;
264       float g=0;
265       
266       for (i=0;i<N;i++)
267          Ryp += iy[0][i]*p[i];
268       
269       for (i=0;i<N;i++)
270          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
271       
272       Ryp = 0;
273       for (i=0;i<N;i++)
274          Ryp += y[0][i]*p[i];
275       
276       for (i=0;i<N;i++)
277          Ryy += y[0][i]*y[0][i];
278       
279       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
280         
281       for (i=0;i<N;i++)
282          x[i] = p[i] + g*y[0][i];
283       
284    }
285
286 }
287
288 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
289 {
290    int i;
291    int iy[N];
292    float y[N];
293    float Rpp=0, Ryp=0, Ryy=0;
294    float g;
295
296    decode_pulses(iy, N, K, dec);
297
298    //for (i=0;i<N;i++)
299    //   printf ("%d ", iy[i]);
300    for (i=0;i<N;i++)
301       Rpp += p[i]*p[i];
302
303    for (i=0;i<N;i++)
304       Ryp += iy[i]*p[i];
305
306    for (i=0;i<N;i++)
307       y[i] = iy[i] - alpha*Ryp*p[i];
308
309    /* Recompute after the projection (I think it's right) */
310    Ryp = 0;
311    for (i=0;i<N;i++)
312       Ryp += y[i]*p[i];
313
314    for (i=0;i<N;i++)
315       Ryy += y[i]*y[i];
316
317    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
318
319    for (i=0;i<N;i++)
320       x[i] = p[i] + g*y[i];
321 }
322
323
324 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
325
326 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
327 {
328    int i,j;
329    int best=0;
330    float best_score=0;
331    float s = 1;
332    int sign;
333    float E;
334    int max_pos = N0-N/B;
335    if (max_pos > 32)
336       max_pos = 32;
337
338    for (i=0;i<max_pos*B;i+=B)
339    {
340       int j;
341       float xy=0, yy=0;
342       float score;
343       for (j=0;j<N;j++)
344       {
345          xy += x[j]*Y[i+j];
346          yy += Y[i+j]*Y[i+j];
347       }
348       score = xy*xy/(.001+yy);
349       if (score > best_score)
350       {
351          best_score = score;
352          best = i;
353          if (xy>0)
354             s = 1;
355          else
356             s = -1;
357       }
358    }
359    if (s<0)
360       sign = 1;
361    else
362       sign = 0;
363    //printf ("%d %d ", sign, best);
364    ec_enc_uint(enc,sign,2);
365    ec_enc_uint(enc,best/B,max_pos);
366    //printf ("%d %f\n", best, best_score);
367    
368    float pred_gain;
369    if (K>10)
370       pred_gain = pg[10];
371    else
372       pred_gain = pg[K];
373    E = 1e-10;
374    for (j=0;j<N;j++)
375    {
376       P[j] = s*Y[best+j];
377       E += P[j]*P[j];
378    }
379    E = pred_gain/sqrt(E);
380    for (j=0;j<N;j++)
381       P[j] *= E;
382    if (K>0)
383    {
384       for (j=0;j<N;j++)
385          x[j] -= P[j];
386    } else {
387       for (j=0;j<N;j++)
388          x[j] = P[j];
389    }
390    //printf ("quant ");
391    //for (j=0;j<N;j++) printf ("%f ", P[j]);
392
393 }
394
395 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
396 {
397    int j;
398    int sign;
399    float s;
400    int best;
401    float E;
402    int max_pos = N0-N/B;
403    if (max_pos > 32)
404       max_pos = 32;
405    
406    sign = ec_dec_uint(dec, 2);
407    if (sign == 0)
408       s = 1;
409    else
410       s = -1;
411    
412    best = B*ec_dec_uint(dec, max_pos);
413    //printf ("%d %d ", sign, best);
414
415    float pred_gain;
416    if (K>10)
417       pred_gain = pg[10];
418    else
419       pred_gain = pg[K];
420    E = 1e-10;
421    for (j=0;j<N;j++)
422    {
423       P[j] = s*Y[best+j];
424       E += P[j]*P[j];
425    }
426    E = pred_gain/sqrt(E);
427    for (j=0;j<N;j++)
428       P[j] *= E;
429    if (K==0)
430    {
431       for (j=0;j<N;j++)
432          x[j] = P[j];
433    }
434 }
435
436 void intra_fold(float *x, int N, float *Y, float *P, int B, int N0, int Nmax)
437 {
438    int i, j;
439    float E;
440    
441    E = 1e-10;
442    if (N0 >= Nmax/2)
443    {
444       for (i=0;i<B;i++)
445       {
446          for (j=0;j<N/B;j++)
447          {
448             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
449             E += P[j*B+i]*P[j*B+i];
450          }
451       }
452    } else {
453       for (j=0;j<N;j++)
454       {
455          P[j] = Y[j];
456          E += P[j]*P[j];
457       }
458    }
459    E = 1.f/sqrt(E);
460    for (j=0;j<N;j++)
461       P[j] *= E;
462    for (j=0;j<N;j++)
463       x[j] = P[j];
464 }
465