Updating only the L-best entries in alg_quant() that are useful.
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
42 {
43    int L = 3;
44    //float tata[200];
45    float _y[L][N];
46    int _iy[L][N];
47    //float tata2[200];
48    float _ny[L][N];
49    int _iny[L][N];
50    float *(ny[L]), *(y[L]);
51    int *(iny[L]), *(iy[L]);
52    int i, j, m;
53    int pulsesLeft;
54    float xy[L], nxy[L];
55    float yy[L], nyy[L];
56    float yp[L], nyp[L];
57    float best_scores[L];
58    float Rpp=0, Rxp=0;
59    float gain[L];
60    int maxL = 1;
61    
62    for (m=0;m<L;m++)
63    {
64       ny[m] = _ny[m];
65       iny[m] = _iny[m];
66       y[m] = _y[m];
67       iy[m] = _iy[m];
68    }
69    
70    for (j=0;j<N;j++)
71       Rpp += p[j]*p[j];
72    //if (Rpp>.01)
73    //   alpha = (1-sqrt(1-Rpp))/Rpp;
74    for (j=0;j<N;j++)
75       Rxp += x[j]*p[j];
76    for (m=0;m<L;m++)
77       for (i=0;i<N;i++)
78          y[m][i] = 0;
79       
80    for (m=0;m<L;m++)
81       for (i=0;i<N;i++)
82          ny[m][i] = 0;
83
84    for (m=0;m<L;m++)
85       for (i=0;i<N;i++)
86          iy[m][i] = iny[m][i] = 0;
87
88    for (m=0;m<L;m++)
89       xy[m] = yy[m] = yp[m] = gain[m] = 0;
90    
91    pulsesLeft = K;
92    while (pulsesLeft > 0)
93    {
94       int pulsesAtOnce=1;
95       int Lupdate = L;
96       int L2 = L;
97       pulsesAtOnce = pulsesLeft/N;
98       if (pulsesAtOnce<1)
99          pulsesAtOnce = 1;
100       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
101          Lupdate = 1;
102       //printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);
103       L2 = Lupdate;
104       if (L2>maxL)
105       {
106          L2 = maxL;
107          maxL *= N;
108       }
109
110       for (m=0;m<L;m++)
111          best_scores[m] = -1e10;
112
113       for (m=0;m<L2;m++)
114       {
115          for (j=0;j<N;j++)
116          {
117             int sign;
118             for (sign=-1;sign<=1;sign+=2)
119             {
120                /* All pulses at one location must have the same size */
121                if (iy[m][j]*sign < 0)
122                   continue;
123                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
124                float tmp_xy, tmp_yy, tmp_yp;
125                float score;
126                float g;
127                float s = sign*pulsesAtOnce;
128                
129                /* Updating the sums the the new pulse(s) */
130                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
131                tmp_yy = yy[m] + 2*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*s*s*alpha*p[j]*p[j];
132                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
133                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
134                score = 2*g*tmp_xy - g*g*tmp_yy;
135
136                if (score>best_scores[Lupdate-1])
137                {
138                   int k, n;
139                   int id = Lupdate-1;
140                   float *tmp_ny;
141                   int *tmp_iny;
142                   
143                   tmp_ny = ny[Lupdate-1];
144                   tmp_iny = iny[Lupdate-1];
145                   while (id > 0 && score > best_scores[id-1])
146                      id--;
147                
148                   for (k=Lupdate-1;k>id;k--)
149                   {
150                      nxy[k] = nxy[k-1];
151                      nyy[k] = nyy[k-1];
152                      nyp[k] = nyp[k-1];
153                      //fprintf(stderr, "%d %d \n", N, k);
154                      ny[k] = ny[k-1];
155                      iny[k] = iny[k-1];
156                      gain[k] = gain[k-1];
157                      best_scores[k] = best_scores[k-1];
158                   }
159
160                   ny[id] = tmp_ny;
161                   iny[id] = tmp_iny;
162                   
163                   nxy[id] = tmp_xy;
164                   nyy[id] = tmp_yy;
165                   nyp[id] = tmp_yp;
166                   gain[id] = g;
167                   for (n=0;n<N;n++)
168                      ny[id][n] = y[m][n];
169                   ny[id][j] += s;
170                   for (n=0;n<N;n++)
171                      ny[id][n] -= alpha*s*p[j]*p[n];
172                
173                   for (n=0;n<N;n++)
174                      iny[id][n] = iy[m][n];
175                   if (s>0)
176                      iny[id][j] += pulsesAtOnce;
177                   else
178                      iny[id][j] -= pulsesAtOnce;
179                   best_scores[id] = score;
180                }
181             }   
182          }
183          
184       }
185       int k;
186       for (k=0;k<Lupdate;k++)
187       {
188          float *tmp_ny;
189          int *tmp_iny;
190
191          xy[k] = nxy[k];
192          yy[k] = nyy[k];
193          yp[k] = nyp[k];
194          
195          tmp_ny = ny[k];
196          ny[k] = y[k];
197          y[k] = tmp_ny;
198          tmp_iny = iny[k];
199          iny[k] = iy[k];
200          iy[k] = tmp_iny;
201       }
202       pulsesLeft -= pulsesAtOnce;
203    }
204    
205    for (i=0;i<N;i++)
206       x[i] = p[i]+gain[0]*y[0][i];
207    if (0) {
208       float E=1e-15;
209       int ABS = 0;
210       for (i=0;i<N;i++)
211          ABS += abs(iy[0][i]);
212       //if (K != ABS)
213       //   printf ("%d %d\n", K, ABS);
214       for (i=0;i<N;i++)
215          E += x[i]*x[i];
216       //printf ("%f\n", E);
217       E = 1/sqrt(E);
218       for (i=0;i<N;i++)
219          x[i] *= E;
220    }
221    int comb[K];
222    int signs[K];
223    //for (i=0;i<N;i++)
224    //   printf ("%d ", iy[0][i]);
225    pulse2comb(N, K, comb, signs, iy[0]); 
226    ec_enc_uint64(enc,icwrs64(N, K, comb, signs),ncwrs64(N, K));
227    //printf ("%llu ", icwrs64(N, K, comb, signs));
228    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
229       due to the recursive computation used in quantisation */
230    if (1) {
231       float Ryp=0;
232       float Rpp=0;
233       float Ryy=0;
234       float g=0;
235       for (i=0;i<N;i++)
236          Rpp += p[i]*p[i];
237       
238       for (i=0;i<N;i++)
239          Ryp += iy[0][i]*p[i];
240       
241       for (i=0;i<N;i++)
242          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
243       
244       Ryp = 0;
245       for (i=0;i<N;i++)
246          Ryp += y[0][i]*p[i];
247       
248       for (i=0;i<N;i++)
249          Ryy += y[0][i]*y[0][i];
250       
251       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
252         
253       for (i=0;i<N;i++)
254          x[i] = p[i] + g*y[0][i];
255       
256    }
257
258 }
259
260 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
261 {
262    int i;
263    celt_uint64_t id;
264    int comb[K];
265    int signs[K];
266    int iy[N];
267    float y[N];
268    float Rpp=0, Ryp=0, Ryy=0;
269    float g;
270
271    id = ec_dec_uint64(dec, ncwrs64(N, K));
272    //printf ("%llu ", id);
273    cwrsi64(N, K, id, comb, signs);
274    comb2pulse(N, K, iy, comb, signs);
275    //for (i=0;i<N;i++)
276    //   printf ("%d ", iy[i]);
277    for (i=0;i<N;i++)
278       Rpp += p[i]*p[i];
279
280    for (i=0;i<N;i++)
281       Ryp += iy[i]*p[i];
282
283    for (i=0;i<N;i++)
284       y[i] = iy[i] - alpha*Ryp*p[i];
285
286    /* Recompute after the projection (I think it's right) */
287    Ryp = 0;
288    for (i=0;i<N;i++)
289       Ryp += y[i]*p[i];
290
291    for (i=0;i<N;i++)
292       Ryy += y[i]*y[i];
293
294    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
295
296    for (i=0;i<N;i++)
297       x[i] = p[i] + g*y[i];
298 }
299
300
301 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
302
303 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
304 {
305    int i,j;
306    int best=0;
307    float best_score=0;
308    float s = 1;
309    int sign;
310    float E;
311    int max_pos = N0-N/B;
312    if (max_pos > 32)
313       max_pos = 32;
314
315    for (i=0;i<max_pos*B;i+=B)
316    {
317       int j;
318       float xy=0, yy=0;
319       float score;
320       for (j=0;j<N;j++)
321       {
322          xy += x[j]*Y[i+j];
323          yy += Y[i+j]*Y[i+j];
324       }
325       score = xy*xy/(.001+yy);
326       if (score > best_score)
327       {
328          best_score = score;
329          best = i;
330          if (xy>0)
331             s = 1;
332          else
333             s = -1;
334       }
335    }
336    if (s<0)
337       sign = 1;
338    else
339       sign = 0;
340    //printf ("%d %d ", sign, best);
341    ec_enc_uint(enc,sign,2);
342    ec_enc_uint(enc,best/B,max_pos);
343    //printf ("%d %f\n", best, best_score);
344    
345    float pred_gain;
346    if (K>10)
347       pred_gain = pg[10];
348    else
349       pred_gain = pg[K];
350    E = 1e-10;
351    for (j=0;j<N;j++)
352    {
353       P[j] = s*Y[best+j];
354       E += P[j]*P[j];
355    }
356    E = pred_gain/sqrt(E);
357    for (j=0;j<N;j++)
358       P[j] *= E;
359    if (K>0)
360    {
361       for (j=0;j<N;j++)
362          x[j] -= P[j];
363    } else {
364       for (j=0;j<N;j++)
365          x[j] = P[j];
366    }
367    //printf ("quant ");
368    //for (j=0;j<N;j++) printf ("%f ", P[j]);
369
370 }
371
372 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
373 {
374    int j;
375    int sign;
376    float s;
377    int best;
378    float E;
379    int max_pos = N0-N/B;
380    if (max_pos > 32)
381       max_pos = 32;
382    
383    sign = ec_dec_uint(dec, 2);
384    if (sign == 0)
385       s = 1;
386    else
387       s = -1;
388    
389    best = B*ec_dec_uint(dec, max_pos);
390    //printf ("%d %d ", sign, best);
391
392    float pred_gain;
393    if (K>10)
394       pred_gain = pg[10];
395    else
396       pred_gain = pg[K];
397    E = 1e-10;
398    for (j=0;j<N;j++)
399    {
400       P[j] = s*Y[best+j];
401       E += P[j]*P[j];
402    }
403    E = pred_gain/sqrt(E);
404    for (j=0;j<N;j++)
405       P[j] *= E;
406    if (K==0)
407    {
408       for (j=0;j<N;j++)
409          x[j] = P[j];
410    }
411 }
412
413 void intra_fold(float *x, int N, int K, float *Y, float *P, int B, int N0)
414 {
415    int j;
416    float E;
417    
418    E = 1e-10;
419    for (j=0;j<N;j++)
420    {
421       P[j] = Y[j];
422       E += P[j]*P[j];
423    }
424    E = 1.f/sqrt(E);
425    for (j=0;j<N;j++)
426       P[j] *= E;
427    for (j=0;j<N;j++)
428       x[j] = P[j];
429 }
430