Doing intra-frame prediction backwards (and a few comments)
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41
42 /* Enable this or define your own implementation if you want to speed up the
43    VQ search (used in inner loop only) */
44 #if 0
45 #include <xmmintrin.h>
46 static inline float approx_sqrt(float x)
47 {
48    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
49    return x;
50 }
51 static inline float approx_inv(float x)
52 {
53    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
54    return x;
55 }
56 #else
57 #define approx_sqrt(x) (sqrt(x))
58 #define approx_inv(x) (1.f/(x))
59 #endif
60
61 /** All the info necessary to keep track of a hypothesis during the search */
62 struct NBest {
63    float score;
64    float gain;
65    int sign;
66    int pos;
67    int orig;
68    float xy;
69    float yy;
70    float yp;
71 };
72
73 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
74 {
75    int L = 3;
76    VARDECL(float *_y);
77    VARDECL(float *_ny);
78    VARDECL(int *_iy);
79    VARDECL(int *_iny);
80    VARDECL(float **y);
81    VARDECL(float **ny);
82    VARDECL(int **iy);
83    VARDECL(int **iny);
84    int i, j, k, m;
85    int pulsesLeft;
86    VARDECL(float *xy);
87    VARDECL(float *yy);
88    VARDECL(float *yp);
89    VARDECL(struct NBest *_nbest);
90    VARDECL(struct NBest **nbest);
91    float Rpp=0, Rxp=0;
92    int maxL = 1;
93    
94    ALLOC(_y, L*N, float);
95    ALLOC(_ny, L*N, float);
96    ALLOC(_iy, L*N, int);
97    ALLOC(_iny, L*N, int);
98    ALLOC(y, L*N, float*);
99    ALLOC(ny, L*N, float*);
100    ALLOC(iy, L*N, int*);
101    ALLOC(iny, L*N, int*);
102    
103    ALLOC(xy, L, float);
104    ALLOC(yy, L, float);
105    ALLOC(yp, L, float);
106    ALLOC(_nbest, L, struct NBest);
107    ALLOC(nbest, L, struct NBest *);
108
109    for (m=0;m<L;m++)
110       nbest[m] = &_nbest[m];
111    
112    for (m=0;m<L;m++)
113    {
114       ny[m] = &_ny[m*N];
115       iny[m] = &_iny[m*N];
116       y[m] = &_y[m*N];
117       iy[m] = &_iy[m*N];
118    }
119    
120    for (j=0;j<N;j++)
121    {
122       Rpp += p[j]*p[j];
123       Rxp += x[j]*p[j];
124    }
125    
126    /* We only need to initialise the zero because the first iteration only uses that */
127    for (i=0;i<N;i++)
128       y[0][i] = 0;
129    for (i=0;i<N;i++)
130       iy[0][i] = 0;
131    xy[0] = yy[0] = yp[0] = 0;
132
133    pulsesLeft = K;
134    while (pulsesLeft > 0)
135    {
136       int pulsesAtOnce=1;
137       int Lupdate = L;
138       int L2 = L;
139       
140       /* Decide on complexity strategy */
141       pulsesAtOnce = pulsesLeft/N;
142       if (pulsesAtOnce<1)
143          pulsesAtOnce = 1;
144       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
145          Lupdate = 1;
146       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
147       L2 = Lupdate;
148       if (L2>maxL)
149       {
150          L2 = maxL;
151          maxL *= N;
152       }
153
154       for (m=0;m<Lupdate;m++)
155          nbest[m]->score = -1e10f;
156
157       for (m=0;m<L2;m++)
158       {
159          for (j=0;j<N;j++)
160          {
161             int sign;
162             /*if (x[j]>0) sign=1; else sign=-1;*/
163             for (sign=-1;sign<=1;sign+=2)
164             {
165                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
166                float tmp_xy, tmp_yy, tmp_yp;
167                float score;
168                float g;
169                float s = sign*pulsesAtOnce;
170                
171                /* All pulses at one location must have the same sign. */
172                if (iy[m][j]*sign < 0)
173                   continue;
174
175                /* Updating the sums of the new pulse(s) */
176                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
177                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2.f*alpha*s*p[j]*yp[m] - 2.f*s*s*alpha*p[j]*p[j];
178                tmp_yp = yp[m] + s*p[j]               *(1.f-alpha*Rpp);
179                
180                /* Compute the gain such that ||p + g*y|| = 1 */
181                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
182                /* Knowing that gain, what the error: (x-g*y)^2 
183                   (result is negated and we discard x^2 because it's constant) */
184                score = 2.f*g*tmp_xy - g*g*tmp_yy;
185
186                if (score>nbest[Lupdate-1]->score)
187                {
188                   int k;
189                   int id = Lupdate-1;
190                   struct NBest *tmp_best;
191
192                   /* Save some pointers that would be deleted and use them for the current entry*/
193                   tmp_best = nbest[Lupdate-1];
194                   while (id > 0 && score > nbest[id-1]->score)
195                      id--;
196                
197                   for (k=Lupdate-1;k>id;k--)
198                      nbest[k] = nbest[k-1];
199
200                   nbest[id] = tmp_best;
201                   nbest[id]->score = score;
202                   nbest[id]->pos = j;
203                   nbest[id]->orig = m;
204                   nbest[id]->sign = sign;
205                   nbest[id]->gain = g;
206                   nbest[id]->xy = tmp_xy;
207                   nbest[id]->yy = tmp_yy;
208                   nbest[id]->yp = tmp_yp;
209                }
210             }
211          }
212
213       }
214       /* Only now that we've made the final choice, update ny/iny and others */
215       for (k=0;k<Lupdate;k++)
216       {
217          int n;
218          int is;
219          float s;
220          is = nbest[k]->sign*pulsesAtOnce;
221          s = is;
222          for (n=0;n<N;n++)
223             ny[k][n] = y[nbest[k]->orig][n] - alpha*s*p[nbest[k]->pos]*p[n];
224          ny[k][nbest[k]->pos] += s;
225
226          for (n=0;n<N;n++)
227             iny[k][n] = iy[nbest[k]->orig][n];
228          iny[k][nbest[k]->pos] += is;
229
230          xy[k] = nbest[k]->xy;
231          yy[k] = nbest[k]->yy;
232          yp[k] = nbest[k]->yp;
233       }
234       /* Swap ny/iny with y/iy */
235       for (k=0;k<Lupdate;k++)
236       {
237          float *tmp_ny;
238          int *tmp_iny;
239
240          tmp_ny = ny[k];
241          ny[k] = y[k];
242          y[k] = tmp_ny;
243          tmp_iny = iny[k];
244          iny[k] = iy[k];
245          iy[k] = tmp_iny;
246       }
247       pulsesLeft -= pulsesAtOnce;
248    }
249    
250    if (0) {
251       float err=0;
252       for (i=0;i<N;i++)
253          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
254       /*if (N<=10)
255         printf ("%f %d %d\n", err, K, N);*/
256    }
257    for (i=0;i<N;i++)
258       x[i] = p[i]+nbest[0]->gain*y[0][i];
259    /* Sanity checks, don't bother */
260    if (0) {
261       float E=1e-15;
262       int ABS = 0;
263       for (i=0;i<N;i++)
264          ABS += abs(iy[0][i]);
265       /*if (K != ABS)
266          printf ("%d %d\n", K, ABS);*/
267       for (i=0;i<N;i++)
268          E += x[i]*x[i];
269       /*printf ("%f\n", E);*/
270       E = 1/sqrt(E);
271       for (i=0;i<N;i++)
272          x[i] *= E;
273    }
274    
275    encode_pulses(iy[0], N, K, enc);
276    
277    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
278       due to the recursive computation used in quantisation.
279       Not quite sure whether we need that or not */
280    if (1) {
281       float Ryp=0;
282       float Ryy=0;
283       float g=0;
284       
285       for (i=0;i<N;i++)
286          Ryp += iy[0][i]*p[i];
287       
288       for (i=0;i<N;i++)
289          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
290       
291       Ryp = 0;
292       for (i=0;i<N;i++)
293          Ryp += y[0][i]*p[i];
294       
295       for (i=0;i<N;i++)
296          Ryy += y[0][i]*y[0][i];
297       
298       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
299         
300       for (i=0;i<N;i++)
301          x[i] = p[i] + g*y[0][i];
302       
303    }
304
305 }
306
307 /** Decode pulse vector and combine the result with the pitch vector to produce
308     the final normalised signal in the current band. */
309 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
310 {
311    int i;
312    float Rpp=0, Ryp=0, Ryy=0;
313    float g;
314    VARDECL(int *iy);
315    VARDECL(float *y);
316    
317    ALLOC(iy, N, int);
318    ALLOC(y, N, float);
319
320    decode_pulses(iy, N, K, dec);
321
322    /*for (i=0;i<N;i++)
323       printf ("%d ", iy[i]);*/
324    for (i=0;i<N;i++)
325       Rpp += p[i]*p[i];
326
327    for (i=0;i<N;i++)
328       Ryp += iy[i]*p[i];
329
330    for (i=0;i<N;i++)
331       y[i] = iy[i] - alpha*Ryp*p[i];
332
333    /* Recompute after the projection (I think it's right) */
334    Ryp = 0;
335    for (i=0;i<N;i++)
336       Ryp += y[i]*p[i];
337
338    for (i=0;i<N;i++)
339       Ryy += y[i]*y[i];
340
341    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
342
343    for (i=0;i<N;i++)
344       x[i] = p[i] + g*y[i];
345 }
346
347
348 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
349
350 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
351 {
352    int i,j;
353    int best=0;
354    float best_score=0;
355    float s = 1;
356    int sign;
357    float E;
358    float pred_gain;
359    int max_pos = N0-N/B;
360    if (max_pos > 32)
361       max_pos = 32;
362
363    for (i=0;i<max_pos*B;i+=B)
364    {
365       int j;
366       float xy=0, yy=0;
367       float score;
368       for (j=0;j<N;j++)
369       {
370          xy += x[j]*Y[i+N-j-1];
371          yy += Y[i+N-j-1]*Y[i+N-j-1];
372       }
373       score = xy*xy/(.001+yy);
374       if (score > best_score)
375       {
376          best_score = score;
377          best = i;
378          if (xy>0)
379             s = 1;
380          else
381             s = -1;
382       }
383    }
384    if (s<0)
385       sign = 1;
386    else
387       sign = 0;
388    /*printf ("%d %d ", sign, best);*/
389    ec_enc_uint(enc,sign,2);
390    ec_enc_uint(enc,best/B,max_pos);
391    /*printf ("%d %f\n", best, best_score);*/
392    
393    if (K>10)
394       pred_gain = pg[10];
395    else
396       pred_gain = pg[K];
397    E = 1e-10;
398    for (j=0;j<N;j++)
399    {
400       P[j] = s*Y[best+N-j-1];
401       E += P[j]*P[j];
402    }
403    E = pred_gain/sqrt(E);
404    for (j=0;j<N;j++)
405       P[j] *= E;
406    if (K>0)
407    {
408       for (j=0;j<N;j++)
409          x[j] -= P[j];
410    } else {
411       for (j=0;j<N;j++)
412          x[j] = P[j];
413    }
414    /*printf ("quant ");*/
415    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
416
417 }
418
419 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
420 {
421    int j;
422    int sign;
423    float s;
424    int best;
425    float E;
426    float pred_gain;
427    int max_pos = N0-N/B;
428    if (max_pos > 32)
429       max_pos = 32;
430    
431    sign = ec_dec_uint(dec, 2);
432    if (sign == 0)
433       s = 1;
434    else
435       s = -1;
436    
437    best = B*ec_dec_uint(dec, max_pos);
438    /*printf ("%d %d ", sign, best);*/
439
440    if (K>10)
441       pred_gain = pg[10];
442    else
443       pred_gain = pg[K];
444    E = 1e-10;
445    for (j=0;j<N;j++)
446    {
447       P[j] = s*Y[best+N-j-1];
448       E += P[j]*P[j];
449    }
450    E = pred_gain/sqrt(E);
451    for (j=0;j<N;j++)
452       P[j] *= E;
453    if (K==0)
454    {
455       for (j=0;j<N;j++)
456          x[j] = P[j];
457    }
458 }
459
460 void intra_fold(float *x, int N, float *Y, float *P, int B, int N0, int Nmax)
461 {
462    int i, j;
463    float E;
464    
465    E = 1e-10;
466    if (N0 >= Nmax/2)
467    {
468       for (i=0;i<B;i++)
469       {
470          for (j=0;j<N/B;j++)
471          {
472             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
473             E += P[j*B+i]*P[j*B+i];
474          }
475       }
476    } else {
477       for (j=0;j<N;j++)
478       {
479          P[j] = Y[j];
480          E += P[j]*P[j];
481       }
482    }
483    E = 1.f/sqrt(E);
484    for (j=0;j<N;j++)
485       P[j] *= E;
486    for (j=0;j<N;j++)
487       x[j] = P[j];
488 }
489