doing the folding properly.
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
42 {
43    int L = 3;
44    //float tata[200];
45    float _y[L][N];
46    int _iy[L][N];
47    //float tata2[200];
48    float _ny[L][N];
49    int _iny[L][N];
50    float *(ny[L]), *(y[L]);
51    int *(iny[L]), *(iy[L]);
52    int i, j, m;
53    int pulsesLeft;
54    float xy[L], nxy[L];
55    float yy[L], nyy[L];
56    float yp[L], nyp[L];
57    float best_scores[L];
58    float Rpp=0, Rxp=0;
59    float gain[L];
60    int maxL = 1;
61    
62    for (m=0;m<L;m++)
63    {
64       ny[m] = _ny[m];
65       iny[m] = _iny[m];
66       y[m] = _y[m];
67       iy[m] = _iy[m];
68    }
69    
70    for (j=0;j<N;j++)
71       Rpp += p[j]*p[j];
72    //if (Rpp>.01)
73    //   alpha = (1-sqrt(1-Rpp))/Rpp;
74    for (j=0;j<N;j++)
75       Rxp += x[j]*p[j];
76    for (m=0;m<L;m++)
77       for (i=0;i<N;i++)
78          y[m][i] = 0;
79       
80    for (m=0;m<L;m++)
81       for (i=0;i<N;i++)
82          ny[m][i] = 0;
83
84    for (m=0;m<L;m++)
85       for (i=0;i<N;i++)
86          iy[m][i] = iny[m][i] = 0;
87
88    for (m=0;m<L;m++)
89       xy[m] = yy[m] = yp[m] = gain[m] = 0;
90    
91    pulsesLeft = K;
92    while (pulsesLeft > 0)
93    {
94       int pulsesAtOnce=1;
95       int Lupdate = L;
96       int L2 = L;
97       pulsesAtOnce = pulsesLeft/N;
98       if (pulsesAtOnce<1)
99          pulsesAtOnce = 1;
100       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
101          Lupdate = 1;
102       //printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);
103       L2 = Lupdate;
104       if (L2>maxL)
105       {
106          L2 = maxL;
107          maxL *= N;
108       }
109
110       for (m=0;m<L;m++)
111          best_scores[m] = -1e10;
112
113       for (m=0;m<L2;m++)
114       {
115          for (j=0;j<N;j++)
116          {
117             int sign;
118             //if (x[j]>0) sign=1; else sign=-1;
119             for (sign=-1;sign<=1;sign+=2)
120             {
121                /* All pulses at one location must have the same sign. Also,
122                   only consider sign in the same direction as x[j], except for the
123                   last pulses */
124                if (iy[m][j]*sign < 0 || (x[j]*sign<0 && pulsesLeft>((K+1)>>1)))
125                   continue;
126                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
127                float tmp_xy, tmp_yy, tmp_yp;
128                float score;
129                float g;
130                float s = sign*pulsesAtOnce;
131                
132                /* Updating the sums the the new pulse(s) */
133                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
134                tmp_yy = yy[m] + 2*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*s*s*alpha*p[j]*p[j];
135                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
136                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
137                score = 2*g*tmp_xy - g*g*tmp_yy;
138
139                if (score>best_scores[Lupdate-1])
140                {
141                   int k, n;
142                   int id = Lupdate-1;
143                   float *tmp_ny;
144                   int *tmp_iny;
145                   
146                   tmp_ny = ny[Lupdate-1];
147                   tmp_iny = iny[Lupdate-1];
148                   while (id > 0 && score > best_scores[id-1])
149                      id--;
150                
151                   for (k=Lupdate-1;k>id;k--)
152                   {
153                      nxy[k] = nxy[k-1];
154                      nyy[k] = nyy[k-1];
155                      nyp[k] = nyp[k-1];
156                      //fprintf(stderr, "%d %d \n", N, k);
157                      ny[k] = ny[k-1];
158                      iny[k] = iny[k-1];
159                      gain[k] = gain[k-1];
160                      best_scores[k] = best_scores[k-1];
161                   }
162
163                   ny[id] = tmp_ny;
164                   iny[id] = tmp_iny;
165
166                   nxy[id] = tmp_xy;
167                   nyy[id] = tmp_yy;
168                   nyp[id] = tmp_yp;
169                   gain[id] = g;
170                   for (n=0;n<N;n++)
171                      ny[id][n] = y[m][n] - alpha*s*p[j]*p[n];
172                   ny[id][j] += s;
173
174                   for (n=0;n<N;n++)
175                      iny[id][n] = iy[m][n];
176                   if (s>0)
177                      iny[id][j] += pulsesAtOnce;
178                   else
179                      iny[id][j] -= pulsesAtOnce;
180                   best_scores[id] = score;
181                }
182             }
183          }
184
185       }
186       int k;
187       for (k=0;k<Lupdate;k++)
188       {
189          float *tmp_ny;
190          int *tmp_iny;
191
192          xy[k] = nxy[k];
193          yy[k] = nyy[k];
194          yp[k] = nyp[k];
195          
196          tmp_ny = ny[k];
197          ny[k] = y[k];
198          y[k] = tmp_ny;
199          tmp_iny = iny[k];
200          iny[k] = iy[k];
201          iy[k] = tmp_iny;
202       }
203       pulsesLeft -= pulsesAtOnce;
204    }
205    
206    if (0) {
207       float err=0;
208       for (i=0;i<N;i++)
209          err += (x[i]-gain[0]*y[0][i])*(x[i]-gain[0]*y[0][i]);
210       //if (N<=10)
211       //printf ("%f %d %d\n", err, K, N);
212    }
213    for (i=0;i<N;i++)
214       x[i] = p[i]+gain[0]*y[0][i];
215    if (0) {
216       float E=1e-15;
217       int ABS = 0;
218       for (i=0;i<N;i++)
219          ABS += abs(iy[0][i]);
220       //if (K != ABS)
221       //   printf ("%d %d\n", K, ABS);
222       for (i=0;i<N;i++)
223          E += x[i]*x[i];
224       //printf ("%f\n", E);
225       E = 1/sqrt(E);
226       for (i=0;i<N;i++)
227          x[i] *= E;
228    }
229    int comb[K];
230    int signs[K];
231    //for (i=0;i<N;i++)
232    //   printf ("%d ", iy[0][i]);
233    pulse2comb(N, K, comb, signs, iy[0]); 
234    ec_enc_uint64(enc,icwrs64(N, K, comb, signs),ncwrs64(N, K));
235    //printf ("%llu ", icwrs64(N, K, comb, signs));
236    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
237       due to the recursive computation used in quantisation.
238       Not quite sure whether we need that or not */
239    if (1) {
240       float Ryp=0;
241       float Rpp=0;
242       float Ryy=0;
243       float g=0;
244       for (i=0;i<N;i++)
245          Rpp += p[i]*p[i];
246       
247       for (i=0;i<N;i++)
248          Ryp += iy[0][i]*p[i];
249       
250       for (i=0;i<N;i++)
251          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
252       
253       Ryp = 0;
254       for (i=0;i<N;i++)
255          Ryp += y[0][i]*p[i];
256       
257       for (i=0;i<N;i++)
258          Ryy += y[0][i]*y[0][i];
259       
260       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
261         
262       for (i=0;i<N;i++)
263          x[i] = p[i] + g*y[0][i];
264       
265    }
266
267 }
268
269 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
270 {
271    int i;
272    celt_uint64_t id;
273    int comb[K];
274    int signs[K];
275    int iy[N];
276    float y[N];
277    float Rpp=0, Ryp=0, Ryy=0;
278    float g;
279
280    id = ec_dec_uint64(dec, ncwrs64(N, K));
281    //printf ("%llu ", id);
282    cwrsi64(N, K, id, comb, signs);
283    comb2pulse(N, K, iy, comb, signs);
284    //for (i=0;i<N;i++)
285    //   printf ("%d ", iy[i]);
286    for (i=0;i<N;i++)
287       Rpp += p[i]*p[i];
288
289    for (i=0;i<N;i++)
290       Ryp += iy[i]*p[i];
291
292    for (i=0;i<N;i++)
293       y[i] = iy[i] - alpha*Ryp*p[i];
294
295    /* Recompute after the projection (I think it's right) */
296    Ryp = 0;
297    for (i=0;i<N;i++)
298       Ryp += y[i]*p[i];
299
300    for (i=0;i<N;i++)
301       Ryy += y[i]*y[i];
302
303    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
304
305    for (i=0;i<N;i++)
306       x[i] = p[i] + g*y[i];
307 }
308
309
310 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
311
312 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
313 {
314    int i,j;
315    int best=0;
316    float best_score=0;
317    float s = 1;
318    int sign;
319    float E;
320    int max_pos = N0-N/B;
321    if (max_pos > 32)
322       max_pos = 32;
323
324    for (i=0;i<max_pos*B;i+=B)
325    {
326       int j;
327       float xy=0, yy=0;
328       float score;
329       for (j=0;j<N;j++)
330       {
331          xy += x[j]*Y[i+j];
332          yy += Y[i+j]*Y[i+j];
333       }
334       score = xy*xy/(.001+yy);
335       if (score > best_score)
336       {
337          best_score = score;
338          best = i;
339          if (xy>0)
340             s = 1;
341          else
342             s = -1;
343       }
344    }
345    if (s<0)
346       sign = 1;
347    else
348       sign = 0;
349    //printf ("%d %d ", sign, best);
350    ec_enc_uint(enc,sign,2);
351    ec_enc_uint(enc,best/B,max_pos);
352    //printf ("%d %f\n", best, best_score);
353    
354    float pred_gain;
355    if (K>10)
356       pred_gain = pg[10];
357    else
358       pred_gain = pg[K];
359    E = 1e-10;
360    for (j=0;j<N;j++)
361    {
362       P[j] = s*Y[best+j];
363       E += P[j]*P[j];
364    }
365    E = pred_gain/sqrt(E);
366    for (j=0;j<N;j++)
367       P[j] *= E;
368    if (K>0)
369    {
370       for (j=0;j<N;j++)
371          x[j] -= P[j];
372    } else {
373       for (j=0;j<N;j++)
374          x[j] = P[j];
375    }
376    //printf ("quant ");
377    //for (j=0;j<N;j++) printf ("%f ", P[j]);
378
379 }
380
381 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
382 {
383    int j;
384    int sign;
385    float s;
386    int best;
387    float E;
388    int max_pos = N0-N/B;
389    if (max_pos > 32)
390       max_pos = 32;
391    
392    sign = ec_dec_uint(dec, 2);
393    if (sign == 0)
394       s = 1;
395    else
396       s = -1;
397    
398    best = B*ec_dec_uint(dec, max_pos);
399    //printf ("%d %d ", sign, best);
400
401    float pred_gain;
402    if (K>10)
403       pred_gain = pg[10];
404    else
405       pred_gain = pg[K];
406    E = 1e-10;
407    for (j=0;j<N;j++)
408    {
409       P[j] = s*Y[best+j];
410       E += P[j]*P[j];
411    }
412    E = pred_gain/sqrt(E);
413    for (j=0;j<N;j++)
414       P[j] *= E;
415    if (K==0)
416    {
417       for (j=0;j<N;j++)
418          x[j] = P[j];
419    }
420 }
421
422 void intra_fold(float *x, int N, float *Y, float *P, int B, int N0, int Nmax)
423 {
424    int i, j;
425    float E;
426    
427    E = 1e-10;
428    if (N0 >= Nmax/2)
429    {
430       for (i=0;i<B;i++)
431       {
432          for (j=0;j<N/B;j++)
433          {
434             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
435             E += P[j*B+i]*P[j*B+i];
436          }
437       }
438    } else {
439       for (j=0;j<N;j++)
440       {
441          P[j] = Y[j];
442          E += P[j]*P[j];
443       }
444    }
445    E = 1.f/sqrt(E);
446    for (j=0;j<N;j++)
447       P[j] *= E;
448    for (j=0;j<N;j++)
449       x[j] = P[j];
450 }
451