vq search is now moving much less data around
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37 struct NBest {
38    float score;
39    float gain;
40    int sign;
41    int pos;
42    int orig;
43    float xy;
44    float yy;
45    float yp;
46 };
47
48 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
49    a combination of pulses such that its norm is still equal to 1. The only difference with 
50    the quantiser above is that the search is more complete. */
51 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
52 {
53    int L = 3;
54    //float tata[200];
55    float _y[L][N];
56    int _iy[L][N];
57    //float tata2[200];
58    float _ny[L][N];
59    int _iny[L][N];
60    float *(ny[L]), *(y[L]);
61    int *(iny[L]), *(iy[L]);
62    int i, j, m;
63    int pulsesLeft;
64    float xy[L];
65    float yy[L];
66    float yp[L];
67    struct NBest _nbest[L];
68    struct NBest *(nbest[L]);
69    float Rpp=0, Rxp=0;
70    int maxL = 1;
71    
72    for (m=0;m<L;m++)
73       nbest[m] = &_nbest[m];
74    
75    for (m=0;m<L;m++)
76    {
77       ny[m] = _ny[m];
78       iny[m] = _iny[m];
79       y[m] = _y[m];
80       iy[m] = _iy[m];
81    }
82    
83    for (j=0;j<N;j++)
84       Rpp += p[j]*p[j];
85    //if (Rpp>.01)
86    //   alpha = (1-sqrt(1-Rpp))/Rpp;
87    for (j=0;j<N;j++)
88       Rxp += x[j]*p[j];
89    for (m=0;m<L;m++)
90       for (i=0;i<N;i++)
91          y[m][i] = 0;
92       
93    for (m=0;m<L;m++)
94       for (i=0;i<N;i++)
95          ny[m][i] = 0;
96
97    for (m=0;m<L;m++)
98       for (i=0;i<N;i++)
99          iy[m][i] = iny[m][i] = 0;
100
101    for (m=0;m<L;m++)
102       xy[m] = yy[m] = yp[m] = 0;
103    
104    pulsesLeft = K;
105    while (pulsesLeft > 0)
106    {
107       int pulsesAtOnce=1;
108       int Lupdate = L;
109       int L2 = L;
110       pulsesAtOnce = pulsesLeft/N;
111       if (pulsesAtOnce<1)
112          pulsesAtOnce = 1;
113       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
114          Lupdate = 1;
115       //printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);
116       L2 = Lupdate;
117       if (L2>maxL)
118       {
119          L2 = maxL;
120          maxL *= N;
121       }
122
123       for (m=0;m<L;m++)
124          nbest[m]->score = -1e10;
125
126       for (m=0;m<L2;m++)
127       {
128          for (j=0;j<N;j++)
129          {
130             int sign;
131             //if (x[j]>0) sign=1; else sign=-1;
132             for (sign=-1;sign<=1;sign+=2)
133             {
134                /* All pulses at one location must have the same sign. Also,
135                   only consider sign in the same direction as x[j], except for the
136                   last pulses */
137                if (iy[m][j]*sign < 0 || (x[j]*sign<0 && pulsesLeft>((K+1)>>1)))
138                   continue;
139                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
140                float tmp_xy, tmp_yy, tmp_yp;
141                float score;
142                float g;
143                float s = sign*pulsesAtOnce;
144                
145                /* Updating the sums the the new pulse(s) */
146                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
147                tmp_yy = yy[m] + 2*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*s*s*alpha*p[j]*p[j];
148                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
149                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
150                score = 2*g*tmp_xy - g*g*tmp_yy;
151
152                if (score>nbest[Lupdate-1]->score)
153                {
154                   int k, n;
155                   int id = Lupdate-1;
156                   struct NBest *tmp_best;
157
158                   /* Save some pointers that would be deleted and use them for the current entry*/
159                   tmp_best = nbest[Lupdate-1];
160                   while (id > 0 && score > nbest[id-1]->score)
161                      id--;
162                
163                   for (k=Lupdate-1;k>id;k--)
164                      nbest[k] = nbest[k-1];
165
166                   nbest[id] = tmp_best;
167                   nbest[id]->score = score;
168                   nbest[id]->pos = j;
169                   nbest[id]->orig = m;
170                   nbest[id]->sign = sign;
171                   nbest[id]->gain = g;
172                   nbest[id]->xy = tmp_xy;
173                   nbest[id]->yy = tmp_yy;
174                   nbest[id]->yp = tmp_yp;
175                }
176             }
177          }
178
179       }
180       int k;
181       for (k=0;k<Lupdate;k++)
182       {
183          int n;
184          int is;
185          float s;
186          is = nbest[k]->sign*pulsesAtOnce;
187          s = is;
188          for (n=0;n<N;n++)
189             ny[k][n] = y[nbest[k]->orig][n] - alpha*s*p[nbest[k]->pos]*p[n];
190          ny[k][nbest[k]->pos] += s;
191       
192          for (n=0;n<N;n++)
193             iny[k][n] = iy[nbest[k]->orig][n];
194          iny[k][nbest[k]->pos] += is;
195          
196          xy[k] = nbest[k]->xy;
197          yy[k] = nbest[k]->yy;
198          yp[k] = nbest[k]->yp;
199       }
200          
201       for (k=0;k<Lupdate;k++)
202       {
203          float *tmp_ny;
204          int *tmp_iny;
205
206          tmp_ny = ny[k];
207          ny[k] = y[k];
208          y[k] = tmp_ny;
209          tmp_iny = iny[k];
210          iny[k] = iy[k];
211          iy[k] = tmp_iny;
212       }
213       pulsesLeft -= pulsesAtOnce;
214    }
215    
216    if (0) {
217       float err=0;
218       for (i=0;i<N;i++)
219          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
220       //if (N<=10)
221       //printf ("%f %d %d\n", err, K, N);
222    }
223    for (i=0;i<N;i++)
224       x[i] = p[i]+nbest[0]->gain*y[0][i];
225    if (0) {
226       float E=1e-15;
227       int ABS = 0;
228       for (i=0;i<N;i++)
229          ABS += abs(iy[0][i]);
230       //if (K != ABS)
231       //   printf ("%d %d\n", K, ABS);
232       for (i=0;i<N;i++)
233          E += x[i]*x[i];
234       //printf ("%f\n", E);
235       E = 1/sqrt(E);
236       for (i=0;i<N;i++)
237          x[i] *= E;
238    }
239    
240    encode_pulses(iy[0], N, K, enc);
241    
242    //printf ("%llu ", icwrs64(N, K, comb, signs));
243    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
244       due to the recursive computation used in quantisation.
245       Not quite sure whether we need that or not */
246    if (1) {
247       float Ryp=0;
248       float Rpp=0;
249       float Ryy=0;
250       float g=0;
251       for (i=0;i<N;i++)
252          Rpp += p[i]*p[i];
253       
254       for (i=0;i<N;i++)
255          Ryp += iy[0][i]*p[i];
256       
257       for (i=0;i<N;i++)
258          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
259       
260       Ryp = 0;
261       for (i=0;i<N;i++)
262          Ryp += y[0][i]*p[i];
263       
264       for (i=0;i<N;i++)
265          Ryy += y[0][i]*y[0][i];
266       
267       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
268         
269       for (i=0;i<N;i++)
270          x[i] = p[i] + g*y[0][i];
271       
272    }
273
274 }
275
276 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
277 {
278    int i;
279    int iy[N];
280    float y[N];
281    float Rpp=0, Ryp=0, Ryy=0;
282    float g;
283
284    decode_pulses(iy, N, K, dec);
285
286    //for (i=0;i<N;i++)
287    //   printf ("%d ", iy[i]);
288    for (i=0;i<N;i++)
289       Rpp += p[i]*p[i];
290
291    for (i=0;i<N;i++)
292       Ryp += iy[i]*p[i];
293
294    for (i=0;i<N;i++)
295       y[i] = iy[i] - alpha*Ryp*p[i];
296
297    /* Recompute after the projection (I think it's right) */
298    Ryp = 0;
299    for (i=0;i<N;i++)
300       Ryp += y[i]*p[i];
301
302    for (i=0;i<N;i++)
303       Ryy += y[i]*y[i];
304
305    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
306
307    for (i=0;i<N;i++)
308       x[i] = p[i] + g*y[i];
309 }
310
311
312 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
313
314 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
315 {
316    int i,j;
317    int best=0;
318    float best_score=0;
319    float s = 1;
320    int sign;
321    float E;
322    int max_pos = N0-N/B;
323    if (max_pos > 32)
324       max_pos = 32;
325
326    for (i=0;i<max_pos*B;i+=B)
327    {
328       int j;
329       float xy=0, yy=0;
330       float score;
331       for (j=0;j<N;j++)
332       {
333          xy += x[j]*Y[i+j];
334          yy += Y[i+j]*Y[i+j];
335       }
336       score = xy*xy/(.001+yy);
337       if (score > best_score)
338       {
339          best_score = score;
340          best = i;
341          if (xy>0)
342             s = 1;
343          else
344             s = -1;
345       }
346    }
347    if (s<0)
348       sign = 1;
349    else
350       sign = 0;
351    //printf ("%d %d ", sign, best);
352    ec_enc_uint(enc,sign,2);
353    ec_enc_uint(enc,best/B,max_pos);
354    //printf ("%d %f\n", best, best_score);
355    
356    float pred_gain;
357    if (K>10)
358       pred_gain = pg[10];
359    else
360       pred_gain = pg[K];
361    E = 1e-10;
362    for (j=0;j<N;j++)
363    {
364       P[j] = s*Y[best+j];
365       E += P[j]*P[j];
366    }
367    E = pred_gain/sqrt(E);
368    for (j=0;j<N;j++)
369       P[j] *= E;
370    if (K>0)
371    {
372       for (j=0;j<N;j++)
373          x[j] -= P[j];
374    } else {
375       for (j=0;j<N;j++)
376          x[j] = P[j];
377    }
378    //printf ("quant ");
379    //for (j=0;j<N;j++) printf ("%f ", P[j]);
380
381 }
382
383 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
384 {
385    int j;
386    int sign;
387    float s;
388    int best;
389    float E;
390    int max_pos = N0-N/B;
391    if (max_pos > 32)
392       max_pos = 32;
393    
394    sign = ec_dec_uint(dec, 2);
395    if (sign == 0)
396       s = 1;
397    else
398       s = -1;
399    
400    best = B*ec_dec_uint(dec, max_pos);
401    //printf ("%d %d ", sign, best);
402
403    float pred_gain;
404    if (K>10)
405       pred_gain = pg[10];
406    else
407       pred_gain = pg[K];
408    E = 1e-10;
409    for (j=0;j<N;j++)
410    {
411       P[j] = s*Y[best+j];
412       E += P[j]*P[j];
413    }
414    E = pred_gain/sqrt(E);
415    for (j=0;j<N;j++)
416       P[j] *= E;
417    if (K==0)
418    {
419       for (j=0;j<N;j++)
420          x[j] = P[j];
421    }
422 }
423
424 void intra_fold(float *x, int N, float *Y, float *P, int B, int N0, int Nmax)
425 {
426    int i, j;
427    float E;
428    
429    E = 1e-10;
430    if (N0 >= Nmax/2)
431    {
432       for (i=0;i<B;i++)
433       {
434          for (j=0;j<N/B;j++)
435          {
436             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
437             E += P[j*B+i]*P[j*B+i];
438          }
439       }
440    } else {
441       for (j=0;j<N;j++)
442       {
443          P[j] = Y[j];
444          E += P[j]*P[j];
445       }
446    }
447    E = 1.f/sqrt(E);
448    for (j=0;j<N;j++)
449       P[j] *= E;
450    for (j=0;j<N;j++)
451       x[j] = P[j];
452 }
453